diff options
Diffstat (limited to 'src/wsgi.py')
-rwxr-xr-x | src/wsgi.py | 280 |
1 files changed, 123 insertions, 157 deletions
diff --git a/src/wsgi.py b/src/wsgi.py index 8ab178a..b6a1a10 100755 --- a/src/wsgi.py +++ b/src/wsgi.py @@ -1,167 +1,133 @@ -#! /usr/bin/env python3 - import os import sys -import json -import authn -import index -import falcon +import uvicorn + +from fastapi import FastAPI, Depends, Request +from fastapi.responses import JSONResponse +from fastapi_jwt_auth import AuthJWT +from fastapi_jwt_auth.exceptions import AuthJWTException +from pydantic import BaseModel +from index import CouchIindex from db import DictDB -from base64 import b64decode -from wsgiref.simple_server import make_server - -try: - database = os.environ['COUCHDB_NAME'] - hostname = os.environ['COUCHDB_HOSTNAME'] - username = os.environ['COUCHDB_USER'] - password = os.environ['COUCHDB_PASSWORD'] -except KeyError: - print('The environment variables COUCHDB_NAME, COUCHDB_HOSTNAME,' + - ' COUCHDB_USER and COUCHDB_PASSWORD must be set.') - sys.exit(-1) - - -class CollectorResource(): - def __init__(self, db, users): - self._db = db - self._users = users - - def parse_error(data): - return "I want valid JSON but got this:\n{}\n".format(data) - - def user_auth(self, auth_header, authfun): - if not auth_header: - return None, None # Fail. - - BAlit, b64 = auth_header.split() - if BAlit != "Basic": - return None, None # Fail - - userbytes, pwbytes = b64decode(b64).split(b':') - try: - user = userbytes.decode('utf-8') - pw = pwbytes.decode('utf-8') - except Exception: - return None, None # Fail - return authfun(user, pw) - - -class EPGet(CollectorResource): - def on_get(self, req, resp, key=None): - out = list() - selectors = dict() - - limit = 25 - skip = 0 - - orgs = self.user_auth(req.auth, self._users.read_perms) - - if not orgs: - resp.status = falcon.HTTP_401 - resp.text = json.dumps({ - 'status': 'error', - 'message': 'Invalid username or password\n' - }) - return - - if key: - out = self._db.get(key) - resp.text = json.dumps({'status': 'success', 'data': out}) - return - - for param in req.params: - if param == 'limit': - limit = req.params['limit'] - elif param == 'skip': - skip = req.params['skip'] - for i in index.indexes: - for j in i['index']['fields']: - if j == param: - selectors[param] = req.params[param] - - for org in orgs: - selectors['domain'] = org - data = self._db.search(**selectors, limit=limit, skip=skip) - if data: - out += data - - resp.text = json.dumps({'status': 'success', 'data': out}) - - -class EPAdd(CollectorResource): - def on_post(self, req, resp): - resp.status = falcon.HTTP_200 - resp.content_type = falcon.MEDIA_TEXT - self._indata = [] - - orgs = self.user_auth(req.auth, self._users.write_perms) - if not orgs: - resp.status = falcon.HTTP_401 - resp.text = json.dumps( - {'status': 'error', 'message': 'Invalid user or password\n'}) - return - - # NOTE: Allowing writing to _any_ org! - # TODO: Allow only input where input.domain in orgs == True. - - # TODO: can we do json.load(req.bounded_stream, - # cls=customDecoder) where our decoder calls JSONDecoder after - # decoding UTF-8? - - # NOTE: Reading the whole body in one go instead of streaming - # it nicely. - rawin = req.bounded_stream.read() - - try: - decodedin = rawin.decode('UTF-8') - except Exception: - resp.status = falcon.HTTP_400 - resp.text = json.dumps( - {'status': 'error', 'message': 'Need UTF-8\n'}) - return - - try: - json_data = json.loads(decodedin) - except TypeError: - print('DEBUG: type error') - resp.status = falcon.HTTP_400 - resp.text = json.dumps( - {'status': 'error', 'message': CollectorResource.parse_error(decodedin)}) - - return - except json.decoder.JSONDecodeError: - print('DEBUG: json decode error') - resp.status = falcon.HTTP_400 - resp.text = json.dumps( - {'status': 'error', 'message': CollectorResource.parse_error(decodedin)}) - return - - keys = self._db.add(json_data) - resp.text = json.dumps({'status': 'success', 'key': keys}) - - -def main(port=8000, wsgi_helper=False): - db = DictDB(database, hostname, username, password) - users = authn.UserDB('wsgi_demo_users.yaml') - - app = falcon.App(cors_enable=True) - app.add_route('/sc/v0/add', EPAdd(db, users)) - app.add_route('/sc/v0/get', EPGet(db, users)) - app.add_route('/sc/v0/get/{key}', EPGet(db, users)) - - if wsgi_helper: + +app = FastAPI() +db = DictDB() + + +def get_pubkey(): + try: + keypath = os.environ['JWT_PUBKEY_PATH'] + + with open(keypath, "r") as fd: + pubkey = fd.read() + except KeyError: + print("Could not find environment variable JWT_PUBKEY_PATH") + sys.exit(-1) + except FileNotFoundError: + print(f"Could not find JWT certificate in {keypath}") + sys.exit(-1) + + return pubkey + + +def get_data(key=None, limit=25, skip=0, ip=None, + port=None, asn=None, domain=None): + + selectors = dict() + indexes = CouchIindex().dict() + selectors['domain'] = domain + + if ip and 'ip' in indexes: + selectors['ip'] = ip + if port and 'port' in indexes: + selectors['port'] = port + if asn and 'asn' in indexes: + selectors['asn'] = asn + + data = db.search(**selectors, limit=limit, skip=skip) + + return data + + +class JWTConfig(BaseModel): + authjwt_algorithm: str = "ES256" + authjwt_public_key: str = get_pubkey() + + +@AuthJWT.load_config +def jwt_config(): + return JWTConfig() + + +@app.exception_handler(AuthJWTException) +def authjwt_exception_handler(request: Request, exc: AuthJWTException): + return JSONResponse(content={"status": "error", "message": + exc.message}, status_code=400) + + +@app.exception_handler(RuntimeError) +def app_exception_handler(request: Request, exc: RuntimeError): + return JSONResponse(content={"status": "error", "message": + str(exc.with_traceback(None))}, + status_code=400) + + +@app.get('/sc/v0/get') +async def get(key=None, limit=25, skip=0, ip=None, port=None, + asn=None, Authorize: AuthJWT = Depends()): + + Authorize.jwt_required() + + data = [] + raw_jwt = Authorize.get_raw_jwt() + + if 'domains' not in raw_jwt: + return JSONResponse(content={"status": "error", + "message": "Could not find domains" + + "claim in JWT token"}, + status_code=400) + else: + domains = raw_jwt['domains'] + + for domain in domains: + data.extend(get_data(key, limit, skip, ip, port, asn, domain)) + + return JSONResponse(content={"status": "success", "docs": data}) + + +@app.get('/sc/v0/get/{key}') +async def get_key(key=None, Authorize: AuthJWT = Depends()): + + Authorize.jwt_required() + + data = get_data(key) + + return JSONResponse(content={"status": "success", "docs": data}) + + +@app.post('/sc/v0/add') +async def add(data: Request, Authorize: AuthJWT = Depends()): + + # Maybe we should protect this enpoint too and let the scanner use + # a JWT token as well. + # Authorize.jwt_required() + + json_data = await data.json() + + key = db.add(json_data) + + return JSONResponse(content={"status": "success", "docs": key}) + + +def main(standalone=False): + if not standalone: return app - print('Serving on port 8000...') - httpd = make_server('', port, app) - httpd.serve_forever() + uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug") if __name__ == '__main__': - try: - sys.exit(main()) - except KeyboardInterrupt: - print('\nBye!') + main(standalone=True) else: - app = main(port=8000, wsgi_helper=True) + app = main() |