summaryrefslogtreecommitdiff
path: root/src/main.py
diff options
context:
space:
mode:
authorKristofer Hallin <kristofer@sunet.se>2021-11-17 09:52:05 +0100
committerKristofer Hallin <kristofer@sunet.se>2021-11-17 09:53:12 +0100
commitfc31d040886ddd9495a0318a7272468fe81a215e (patch)
treeeaa5ccdaee6abec69b6e9678f14972abf7c45abd /src/main.py
parentfadb0f24bb55697a1ba34611a4288d12e25065d1 (diff)
* Rename wsgi.py to main.py, we're not using WSGI.
* Added env variable.
Diffstat (limited to 'src/main.py')
-rwxr-xr-xsrc/main.py145
1 files changed, 145 insertions, 0 deletions
diff --git a/src/main.py b/src/main.py
new file mode 100755
index 0000000..9e028b0
--- /dev/null
+++ b/src/main.py
@@ -0,0 +1,145 @@
+import os
+import sys
+import uvicorn
+
+from fastapi import FastAPI, Depends, Request
+from fastapi.responses import JSONResponse
+from fastapi_jwt_auth import AuthJWT
+from fastapi_jwt_auth.exceptions import AuthJWTException
+from pydantic import BaseModel
+from index import CouchIindex
+import time
+from db import DictDB
+
+app = FastAPI()
+
+for i in range(10):
+ try:
+ db = DictDB()
+ except Exception:
+ print(
+ f'Database not responding, will try again soon. Attempt {i + 1} of 10.')
+ else:
+ break
+ time.sleep(10)
+else:
+ print('Database did not respond after 10 attempts, quitting.')
+ sys.exit(-1)
+
+
+def get_pubkey():
+ try:
+ if 'keypath' in os.environ:
+ keypath = os.environ['JWT_PUBKEY_PATH']
+ else:
+ keypath = '/opt/certs/public.pem'
+
+ with open(keypath, "r") as fd:
+ pubkey = fd.read()
+ except FileNotFoundError:
+ print(f"Could not find JWT certificate in {keypath}")
+ sys.exit(-1)
+
+ return pubkey
+
+
+def get_data(key=None, limit=25, skip=0, ip=None,
+ port=None, asn=None, domain=None):
+
+ selectors = dict()
+ indexes = CouchIindex().dict()
+ selectors['domain'] = domain
+
+ if ip and 'ip' in indexes:
+ selectors['ip'] = ip
+ if port and 'port' in indexes:
+ selectors['port'] = port
+ if asn and 'asn' in indexes:
+ selectors['asn'] = asn
+
+ data = db.search(**selectors, limit=limit, skip=skip)
+
+ return data
+
+
+class JWTConfig(BaseModel):
+ authjwt_algorithm: str = "ES256"
+ authjwt_public_key: str = get_pubkey()
+
+
+@AuthJWT.load_config
+def jwt_config():
+ return JWTConfig()
+
+
+@app.exception_handler(AuthJWTException)
+def authjwt_exception_handler(request: Request, exc: AuthJWTException):
+ return JSONResponse(content={"status": "error", "message":
+ exc.message}, status_code=400)
+
+
+@app.exception_handler(RuntimeError)
+def app_exception_handler(request: Request, exc: RuntimeError):
+ return JSONResponse(content={"status": "error", "message":
+ str(exc.with_traceback(None))},
+ status_code=400)
+
+
+@app.get('/sc/v0/get')
+async def get(key=None, limit=25, skip=0, ip=None, port=None,
+ asn=None, Authorize: AuthJWT = Depends()):
+
+ Authorize.jwt_required()
+
+ data = []
+ raw_jwt = Authorize.get_raw_jwt()
+
+ if 'domains' not in raw_jwt:
+ return JSONResponse(content={"status": "error",
+ "message": "Could not find domains" +
+ "claim in JWT token"},
+ status_code=400)
+ else:
+ domains = raw_jwt['domains']
+
+ for domain in domains:
+ data.extend(get_data(key, limit, skip, ip, port, asn, domain))
+
+ return JSONResponse(content={"status": "success", "docs": data})
+
+
+@app.get('/sc/v0/get/{key}')
+async def get_key(key=None, Authorize: AuthJWT = Depends()):
+
+ Authorize.jwt_required()
+
+ data = get_data(key)
+
+ return JSONResponse(content={"status": "success", "docs": data})
+
+
+@app.post('/sc/v0/add')
+async def add(data: Request, Authorize: AuthJWT = Depends()):
+
+ # Maybe we should protect this enpoint too and let the scanner use
+ # a JWT token as well.
+ # Authorize.jwt_required()
+
+ json_data = await data.json()
+
+ key = db.add(json_data)
+
+ return JSONResponse(content={"status": "success", "docs": key})
+
+
+def main(standalone=False):
+ if not standalone:
+ return app
+
+ uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug")
+
+
+if __name__ == '__main__':
+ main(standalone=True)
+else:
+ app = main()