diff options
author | Kristofer Hallin <kristofer@sunet.se> | 2022-01-17 14:01:08 +0100 |
---|---|---|
committer | Kristofer Hallin <kristofer@sunet.se> | 2022-01-17 14:01:08 +0100 |
commit | bb5029d512a58021718061aca439383c8b11e575 (patch) | |
tree | 74354b6bf55a9159695eea695653ef03009e5ad4 /src | |
parent | 571997129ba5275cc5e148a8ac1c0f64d895a9ef (diff) | |
parent | 0b55f7ff7cdd3b78bd9992063208476c1c080a02 (diff) |
* Merge branch 'main' into feature.callhome
* New API endpoints
* Updated requirements
Diffstat (limited to 'src')
-rwxr-xr-x | src/authn.py | 112 | ||||
-rwxr-xr-x | src/db/db.py | 9 | ||||
-rw-r--r-- | src/db/scanner.py | 30 | ||||
-rwxr-xr-x | src/main.py | 1 | ||||
-rw-r--r-- | src/routers/collector.py | 67 | ||||
-rw-r--r-- | src/routers/scanner.py | 6 |
6 files changed, 80 insertions, 145 deletions
diff --git a/src/authn.py b/src/authn.py deleted file mode 100755 index e90118a..0000000 --- a/src/authn.py +++ /dev/null @@ -1,112 +0,0 @@ -#! /usr/bin/env python3 - -import yaml - - -class Authz: - def __init__(self, org, perms): - self._org = org - self._perms = perms - - def dump(self): - return "{}: {}".format(self._org, self._perms) - - def read_p(self): - return 'r' in self._perms - - def write_p(self): - return 'w' in self._perms - - -class User: - def __init__(self, username, pw, authz): - self._username = username - self._password = pw - self._authz = {} - for org, perms in authz.items(): - self._authz[org] = Authz(org, perms) - - def dump(self): - return ["{}/{}: {}".format(self._username, self._password, auth.dump()) - for auth in self._authz.values()] - - def authn_p(self, pw): - return pw == self._password - - def orgnames(self): - return [x for x in self._authz.keys()] - - def read_perms(self): - acc = [] - for k, v in self._authz.items(): - if v.read_p(): - acc.append(k) - return acc - - def write_perms(self): - acc = [] - for k, v in self._authz.items(): - if v.write_p(): - acc.append(k) - return acc - - -class UserDB: - def __init__(self, yamlfile): - self._users = {} - for u, d in yaml.safe_load(open(yamlfile)).items(): - self._users[u] = User(u, d['pw'], d['authz']) - - def dump(self): - return [u.dump() for u in self._users.values()] - - def user_authn_p(self, username, password): - user = self._users.get(username) - if not user: - return False - return user.authn_p(password) - - def orgs_for_user(self, username): - return self._users.get(username).orgnames() - - def read_perms(self, username, password): - user = self._users.get(username) - if not user: - return None - if not user.authn_p(password): - return None - return user.read_perms() - - def write_perms(self, username, password): - user = self._users.get(username) - if not user: - return None - if not user.authn_p(password): - return None - return user.write_perms() - - -def self_test(): - db = UserDB('userdb.yaml') - print(db.dump()) - - orgs = db.orgs_for_user('user3') - assert('sunet.se' in orgs) - assert('su.se' in orgs) - assert(len(orgs) == 2) - - assert(db.user_authn_p('user3', 'pw3') == True) - assert(db.user_authn_p('user3', 'wrongpw') == False) - - rp = db.read_perms('user3', 'pw3') - assert(len(rp) == 2) - assert('sunet.se' in rp) - assert('su.se' in rp) - - wp = db.write_perms('user3', 'pw3') - assert(len(wp) == 1) - assert('sunet.se' in wp) - - -if __name__ == '__main__': - self_test() diff --git a/src/db/db.py b/src/db/db.py index cbb87ce..3926fda 100755 --- a/src/db/db.py +++ b/src/db/db.py @@ -1,12 +1,3 @@ -# A database storing dictionaries, keyed on a timestamp. value = A -# dict which will be stored as a JSON object encoded in UTF-8. Note -# that dict keys of type integer or float will become strings while -# values will keep their type. - -# Note that there's a (slim) chance that you'd stomp on the previous -# value if you're too quick with generating the timestamps, ie -# invoking time.time() several times quickly enough. - import os import sys import time diff --git a/src/db/scanner.py b/src/db/scanner.py index e9ac8c3..625fd8e 100644 --- a/src/db/scanner.py +++ b/src/db/scanner.py @@ -1,7 +1,7 @@ import enum from datetime import datetime -from sqlalchemy import (Boolean, Column, DateTime, Integer, Unicode, +from sqlalchemy import (JSON, Boolean, Column, DateTime, Integer, Unicode, UniqueConstraint, create_engine) from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.declarative import declarative_base @@ -22,13 +22,12 @@ class Scanner(Base): id = Column(Integer, autoincrement=True, primary_key=True) uuid = Column(Unicode(37), nullable=False) - enabled = Column(Boolean, nullable=False) + enabled = Column(Boolean, default=False, nullable=False) first_seen = Column(DateTime, default=datetime.utcnow, nullable=False) last_seen = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) - comment = Column(Unicode(255), nullable=True) - scanners = Column(Unicode(2048), nullable=False) - targets = Column(Unicode(255), nullable=True) + comment = Column(Unicode(255), default="", nullable=True) + scanners = Column(JSON, nullable=True) def as_dict(self): """Return JSON serializable dict.""" @@ -47,7 +46,8 @@ class Scanner(Base): @classmethod def comment(cls, uuid, comment): with SqlDB.sql_session() as session: - scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none() + scanner: Scanner = session.query(Scanner).filter( + Scanner.uuid == uuid).one_or_none() if scanner: scanner.comment = comment @@ -58,7 +58,8 @@ class Scanner(Base): @classmethod def enable(cls, uuid): with SqlDB.sql_session() as session: - scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none() + scanner: Scanner = session.query(Scanner).filter( + Scanner.uuid == uuid).one_or_none() if scanner: scanner.enabled = True @@ -69,7 +70,8 @@ class Scanner(Base): @classmethod def disable(cls, uuid): with SqlDB.sql_session() as session: - scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none() + scanner: Scanner = session.query(Scanner).filter( + Scanner.uuid == uuid).one_or_none() if scanner: scanner.enabled = False @@ -84,9 +86,11 @@ class Scanner(Base): with SqlDB.sql_session() as session: if scanner_id: - scanner: Scanner = session.query(Scanner).filter(Scanner.id == scanner_id).one_or_none() + scanner: Scanner = session.query(Scanner).filter( + Scanner.id == scanner_id).one_or_none() elif uuid: - scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none() + scanner: Scanner = session.query(Scanner).filter( + Scanner.uuid == uuid).one_or_none() else: return None @@ -104,7 +108,7 @@ class Scanner(Base): scanner = Scanner() scanner.uuid = uuid scanner.enabled = False - scanner.scanners = "None" + scanner.scanners = {} session.add(scanner) session.flush() @@ -113,11 +117,11 @@ class Scanner(Base): except IntegrityError: return None - @classmethod def is_enabled(cls, uuid): with SqlDB.sql_session() as session: - scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none() + scanner: Scanner = session.query(Scanner).filter( + Scanner.uuid == uuid).one_or_none() if scanner is None: return None diff --git a/src/main.py b/src/main.py index aa3b133..a65971d 100755 --- a/src/main.py +++ b/src/main.py @@ -24,6 +24,7 @@ app.add_middleware( expose_headers=["X-Total-Count"], ) + @app.middleware("http") async def mock_x_total_count_header(request: Request, call_next): response = await call_next(request) diff --git a/src/routers/collector.py b/src/routers/collector.py index 3cda23a..7d91609 100644 --- a/src/routers/collector.py +++ b/src/routers/collector.py @@ -48,18 +48,22 @@ def get_data(key=None, limit=25, skip=0, ip=None, @router.get('/get') async def get(key=None, limit=25, skip=0, ip=None, port=None, asn=None, Authorize: AuthJWT = Depends()): + Authorize.jwt_required() data = [] raw_jwt = Authorize.get_raw_jwt() - if 'domains' not in raw_jwt: - return JSONResponse(content={"status": "error", - "message": "Could not find domains" + - "claim in JWT token"}, - status_code=400) + if "read" not in raw_jwt: + return JSONResponse( + content={ + "status": "error", + "message": "Could not find read claim in JWT token", + }, + status_code=400, + ) else: - domains = raw_jwt['domains'] + domains = raw_jwt["read"] for domain in domains: data.extend(get_data(key, limit, skip, ip, port, asn, domain)) @@ -69,17 +73,39 @@ async def get(key=None, limit=25, skip=0, ip=None, port=None, @router.get('/get/{key}') async def get_key(key=None, Authorize: AuthJWT = Depends()): + Authorize.jwt_required() - # TODO: Use JWT authz and check e.g. domain here + raw_jwt = Authorize.get_raw_jwt() + + if "read" not in raw_jwt: + return JSONResponse( + content={ + "status": "error", + "message": "Could not find read claim in JWT token", + }, + status_code=400, + ) + else: + allowed_domains = raw_jwt["read"] data = get_data(key) + if data["domain"] not in allowed_domains: + return JSONResponse( + content={ + "status": "error", + "message": "User not authorized to view this object", + }, + status_code=400, + ) + return JSONResponse(content={"status": "success", "docs": data}) @router.post('/add') async def add(data: Request, Authorize: AuthJWT = Depends()): + Authorize.jwt_required() json_data = await data.json() @@ -91,11 +117,36 @@ async def add(data: Request, Authorize: AuthJWT = Depends()): @router.delete('/delete/{key}') async def delete(key, Authorize: AuthJWT = Depends()): + Authorize.jwt_required() + raw_jwt = Authorize.get_raw_jwt() + + if "write" not in raw_jwt: + return JSONResponse( + content={ + "status": "error", + "message": "Could not find write claim in JWT token", + }, + status_code=400, + ) + else: + allowed_domains = raw_jwt["write"] + + data = get_data(key) + + if data["domain"] not in allowed_domains: + return JSONResponse( + content={ + "status": "error", + "message": "User not authorized to delete this object", + }, + status_code=400, + ) + if db.delete(key) is None: return JSONResponse(content={"status": "error", "message": "Document not found"}, status_code=400) - return JSONResponse(content={"status": "success", "docs": {}}) + return JSONResponse(content={"status": "success", "docs": data}) diff --git a/src/routers/scanner.py b/src/routers/scanner.py index 9bb0f98..645cd74 100644 --- a/src/routers/scanner.py +++ b/src/routers/scanner.py @@ -84,9 +84,9 @@ async def callhome(uuid, data: Request, Authorize: AuthJWT = Depends()): else: if Scanner.add(uuid): - return JSONResponse(content={"status": "success", - "message": "Scanner added."}, - status_code=200) + return JSONResponse(content={"status": "error", + "message": "Scanner added but disabled."}, + status_code=400) else: return JSONResponse(content={"status": "error", "message": "Failed to add scanner."}, |