# Extracted from a cgit blob view.
# path: root/src/collector/db.py
# blob: 2f16e12237857d6d984a8b15ecedc047f017286c
# (cgit navigation text and the line-number gutter of the web page were removed.)
"""Our database module"""
import asyncio
from dataclasses import dataclass
from sys import exit as app_exit
from time import sleep
from typing import List, Dict, Optional, Any

from bson import ObjectId
from fastapi import HTTPException
from motor.motor_asyncio import (
    AsyncIOMotorClient,
    AsyncIOMotorCollection,
)
from pydantic import BaseModel
from pymongo.errors import OperationFailure


class SearchInput(BaseModel):
    """Search parameters taken from an HTTP request.

    ``DBClient.find`` passes ``search`` to the collection's ``find()`` call and
    applies ``limit`` and ``skip`` to the resulting cursor.
    """

    # Filter document handed to find() unchanged.
    # NOTE(review): there is no explicit default here; whether the field may be
    # omitted from a request depends on the pydantic major version in use - confirm.
    search: Optional[Dict[str, Any]]
    # Maximum number of documents to return (defaults to 25).
    limit: int = 25
    # Number of documents to skip before the first returned one (defaults to 0).
    skip: int = 0


@dataclass()
class DBClient:
    """Hold the MongoDB client and the collection this module works on.

    The query wrappers translate driver failures into ``HTTPException`` so the
    HTTP layer can answer with a status code. Only ``Exception`` subclasses are
    translated: task cancellation, ``KeyboardInterrupt`` and ``SystemExit``
    propagate untouched so shutdown and request cancellation keep working.
    """

    client: AsyncIOMotorClient
    collection: AsyncIOMotorCollection

    def __init__(self, username: str, password: str, collection: str) -> None:
        """Create the client and select ``collection`` in the ``production`` DB.

        :param username: User name placed in the connection URI.
        :param password: Password placed in the connection URI.
        :param collection: Name of the collection inside ``production``.
        :return: None
        """
        # NOTE(review): the credentials are interpolated verbatim. PyMongo expects
        # reserved characters (":", "/", "@", "%", ...) in URI credentials to be
        # percent-encoded - confirm callers pass already-escaped values.
        self.client = AsyncIOMotorClient(
            f"mongodb://{username}:{password}@mongodb:27017/production",
            maxConnecting=4,
            timeoutMS=3000,
            serverSelectionTimeoutMS=3000,
        )
        self.collection = self.client["production"][collection]

    @staticmethod
    def _connection_error(exc: Exception) -> HTTPException:
        """Log ``exc`` and build the generic HTTP 500 error for a failed DB call."""
        print(f"DB connection failed: {exc}")
        return HTTPException(status_code=500, detail="DB connection failed")

    async def startup(self, attempts: int = 5, delay: float = 1.0) -> None:
        """Probe the DB and exit the program when every attempt fails.

        :param attempts: Number of probe queries to try before giving up.
        :param delay: Seconds to wait after each failed attempt.
        :return: None
        """

        for i in range(attempts):
            try:
                # The result is discarded: only a completed round trip matters.
                await self.collection.find_one({"_id": ObjectId("507f1f77bcf86cd799439011")})
                print("Connection to DB - OK")
                break
            except Exception:  # pylint: disable=broad-except
                print(f"WARNING failed to connect to DB - {i} / {attempts - 1}", flush=True)
            # Wait without blocking the event loop (time.sleep() would stall it).
            await asyncio.sleep(delay)
        else:
            # Reached only when no attempt hit the ``break`` above.
            print("Could not connect to DB - mongodb://REDACTED_USERNAME:REDACTED_PASSWORD@mongodb:27017/production")
            app_exit(1)

    async def find(self, search_data: SearchInput) -> Optional[List[Dict[str, Any]]]:
        """Wrap the find() method, handling failures and the return data type.

        Results are sorted on ``timestamp`` in descending order, then limited
        and skipped as requested.

        :param search_data: Instance of SearchInput.
        :return: The matching documents, or None when nothing matched.
        :raises HTTPException: 400 when the DB rejects the query, 500 on any
            other failure while reading the cursor.
        """

        cursor = self.collection.find(search_data.search)

        # Sort on timestamp; the cursor is configured in place.
        cursor.sort("timestamp", -1).limit(search_data.limit).skip(search_data.skip)

        try:
            documents: List[Dict[str, Any]] = [document async for document in cursor]
        except OperationFailure as exc:
            print(f"DB failed to process: {exc.details}")
            raise HTTPException(
                status_code=400,
                detail="Probably wrong syntax, note the dictionary for find: "
                + "https://motor.readthedocs.io/en/stable/tutorial-asyncio.html#async-for",
            ) from exc
        except Exception as exc:
            raise self._connection_error(exc) from exc

        # An empty result is reported as None rather than an empty list.
        return documents or None

    async def find_one(self, object_id: ObjectId) -> Optional[Dict[str, Any]]:
        """Wrap the find_one() method, handling failures and the return data type.

        :param object_id: The object id to find.
        :return: The document, or None when the lookup did not yield a dict.
        :raises HTTPException: 500 when the DB call fails.
        """

        try:
            document = await self.collection.find_one({"_id": object_id})
        except Exception as exc:
            raise self._connection_error(exc) from exc

        return document if isinstance(document, dict) else None

    async def insert_one(self, data: Dict[str, Any]) -> Optional[ObjectId]:
        """Wrap the insert_one() method, handling failures and the return data type.

        :param data: The data to insert into the DB.
        :return: The id of the inserted document, or None when the reported id
            is not an ObjectId.
        :raises HTTPException: 500 when the DB call fails.
        """

        try:
            result = await self.collection.insert_one(data)
        except Exception as exc:
            raise self._connection_error(exc) from exc

        # An ObjectId always renders as 24 hex characters, so the type check
        # alone is sufficient.
        inserted_id = result.inserted_id
        return inserted_id if isinstance(inserted_id, ObjectId) else None

    async def replace_one(self, object_id: ObjectId, data: Dict[str, Any]) -> Optional[ObjectId]:
        """Wrap the replace_one() method, handling failures and the return data type.

        :param object_id: The object id to replace.
        :param data: The data to replace with.
        :return: ``object_id`` when exactly one document matched, else None.
        :raises HTTPException: 500 when the DB call fails.
        """

        try:
            result = await self.collection.replace_one({"_id": object_id}, data)
        except Exception as exc:
            raise self._connection_error(exc) from exc

        return object_id if result.matched_count == 1 else None

    async def delete_one(self, object_id: ObjectId) -> Optional[ObjectId]:
        """Wrap the delete_one() method, handling failures and the return data type.

        :param object_id: The object id to delete from the DB.
        :return: ``object_id`` when exactly one document was deleted, else None.
        :raises HTTPException: 500 when the DB call fails.
        """

        try:
            result = await self.collection.delete_one({"_id": object_id})
        except Exception as exc:
            raise self._connection_error(exc) from exc

        return object_id if result.deleted_count == 1 else None

    async def estimated_document_count(self) -> Optional[int]:
        """Wrap the estimated_document_count() method, handling failures and the return data type.

        :return: The estimated count, or None when the driver did not return an int.
        :raises HTTPException: 500 when the DB call fails.
        """

        try:
            result = await self.collection.estimated_document_count()
        except Exception as exc:
            raise self._connection_error(exc) from exc

        return result if isinstance(result, int) else None