"""HTTP API for querying, adding and deleting documents stored in DictDB.

Read and delete endpoints require a JWT (verified with the public key returned
by ``get_pubkey``). The token's ``read`` / ``write`` claims list the domains
whose documents the caller may read / delete.
"""
import os
import sys
import time

import requests
import uvicorn
from fastapi import Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi_jwt_auth import AuthJWT
from fastapi_jwt_auth.exceptions import AuthJWTException
from pydantic import BaseModel

from db import DictDB
from index import CouchIindex

# How many times to try connecting to the database at startup, and how long
# to wait between attempts.
DB_CONNECT_ATTEMPTS = 10
DB_RETRY_DELAY_SECONDS = 10

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:8001"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["X-Total-Count"],
)


# TODO: X-Total-Count is a hard-coded placeholder; it should report the real
# number of matching documents.
@app.middleware("http")
async def mock_x_total_count_header(request: Request, call_next):
    """Attach a fixed ``X-Total-Count: 100`` header to every response."""
    response = await call_next(request)
    response.headers["X-Total-Count"] = "100"
    return response


# Connect to the database, retrying while it refuses connections (e.g. it is
# still starting up). Exit the process if it never becomes reachable.
for attempt in range(1, DB_CONNECT_ATTEMPTS + 1):
    try:
        db = DictDB()
    except requests.exceptions.ConnectionError:
        print(
            "Database not responding, will try again soon. "
            f"Attempt {attempt} of {DB_CONNECT_ATTEMPTS}."
        )
    else:
        break
    # No point waiting after the last failed attempt; we are about to exit.
    if attempt < DB_CONNECT_ATTEMPTS:
        time.sleep(DB_RETRY_DELAY_SECONDS)
else:
    print(f"Database did not respond after {DB_CONNECT_ATTEMPTS} attempts, quitting.")
    sys.exit(-1)


def get_pubkey():
    """Return the contents of the JWT public key file.

    The path comes from the ``JWT_PUBKEY_PATH`` environment variable and
    defaults to ``/opt/certs/public.pem``. Exits the process if the file does
    not exist.
    """
    keypath = os.environ.get("JWT_PUBKEY_PATH", "/opt/certs/public.pem")
    try:
        with open(keypath, "r") as fd:
            return fd.read()
    except FileNotFoundError:
        print(f"Could not find JWT certificate in {keypath}")
        sys.exit(-1)


def get_data(key=None, limit=25, skip=0, ip=None, port=None, asn=None, domain=None):
    """Fetch one document by key, or search for documents.

    If ``key`` is truthy, return ``db.get(key)`` and ignore every other
    argument -- including ``domain``, so callers must do their own
    authorization check on the result. Otherwise call ``db.search`` filtered
    on ``domain`` plus whichever of ``ip`` / ``port`` / ``asn`` are set and
    present in ``CouchIindex().dict()``, paginated by ``limit`` and ``skip``.
    """
    if key:
        return db.get(key)

    indexes = CouchIindex().dict()
    selectors = {"domain": domain}
    for field, value in (("ip", ip), ("port", port), ("asn", asn)):
        # Only filter on fields that have an index.
        if value and field in indexes:
            selectors[field] = value

    return db.search(**selectors, limit=limit, skip=skip)


class JWTConfig(BaseModel):
    """Settings handed to fastapi_jwt_auth via ``AuthJWT.load_config``."""

    authjwt_algorithm: str = "ES256"
    # Read once, when this class is defined (i.e. at import time).
    authjwt_public_key: str = get_pubkey()


@AuthJWT.load_config
def jwt_config():
    """Provide the JWT settings to fastapi_jwt_auth."""
    return JWTConfig()


def _error_response(message):
    """Build the JSON error body (HTTP 400) used by every endpoint."""
    return JSONResponse(content={"status": "error", "message": message}, status_code=400)


