diff options
author | Kristofer Hallin <kristofer@sunet.se> | 2021-11-17 09:52:05 +0100 |
---|---|---|
committer | Kristofer Hallin <kristofer@sunet.se> | 2021-11-17 09:53:12 +0100 |
commit | fc31d040886ddd9495a0318a7272468fe81a215e (patch) | |
tree | eaa5ccdaee6abec69b6e9678f14972abf7c45abd /src/main.py | |
parent | fadb0f24bb55697a1ba34611a4288d12e25065d1 (diff) |
* Rename wsgi.py to main.py, we're not using WSGI.
* Added env variable.
Diffstat (limited to 'src/main.py')
-rwxr-xr-x | src/main.py | 145 |
1 files changed, 145 insertions, 0 deletions
diff --git a/src/main.py b/src/main.py new file mode 100755 index 0000000..9e028b0 --- /dev/null +++ b/src/main.py @@ -0,0 +1,145 @@ +import os +import sys +import uvicorn + +from fastapi import FastAPI, Depends, Request +from fastapi.responses import JSONResponse +from fastapi_jwt_auth import AuthJWT +from fastapi_jwt_auth.exceptions import AuthJWTException +from pydantic import BaseModel +from index import CouchIindex +import time +from db import DictDB + +app = FastAPI() + +for i in range(10): + try: + db = DictDB() + except Exception: + print( + f'Database not responding, will try again soon. Attempt {i + 1} of 10.') + else: + break + time.sleep(10) +else: + print('Database did not respond after 10 attempts, quitting.') + sys.exit(-1) + + +def get_pubkey(): + try: + if 'keypath' in os.environ: + keypath = os.environ['JWT_PUBKEY_PATH'] + else: + keypath = '/opt/certs/public.pem' + + with open(keypath, "r") as fd: + pubkey = fd.read() + except FileNotFoundError: + print(f"Could not find JWT certificate in {keypath}") + sys.exit(-1) + + return pubkey + + +def get_data(key=None, limit=25, skip=0, ip=None, + port=None, asn=None, domain=None): + + selectors = dict() + indexes = CouchIindex().dict() + selectors['domain'] = domain + + if ip and 'ip' in indexes: + selectors['ip'] = ip + if port and 'port' in indexes: + selectors['port'] = port + if asn and 'asn' in indexes: + selectors['asn'] = asn + + data = db.search(**selectors, limit=limit, skip=skip) + + return data + + +class JWTConfig(BaseModel): + authjwt_algorithm: str = "ES256" + authjwt_public_key: str = get_pubkey() + + +@AuthJWT.load_config +def jwt_config(): + return JWTConfig() + + +@app.exception_handler(AuthJWTException) +def authjwt_exception_handler(request: Request, exc: AuthJWTException): + return JSONResponse(content={"status": "error", "message": + exc.message}, status_code=400) + + +@app.exception_handler(RuntimeError) +def app_exception_handler(request: Request, exc: RuntimeError): + return JSONResponse(content={"status": "error", "message": + str(exc.with_traceback(None))}, + status_code=400) + + +@app.get('/sc/v0/get') +async def get(key=None, limit=25, skip=0, ip=None, port=None, + asn=None, Authorize: AuthJWT = Depends()): + + Authorize.jwt_required() + + data = [] + raw_jwt = Authorize.get_raw_jwt() + + if 'domains' not in raw_jwt: + return JSONResponse(content={"status": "error", + "message": "Could not find domains" + + "claim in JWT token"}, + status_code=400) + else: + domains = raw_jwt['domains'] + + for domain in domains: + data.extend(get_data(key, limit, skip, ip, port, asn, domain)) + + return JSONResponse(content={"status": "success", "docs": data}) + + +@app.get('/sc/v0/get/{key}') +async def get_key(key=None, Authorize: AuthJWT = Depends()): + + Authorize.jwt_required() + + data = get_data(key) + + return JSONResponse(content={"status": "success", "docs": data}) + + +@app.post('/sc/v0/add') +async def add(data: Request, Authorize: AuthJWT = Depends()): + + # Maybe we should protect this enpoint too and let the scanner use + # a JWT token as well. + # Authorize.jwt_required() + + json_data = await data.json() + + key = db.add(json_data) + + return JSONResponse(content={"status": "success", "docs": key}) + + +def main(standalone=False): + if not standalone: + return app + + uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug") + + +if __name__ == '__main__': + main(standalone=True) +else: + app = main() |