"""Our database module""" from typing import List, Dict, Optional, Any from time import sleep from sys import exit as app_exit from dataclasses import dataclass from fastapi import HTTPException from pydantic import BaseModel from pymongo.errors import OperationFailure from motor.motor_asyncio import ( AsyncIOMotorClient, AsyncIOMotorCollection, ) from bson import ObjectId class SearchInput(BaseModel): """Handle search data for HTTP request""" search: Optional[Dict[str, Any]] limit: int = 25 skip: int = 0 @dataclass() class DBClient: """Class to hold database connections for us.""" client: AsyncIOMotorClient collection: AsyncIOMotorCollection def __init__(self, username: str, password: str, collection: str) -> None: self.client = AsyncIOMotorClient( f"mongodb://{username}:{password}@mongodb:27017/production", maxConnecting=4, timeoutMS=3000, serverSelectionTimeoutMS=3000, ) self.collection = self.client["production"][collection] async def startup(self) -> None: """Try query the DB and exit the program if we fail after 5 times. :return: None """ for i in range(5): try: await self.collection.find_one({"_id": ObjectId("507f1f77bcf86cd799439011")}) print("Connection to DB - OK") break except Exception: # pylint: disable=broad-except print(f"WARNING failed to connect to DB - {i} / 4", flush=True) sleep(1) else: print("Could not connect to DB - mongodb://REDACTED_USERNAME:REDACTED_PASSWORD@mongodb:27017/production") app_exit(1) async def find(self, search_data: SearchInput) -> List[Dict[str, Any]]: """Wrap the find() method, handling timeouts and return data type. :param search_data: Instance of SearchInput. :return: Optional[List[Dict[str, Any]]] """ data: List[Dict[str, Any]] = [] cursor = self.collection.find(search_data.search) # Sort on timestamp # TODO: Also sort on IP as well cursor.sort("timestamp", -1).limit(search_data.limit).skip(search_data.skip) try: async for document in cursor: if document is not None: document["_id"] = str(document["_id"]) data.append(document) return data except OperationFailure as exc: print(f"DB failed to process: {exc.details}") raise HTTPException( status_code=400, detail="Probably wrong syntax, note the dictionary for find: " + "https://motor.readthedocs.io/en/stable/tutorial-asyncio.html#async-for", ) from exc except BaseException as exc: print(f"DB connection failed: {exc}") raise HTTPException(status_code=500, detail="DB connection failed") from exc async def find_one(self, object_id: ObjectId) -> Optional[Dict[str, Any]]: """Wrap the find_one() method, handling timeouts and return data type. :param object_id: The object id to find. :return: Optional[Dict[str, Any]] """ try: document = await self.collection.find_one({"_id": object_id}) if isinstance(document, Dict): document["_id"] = str(document["_id"]) return document return None except BaseException as exc: print(f"DB connection failed: {exc}") raise HTTPException(status_code=500, detail="DB connection failed") from exc async def insert_one(self, data: Dict[str, Any]) -> Optional[ObjectId]: """Wrap the insert_one() method, handling timeouts and return data type. :param data: The data to insert into the DB. :return: Optional[ObjectId] """ try: result = await self.collection.insert_one(data) if isinstance(result.inserted_id, ObjectId) and len(str(result.inserted_id)) == 24: return result.inserted_id return None except BaseException as exc: print(f"DB connection failed: {exc}") raise HTTPException(status_code=500, detail="DB connection failed") from exc async def replace_one(self, object_id: ObjectId, data: Dict[str, Any]) -> Optional[ObjectId]: """Wrap the replace_one() method, handling timeouts and return data type. :param object_id: The object id to replace. :param data: The data to replace with. :return: Optional[ObjectId] """ try: result = await self.collection.replace_one({"_id": object_id}, data) if result.matched_count == 1: return object_id return None except BaseException as exc: print(f"DB connection failed: {exc}") raise HTTPException(status_code=500, detail="DB connection failed") from exc async def delete_one(self, object_id: ObjectId) -> Optional[ObjectId]: """Wrap the delete_one() method, handling timeouts and return data type. :param object_id: The object id to delete from the DB. :return: Optional[ObjectId] """ try: result = await self.collection.delete_one({"_id": object_id}) if isinstance(result.deleted_count, int) and result.deleted_count == 1: return object_id return None except BaseException as exc: print(f"DB connection failed: {exc}") raise HTTPException(status_code=500, detail="DB connection failed") from exc async def estimated_document_count(self) -> Optional[int]: """Wrap the estimated_document_count() method, handling timeouts and return data type. :return: Optional[int] """ try: result = await self.collection.estimated_document_count() if isinstance(result, int): return result return None except BaseException as exc: print(f"DB connection failed: {exc}") raise HTTPException(status_code=500, detail="DB connection failed") from exc