diff options
| -rw-r--r-- | requirements.txt | 24 | ||||
| -rwxr-xr-x | src/db.py | 31 | ||||
| -rw-r--r-- | src/index.py | 24 | ||||
| -rwxr-xr-x | src/wsgi.py | 280 | ||||
| -rw-r--r-- | tools/jwt_producer.py | 45 | 
5 files changed, 224 insertions, 180 deletions
| diff --git a/requirements.txt b/requirements.txt index 2447183..93afa37 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,21 @@ -falcon -pyyaml -requests +anyio==3.3.4 +asgiref==3.4.1 +certifi==2021.10.8 +cffi==1.15.0 +charset-normalizer==2.0.7 +click==8.0.3 +cryptography==35.0.0 +fastapi==0.70.0 +fastapi-jwt-auth==0.5.0 +h11==0.12.0 +idna==3.3 +pycparser==2.20 +pydantic==1.8.2 +PyJWT==1.7.1 +requests==2.26.0 +sniffio==1.2.0 +starlette==0.16.0 +typing-extensions==3.10.0.2 +urllib3==1.26.7 +uvicorn==0.15.0 + @@ -7,28 +7,41 @@  # value if you're too quick with generating the timestamps, ie  # invoking time.time() several times quickly enough. +import os +import sys  import time  import couch -import index + +from index import CouchIindex  class DictDB(): -    def __init__(self, database, hostname, username, password): +    def __init__(self):          """          Check if the database exists, otherwise we will create it together          with the indexes specified in index.py.          """ +        try: +            self.database = os.environ['COUCHDB_NAME'] +            self.hostname = os.environ['COUCHDB_HOSTNAME'] +            self.username = os.environ['COUCHDB_USER'] +            self.password = os.environ['COUCHDB_PASSWORD'] +        except KeyError: +            print('The environment variables COUCHDB_NAME, COUCHDB_HOSTNAME,' + +                  ' COUCHDB_USER and COUCHDB_PASSWORD must be set.') +            sys.exit(-1) +          self.server = couch.client.Server( -            f"http://{username}:{password}@{hostname}:5984/") +            f"http://{self.username}:{self.password}@{self.hostname}:5984/")          try: -            self.couchdb = self.server.database(database) +            self.couchdb = self.server.database(self.database)          except couch.exceptions.NotFound:              print("Creating database and indexes.") -            self.couchdb = self.server.create(database) +            self.couchdb = self.server.create(self.database) -            for i in index.indexes: +            for i in CouchIindex():                  self.couchdb.index(i)          self._ts = time.time() @@ -54,12 +67,12 @@ class DictDB():          if type(data) is list:              for item in data:                  item['_id'] = str(self.unique_key()) -            self.couchdb.save_bulk(data) +            ret = self.couchdb.save_bulk(data)          else:              data['_id'] = str(self.unique_key()) -            self.couchdb.save(data) +            ret = self.couchdb.save(data) -        return True +        return ret      def get(self, key):          """ diff --git a/src/index.py b/src/index.py index 837f47e..3541ec7 100644 --- a/src/index.py +++ b/src/index.py @@ -1,5 +1,8 @@ -indexes = [ -    { +from pydantic import BaseSettings + + +class CouchIindex(BaseSettings): +    domain: dict = {          "index": {              "fields": [                  "domain", @@ -7,8 +10,8 @@ indexes = [          },          "name": "domain-json-index",          "type": "json" -    }, -    { +    } +    ip: dict = {          "index": {              "fields": [                  "domain", @@ -17,8 +20,8 @@ indexes = [          },          "name": "ip-json-index",          "type": "json" -    }, -    { +    } +    port: dict = {          "index": {              "fields": [                  "domain", @@ -27,8 +30,8 @@ indexes = [          },          "name": "port-json-index",          "type": "json" -    }, -    { +    } +    asn: dict = {          "index": {              "fields": [                  "domain", @@ -37,8 +40,8 @@ indexes = [          },          "name": "asn-json-index",          "type": "json" -    }, -    { +    } +    asn_country_code: dict = {          "index": {              "fields": [                  "domain", @@ -48,4 +51,3 @@ indexes = [          "name": "asn-country-code-json-index",          "type": "json"      } -] diff --git a/src/wsgi.py b/src/wsgi.py index 8ab178a..b6a1a10 100755 --- a/src/wsgi.py +++ b/src/wsgi.py @@ -1,167 +1,133 @@ -#! /usr/bin/env python3 -  import os  import sys -import json -import authn -import index -import falcon +import uvicorn + +from fastapi import FastAPI, Depends, Request +from fastapi.responses import JSONResponse +from fastapi_jwt_auth import AuthJWT +from fastapi_jwt_auth.exceptions import AuthJWTException +from pydantic import BaseModel +from index import CouchIindex  from db import DictDB -from base64 import b64decode -from wsgiref.simple_server import make_server - -try: -    database = os.environ['COUCHDB_NAME'] -    hostname = os.environ['COUCHDB_HOSTNAME'] -    username = os.environ['COUCHDB_USER'] -    password = os.environ['COUCHDB_PASSWORD'] -except KeyError: -    print('The environment variables COUCHDB_NAME, COUCHDB_HOSTNAME,' + -          ' COUCHDB_USER and COUCHDB_PASSWORD must be set.') -    sys.exit(-1) - - -class CollectorResource(): -    def __init__(self, db, users): -        self._db = db -        self._users = users - -    def parse_error(data): -        return "I want valid JSON but got this:\n{}\n".format(data) - -    def user_auth(self, auth_header, authfun): -        if not auth_header: -            return None, None   # Fail. - -        BAlit, b64 = auth_header.split() -        if BAlit != "Basic": -            return None, None   # Fail - -        userbytes, pwbytes = b64decode(b64).split(b':') -        try: -            user = userbytes.decode('utf-8') -            pw = pwbytes.decode('utf-8') -        except Exception: -            return None, None   # Fail -        return authfun(user, pw) - - -class EPGet(CollectorResource): -    def on_get(self, req, resp, key=None): -        out = list() -        selectors = dict() - -        limit = 25 -        skip = 0 - -        orgs = self.user_auth(req.auth, self._users.read_perms) - -        if not orgs: -            resp.status = falcon.HTTP_401 -            resp.text = json.dumps({ -                'status': 'error', -                'message': 'Invalid username or password\n' -            }) -            return - -        if key: -            out = self._db.get(key) -            resp.text = json.dumps({'status': 'success', 'data': out}) -            return - -        for param in req.params: -            if param == 'limit': -                limit = req.params['limit'] -            elif param == 'skip': -                skip = req.params['skip'] -            for i in index.indexes: -                for j in i['index']['fields']: -                    if j == param: -                        selectors[param] = req.params[param] - -        for org in orgs: -            selectors['domain'] = org -            data = self._db.search(**selectors, limit=limit, skip=skip) -            if data: -                out += data - -        resp.text = json.dumps({'status': 'success', 'data': out}) - - -class EPAdd(CollectorResource): -    def on_post(self, req, resp): -        resp.status = falcon.HTTP_200 -        resp.content_type = falcon.MEDIA_TEXT -        self._indata = [] - -        orgs = self.user_auth(req.auth, self._users.write_perms) -        if not orgs: -            resp.status = falcon.HTTP_401 -            resp.text = json.dumps( -                {'status': 'error', 'message': 'Invalid user or password\n'}) -            return - -        # NOTE: Allowing writing to _any_ org! -        # TODO: Allow only input where input.domain in orgs == True. - -        # TODO: can we do json.load(req.bounded_stream, -        # cls=customDecoder) where our decoder calls JSONDecoder after -        # decoding UTF-8? - -        # NOTE: Reading the whole body in one go instead of streaming -        # it nicely. -        rawin = req.bounded_stream.read() - -        try: -            decodedin = rawin.decode('UTF-8') -        except Exception: -            resp.status = falcon.HTTP_400 -            resp.text = json.dumps( -                {'status': 'error', 'message': 'Need UTF-8\n'}) -            return - -        try: -            json_data = json.loads(decodedin) -        except TypeError: -            print('DEBUG: type error') -            resp.status = falcon.HTTP_400 -            resp.text = json.dumps( -                {'status': 'error', 'message': CollectorResource.parse_error(decodedin)}) - -            return -        except json.decoder.JSONDecodeError: -            print('DEBUG: json decode error') -            resp.status = falcon.HTTP_400 -            resp.text = json.dumps( -                {'status': 'error', 'message': CollectorResource.parse_error(decodedin)}) -            return - -        keys = self._db.add(json_data) -        resp.text = json.dumps({'status': 'success', 'key': keys}) - - -def main(port=8000, wsgi_helper=False): -    db = DictDB(database, hostname, username, password) -    users = authn.UserDB('wsgi_demo_users.yaml') - -    app = falcon.App(cors_enable=True) -    app.add_route('/sc/v0/add', EPAdd(db, users)) -    app.add_route('/sc/v0/get', EPGet(db, users)) -    app.add_route('/sc/v0/get/{key}', EPGet(db, users)) - -    if wsgi_helper: + +app = FastAPI() +db = DictDB() + + +def get_pubkey(): +    try: +        keypath = os.environ['JWT_PUBKEY_PATH'] + +        with open(keypath, "r") as fd: +            pubkey = fd.read() +    except KeyError: +        print("Could not find environment variable JWT_PUBKEY_PATH") +        sys.exit(-1) +    except FileNotFoundError: +        print(f"Could not find JWT certificate in {keypath}") +        sys.exit(-1) + +    return pubkey + + +def get_data(key=None, limit=25, skip=0, ip=None, +             port=None, asn=None, domain=None): + +    selectors = dict() +    indexes = CouchIindex().dict() +    selectors['domain'] = domain + +    if ip and 'ip' in indexes: +        selectors['ip'] = ip +    if port and 'port' in indexes: +        selectors['port'] = port +    if asn and 'asn' in indexes: +        selectors['asn'] = asn + +    data = db.search(**selectors, limit=limit, skip=skip) + +    return data + + +class JWTConfig(BaseModel): +    authjwt_algorithm: str = "ES256" +    authjwt_public_key: str = get_pubkey() + + +@AuthJWT.load_config +def jwt_config(): +    return JWTConfig() + + +@app.exception_handler(AuthJWTException) +def authjwt_exception_handler(request: Request, exc: AuthJWTException): +    return JSONResponse(content={"status": "error", "message": +                                 exc.message}, status_code=400) + + +@app.exception_handler(RuntimeError) +def app_exception_handler(request: Request, exc: RuntimeError): +    return JSONResponse(content={"status": "error", "message": +                                 str(exc.with_traceback(None))}, +                        status_code=400) + + +@app.get('/sc/v0/get') +async def get(key=None, limit=25, skip=0, ip=None, port=None, +              asn=None, Authorize: AuthJWT = Depends()): + +    Authorize.jwt_required() + +    data = [] +    raw_jwt = Authorize.get_raw_jwt() + +    if 'domains' not in raw_jwt: +        return JSONResponse(content={"status": "error", +                                     "message": "Could not find domains" + +                                     "claim in JWT token"}, +                            status_code=400) +    else: +        domains = raw_jwt['domains'] + +    for domain in domains: +        data.extend(get_data(key, limit, skip, ip, port, asn, domain)) + +    return JSONResponse(content={"status": "success", "docs": data}) + + +@app.get('/sc/v0/get/{key}') +async def get_key(key=None, Authorize: AuthJWT = Depends()): + +    Authorize.jwt_required() + +    data = get_data(key) + +    return JSONResponse(content={"status": "success", "docs": data}) + + +@app.post('/sc/v0/add') +async def add(data: Request, Authorize: AuthJWT = Depends()): + +    # Maybe we should protect this enpoint too and let the scanner use +    # a JWT token as well. +    # Authorize.jwt_required() + +    json_data = await data.json() + +    key = db.add(json_data) + +    return JSONResponse(content={"status": "success", "docs": key}) + + +def main(standalone=False): +    if not standalone:          return app -    print('Serving on port 8000...') -    httpd = make_server('', port, app) -    httpd.serve_forever() +    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug")  if __name__ == '__main__': -    try: -        sys.exit(main()) -    except KeyboardInterrupt: -        print('\nBye!') +    main(standalone=True)  else: -    app = main(port=8000, wsgi_helper=True) +    app = main() diff --git a/tools/jwt_producer.py b/tools/jwt_producer.py new file mode 100644 index 0000000..3f8094d --- /dev/null +++ b/tools/jwt_producer.py @@ -0,0 +1,45 @@ +import sys +import jwt +import getopt + + +def usage(): +    progname = sys.argv[0] + +    print(f'{progname} -p <path to public key> -s <path to private key>' + +          '-d <domain>, for example sunet.se>') +    sys.exit(0) + + +def create_token(private_key, domain): +    payload = { +        'type': 'access', +        'domains': [domain]  # We'll just do one domain now +    } + +    with open(private_key, "r") as fd: +        key = fd.read() + +    return jwt.encode(payload=payload, algorithm='ES256', key=key) + + +if __name__ == '__main__': +    try: +        opts, args = getopt.getopt(sys.argv[1:], 'p:d:') +    except getopt.GetoptError: +        usage() + +    if len(sys.argv) != 5: +        usage() + +    for opt, arg in opts: +        if opt == '-p': +            private_key = arg +        elif opt == '-d': +            domain = arg +        else: +            usage() + +    token = create_token(private_key, domain).decode('utf-8') + +    print(f'{token}') | 
