import json
import os
import sys
import time

import uvicorn
from fastapi import Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi_jwt_auth import AuthJWT
from fastapi_jwt_auth.exceptions import AuthJWTException
from pydantic import BaseModel

from db import DictDB
from schema import get_index_keys, validate_collector_data

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:8001"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["X-Total-Count"],
)


# TODO: X-Total-Count
@app.middleware("http")
async def mock_x_total_count_header(request: Request, call_next):
    """Attach a placeholder ``X-Total-Count`` header to every response.

    The value is hard-coded to "100" and does not reflect the real number of
    matching documents.
    """
    response = await call_next(request)
    response.headers["X-Total-Count"] = "100"
    return response


def _connect_db(attempts: int = 10, delay: float = 1.0) -> DictDB:
    """Create the DictDB handle, retrying while construction fails.

    Sleeps ``delay`` seconds between attempts. Exits the process with
    status -1 if none of the ``attempts`` succeeds.
    """
    for attempt in range(1, attempts + 1):
        try:
            return DictDB()
        except Exception:
            # Any failure while constructing DictDB is treated as "database
            # not reachable yet" and retried.
            print(
                f'Database not responding, will try again soon. '
                f'Attempt {attempt} of {attempts}.')
        if attempt < attempts:
            time.sleep(delay)
    print(f'Database did not respond after {attempts} attempts, quitting.')
    sys.exit(-1)


db = _connect_db()


def get_pubkey():
    """Return the contents of the public-key file used to verify JWTs.

    The path is taken from ``$JWT_PUBKEY_PATH`` and defaults to
    ``/opt/certs/public.pem``. Exits the process with status -1 if the file
    cannot be read.
    """
    keypath = os.environ.get('JWT_PUBKEY_PATH', '/opt/certs/public.pem')
    try:
        with open(keypath, "r") as fd:
            return fd.read()
    except FileNotFoundError:
        print(f"Could not find JWT certificate in {keypath}")
    except OSError as err:
        # e.g. permission denied, or the path is a directory
        print(f"Could not read JWT certificate in {keypath}: {err}")
    sys.exit(-1)


def get_data(key=None, limit=25, skip=0, ip=None, port=None, asn=None,
             domain=None):
    """Fetch one document by key, or search the documents of a domain.

    When ``key`` is truthy, return ``db.get(key)`` and ignore every other
    argument, including ``domain``; callers must do their own authorization.
    Otherwise run ``db.search`` filtered on ``domain``, plus each of
    ``ip``/``port``/``asn`` that is truthy and is listed in
    ``get_index_keys()``.
    """
    if key:
        return db.get(key)

    indexes = get_index_keys()
    selectors = {'domain': domain}
    for name, value in (('ip', ip), ('port', port), ('asn', asn)):
        if value and name in indexes:
            selectors[name] = value
    return db.search(**selectors, limit=limit, skip=skip)


class JWTConfig(BaseModel):
    """Settings consumed by fastapi_jwt_auth (ES256, public-key verify)."""

    authjwt_algorithm: str = "ES256"
    # Read once, at import time; the process exits if the key is unreadable.
    authjwt_public_key: str = get_pubkey()


@AuthJWT.load_config
def jwt_config():
    """Provide the JWT settings to fastapi_jwt_auth."""
    return JWTConfig()


@app.exception_handler(AuthJWTException)
def authjwt_exception_handler(request: Request, exc: AuthJWTException):
    """Turn JWT validation failures into a 400 JSON error response."""
    return JSONResponse(content={"status": "error", "message": exc.message},
                        status_code=400)


@app.exception_handler(RuntimeError)
def app_exception_handler(request: Request, exc: RuntimeError):
    """Turn uncaught RuntimeErrors into a 400 JSON error response."""
    # str(exc) gives the message; with_traceback(None) was a pointless
    # mutation that discarded the traceback before anything could log it.
    return JSONResponse(content={"status": "error", "message": str(exc)},
                        status_code=400)


