summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorKristofer Hallin <kristofer@sunet.se>2022-01-17 14:01:08 +0100
committerKristofer Hallin <kristofer@sunet.se>2022-01-17 14:01:08 +0100
commitbb5029d512a58021718061aca439383c8b11e575 (patch)
tree74354b6bf55a9159695eea695653ef03009e5ad4 /src
parent571997129ba5275cc5e148a8ac1c0f64d895a9ef (diff)
parent0b55f7ff7cdd3b78bd9992063208476c1c080a02 (diff)
* Merge branch 'main' into feature.callhome
* New API endpoints * Updated requirements
Diffstat (limited to 'src')
-rwxr-xr-xsrc/authn.py112
-rwxr-xr-xsrc/db/db.py9
-rw-r--r--src/db/scanner.py30
-rwxr-xr-xsrc/main.py1
-rw-r--r--src/routers/collector.py67
-rw-r--r--src/routers/scanner.py6
6 files changed, 80 insertions, 145 deletions
diff --git a/src/authn.py b/src/authn.py
deleted file mode 100755
index e90118a..0000000
--- a/src/authn.py
+++ /dev/null
@@ -1,112 +0,0 @@
-#! /usr/bin/env python3
-
-import yaml
-
-
-class Authz:
- def __init__(self, org, perms):
- self._org = org
- self._perms = perms
-
- def dump(self):
- return "{}: {}".format(self._org, self._perms)
-
- def read_p(self):
- return 'r' in self._perms
-
- def write_p(self):
- return 'w' in self._perms
-
-
-class User:
- def __init__(self, username, pw, authz):
- self._username = username
- self._password = pw
- self._authz = {}
- for org, perms in authz.items():
- self._authz[org] = Authz(org, perms)
-
- def dump(self):
- return ["{}/{}: {}".format(self._username, self._password, auth.dump())
- for auth in self._authz.values()]
-
- def authn_p(self, pw):
- return pw == self._password
-
- def orgnames(self):
- return [x for x in self._authz.keys()]
-
- def read_perms(self):
- acc = []
- for k, v in self._authz.items():
- if v.read_p():
- acc.append(k)
- return acc
-
- def write_perms(self):
- acc = []
- for k, v in self._authz.items():
- if v.write_p():
- acc.append(k)
- return acc
-
-
-class UserDB:
- def __init__(self, yamlfile):
- self._users = {}
- for u, d in yaml.safe_load(open(yamlfile)).items():
- self._users[u] = User(u, d['pw'], d['authz'])
-
- def dump(self):
- return [u.dump() for u in self._users.values()]
-
- def user_authn_p(self, username, password):
- user = self._users.get(username)
- if not user:
- return False
- return user.authn_p(password)
-
- def orgs_for_user(self, username):
- return self._users.get(username).orgnames()
-
- def read_perms(self, username, password):
- user = self._users.get(username)
- if not user:
- return None
- if not user.authn_p(password):
- return None
- return user.read_perms()
-
- def write_perms(self, username, password):
- user = self._users.get(username)
- if not user:
- return None
- if not user.authn_p(password):
- return None
- return user.write_perms()
-
-
-def self_test():
- db = UserDB('userdb.yaml')
- print(db.dump())
-
- orgs = db.orgs_for_user('user3')
- assert('sunet.se' in orgs)
- assert('su.se' in orgs)
- assert(len(orgs) == 2)
-
- assert(db.user_authn_p('user3', 'pw3') == True)
- assert(db.user_authn_p('user3', 'wrongpw') == False)
-
- rp = db.read_perms('user3', 'pw3')
- assert(len(rp) == 2)
- assert('sunet.se' in rp)
- assert('su.se' in rp)
-
- wp = db.write_perms('user3', 'pw3')
- assert(len(wp) == 1)
- assert('sunet.se' in wp)
-
-
-if __name__ == '__main__':
- self_test()
diff --git a/src/db/db.py b/src/db/db.py
index cbb87ce..3926fda 100755
--- a/src/db/db.py
+++ b/src/db/db.py
@@ -1,12 +1,3 @@
-# A database storing dictionaries, keyed on a timestamp. value = A
-# dict which will be stored as a JSON object encoded in UTF-8. Note
-# that dict keys of type integer or float will become strings while
-# values will keep their type.
-
-# Note that there's a (slim) chance that you'd stomp on the previous
-# value if you're too quick with generating the timestamps, ie
-# invoking time.time() several times quickly enough.
-
import os
import sys
import time
diff --git a/src/db/scanner.py b/src/db/scanner.py
index e9ac8c3..625fd8e 100644
--- a/src/db/scanner.py
+++ b/src/db/scanner.py
@@ -1,7 +1,7 @@
import enum
from datetime import datetime
-from sqlalchemy import (Boolean, Column, DateTime, Integer, Unicode,
+from sqlalchemy import (JSON, Boolean, Column, DateTime, Integer, Unicode,
UniqueConstraint, create_engine)
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.declarative import declarative_base
@@ -22,13 +22,12 @@ class Scanner(Base):
id = Column(Integer, autoincrement=True, primary_key=True)
uuid = Column(Unicode(37), nullable=False)
- enabled = Column(Boolean, nullable=False)
+ enabled = Column(Boolean, default=False, nullable=False)
first_seen = Column(DateTime, default=datetime.utcnow, nullable=False)
last_seen = Column(DateTime, default=datetime.utcnow,
onupdate=datetime.utcnow, nullable=False)
- comment = Column(Unicode(255), nullable=True)
- scanners = Column(Unicode(2048), nullable=False)
- targets = Column(Unicode(255), nullable=True)
+ comment = Column(Unicode(255), default="", nullable=True)
+ scanners = Column(JSON, nullable=True)
def as_dict(self):
"""Return JSON serializable dict."""
@@ -47,7 +46,8 @@ class Scanner(Base):
@classmethod
def comment(cls, uuid, comment):
with SqlDB.sql_session() as session:
- scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none()
+ scanner: Scanner = session.query(Scanner).filter(
+ Scanner.uuid == uuid).one_or_none()
if scanner:
scanner.comment = comment
@@ -58,7 +58,8 @@ class Scanner(Base):
@classmethod
def enable(cls, uuid):
with SqlDB.sql_session() as session:
- scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none()
+ scanner: Scanner = session.query(Scanner).filter(
+ Scanner.uuid == uuid).one_or_none()
if scanner:
scanner.enabled = True
@@ -69,7 +70,8 @@ class Scanner(Base):
@classmethod
def disable(cls, uuid):
with SqlDB.sql_session() as session:
- scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none()
+ scanner: Scanner = session.query(Scanner).filter(
+ Scanner.uuid == uuid).one_or_none()
if scanner:
scanner.enabled = False
@@ -84,9 +86,11 @@ class Scanner(Base):
with SqlDB.sql_session() as session:
if scanner_id:
- scanner: Scanner = session.query(Scanner).filter(Scanner.id == scanner_id).one_or_none()
+ scanner: Scanner = session.query(Scanner).filter(
+ Scanner.id == scanner_id).one_or_none()
elif uuid:
- scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none()
+ scanner: Scanner = session.query(Scanner).filter(
+ Scanner.uuid == uuid).one_or_none()
else:
return None
@@ -104,7 +108,7 @@ class Scanner(Base):
scanner = Scanner()
scanner.uuid = uuid
scanner.enabled = False
- scanner.scanners = "None"
+ scanner.scanners = {}
session.add(scanner)
session.flush()
@@ -113,11 +117,11 @@ class Scanner(Base):
except IntegrityError:
return None
-
@classmethod
def is_enabled(cls, uuid):
with SqlDB.sql_session() as session:
- scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none()
+ scanner: Scanner = session.query(Scanner).filter(
+ Scanner.uuid == uuid).one_or_none()
if scanner is None:
return None
diff --git a/src/main.py b/src/main.py
index aa3b133..a65971d 100755
--- a/src/main.py
+++ b/src/main.py
@@ -24,6 +24,7 @@ app.add_middleware(
expose_headers=["X-Total-Count"],
)
+
@app.middleware("http")
async def mock_x_total_count_header(request: Request, call_next):
response = await call_next(request)
diff --git a/src/routers/collector.py b/src/routers/collector.py
index 3cda23a..7d91609 100644
--- a/src/routers/collector.py
+++ b/src/routers/collector.py
@@ -48,18 +48,22 @@ def get_data(key=None, limit=25, skip=0, ip=None,
@router.get('/get')
async def get(key=None, limit=25, skip=0, ip=None, port=None,
asn=None, Authorize: AuthJWT = Depends()):
+
Authorize.jwt_required()
data = []
raw_jwt = Authorize.get_raw_jwt()
- if 'domains' not in raw_jwt:
- return JSONResponse(content={"status": "error",
- "message": "Could not find domains" +
- "claim in JWT token"},
- status_code=400)
+ if "read" not in raw_jwt:
+ return JSONResponse(
+ content={
+ "status": "error",
+ "message": "Could not find read claim in JWT token",
+ },
+ status_code=400,
+ )
else:
- domains = raw_jwt['domains']
+ domains = raw_jwt["read"]
for domain in domains:
data.extend(get_data(key, limit, skip, ip, port, asn, domain))
@@ -69,17 +73,39 @@ async def get(key=None, limit=25, skip=0, ip=None, port=None,
@router.get('/get/{key}')
async def get_key(key=None, Authorize: AuthJWT = Depends()):
+
Authorize.jwt_required()
- # TODO: Use JWT authz and check e.g. domain here
+ raw_jwt = Authorize.get_raw_jwt()
+
+ if "read" not in raw_jwt:
+ return JSONResponse(
+ content={
+ "status": "error",
+ "message": "Could not find read claim in JWT token",
+ },
+ status_code=400,
+ )
+ else:
+ allowed_domains = raw_jwt["read"]
data = get_data(key)
+ if data["domain"] not in allowed_domains:
+ return JSONResponse(
+ content={
+ "status": "error",
+ "message": "User not authorized to view this object",
+ },
+ status_code=400,
+ )
+
return JSONResponse(content={"status": "success", "docs": data})
@router.post('/add')
async def add(data: Request, Authorize: AuthJWT = Depends()):
+
Authorize.jwt_required()
json_data = await data.json()
@@ -91,11 +117,36 @@ async def add(data: Request, Authorize: AuthJWT = Depends()):
@router.delete('/delete/{key}')
async def delete(key, Authorize: AuthJWT = Depends()):
+
Authorize.jwt_required()
+ raw_jwt = Authorize.get_raw_jwt()
+
+ if "write" not in raw_jwt:
+ return JSONResponse(
+ content={
+ "status": "error",
+ "message": "Could not find write claim in JWT token",
+ },
+ status_code=400,
+ )
+ else:
+ allowed_domains = raw_jwt["write"]
+
+ data = get_data(key)
+
+ if data["domain"] not in allowed_domains:
+ return JSONResponse(
+ content={
+ "status": "error",
+ "message": "User not authorized to delete this object",
+ },
+ status_code=400,
+ )
+
if db.delete(key) is None:
return JSONResponse(content={"status": "error",
"message": "Document not found"},
status_code=400)
- return JSONResponse(content={"status": "success", "docs": {}})
+ return JSONResponse(content={"status": "success", "docs": data})
diff --git a/src/routers/scanner.py b/src/routers/scanner.py
index 9bb0f98..645cd74 100644
--- a/src/routers/scanner.py
+++ b/src/routers/scanner.py
@@ -84,9 +84,9 @@ async def callhome(uuid, data: Request, Authorize: AuthJWT = Depends()):
else:
if Scanner.add(uuid):
- return JSONResponse(content={"status": "success",
- "message": "Scanner added."},
- status_code=200)
+ return JSONResponse(content={"status": "error",
+ "message": "Scanner added but disabled."},
+ status_code=400)
else:
return JSONResponse(content={"status": "error",
"message": "Failed to add scanner."},