diff options
Diffstat (limited to 'src/collector')
-rw-r--r--[-rwxr-xr-x] | src/collector/db.py | 185 | ||||
-rwxr-xr-x | src/collector/main.py | 332 | ||||
-rw-r--r-- | src/collector/schema.py | 57 |
3 files changed, 172 insertions, 402 deletions
diff --git a/src/collector/db.py b/src/collector/db.py index 0bfa014..3b16ef5 100755..100644 --- a/src/collector/db.py +++ b/src/collector/db.py @@ -1,148 +1,39 @@ -# A database storing dictionaries, keyed on a timestamp. value = A -# dict which will be stored as a JSON object encoded in UTF-8. Note -# that dict keys of type integer or float will become strings while -# values will keep their type. - -# Note that there's a (slim) chance that you'd stomp on the previous -# value if you're too quick with generating the timestamps, ie -# invoking time.time() several times quickly enough. - -from typing import Dict, List, Tuple, Union, Any -import os -import sys -import time - -from src import couch -from .schema import as_index_list, validate_collector_data - - -class DictDB: - def __init__(self) -> None: - """ - Check if the database exists, otherwise we will create it together - with the indexes specified in index.py. - """ - - print(os.environ) - - try: - self.database = os.environ["COUCHDB_NAME"] - self.hostname = os.environ["COUCHDB_HOSTNAME"] - self.username = os.environ["COUCHDB_USER"] - self.password = os.environ["COUCHDB_PASSWORD"] - except KeyError: - print( - "The environment variables COUCHDB_NAME, COUCHDB_HOSTNAME," - + " COUCHDB_USER and COUCHDB_PASSWORD must be set." - ) - sys.exit(-1) - - if "COUCHDB_PORT" in os.environ: - couchdb_port = os.environ["COUCHDB_PORT"] - else: - couchdb_port = "5984" - - self.server = couch.client.Server(f"http://{self.username}:{self.password}@{self.hostname}:{couchdb_port}/") - - try: - self.couchdb = self.server.database(self.database) - print("Database already exists") - except couch.exceptions.NotFound: - print("Creating database and indexes.") - self.couchdb = self.server.create(self.database) - - for i in as_index_list(): - self.couchdb.index(i) - - self._ts = time.time() - - def unique_key(self) -> int: - """ - Create a unique key based on the current time. We will use this as - the ID for any new documents we store in CouchDB. - """ - - ts = time.time() - while round(ts * 1000) == self._ts: - ts = time.time() - self._ts = round(ts * 1000) - - return self._ts - - # Why batch_write??? - def add(self, data: Union[List[Dict[str, Any]], Dict[str, Any]]) -> Union[str, Tuple[str, str]]: - """ - Store a document in CouchDB. - """ - - if isinstance(data, List): - for item in data: - error = validate_collector_data(item) - if error != "": - return error - item["_id"] = str(self.unique_key()) - ret: Tuple[str, str] = self.couchdb.save_bulk(data) +"""Our database module""" +from time import sleep +from sys import exit as app_exit +from dataclasses import dataclass + +from motor.motor_asyncio import ( + AsyncIOMotorClient, + AsyncIOMotorCollection, +) +from bson import ObjectId + + +@dataclass() +class DBClient: + """Class to hold database connections for us.""" + + client: AsyncIOMotorClient + collection: AsyncIOMotorCollection + + def __init__(self, username: str, password: str, collection: str) -> None: + self.client = AsyncIOMotorClient(f"mongodb://{username}:{password}@mongodb:27017/production", timeoutMS=2000) + self.collection = self.client["production"][collection] + + async def check_server(self) -> None: + """Try query the DB and exit the program if we fail after 5 times. + + :return: None + """ + for i in range(5): + try: + await self.collection.find_one({"_id": ObjectId("507f1f77bcf86cd799439011")}) + print("Connection to DB - OK") + break + except: # pylint: disable=bare-except + print(f"WARNING failed to connect to DB - {i} / 4", flush=True) + sleep(1) else: - error = validate_collector_data(data) - if error != "": - return error - data["_id"] = str(self.unique_key()) - ret = self.couchdb.save(data) - - return ret - - def get(self, key: int) -> Dict[str, Any]: - """ - Get a document based on its ID, return an empty dict if not found. - """ - - try: - doc: Dict[str, Any] = self.couchdb.get(key) - except couch.exceptions.NotFound: - doc = {} - - return doc - - # - # def slice(self, key_from=None, key_to=None): - # pass - - def search(self, limit: int = 25, skip: int = 0, **kwargs: Any) -> List[Dict[str, Any]]: - """ - Execute a Mango query, ideally we should have an index matching - the query otherwise things will be slow. - """ - - data: List[Dict[str, Any]] = [] - selector: Dict[str, Any] = {} - - try: - limit = int(limit) - skip = int(skip) - except ValueError: - limit = 25 - skip = 0 - - if kwargs: - selector = {"limit": limit, "skip": skip, "selector": {}} - - for key in kwargs: - if kwargs[key] and kwargs[key].isnumeric(): - kwargs[key] = int(kwargs[key]) - selector["selector"][key] = {"$eq": kwargs[key]} - - for doc in self.couchdb.find(selector, wrapper=None, limit=5): - data.append(doc) - - return data - - def delete(self, key: int) -> Union[int, None]: - """ - Delete a document based on its ID. - """ - try: - self.couchdb.delete(key) - except couch.exceptions.NotFound: - return None - - return key + print("Could not connect to DB - mongodb://REDACTED_USERNAME:REDACTED_PASSWORD@mongodb:27017/production") + app_exit(1) diff --git a/src/collector/main.py b/src/collector/main.py index c363885..096b788 100755 --- a/src/collector/main.py +++ b/src/collector/main.py @@ -1,267 +1,175 @@ -from typing import Dict, Union, List, Callable, Awaitable, Any -import json -import os +"""Our main module""" +from typing import Dict, Optional, List, Any +from os import environ +import asyncio import sys -import time +from json.decoder import JSONDecodeError -import uvicorn -from fastapi import Depends, FastAPI, Request, Response -from fastapi.middleware.cors import CORSMiddleware +from fastapi import FastAPI, Request from fastapi.responses import JSONResponse -from fastapi_jwt_auth import AuthJWT -from fastapi_jwt_auth.auth_config import AuthConfig -from fastapi_jwt_auth.exceptions import AuthJWTException from pydantic import BaseModel - -from .db import DictDB -from .schema import get_index_keys, validate_collector_data - -app = FastAPI() - -app.add_middleware( - CORSMiddleware, - allow_origins=["http://localhost:8001"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - expose_headers=["X-Total-Count"], +from pymongo.errors import OperationFailure +from bson import ( + ObjectId, + json_util, ) +from dotenv import load_dotenv -# TODO: X-Total-Count - - -@app.middleware("http") -async def mock_x_total_count_header(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: - - print(type(call_next)) - - response: Response = await call_next(request) - response.headers["X-Total-Count"] = "100" - return response - - -for i in range(10): - try: - db = DictDB() - except Exception as e: - print(f"Database not responding, will try again soon. Attempt {i + 1} of 10.") - else: - break - time.sleep(1) -else: - print("Database did not respond after 10 attempts, quitting.") - sys.exit(-1) - - -def get_pubkey() -> str: - try: - if "JWT_PUBKEY_PATH" in os.environ: - keypath = os.environ["JWT_PUBKEY_PATH"] - else: - keypath = "/opt/certs/public.pem" - - with open(keypath, "r") as fd: - pubkey = fd.read() - except FileNotFoundError: - print(f"Could not find JWT certificate in {keypath}") - sys.exit(-1) - - return pubkey - - -def get_data( - key: Union[int, None] = None, - limit: int = 25, - skip: int = 0, - ip: Union[str, None] = None, - port: Union[int, None] = None, - asn: Union[str, None] = None, - domain: Union[str, None] = None, -) -> List[Dict[str, Any]]: - if key: - return [db.get(key)] +from .db import DBClient +from .schema import valid_schema - selectors: Dict[str, Any] = {} - indexes = get_index_keys() - selectors["domain"] = domain - if ip and "ip" in indexes: - selectors["ip"] = ip - if port and "port" in indexes: - selectors["port"] = port - if asn and "asn" in indexes: - selectors["asn"] = asn +load_dotenv() +# Get credentials +if "MONGODB_USERNAME" not in environ or "MONGODB_PASSWORD" not in environ or "MONGODB_COLLECTION" not in environ: + print("Missing MONGODB_USERNAME or MONGODB_PASSWORD or MONGODB_COLLECTION in env") + sys.exit(1) - data: List[Dict[str, Any]] = db.search(**selectors, limit=limit, skip=skip) +# Create DB object +db = DBClient(environ["MONGODB_USERNAME"], environ["MONGODB_PASSWORD"], environ["MONGODB_COLLECTION"]) - return data - - -class JWTConfig(BaseModel): - authjwt_algorithm: str = "ES256" - authjwt_public_key: str = get_pubkey() +# Check DB +loop = asyncio.get_running_loop() +startup_task = loop.create_task(db.check_server()) +app = FastAPI() -@AuthJWT.load_config # type: ignore -def jwt_config(): - return JWTConfig() +# @app.exception_handler(RuntimeError) +# def app_exception_handler(request: Request, exc: RuntimeError) -> JSONResponse: +# print(exc, flush=True) +# return JSONResponse(content={"status": "error", "message": str(exc.with_traceback(None))}, status_code=400) +# return JSONResponse(content={"status": "error", "message": "Error during processing"}, status_code=400) -@app.exception_handler(AuthJWTException) -def authjwt_exception_handler(request: Request, exc: AuthJWTException) -> JSONResponse: - return JSONResponse(content={"status": "error", "message": exc.message}, status_code=400) +class SearchInput(BaseModel): + """Handle search data for HTTP request""" -@app.exception_handler(RuntimeError) -def app_exception_handler(request: Request, exc: RuntimeError) -> JSONResponse: - return JSONResponse(content={"status": "error", "message": str(exc.with_traceback(None))}, status_code=400) + search: Optional[Dict[str, Any]] + limit: int = 25 + skip: int = 0 -@app.get("/sc/v0/get") -async def get( - key: Union[int, None] = None, - limit: int = 25, - skip: int = 0, - ip: Union[str, None] = None, - port: Union[int, None] = None, - asn: Union[str, None] = None, - Authorize: AuthJWT = Depends(), -) -> JSONResponse: +@app.post("/sc/v0/search") +async def search(search_data: SearchInput) -> JSONResponse: + """/sc/v0/search, POST method - Authorize.jwt_required() + :param search_data: The search data. + :return: JSONResponse + """ + data: List[Dict[str, Any]] = [] - data = [] - raw_jwt = Authorize.get_raw_jwt() + cursor = db.collection.find(search_data.search) + cursor.sort("timestamp", -1).limit(search_data.limit).skip(search_data.skip) - if "read" not in raw_jwt: + try: + async for document in cursor: + data.append(document) + except OperationFailure as exc: + print(f"DB failed to process: {exc.details}") return JSONResponse( content={ "status": "error", - "message": "Could not find read claim in JWT token", + "message": "Probably wrong syntax, note the dictionary for find: " + + "https://motor.readthedocs.io/en/stable/tutorial-asyncio.html#async-for", }, status_code=400, ) - else: - domains = raw_jwt["read"] - for domain in domains: - data.extend(get_data(key, limit, skip, ip, port, asn, domain)) + if not data: + return JSONResponse(content={"status": "error", "message": "Document not found"}, status_code=400) - return JSONResponse(content={"status": "success", "docs": data}) + return JSONResponse(content={"status": "success", "docs": json_util.dumps(data)}) -@app.get("/sc/v0/get/{key}") -async def get_key(key: Union[int, None] = None, Authorize: AuthJWT = Depends()) -> JSONResponse: +@app.post("/sc/v0") +async def create(request: Request) -> JSONResponse: + """/sc/v0, POST method - Authorize.jwt_required() + :param request: The request where we get the json body. + :return: JSONResponse + """ - raw_jwt = Authorize.get_raw_jwt() + try: + json_data = await request.json() + except JSONDecodeError: + return JSONResponse(content={"status": "error", "message": "Invalid JSON"}, status_code=400) - if "read" not in raw_jwt: - return JSONResponse( - content={ - "status": "error", - "message": "Could not find read claim in JWT token", - }, - status_code=400, - ) - else: - allowed_domains = raw_jwt["read"] + if not valid_schema(json_data): + return JSONResponse(content={"status": "error", "message": "Not our JSON schema"}, status_code=400) - data_list = get_data(key) + result = await db.collection.insert_one(json_data) + return JSONResponse(content={"status": "success", "key": str(result.inserted_id)}) - # Handle if missing - data = data_list[0] - if data and data["domain"] not in allowed_domains: - return JSONResponse( - content={ - "status": "error", - "message": "User not authorized to view this object", - }, - status_code=400, - ) - - return JSONResponse(content={"status": "success", "docs": data}) +@app.put("/sc/v0") +async def update(request: Request) -> JSONResponse: + """/sc/v0, PUT method - -# WHY IS AUTH OUTCOMMENTED??? -@app.post("/sc/v0/add") -async def add(data: Request, Authorize: AuthJWT = Depends()) -> JSONResponse: - # Authorize.jwt_required() + :param request: The request where we get the json body. + :return: JSONResponse + """ try: - json_data = await data.json() - except json.decoder.JSONDecodeError: - return JSONResponse( - content={ - "status": "error", - "message": "Invalid JSON.", - }, - status_code=400, - ) - - key = db.add(json_data) - - if isinstance(key, str): - return JSONResponse( - content={ - "status": "error", - "message": key, - }, - status_code=400, - ) - - return JSONResponse(content={"status": "success", "docs": key}) - + json_data = await request.json() + except JSONDecodeError: + return JSONResponse(content={"status": "error", "message": "Invalid JSON"}, status_code=400) + + if "_id" not in json_data: + return JSONResponse(content={"status": "error", "message": "Missing key '_id'"}, status_code=400) + + # Get the key + if isinstance(json_data["_id"], str): + object_id = ObjectId(json_data["_id"]) + elif ( + isinstance(json_data["_id"], dict) and "$oid" in json_data["_id"] and isinstance(json_data["_id"]["$oid"], str) + ): + object_id = ObjectId(json_data["_id"]["$oid"]) + else: + return JSONResponse(content={"status": "error", "message": "Missing key '_id' with valid id"}, status_code=400) -@app.delete("/sc/v0/delete/{key}") -async def delete(key: int, Authorize: AuthJWT = Depends()) -> JSONResponse: + # Ensure the updating key exist + document = await db.collection.find_one({"_id": object_id}) + if document is None: + return JSONResponse(content={"status": "error", "message": "Document not found"}, status_code=400) - Authorize.jwt_required() + # Ensure valid schema + del json_data["_id"] + if not valid_schema(json_data): + return JSONResponse(content={"status": "error", "message": "Not our JSON schema"}, status_code=400) - raw_jwt = Authorize.get_raw_jwt() + # Replace the data + json_data["_id"] = object_id + await db.collection.replace_one({"_id": object_id}, json_data) + return JSONResponse(content={"status": "success", "key": str(object_id)}) - if "write" not in raw_jwt: - return JSONResponse( - content={ - "status": "error", - "message": "Could not find write claim in JWT token", - }, - status_code=400, - ) - else: - allowed_domains = raw_jwt["write"] - data_list = get_data(key) +@app.get("/sc/v0/{key}") +async def get(key: str) -> JSONResponse: + """/sc/v0, POST method - # Handle if missing - data = data_list[0] + :param key: The document key in the database. + :return: JSONResponse + """ - if data and data["domain"] not in allowed_domains: - return JSONResponse( - content={ - "status": "error", - "message": "User not authorized to delete this object", - }, - status_code=400, - ) + document = await db.collection.find_one({"_id": ObjectId(key)}) - if db.delete(key) is None: + if document is None: return JSONResponse(content={"status": "error", "message": "Document not found"}, status_code=400) - return JSONResponse(content={"status": "success", "docs": data}) + return JSONResponse(content={"status": "success", "docs": json_util.dumps(document)}) -# def main(standalone: bool = False): -# print(type(app)) -# if not standalone: -# return app +@app.delete("/sc/v0/{key}") +async def delete(key: str) -> JSONResponse: + """/sc/v0, POST method -# uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug") + :param key: The document key in the database. + :return: JSONResponse + """ + result = await db.collection.delete_one({"_id": ObjectId(key)}) + if result.deleted_count == 0: + return JSONResponse(content={"status": "error", "message": "Document not found"}, status_code=400) -# if __name__ == "__main__": -# main(standalone=True) -# else: -# app = main() + return JSONResponse(content={"status": "success", "key": key}) diff --git a/src/collector/schema.py b/src/collector/schema.py index e291f10..221990a 100644 --- a/src/collector/schema.py +++ b/src/collector/schema.py @@ -1,8 +1,5 @@ -from typing import List, Any, Dict -import json -import sys -import traceback - +"""Our schema module""" +from typing import Any, Dict import jsonschema # fmt:off @@ -64,7 +61,8 @@ schema = { ] }, { - "required": [ + "required": + [ "display_name", "investigation_needed", # "reliability", # TODO: reliability is required if investigation_needed = true @@ -93,44 +91,17 @@ schema = { "result", ], } -# fmt:on - - -def get_index_keys() -> List[str]: - keys: List[str] = [] - for key in schema["properties"]: - keys.append(key) - return keys -def as_index_list() -> List[Dict[str, Any]]: - index_list: List[Dict[str, Any]] = [] - for key in schema["properties"]: - name = f"{key}-json-index" - index = { - "index": { - "fields": [ - key, - ] - }, - "name": name, - "type": "json", - } - index_list.append(index) - - return index_list +def valid_schema(json_data: Dict[str, Any]) -> bool: + """Check if json data follows the schema. - -def validate_collector_data(json_blob: Dict[str, Any]) -> str: + :param json_data: Json object + :return: bool + """ try: - jsonschema.validate(json_blob, schema, format_checker=jsonschema.FormatChecker()) - except jsonschema.exceptions.ValidationError as e: - return f"Validation failed with error: {e.message}" - return "" - - -if __name__ == "__main__": - with open(sys.argv[1]) as fd: - json_data = json.loads(fd.read()) - - print(validate_collector_data(json_data)) + jsonschema.validate(json_data, schema, format_checker=jsonschema.FormatChecker()) + except jsonschema.exceptions.ValidationError as exc: + print(f"Validation failed with error: {exc.message}") + return False + return True |