summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKristofer Hallin <kristofer@sunet.se>2021-10-29 09:29:16 +0200
committerKristofer Hallin <kristofer@sunet.se>2021-10-29 09:29:16 +0200
commit2bfbe7568a8c6477de60a676d9027dcb9714af42 (patch)
treef4da7aee25d4728659182a2cc19197f341d165d3
parent34a353a539f71b6a87413b58ea483b36f94e3516 (diff)
Use FastAPI + JWT instead of Falcon.
-rw-r--r--requirements.txt24
-rwxr-xr-xsrc/db.py31
-rw-r--r--src/index.py24
-rwxr-xr-xsrc/wsgi.py258
4 files changed, 156 insertions, 181 deletions
diff --git a/requirements.txt b/requirements.txt
index 2447183..93afa37 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,21 @@
-falcon
-pyyaml
-requests
+anyio==3.3.4
+asgiref==3.4.1
+certifi==2021.10.8
+cffi==1.15.0
+charset-normalizer==2.0.7
+click==8.0.3
+cryptography==35.0.0
+fastapi==0.70.0
+fastapi-jwt-auth==0.5.0
+h11==0.12.0
+idna==3.3
+pycparser==2.20
+pydantic==1.8.2
+PyJWT==1.7.1
+requests==2.26.0
+sniffio==1.2.0
+starlette==0.16.0
+typing-extensions==3.10.0.2
+urllib3==1.26.7
+uvicorn==0.15.0
+
diff --git a/src/db.py b/src/db.py
index 2308e8c..7e83d96 100755
--- a/src/db.py
+++ b/src/db.py
@@ -7,28 +7,41 @@
# value if you're too quick with generating the timestamps, ie
# invoking time.time() several times quickly enough.
+import os
+import sys
import time
import couch
-import index
+
+from index import CouchIindex
class DictDB():
- def __init__(self, database, hostname, username, password):
+ def __init__(self):
"""
Check if the database exists, otherwise we will create it together
with the indexes specified in index.py.
"""
+ try:
+ self.database = os.environ['COUCHDB_NAME']
+ self.hostname = os.environ['COUCHDB_HOSTNAME']
+ self.username = os.environ['COUCHDB_USER']
+ self.password = os.environ['COUCHDB_PASSWORD']
+ except KeyError:
+ print('The environment variables COUCHDB_NAME, COUCHDB_HOSTNAME,' +
+ ' COUCHDB_USER and COUCHDB_PASSWORD must be set.')
+ sys.exit(-1)
+
self.server = couch.client.Server(
- f"http://{username}:{password}@{hostname}:5984/")
+ f"http://{self.username}:{self.password}@{self.hostname}:5984/")
try:
- self.couchdb = self.server.database(database)
+ self.couchdb = self.server.database(self.database)
except couch.exceptions.NotFound:
print("Creating database and indexes.")
- self.couchdb = self.server.create(database)
+ self.couchdb = self.server.create(self.database)
- for i in index.indexes:
+ for i in CouchIindex():
self.couchdb.index(i)
self._ts = time.time()
@@ -54,12 +67,12 @@ class DictDB():
if type(data) is list:
for item in data:
item['_id'] = str(self.unique_key())
- self.couchdb.save_bulk(data)
+ ret = self.couchdb.save_bulk(data)
else:
data['_id'] = str(self.unique_key())
- self.couchdb.save(data)
+ ret = self.couchdb.save(data)
- return True
+ return ret
def get(self, key):
"""
diff --git a/src/index.py b/src/index.py
index 837f47e..3541ec7 100644
--- a/src/index.py
+++ b/src/index.py
@@ -1,5 +1,8 @@
-indexes = [
- {
+from pydantic import BaseSettings
+
+
+class CouchIindex(BaseSettings):
+ domain: dict = {
"index": {
"fields": [
"domain",
@@ -7,8 +10,8 @@ indexes = [
},
"name": "domain-json-index",
"type": "json"
- },
- {
+ }
+ ip: dict = {
"index": {
"fields": [
"domain",
@@ -17,8 +20,8 @@ indexes = [
},
"name": "ip-json-index",
"type": "json"
- },
- {
+ }
+ port: dict = {
"index": {
"fields": [
"domain",
@@ -27,8 +30,8 @@ indexes = [
},
"name": "port-json-index",
"type": "json"
- },
- {
+ }
+ asn: dict = {
"index": {
"fields": [
"domain",
@@ -37,8 +40,8 @@ indexes = [
},
"name": "asn-json-index",
"type": "json"
- },
- {
+ }
+ asn_country_code: dict = {
"index": {
"fields": [
"domain",
@@ -48,4 +51,3 @@ indexes = [
"name": "asn-country-code-json-index",
"type": "json"
}
-]
diff --git a/src/wsgi.py b/src/wsgi.py
index 8ab178a..b690bc0 100755
--- a/src/wsgi.py
+++ b/src/wsgi.py
@@ -1,167 +1,109 @@
-#! /usr/bin/env python3
+import uvicorn
-import os
-import sys
-import json
-import authn
-import index
-import falcon
+from fastapi import FastAPI, Depends, Request
+from fastapi.responses import JSONResponse
+from fastapi_jwt_auth import AuthJWT
+from fastapi_jwt_auth.exceptions import AuthJWTException
+from pydantic import BaseModel
+from index import CouchIindex
from db import DictDB
-from base64 import b64decode
-from wsgiref.simple_server import make_server
-
-try:
- database = os.environ['COUCHDB_NAME']
- hostname = os.environ['COUCHDB_HOSTNAME']
- username = os.environ['COUCHDB_USER']
- password = os.environ['COUCHDB_PASSWORD']
-except KeyError:
- print('The environment variables COUCHDB_NAME, COUCHDB_HOSTNAME,' +
- ' COUCHDB_USER and COUCHDB_PASSWORD must be set.')
- sys.exit(-1)
-
-
-class CollectorResource():
- def __init__(self, db, users):
- self._db = db
- self._users = users
-
- def parse_error(data):
- return "I want valid JSON but got this:\n{}\n".format(data)
-
- def user_auth(self, auth_header, authfun):
- if not auth_header:
- return None, None # Fail.
-
- BAlit, b64 = auth_header.split()
- if BAlit != "Basic":
- return None, None # Fail
-
- userbytes, pwbytes = b64decode(b64).split(b':')
- try:
- user = userbytes.decode('utf-8')
- pw = pwbytes.decode('utf-8')
- except Exception:
- return None, None # Fail
- return authfun(user, pw)
-
-
-class EPGet(CollectorResource):
- def on_get(self, req, resp, key=None):
- out = list()
- selectors = dict()
-
- limit = 25
- skip = 0
-
- orgs = self.user_auth(req.auth, self._users.read_perms)
-
- if not orgs:
- resp.status = falcon.HTTP_401
- resp.text = json.dumps({
- 'status': 'error',
- 'message': 'Invalid username or password\n'
- })
- return
-
- if key:
- out = self._db.get(key)
- resp.text = json.dumps({'status': 'success', 'data': out})
- return
-
- for param in req.params:
- if param == 'limit':
- limit = req.params['limit']
- elif param == 'skip':
- skip = req.params['skip']
- for i in index.indexes:
- for j in i['index']['fields']:
- if j == param:
- selectors[param] = req.params[param]
-
- for org in orgs:
- selectors['domain'] = org
- data = self._db.search(**selectors, limit=limit, skip=skip)
- if data:
- out += data
-
- resp.text = json.dumps({'status': 'success', 'data': out})
-
-
-class EPAdd(CollectorResource):
- def on_post(self, req, resp):
- resp.status = falcon.HTTP_200
- resp.content_type = falcon.MEDIA_TEXT
- self._indata = []
-
- orgs = self.user_auth(req.auth, self._users.write_perms)
- if not orgs:
- resp.status = falcon.HTTP_401
- resp.text = json.dumps(
- {'status': 'error', 'message': 'Invalid user or password\n'})
- return
-
- # NOTE: Allowing writing to _any_ org!
- # TODO: Allow only input where input.domain in orgs == True.
-
- # TODO: can we do json.load(req.bounded_stream,
- # cls=customDecoder) where our decoder calls JSONDecoder after
- # decoding UTF-8?
-
- # NOTE: Reading the whole body in one go instead of streaming
- # it nicely.
- rawin = req.bounded_stream.read()
-
- try:
- decodedin = rawin.decode('UTF-8')
- except Exception:
- resp.status = falcon.HTTP_400
- resp.text = json.dumps(
- {'status': 'error', 'message': 'Need UTF-8\n'})
- return
-
- try:
- json_data = json.loads(decodedin)
- except TypeError:
- print('DEBUG: type error')
- resp.status = falcon.HTTP_400
- resp.text = json.dumps(
- {'status': 'error', 'message': CollectorResource.parse_error(decodedin)})
-
- return
- except json.decoder.JSONDecodeError:
- print('DEBUG: json decode error')
- resp.status = falcon.HTTP_400
- resp.text = json.dumps(
- {'status': 'error', 'message': CollectorResource.parse_error(decodedin)})
- return
-
- keys = self._db.add(json_data)
- resp.text = json.dumps({'status': 'success', 'key': keys})
-
-
-def main(port=8000, wsgi_helper=False):
- db = DictDB(database, hostname, username, password)
- users = authn.UserDB('wsgi_demo_users.yaml')
-
- app = falcon.App(cors_enable=True)
- app.add_route('/sc/v0/add', EPAdd(db, users))
- app.add_route('/sc/v0/get', EPGet(db, users))
- app.add_route('/sc/v0/get/{key}', EPGet(db, users))
-
- if wsgi_helper:
+
+app = FastAPI()
+db = DictDB()
+
+public_key = """-----BEGIN PUBLIC KEY-----
+MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEPW8bkkVIq4BX8eWwlUOUYbJhiGDv
+K/6xY5T0BsvV6pbMoIUfgeThVOq5I3CmXxLt+qyPska6ol9fTN7woZLsCg==
+-----END PUBLIC KEY-----"""
+
+
+def get_data(key=None, limit=25, skip=0, ip=None,
+ port=None, asn=None):
+
+ selectors = dict()
+ indexes = CouchIindex().dict()
+
+ selectors['domain'] = 'sunet.se'
+
+ if ip and 'ip' in indexes:
+ selectors['ip'] = ip
+ if port and 'port' in indexes:
+ selectors['port'] = port
+ if asn and 'asn' in indexes:
+ selectors['asn'] = asn
+
+ data = db.search(**selectors, limit=limit, skip=skip)
+
+ return JSONResponse(content={"status": "success", "data": data})
+
+
+class JWTConfig(BaseModel):
+ authjwt_algorithm: str = "ES256"
+ authjwt_public_key: str = public_key
+
+
+@AuthJWT.load_config
+def jwt_config():
+ return JWTConfig()
+
+
+@app.exception_handler(AuthJWTException)
+def authjwt_exception_handler(request: Request, exc: AuthJWTException):
+ return JSONResponse(content={"status": "error", "message":
+ exc.message}, status_code=400)
+
+
+@app.exception_handler(RuntimeError)
+def app_exception_handler(request: Request, exc: RuntimeError):
+ return JSONResponse(content={"status": "error", "message":
+ str(exc.with_traceback(None))},
+ status_code=400)
+
+
+@app.get('/sc/v0/get')
+async def get(key=None, limit=25, skip=0, ip=None, port=None,
+ asn=None, Authorize: AuthJWT = Depends()):
+
+ Authorize.jwt_required()
+
+ return get_data(key, limit, skip, ip, port, asn)
+
+
+@app.get('/sc/v0/get/{key}')
+async def get_key(key=None, Authorize: AuthJWT = Depends()):
+
+ Authorize.jwt_required()
+
+ return get_data(key)
+
+
+@app.post('/sc/v0/add')
+async def add(data: Request, Authorize: AuthJWT = Depends()):
+
+ Authorize.jwt_required()
+
+ orgs = ['sunet.se']
+
+ if not orgs:
+ return JSONResponse(content={"status": "error", "message":
+ "Could not find an organization"}, status_code=400)
+
+ json_data = await data.json()
+
+ key = db.add(json_data)
+
+ return JSONResponse(content={"status": "success", "docs": key})
+
+
+def main(standalone=False):
+ if not standalone:
return app
- print('Serving on port 8000...')
- httpd = make_server('', port, app)
- httpd.serve_forever()
+ uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug")
if __name__ == '__main__':
- try:
- sys.exit(main())
- except KeyboardInterrupt:
- print('\nBye!')
+ main(standalone=True)
else:
- app = main(port=8000, wsgi_helper=True)
+ app = main()