def _fetch_authorized(key, allowed_domains, denied_message):
    """Look up the document stored under ``key`` and check its domain.

    Returns ``(document, None)`` when the document exists and its ``domain``
    is in ``allowed_domains``; otherwise ``(None, error_response)``.
    """
    document = get_data(key)
    if not document:
        return None, _error_response("Document not found")
    if document["domain"] not in allowed_domains:
        return None, _error_response(denied_message)
    return document, None


@app.exception_handler(AuthJWTException)
def authjwt_exception_handler(request: Request, exc: AuthJWTException):
    """Report JWT authentication failures as a 400 JSON error."""
    return _error_response(exc.message)


@app.exception_handler(RuntimeError)
def app_exception_handler(request: Request, exc: RuntimeError):
    """Report uncaught RuntimeErrors as a 400 JSON error."""
    return _error_response(str(exc))


@app.get("/sc/v0/get")
async def get(
    key=None,
    limit=25,
    skip=0,
    ip=None,
    port=None,
    asn=None,
    Authorize: AuthJWT = Depends(),
):
    """Search documents in every domain listed in the JWT ``read`` claim.

    When ``key`` is given, return only that document (as a one-element
    ``docs`` list), and only if its domain is one the caller may read.
    """
    Authorize.jwt_required()

    raw_jwt = Authorize.get_raw_jwt()
    if "read" not in raw_jwt:
        return _error_response("Could not find read claim in JWT token")
    domains = raw_jwt["read"]

    if key:
        # A key lookup bypasses get_data's domain filter, so the domain
        # must be checked against the read claim explicitly.
        document, error = _fetch_authorized(
            key, domains, "User not authorized to view this object"
        )
        if error is not None:
            return error
        return JSONResponse(content={"status": "success", "docs": [document]})

    data = []
    for domain in domains:
        data.extend(
            get_data(limit=limit, skip=skip, ip=ip, port=port, asn=asn, domain=domain)
        )
    return JSONResponse(content={"status": "success", "docs": data})


@app.get("/sc/v0/get/{key}")
async def get_key(key=None, Authorize: AuthJWT = Depends()):
    """Return the document stored under ``key`` if the caller may read its domain."""
    Authorize.jwt_required()

    raw_jwt = Authorize.get_raw_jwt()
    if "read" not in raw_jwt:
        return _error_response("Could not find read claim in JWT token")

    data, error = _fetch_authorized(
        key, raw_jwt["read"], "User not authorized to view this object"
    )
    if error is not None:
        return error
    return JSONResponse(content={"status": "success", "docs": data})


@app.post("/sc/v0/add")
async def add(data: Request, Authorize: AuthJWT = Depends()):
    """Store the JSON request body via ``db.add`` and return its result.

    NOTE: this endpoint is currently unauthenticated.
    """
    # TODO: consider protecting this endpoint too (Authorize.jwt_required())
    # and letting the scanner authenticate with a JWT token as well.
    json_data = await data.json()
    key = db.add(json_data)
    return JSONResponse(content={"status": "success", "docs": key})


@app.delete("/sc/v0/delete/{key}")
async def delete(key, Authorize: AuthJWT = Depends()):
    """Delete the document stored under ``key`` if the caller may write its domain.

    On success the deleted document is echoed back in ``docs``.
    """
    Authorize.jwt_required()

    raw_jwt = Authorize.get_raw_jwt()
    if "write" not in raw_jwt:
        return _error_response("Could not find write claim in JWT token")

    data, error = _fetch_authorized(
        key, raw_jwt["write"], "User not authorized to delete this object"
    )
    if error is not None:
        return error

    if db.delete(key) is None:
        return _error_response("Document not found")
    return JSONResponse(content={"status": "success", "docs": data})


def main(standalone=False):
    """Return the ASGI app, or serve it with uvicorn on 0.0.0.0:8000 if ``standalone``."""
    if not standalone:
        return app
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug")


if __name__ == "__main__":
    main(standalone=True)
else:
    app = main()