summaryrefslogtreecommitdiff
path: root/src/soc_collector/db.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/soc_collector/db.py')
-rw-r--r--src/soc_collector/db.py178
1 files changed, 178 insertions, 0 deletions
diff --git a/src/soc_collector/db.py b/src/soc_collector/db.py
new file mode 100644
index 0000000..d601a82
--- /dev/null
+++ b/src/soc_collector/db.py
@@ -0,0 +1,178 @@
+"""Our database module"""
+from typing import List, Dict, Optional, Any
+from time import sleep
+from sys import exit as app_exit
+from dataclasses import dataclass
+
+from fastapi import HTTPException
+from pydantic import BaseModel
+from pymongo.errors import OperationFailure
+from pymongo import (
+ ASCENDING,
+ DESCENDING,
+)
+from motor.motor_asyncio import (
+ AsyncIOMotorClient,
+ AsyncIOMotorCollection,
+)
+from bson import ObjectId
+
+
+class SearchInput(BaseModel):
+ """Handle search data for HTTP request"""
+
+ search: Optional[Dict[str, Any]]
+ limit: int = 25
+ skip: int = 0
+
+
+@dataclass()
+class DBClient:
+ """Class to hold database connections for us."""
+
+ client: AsyncIOMotorClient
+ collection: AsyncIOMotorCollection
+
+ def __init__(self, username: str, password: str, collection: str) -> None:
+ self.client = AsyncIOMotorClient(
+ f"mongodb://{username}:{password}@mongodb:27017/production",
+ maxConnecting=4,
+ timeoutMS=3000,
+ serverSelectionTimeoutMS=3000,
+ )
+ self.collection = self.client["production"][collection]
+
+ async def startup(self) -> None:
+ """Try query the DB and exit the program if we fail after 5 times.
+
+ :return: None
+ """
+
+ for i in range(5):
+ try:
+ await self.collection.find_one({"_id": ObjectId("507f1f77bcf86cd799439011")})
+ print("Connection to DB - OK")
+ break
+ except Exception: # pylint: disable=broad-except
+ print(f"WARNING failed to connect to DB - {i} / 4", flush=True)
+ sleep(1)
+ else:
+ print("Could not connect to DB - mongodb://REDACTED_USERNAME:REDACTED_PASSWORD@mongodb:27017/production")
+ app_exit(1)
+
+ async def find(self, search_data: SearchInput) -> List[Dict[str, Any]]:
+ """Wrap the find() method, handling timeouts and return data type.
+
+ :param search_data: Instance of SearchInput.
+ :return: Optional[List[Dict[str, Any]]]
+ """
+
+ data: List[Dict[str, Any]] = []
+ cursor = self.collection.find(search_data.search)
+
+ cursor.sort({"ip": ASCENDING, "timestamp": DESCENDING}).limit(search_data.limit).skip(search_data.skip)
+
+ try:
+ async for document in cursor:
+ if document is not None:
+ document["_id"] = str(document["_id"])
+ data.append(document)
+
+ return data
+
+ except OperationFailure as exc:
+ print(f"DB failed to process: {exc.details}")
+ raise HTTPException(
+ status_code=400,
+ detail="Probably wrong syntax, note the dictionary for find: "
+ + "https://motor.readthedocs.io/en/stable/tutorial-asyncio.html#async-for",
+ ) from exc
+ except BaseException as exc:
+ print(f"DB connection failed: {exc}")
+ raise HTTPException(status_code=500, detail="DB connection failed") from exc
+
+ async def find_one(self, object_id: ObjectId) -> Optional[Dict[str, Any]]:
+ """Wrap the find_one() method, handling timeouts and return data type.
+
+ :param object_id: The object id to find.
+ :return: Optional[Dict[str, Any]]
+ """
+
+ try:
+ document = await self.collection.find_one({"_id": object_id})
+ if isinstance(document, Dict):
+ document["_id"] = str(document["_id"])
+ return document
+ return None
+
+ except BaseException as exc:
+ print(f"DB connection failed: {exc}")
+ raise HTTPException(status_code=500, detail="DB connection failed") from exc
+
+ async def insert_one(self, data: Dict[str, Any]) -> Optional[ObjectId]:
+ """Wrap the insert_one() method, handling timeouts and return data type.
+
+ :param data: The data to insert into the DB.
+ :return: Optional[ObjectId]
+ """
+
+ try:
+ result = await self.collection.insert_one(data)
+ if isinstance(result.inserted_id, ObjectId) and len(str(result.inserted_id)) == 24:
+ return result.inserted_id
+ return None
+
+ except BaseException as exc:
+ print(f"DB connection failed: {exc}")
+ raise HTTPException(status_code=500, detail="DB connection failed") from exc
+
+ async def replace_one(self, object_id: ObjectId, data: Dict[str, Any]) -> Optional[ObjectId]:
+ """Wrap the replace_one() method, handling timeouts and return data type.
+
+ :param object_id: The object id to replace.
+ :param data: The data to replace with.
+ :return: Optional[ObjectId]
+ """
+
+ try:
+ result = await self.collection.replace_one({"_id": object_id}, data)
+ if result.matched_count == 1:
+ return object_id
+ return None
+
+ except BaseException as exc:
+ print(f"DB connection failed: {exc}")
+ raise HTTPException(status_code=500, detail="DB connection failed") from exc
+
+ async def delete_one(self, object_id: ObjectId) -> Optional[ObjectId]:
+ """Wrap the delete_one() method, handling timeouts and return data type.
+
+ :param object_id: The object id to delete from the DB.
+ :return: Optional[ObjectId]
+ """
+
+ try:
+ result = await self.collection.delete_one({"_id": object_id})
+ if isinstance(result.deleted_count, int) and result.deleted_count == 1:
+ return object_id
+ return None
+
+ except BaseException as exc:
+ print(f"DB connection failed: {exc}")
+ raise HTTPException(status_code=500, detail="DB connection failed") from exc
+
+ async def estimated_document_count(self) -> Optional[int]:
+ """Wrap the estimated_document_count() method, handling timeouts and return data type.
+
+ :return: Optional[int]
+ """
+
+ try:
+ result = await self.collection.estimated_document_count()
+ if isinstance(result, int):
+ return result
+ return None
+
+ except BaseException as exc:
+ print(f"DB connection failed: {exc}")
+ raise HTTPException(status_code=500, detail="DB connection failed") from exc