diff options
Diffstat (limited to 'src/wsgi.py')
-rwxr-xr-x | src/wsgi.py | 60 |
1 files changed, 42 insertions, 18 deletions
diff --git a/src/wsgi.py b/src/wsgi.py index b690bc0..5f56fb9 100755 --- a/src/wsgi.py +++ b/src/wsgi.py @@ -1,3 +1,5 @@ +import os +import sys import uvicorn from fastapi import FastAPI, Depends, Request @@ -12,19 +14,29 @@ from db import DictDB app = FastAPI() db = DictDB() -public_key = """-----BEGIN PUBLIC KEY----- -MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEPW8bkkVIq4BX8eWwlUOUYbJhiGDv -K/6xY5T0BsvV6pbMoIUfgeThVOq5I3CmXxLt+qyPska6ol9fTN7woZLsCg== ------END PUBLIC KEY-----""" + +def get_pubkey(): + try: + keypath = os.environ['JWT_PUBKEY_PATH'] + + with open(keypath, "r") as fd: + pubkey = fd.read() + except KeyError: + print("Could not find environment variable JWT_PUBKEY_PATH") + sys.exit(-1) + except FileNotFoundError: + print(f"Could not find JWT certificate in {keypath}") + sys.exit(-1) + + return pubkey def get_data(key=None, limit=25, skip=0, ip=None, - port=None, asn=None): + port=None, asn=None, domain=None): selectors = dict() indexes = CouchIindex().dict() - - selectors['domain'] = 'sunet.se' + selectors['domain'] = domain if ip and 'ip' in indexes: selectors['ip'] = ip @@ -35,12 +47,12 @@ def get_data(key=None, limit=25, skip=0, ip=None, data = db.search(**selectors, limit=limit, skip=skip) - return JSONResponse(content={"status": "success", "data": data}) + return data class JWTConfig(BaseModel): authjwt_algorithm: str = "ES256" - authjwt_public_key: str = public_key + authjwt_public_key: str = get_pubkey() @AuthJWT.load_config @@ -67,7 +79,21 @@ async def get(key=None, limit=25, skip=0, ip=None, port=None, Authorize.jwt_required() - return get_data(key, limit, skip, ip, port, asn) + data = [] + raw_jwt = Authorize.get_raw_jwt() + + if 'domains' not in raw_jwt: + return JSONResponse(content={"status": "error", + "message": "Could not find domains" + + "claim in JWT token"}, + status_code=400) + else: + domains = raw_jwt['domains'] + + for domain in domains: + data.extend(get_data(key, limit, skip, ip, port, asn, domain)) + + return JSONResponse(content={"statuc": "success", "docs": data}) @app.get('/sc/v0/get/{key}') @@ -75,19 +101,17 @@ async def get_key(key=None, Authorize: AuthJWT = Depends()): Authorize.jwt_required() - return get_data(key) + data = get_data(key) + + return JSONResponse(content={"statuc": "success", "docs": data}) @app.post('/sc/v0/add') async def add(data: Request, Authorize: AuthJWT = Depends()): - Authorize.jwt_required() - - orgs = ['sunet.se'] - - if not orgs: - return JSONResponse(content={"status": "error", "message": - "Could not find an organization"}, status_code=400) + # Maybe we should protect this enpoint too and let the scanner use + # a JWT token as well. + # Authorize.jwt_required() json_data = await data.json() |