diff options
Diffstat (limited to 'src/main.py')
-rwxr-xr-x | src/main.py | 113 |
1 files changed, 7 insertions, 106 deletions
diff --git a/src/main.py b/src/main.py index f95a09c..aa3b133 100755 --- a/src/main.py +++ b/src/main.py @@ -1,20 +1,20 @@ import os import sys -import uvicorn -from fastapi import FastAPI, Depends, Request +import uvicorn +from fastapi import Depends, FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from fastapi_jwt_auth import AuthJWT from fastapi_jwt_auth.exceptions import AuthJWTException -from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel -from index import CouchIindex -import time -from db import DictDB -import requests + +import routers app = FastAPI() +app.include_router(routers.router, prefix='/sc/v0') + app.add_middleware( CORSMiddleware, allow_origins=["http://localhost:8001"], @@ -24,28 +24,12 @@ app.add_middleware( expose_headers=["X-Total-Count"], ) -# TODO: X-Total-Count - - @app.middleware("http") async def mock_x_total_count_header(request: Request, call_next): response = await call_next(request) response.headers["X-Total-Count"] = "100" return response -for i in range(10): - try: - db = DictDB() - except requests.exceptions.ConnectionError: - print(f'Database not responding, will try again soon.' + - 'Attempt {i + 1} of 10.') - else: - break - time.sleep(10) -else: - print('Database did not respond after 10 attempts, quitting.') - sys.exit(-1) - def get_pubkey(): try: @@ -63,27 +47,6 @@ def get_pubkey(): return pubkey -def get_data(key=None, limit=25, skip=0, ip=None, - port=None, asn=None, domain=None): - if key: - return db.get(key) - - selectors = dict() - indexes = CouchIindex().dict() - selectors['domain'] = domain - - if ip and 'ip' in indexes: - selectors['ip'] = ip - if port and 'port' in indexes: - selectors['port'] = port - if asn and 'asn' in indexes: - selectors['asn'] = asn - - data = db.search(**selectors, limit=limit, skip=skip) - - return data - - class JWTConfig(BaseModel): authjwt_algorithm: str = "ES256" authjwt_public_key: str = get_pubkey() @@ -107,68 +70,6 @@ def app_exception_handler(request: Request, exc: RuntimeError): status_code=400) -@app.get('/sc/v0/get') -async def get(key=None, limit=25, skip=0, ip=None, port=None, - asn=None, Authorize: AuthJWT = Depends()): - - Authorize.jwt_required() - - data = [] - raw_jwt = Authorize.get_raw_jwt() - - if 'domains' not in raw_jwt: - return JSONResponse(content={"status": "error", - "message": "Could not find domains" + - "claim in JWT token"}, - status_code=400) - else: - domains = raw_jwt['domains'] - - for domain in domains: - data.extend(get_data(key, limit, skip, ip, port, asn, domain)) - - return JSONResponse(content={"status": "success", "docs": data}) - - -@app.get('/sc/v0/get/{key}') -async def get_key(key=None, Authorize: AuthJWT = Depends()): - - Authorize.jwt_required() - - # TODO: Use JWT authz and check e.g. domain here - - data = get_data(key) - - return JSONResponse(content={"status": "success", "docs": data}) - - -@app.post('/sc/v0/add') -async def add(data: Request, Authorize: AuthJWT = Depends()): - - # Maybe we should protect this enpoint too and let the scanner use - # a JWT token as well. - # Authorize.jwt_required() - - json_data = await data.json() - - key = db.add(json_data) - - return JSONResponse(content={"status": "success", "docs": key}) - - -@app.delete('/sc/v0/delete/{key}') -async def delete(key, Authorize: AuthJWT = Depends()): - - Authorize.jwt_required() - - if db.delete(key) is None: - return JSONResponse(content={"status": "error", - "message": "Document not found"}, - status_code=400) - - return JSONResponse(content={"status": "success", "docs": {}}) - - def main(standalone=False): if not standalone: return app |