def _error_response(message, status_code=400):
    """Build the service's standard ``{"status": "error"}`` JSON response."""
    return JSONResponse(
        content={
            "status": "error",
            "message": message,
        },
        status_code=status_code,
    )


@app.get('/sc/v0/get')
async def get(key=None, limit: int = 25, skip: int = 0, ip=None, port=None,
              asn=None, Authorize: AuthJWT = Depends()):
    """List documents visible to the caller.

    Requires a JWT with a ``read`` claim listing the domains the caller may
    read. With ``key``, returns that single document (in a list) if its
    domain is readable, else an empty list. Without ``key``, searches each
    readable domain (``limit``/``skip`` apply per domain) and concatenates
    the results.
    """
    Authorize.jwt_required()
    raw_jwt = Authorize.get_raw_jwt()
    if "read" not in raw_jwt:
        return _error_response("Could not find read claim in JWT token")
    domains = raw_jwt["read"]

    data = []
    if key:
        # get_data(key) ignores the domain filter, so the domain check has
        # to happen here, and the single document is appended (extending a
        # list with a dict would add its field names, not the document).
        doc = get_data(key)
        if doc and doc["domain"] in domains:
            data.append(doc)
    else:
        for domain in domains:
            data.extend(get_data(None, limit, skip, ip, port, asn, domain))
    return JSONResponse(content={"status": "success", "docs": data})


@app.get('/sc/v0/get/{key}')
async def get_key(key=None, Authorize: AuthJWT = Depends()):
    """Return the document stored under ``key``.

    Requires a JWT ``read`` claim that includes the document's domain. A
    missing document is returned as ``"docs": null`` with status success.
    """
    Authorize.jwt_required()
    raw_jwt = Authorize.get_raw_jwt()
    if "read" not in raw_jwt:
        return _error_response("Could not find read claim in JWT token")
    allowed_domains = raw_jwt["read"]

    data = get_data(key)
    if data and data["domain"] not in allowed_domains:
        return _error_response("User not authorized to view this object")
    return JSONResponse(content={"status": "success", "docs": data})


@app.post('/sc/v0/add')
async def add(data: Request, Authorize: AuthJWT = Depends()):
    """Store the JSON request body via ``db.add``.

    ``db.add`` signals failure by returning an error string, which is
    reported as a 400 error; any other return value is sent back in
    ``docs``.
    """
    # SECURITY(review): JWT verification is disabled, so any client can
    # insert documents and no "write" claim / domain check is applied.
    # Presumably disabled until clients send tokens -- confirm and re-enable.
    # Authorize.jwt_required()
    try:
        json_data = await data.json()
    except json.decoder.JSONDecodeError:
        return _error_response("Invalid JSON.")

    key = db.add(json_data)
    if isinstance(key, str):
        return _error_response(key)
    return JSONResponse(content={"status": "success", "docs": key})


@app.delete('/sc/v0/delete/{key}')
async def delete(key, Authorize: AuthJWT = Depends()):
    """Delete the document stored under ``key``.

    Requires a JWT ``write`` claim that includes the document's domain.
    Returns the document as fetched before deletion in ``docs``.
    """
    Authorize.jwt_required()
    raw_jwt = Authorize.get_raw_jwt()
    if "write" not in raw_jwt:
        return _error_response("Could not find write claim in JWT token")
    allowed_domains = raw_jwt["write"]

    data = get_data(key)
    if data and data["domain"] not in allowed_domains:
        return _error_response("User not authorized to delete this object")
    if db.delete(key) is None:
        return _error_response("Document not found")
    return JSONResponse(content={"status": "success", "docs": data})


def main(standalone=False):
    """Return the ASGI app, or serve it with uvicorn when ``standalone``."""
    if not standalone:
        return app
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug")


if __name__ == '__main__':
    main(standalone=True)
else:
    app = main()