summaryrefslogtreecommitdiff
path: root/src/collector
diff options
context:
space:
mode:
Diffstat (limited to 'src/collector')
-rw-r--r--[-rwxr-xr-x]src/collector/db.py185
-rwxr-xr-xsrc/collector/main.py332
-rw-r--r--src/collector/schema.py57
3 files changed, 172 insertions, 402 deletions
diff --git a/src/collector/db.py b/src/collector/db.py
index 0bfa014..3b16ef5 100755..100644
--- a/src/collector/db.py
+++ b/src/collector/db.py
@@ -1,148 +1,39 @@
-# A database storing dictionaries, keyed on a timestamp. value = A
-# dict which will be stored as a JSON object encoded in UTF-8. Note
-# that dict keys of type integer or float will become strings while
-# values will keep their type.
-
-# Note that there's a (slim) chance that you'd stomp on the previous
-# value if you're too quick with generating the timestamps, ie
-# invoking time.time() several times quickly enough.
-
-from typing import Dict, List, Tuple, Union, Any
-import os
-import sys
-import time
-
-from src import couch
-from .schema import as_index_list, validate_collector_data
-
-
-class DictDB:
- def __init__(self) -> None:
- """
- Check if the database exists, otherwise we will create it together
- with the indexes specified in index.py.
- """
-
- print(os.environ)
-
- try:
- self.database = os.environ["COUCHDB_NAME"]
- self.hostname = os.environ["COUCHDB_HOSTNAME"]
- self.username = os.environ["COUCHDB_USER"]
- self.password = os.environ["COUCHDB_PASSWORD"]
- except KeyError:
- print(
- "The environment variables COUCHDB_NAME, COUCHDB_HOSTNAME,"
- + " COUCHDB_USER and COUCHDB_PASSWORD must be set."
- )
- sys.exit(-1)
-
- if "COUCHDB_PORT" in os.environ:
- couchdb_port = os.environ["COUCHDB_PORT"]
- else:
- couchdb_port = "5984"
-
- self.server = couch.client.Server(f"http://{self.username}:{self.password}@{self.hostname}:{couchdb_port}/")
-
- try:
- self.couchdb = self.server.database(self.database)
- print("Database already exists")
- except couch.exceptions.NotFound:
- print("Creating database and indexes.")
- self.couchdb = self.server.create(self.database)
-
- for i in as_index_list():
- self.couchdb.index(i)
-
- self._ts = time.time()
-
- def unique_key(self) -> int:
- """
- Create a unique key based on the current time. We will use this as
- the ID for any new documents we store in CouchDB.
- """
-
- ts = time.time()
- while round(ts * 1000) == self._ts:
- ts = time.time()
- self._ts = round(ts * 1000)
-
- return self._ts
-
- # Why batch_write???
- def add(self, data: Union[List[Dict[str, Any]], Dict[str, Any]]) -> Union[str, Tuple[str, str]]:
- """
- Store a document in CouchDB.
- """
-
- if isinstance(data, List):
- for item in data:
- error = validate_collector_data(item)
- if error != "":
- return error
- item["_id"] = str(self.unique_key())
- ret: Tuple[str, str] = self.couchdb.save_bulk(data)
+"""Our database module"""
+from time import sleep
+from sys import exit as app_exit
+from dataclasses import dataclass
+
+from motor.motor_asyncio import (
+ AsyncIOMotorClient,
+ AsyncIOMotorCollection,
+)
+from bson import ObjectId
+
+
+@dataclass()
+class DBClient:
+ """Class to hold database connections for us."""
+
+ client: AsyncIOMotorClient
+ collection: AsyncIOMotorCollection
+
+ def __init__(self, username: str, password: str, collection: str) -> None:
+ self.client = AsyncIOMotorClient(f"mongodb://{username}:{password}@mongodb:27017/production", timeoutMS=2000)
+ self.collection = self.client["production"][collection]
+
+ async def check_server(self) -> None:
+ """Try query the DB and exit the program if we fail after 5 times.
+
+ :return: None
+ """
+ for i in range(5):
+ try:
+ await self.collection.find_one({"_id": ObjectId("507f1f77bcf86cd799439011")})
+ print("Connection to DB - OK")
+ break
+ except: # pylint: disable=bare-except
+ print(f"WARNING failed to connect to DB - {i} / 4", flush=True)
+ sleep(1)
else:
- error = validate_collector_data(data)
- if error != "":
- return error
- data["_id"] = str(self.unique_key())
- ret = self.couchdb.save(data)
-
- return ret
-
- def get(self, key: int) -> Dict[str, Any]:
- """
- Get a document based on its ID, return an empty dict if not found.
- """
-
- try:
- doc: Dict[str, Any] = self.couchdb.get(key)
- except couch.exceptions.NotFound:
- doc = {}
-
- return doc
-
- #
- # def slice(self, key_from=None, key_to=None):
- # pass
-
- def search(self, limit: int = 25, skip: int = 0, **kwargs: Any) -> List[Dict[str, Any]]:
- """
- Execute a Mango query, ideally we should have an index matching
- the query otherwise things will be slow.
- """
-
- data: List[Dict[str, Any]] = []
- selector: Dict[str, Any] = {}
-
- try:
- limit = int(limit)
- skip = int(skip)
- except ValueError:
- limit = 25
- skip = 0
-
- if kwargs:
- selector = {"limit": limit, "skip": skip, "selector": {}}
-
- for key in kwargs:
- if kwargs[key] and kwargs[key].isnumeric():
- kwargs[key] = int(kwargs[key])
- selector["selector"][key] = {"$eq": kwargs[key]}
-
- for doc in self.couchdb.find(selector, wrapper=None, limit=5):
- data.append(doc)
-
- return data
-
- def delete(self, key: int) -> Union[int, None]:
- """
- Delete a document based on its ID.
- """
- try:
- self.couchdb.delete(key)
- except couch.exceptions.NotFound:
- return None
-
- return key
+ print("Could not connect to DB - mongodb://REDACTED_USERNAME:REDACTED_PASSWORD@mongodb:27017/production")
+ app_exit(1)
diff --git a/src/collector/main.py b/src/collector/main.py
index c363885..096b788 100755
--- a/src/collector/main.py
+++ b/src/collector/main.py
@@ -1,267 +1,175 @@
-from typing import Dict, Union, List, Callable, Awaitable, Any
-import json
-import os
+"""Our main module"""
+from typing import Dict, Optional, List, Any
+from os import environ
+import asyncio
import sys
-import time
+from json.decoder import JSONDecodeError
-import uvicorn
-from fastapi import Depends, FastAPI, Request, Response
-from fastapi.middleware.cors import CORSMiddleware
+from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
-from fastapi_jwt_auth import AuthJWT
-from fastapi_jwt_auth.auth_config import AuthConfig
-from fastapi_jwt_auth.exceptions import AuthJWTException
from pydantic import BaseModel
-
-from .db import DictDB
-from .schema import get_index_keys, validate_collector_data
-
-app = FastAPI()
-
-app.add_middleware(
- CORSMiddleware,
- allow_origins=["http://localhost:8001"],
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
- expose_headers=["X-Total-Count"],
+from pymongo.errors import OperationFailure
+from bson import (
+ ObjectId,
+ json_util,
)
+from dotenv import load_dotenv
-# TODO: X-Total-Count
-
-
-@app.middleware("http")
-async def mock_x_total_count_header(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
-
- print(type(call_next))
-
- response: Response = await call_next(request)
- response.headers["X-Total-Count"] = "100"
- return response
-
-
-for i in range(10):
- try:
- db = DictDB()
- except Exception as e:
- print(f"Database not responding, will try again soon. Attempt {i + 1} of 10.")
- else:
- break
- time.sleep(1)
-else:
- print("Database did not respond after 10 attempts, quitting.")
- sys.exit(-1)
-
-
-def get_pubkey() -> str:
- try:
- if "JWT_PUBKEY_PATH" in os.environ:
- keypath = os.environ["JWT_PUBKEY_PATH"]
- else:
- keypath = "/opt/certs/public.pem"
-
- with open(keypath, "r") as fd:
- pubkey = fd.read()
- except FileNotFoundError:
- print(f"Could not find JWT certificate in {keypath}")
- sys.exit(-1)
-
- return pubkey
-
-
-def get_data(
- key: Union[int, None] = None,
- limit: int = 25,
- skip: int = 0,
- ip: Union[str, None] = None,
- port: Union[int, None] = None,
- asn: Union[str, None] = None,
- domain: Union[str, None] = None,
-) -> List[Dict[str, Any]]:
- if key:
- return [db.get(key)]
+from .db import DBClient
+from .schema import valid_schema
- selectors: Dict[str, Any] = {}
- indexes = get_index_keys()
- selectors["domain"] = domain
- if ip and "ip" in indexes:
- selectors["ip"] = ip
- if port and "port" in indexes:
- selectors["port"] = port
- if asn and "asn" in indexes:
- selectors["asn"] = asn
+load_dotenv()
+# Get credentials
+if "MONGODB_USERNAME" not in environ or "MONGODB_PASSWORD" not in environ or "MONGODB_COLLECTION" not in environ:
+ print("Missing MONGODB_USERNAME or MONGODB_PASSWORD or MONGODB_COLLECTION in env")
+ sys.exit(1)
- data: List[Dict[str, Any]] = db.search(**selectors, limit=limit, skip=skip)
+# Create DB object
+db = DBClient(environ["MONGODB_USERNAME"], environ["MONGODB_PASSWORD"], environ["MONGODB_COLLECTION"])
- return data
-
-
-class JWTConfig(BaseModel):
- authjwt_algorithm: str = "ES256"
- authjwt_public_key: str = get_pubkey()
+# Check DB
+loop = asyncio.get_running_loop()
+startup_task = loop.create_task(db.check_server())
+app = FastAPI()
-@AuthJWT.load_config # type: ignore
-def jwt_config():
- return JWTConfig()
+# @app.exception_handler(RuntimeError)
+# def app_exception_handler(request: Request, exc: RuntimeError) -> JSONResponse:
+# print(exc, flush=True)
+# return JSONResponse(content={"status": "error", "message": str(exc.with_traceback(None))}, status_code=400)
+# return JSONResponse(content={"status": "error", "message": "Error during processing"}, status_code=400)
-@app.exception_handler(AuthJWTException)
-def authjwt_exception_handler(request: Request, exc: AuthJWTException) -> JSONResponse:
- return JSONResponse(content={"status": "error", "message": exc.message}, status_code=400)
+class SearchInput(BaseModel):
+ """Handle search data for HTTP request"""
-@app.exception_handler(RuntimeError)
-def app_exception_handler(request: Request, exc: RuntimeError) -> JSONResponse:
- return JSONResponse(content={"status": "error", "message": str(exc.with_traceback(None))}, status_code=400)
+ search: Optional[Dict[str, Any]]
+ limit: int = 25
+ skip: int = 0
-@app.get("/sc/v0/get")
-async def get(
- key: Union[int, None] = None,
- limit: int = 25,
- skip: int = 0,
- ip: Union[str, None] = None,
- port: Union[int, None] = None,
- asn: Union[str, None] = None,
- Authorize: AuthJWT = Depends(),
-) -> JSONResponse:
+@app.post("/sc/v0/search")
+async def search(search_data: SearchInput) -> JSONResponse:
+ """/sc/v0/search, POST method
- Authorize.jwt_required()
+ :param search_data: The search data.
+ :return: JSONResponse
+ """
+ data: List[Dict[str, Any]] = []
- data = []
- raw_jwt = Authorize.get_raw_jwt()
+ cursor = db.collection.find(search_data.search)
+ cursor.sort("timestamp", -1).limit(search_data.limit).skip(search_data.skip)
- if "read" not in raw_jwt:
+ try:
+ async for document in cursor:
+ data.append(document)
+ except OperationFailure as exc:
+ print(f"DB failed to process: {exc.details}")
return JSONResponse(
content={
"status": "error",
- "message": "Could not find read claim in JWT token",
+ "message": "Probably wrong syntax, note the dictionary for find: "
+ + "https://motor.readthedocs.io/en/stable/tutorial-asyncio.html#async-for",
},
status_code=400,
)
- else:
- domains = raw_jwt["read"]
- for domain in domains:
- data.extend(get_data(key, limit, skip, ip, port, asn, domain))
+ if not data:
+ return JSONResponse(content={"status": "error", "message": "Document not found"}, status_code=400)
- return JSONResponse(content={"status": "success", "docs": data})
+ return JSONResponse(content={"status": "success", "docs": json_util.dumps(data)})
-@app.get("/sc/v0/get/{key}")
-async def get_key(key: Union[int, None] = None, Authorize: AuthJWT = Depends()) -> JSONResponse:
+@app.post("/sc/v0")
+async def create(request: Request) -> JSONResponse:
+ """/sc/v0, POST method
- Authorize.jwt_required()
+ :param request: The request where we get the json body.
+ :return: JSONResponse
+ """
- raw_jwt = Authorize.get_raw_jwt()
+ try:
+ json_data = await request.json()
+ except JSONDecodeError:
+ return JSONResponse(content={"status": "error", "message": "Invalid JSON"}, status_code=400)
- if "read" not in raw_jwt:
- return JSONResponse(
- content={
- "status": "error",
- "message": "Could not find read claim in JWT token",
- },
- status_code=400,
- )
- else:
- allowed_domains = raw_jwt["read"]
+ if not valid_schema(json_data):
+ return JSONResponse(content={"status": "error", "message": "Not our JSON schema"}, status_code=400)
- data_list = get_data(key)
+ result = await db.collection.insert_one(json_data)
+ return JSONResponse(content={"status": "success", "key": str(result.inserted_id)})
- # Handle if missing
- data = data_list[0]
- if data and data["domain"] not in allowed_domains:
- return JSONResponse(
- content={
- "status": "error",
- "message": "User not authorized to view this object",
- },
- status_code=400,
- )
-
- return JSONResponse(content={"status": "success", "docs": data})
+@app.put("/sc/v0")
+async def update(request: Request) -> JSONResponse:
+ """/sc/v0, PUT method
-
-# WHY IS AUTH OUTCOMMENTED???
-@app.post("/sc/v0/add")
-async def add(data: Request, Authorize: AuthJWT = Depends()) -> JSONResponse:
- # Authorize.jwt_required()
+ :param request: The request where we get the json body.
+ :return: JSONResponse
+ """
try:
- json_data = await data.json()
- except json.decoder.JSONDecodeError:
- return JSONResponse(
- content={
- "status": "error",
- "message": "Invalid JSON.",
- },
- status_code=400,
- )
-
- key = db.add(json_data)
-
- if isinstance(key, str):
- return JSONResponse(
- content={
- "status": "error",
- "message": key,
- },
- status_code=400,
- )
-
- return JSONResponse(content={"status": "success", "docs": key})
-
+ json_data = await request.json()
+ except JSONDecodeError:
+ return JSONResponse(content={"status": "error", "message": "Invalid JSON"}, status_code=400)
+
+ if "_id" not in json_data:
+ return JSONResponse(content={"status": "error", "message": "Missing key '_id'"}, status_code=400)
+
+ # Get the key
+ if isinstance(json_data["_id"], str):
+ object_id = ObjectId(json_data["_id"])
+ elif (
+ isinstance(json_data["_id"], dict) and "$oid" in json_data["_id"] and isinstance(json_data["_id"]["$oid"], str)
+ ):
+ object_id = ObjectId(json_data["_id"]["$oid"])
+ else:
+ return JSONResponse(content={"status": "error", "message": "Missing key '_id' with valid id"}, status_code=400)
-@app.delete("/sc/v0/delete/{key}")
-async def delete(key: int, Authorize: AuthJWT = Depends()) -> JSONResponse:
+ # Ensure the updating key exist
+ document = await db.collection.find_one({"_id": object_id})
+ if document is None:
+ return JSONResponse(content={"status": "error", "message": "Document not found"}, status_code=400)
- Authorize.jwt_required()
+ # Ensure valid schema
+ del json_data["_id"]
+ if not valid_schema(json_data):
+ return JSONResponse(content={"status": "error", "message": "Not our JSON schema"}, status_code=400)
- raw_jwt = Authorize.get_raw_jwt()
+ # Replace the data
+ json_data["_id"] = object_id
+ await db.collection.replace_one({"_id": object_id}, json_data)
+ return JSONResponse(content={"status": "success", "key": str(object_id)})
- if "write" not in raw_jwt:
- return JSONResponse(
- content={
- "status": "error",
- "message": "Could not find write claim in JWT token",
- },
- status_code=400,
- )
- else:
- allowed_domains = raw_jwt["write"]
- data_list = get_data(key)
+@app.get("/sc/v0/{key}")
+async def get(key: str) -> JSONResponse:
+ """/sc/v0, POST method
- # Handle if missing
- data = data_list[0]
+ :param key: The document key in the database.
+ :return: JSONResponse
+ """
- if data and data["domain"] not in allowed_domains:
- return JSONResponse(
- content={
- "status": "error",
- "message": "User not authorized to delete this object",
- },
- status_code=400,
- )
+ document = await db.collection.find_one({"_id": ObjectId(key)})
- if db.delete(key) is None:
+ if document is None:
return JSONResponse(content={"status": "error", "message": "Document not found"}, status_code=400)
- return JSONResponse(content={"status": "success", "docs": data})
+ return JSONResponse(content={"status": "success", "docs": json_util.dumps(document)})
-# def main(standalone: bool = False):
-# print(type(app))
-# if not standalone:
-# return app
+@app.delete("/sc/v0/{key}")
+async def delete(key: str) -> JSONResponse:
+ """/sc/v0, POST method
-# uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug")
+ :param key: The document key in the database.
+ :return: JSONResponse
+ """
+ result = await db.collection.delete_one({"_id": ObjectId(key)})
+ if result.deleted_count == 0:
+ return JSONResponse(content={"status": "error", "message": "Document not found"}, status_code=400)
-# if __name__ == "__main__":
-# main(standalone=True)
-# else:
-# app = main()
+ return JSONResponse(content={"status": "success", "key": key})
diff --git a/src/collector/schema.py b/src/collector/schema.py
index e291f10..221990a 100644
--- a/src/collector/schema.py
+++ b/src/collector/schema.py
@@ -1,8 +1,5 @@
-from typing import List, Any, Dict
-import json
-import sys
-import traceback
-
+"""Our schema module"""
+from typing import Any, Dict
import jsonschema
# fmt:off
@@ -64,7 +61,8 @@ schema = {
]
},
{
- "required": [
+ "required":
+ [
"display_name",
"investigation_needed",
# "reliability", # TODO: reliability is required if investigation_needed = true
@@ -93,44 +91,17 @@ schema = {
"result",
],
}
-# fmt:on
-
-
-def get_index_keys() -> List[str]:
- keys: List[str] = []
- for key in schema["properties"]:
- keys.append(key)
- return keys
-def as_index_list() -> List[Dict[str, Any]]:
- index_list: List[Dict[str, Any]] = []
- for key in schema["properties"]:
- name = f"{key}-json-index"
- index = {
- "index": {
- "fields": [
- key,
- ]
- },
- "name": name,
- "type": "json",
- }
- index_list.append(index)
-
- return index_list
+def valid_schema(json_data: Dict[str, Any]) -> bool:
+ """Check if json data follows the schema.
-
-def validate_collector_data(json_blob: Dict[str, Any]) -> str:
+ :param json_data: Json object
+ :return: bool
+ """
try:
- jsonschema.validate(json_blob, schema, format_checker=jsonschema.FormatChecker())
- except jsonschema.exceptions.ValidationError as e:
- return f"Validation failed with error: {e.message}"
- return ""
-
-
-if __name__ == "__main__":
- with open(sys.argv[1]) as fd:
- json_data = json.loads(fd.read())
-
- print(validate_collector_data(json_data))
+ jsonschema.validate(json_data, schema, format_checker=jsonschema.FormatChecker())
+ except jsonschema.exceptions.ValidationError as exc:
+ print(f"Validation failed with error: {exc.message}")
+ return False
+ return True