diff options
-rw-r--r-- | src/soc_collector/db.py | 16 | ||||
-rwxr-xr-x | src/soc_collector/main.py | 24 | ||||
-rw-r--r-- | src/soc_collector/soc_collector_cli.py | 2 |
3 files changed, 31 insertions, 11 deletions
diff --git a/src/soc_collector/db.py b/src/soc_collector/db.py index f537f4a..b10d865 100644 --- a/src/soc_collector/db.py +++ b/src/soc_collector/db.py @@ -6,6 +6,8 @@ from dataclasses import dataclass from fastapi import HTTPException from pydantic import BaseModel +from bson import ObjectId +from bson.errors import InvalidId from pymongo.errors import OperationFailure from pymongo import ( ASCENDING, @@ -15,7 +17,19 @@ from motor.motor_asyncio import ( AsyncIOMotorClient, AsyncIOMotorCollection, ) -from bson import ObjectId + + +def object_id_from_key(key: str) -> ObjectId: + """Get ObjectId from key, 400 if invalid ObjectId + + :param key: Key. + :return: ObjectId + """ + + try: + return ObjectId(key) + except InvalidId as exc: + raise HTTPException(status_code=400, detail="Invalid key/object id") from exc class SearchInput(BaseModel): diff --git a/src/soc_collector/main.py b/src/soc_collector/main.py index eb6041f..e70199b 100755 --- a/src/soc_collector/main.py +++ b/src/soc_collector/main.py @@ -6,10 +6,11 @@ from json.decoder import JSONDecodeError from fastapi import FastAPI, Request from fastapi.responses import JSONResponse -from bson import ObjectId + from .db import ( DBClient, SearchInput, + object_id_from_key, ) from .schema import valid_schema from .auth import authorize_client, load_api_keys @@ -66,6 +67,11 @@ async def create(request: Request) -> JSONResponse: except JSONDecodeError: return JSONResponse(content={"status": "error", "message": "Invalid JSON"}, status_code=400) + if "_id" in json_data: + return JSONResponse( + content={"status": "error", "message": "Internal key '_id' must not exist in document"}, status_code=400 + ) + if not valid_schema(json_data): return JSONResponse(content={"status": "error", "message": "Not our JSON schema"}, status_code=400) @@ -98,11 +104,7 @@ async def replace(request: Request) -> JSONResponse: # pylint: disable=too-many # Get the key if isinstance(json_data["_id"], str): - object_id = ObjectId(json_data["_id"]) - elif ( - isinstance(json_data["_id"], dict) and "$oid" in json_data["_id"] and isinstance(json_data["_id"]["$oid"], str) - ): - object_id = ObjectId(json_data["_id"]["$oid"]) + object_id = object_id_from_key(json_data["_id"]) else: return JSONResponse(content={"status": "error", "message": "Missing key '_id' with valid id"}, status_code=400) @@ -138,7 +140,9 @@ async def get(request: Request, key: str) -> JSONResponse: # Ensure authorization authorize_client(request, API_KEYS) - document = await db.find_one(ObjectId(key)) + object_id = object_id_from_key(key) + + document = await db.find_one(object_id) if document is None: return JSONResponse(content={"status": "error", "message": "Document not found"}, status_code=404) @@ -157,12 +161,14 @@ async def delete(request: Request, key: str) -> JSONResponse: # Ensure authorization authorize_client(request, API_KEYS) - result = await db.delete_one(ObjectId(key)) + object_id = object_id_from_key(key) + + result = await db.delete_one(object_id) if result is None: return JSONResponse(content={"status": "error", "message": "Document not found"}, status_code=404) - return JSONResponse(content={"status": "success", "key": str(key)}) + return JSONResponse(content={"status": "success", "key": str(object_id)}) @app.get("/info") diff --git a/src/soc_collector/soc_collector_cli.py b/src/soc_collector/soc_collector_cli.py index 9f8c793..f9b0fad 100644 --- a/src/soc_collector/soc_collector_cli.py +++ b/src/soc_collector/soc_collector_cli.py @@ -231,7 +231,7 @@ def main() -> None: """, ) - parser.add_argument("data", default="info", help="json blob or path to file") + parser.add_argument("data", nargs="?", help="json blob or path to file") parser.add_argument( "extra_data", nargs="?", |