diff options
-rw-r--r-- | data/mongodb_container/Dockerfile | 2 | ||||
-rwxr-xr-x | dev-run.sh | 98 | ||||
-rw-r--r-- | src/soc_collector/db.py | 39 | ||||
-rwxr-xr-x | src/soc_collector/main.py | 46 | ||||
-rw-r--r-- | src/soc_collector/schema.py | 28 | ||||
-rw-r--r-- | src/soc_collector/soc_collector_cli.py | 130 | ||||
-rw-r--r-- | tests/__init__.py | 0 | ||||
-rw-r--r-- | tests/data/example_data_1.json (renamed from data/example_data_1.json) | 0 | ||||
-rw-r--r-- | tests/data/example_data_1_replace_test.json | 57 | ||||
-rw-r--r-- | tests/data/example_data_3.json (renamed from data/example_data_3.json) | 0 | ||||
-rw-r--r-- | tests/data/example_data_3_replace_test.json (renamed from data/example_data_3_replace_test.json) | 0 | ||||
-rw-r--r-- | tests/test_auth.py | 275 | ||||
-rw-r--r-- | tests/test_delete.py | 86 | ||||
-rw-r--r-- | tests/test_get.py | 82 | ||||
-rw-r--r-- | tests/test_info.py | 37 | ||||
-rw-r--r-- | tests/test_insert.py | 110 | ||||
-rw-r--r-- | tests/test_replace.py | 102 | ||||
-rw-r--r-- | tests/test_search.py | 125 | ||||
-rw-r--r-- | tests/test_soc_collector_cli.py | 244 |
19 files changed, 1259 insertions, 202 deletions
diff --git a/data/mongodb_container/Dockerfile b/data/mongodb_container/Dockerfile index 8e17161..16e6b1f 100644 --- a/data/mongodb_container/Dockerfile +++ b/data/mongodb_container/Dockerfile @@ -20,7 +20,7 @@ RUN find / -xdev -perm /6000 -type f -exec chmod a-s {} \; || true COPY ./data/mongodb_entrypoint.sh /mongodb_entrypoint.sh COPY ./data/init-mongodb.js /init-mongodb.js COPY ./data/healthcheck-mongodb.js /healthcheck-mongodb.js -COPY ./healthcheck.sh /healthcheck.sh +COPY ./data/healthcheck.sh /healthcheck.sh USER mongodb @@ -1,103 +1,23 @@ #!/bin/bash echo "Checking package" -mypy --strict --namespace-packages --ignore-missing-imports --cache-dir=/tmp/ src/soc_collector/*.py # || exit 1 +mypy --strict --namespace-packages --ignore-missing-imports --cache-dir=/tmp/ src/soc_collector/*.py || exit 1 black --line-length 120 src/soc_collector/*.py # || exit 1 pylint --max-line-length 120 src/soc_collector/*.py # || exit 1 +echo "Checking tests" +mypy --strict --namespace-packages --ignore-missing-imports --cache-dir=/tmp/ tests/*.py # || exit 1 +black --line-length 120 tests/*.py # || exit 1 +pylint --disable R0801 --max-line-length 120 tests/*.py # || exit 1 + mkdir -p data/mongodb_data -sudo chown -R $USER data/mongodb_data +sudo chown -R "$USER" data/mongodb_data docker-compose -f docker-compose.yml build sudo chown -R 101 data/mongodb_data docker-compose -f docker-compose.yml up -d sleep 3 - - -echo -echo -curl -v -k --data-binary @data/example_data_3.json https://127.0.0.1:8000/sc/v0 -echo -echo - -curl -v -k -X DELETE https://127.0.0.1:8000/sc/v0/63702570e004d2b0b2254d27 -echo -echo -curl -v -k -X DELETE https://127.0.0.1:8000/sc/v0/63702570e004d2b0b2254d27 -echo -echo - -curl -v -k -d '{"search": {"port": {"$lt": 4}}}' -H 'Content-Type: application/json' https://127.0.0.1:8000/sc/v0/search -echo -echo -curl -v -k -d '{"search": {"port": 112}}' -H 'Content-Type: application/json' https://127.0.0.1:8000/sc/v0/search -echo -echo -curl -v -k -d '{"search": {"port": {"$gt": 4}}}' -H 'Content-Type: application/json' https://127.0.0.1:8000/sc/v0/search -echo -echo -curl -v -k -d '{"search": {"port": 111}}' -H 'Content-Type: application/json' https://127.0.0.1:8000/sc/v0/search -echo -echo -curl -v -k -d '{"search": {"port": {"sdfsf": 7}}}' -H 'Content-Type: application/json' https://127.0.0.1:8000/sc/v0/search -echo -echo -curl -v -k -d '{"search": {"port": {"$sdfsf": 7}}}' -H 'Content-Type: application/json' https://127.0.0.1:8000/sc/v0/search -echo -echo -curl -v -k -d '{"search": {"portfdv": {"$asa": 7}}}' -H 'Content-Type: application/json' https://127.0.0.1:8000/sc/v0/search -echo -echo - -echo -echo -curl -v -k -X PUT --data-binary @data/example_data_3_replace_test.json https://127.0.0.1:8000/sc/v0 - -echo -echo -curl -v -k https://127.0.0.1:8000/info - - -# bash quickstart.sh -b || exit 1 -# sleep 3 -# JWT=$(curl -k http://localhost:8000/api/v1.0/auth -X POST -p -u usr:pwd | jq -r .access_token) || exit 1 -# curl -k --data-binary @example_data_1.json -H "Authorization: Bearer $JWT" https://localhost:1443/sc/v0/add || exit 1 -# curl -k --data-binary @example_data_3.json -H "Authorization: Bearer $JWT" https://localhost:1443/sc/v0/add || exit 1 -# sleep 1 -# curl -k -H "Authorization: Bearer $JWT" https://localhost:1443/sc/v0/get | json_pp -json_opt utf8,pretty || exit 1 - -# curl -k -H "Authorization: Bearer $JWT" https://localhost:1443/sc/v0/get?port=111 || exit 1 - -# echo "OK" -# exit 0 - - -#echo "Checking tests" -#mypy --strict --namespace-packages --ignore-missing-imports --cache-dir=/dev/null tests/*.py || exit 1 -#black --line-length 120 tests/*.py || exit 1 -#pylint --max-line-length 120 tests/*.py || exit 1 - -# Stop old container, build and run the new one -# docker build -t pkcs11_ca_service_http . -# docker stop /pkcs11_ca_service_http -# docker rm /pkcs11_ca_service_http -# docker run \ -# --name pkcs11_ca_service_http \ -# --net pkcs11_ca_service_network \ -# --restart always \ -# --security-opt no-new-privileges \ -# --cap-drop all \ -# --read-only \ -# --memory 256m \ -# --cpus 2.75 \ -# --mount type=tmpfs,target=/dev/shm,readonly=true \ -# -v /app_softhsm:/var/lib/softhsm/tokens \ -# -p 8000:8000 \ -# -d \ -# pkcs11_ca_service_http - -# sleep 2 -# echo "Running tests" -# python3 -m unittest +echo "Running tests" +python3 -m unittest diff --git a/src/soc_collector/db.py b/src/soc_collector/db.py index b10d865..b1501d8 100644 --- a/src/soc_collector/db.py +++ b/src/soc_collector/db.py @@ -7,7 +7,6 @@ from dataclasses import dataclass from fastapi import HTTPException from pydantic import BaseModel from bson import ObjectId -from bson.errors import InvalidId from pymongo.errors import OperationFailure from pymongo import ( ASCENDING, @@ -19,19 +18,6 @@ from motor.motor_asyncio import ( ) -def object_id_from_key(key: str) -> ObjectId: - """Get ObjectId from key, 400 if invalid ObjectId - - :param key: Key. - :return: ObjectId - """ - - try: - return ObjectId(key) - except InvalidId as exc: - raise HTTPException(status_code=400, detail="Invalid key/object id") from exc - - class SearchInput(BaseModel): """Handle search data for HTTP request""" @@ -71,7 +57,10 @@ class DBClient: print(f"WARNING failed to connect to DB - {i} / 4", flush=True) sleep(1) else: - print("Could not connect to DB - mongodb://REDACTED_USERNAME:REDACTED_PASSWORD@mongodb:27017/production") + print( + "Could not connect to DB - mongodb://REDACTED_USERNAME:REDACTED_PASSWORD@mongodb:27017/production", + flush=True, + ) app_exit(1) async def find(self, search_data: SearchInput) -> List[Dict[str, Any]]: @@ -82,11 +71,11 @@ class DBClient: """ data: List[Dict[str, Any]] = [] - cursor = self.collection.find(search_data.filter) - - cursor.sort([("ip", ASCENDING), ("timestamp", DESCENDING)]).limit(search_data.limit).skip(search_data.skip) try: + cursor = self.collection.find(search_data.filter) + cursor.sort([("ip", ASCENDING), ("timestamp", DESCENDING)]).limit(search_data.limit).skip(search_data.skip) + async for document in cursor: if document is not None: document["_id"] = str(document["_id"]) @@ -101,7 +90,7 @@ class DBClient: detail="Probably wrong syntax, note the dictionary for find: " + "https://motor.readthedocs.io/en/stable/tutorial-asyncio.html#async-for", ) from exc - except BaseException as exc: + except Exception as exc: print(f"DB connection failed: {exc}") raise HTTPException(status_code=500, detail="DB connection failed") from exc @@ -119,7 +108,7 @@ class DBClient: return document return None - except BaseException as exc: + except Exception as exc: print(f"DB connection failed: {exc}") raise HTTPException(status_code=500, detail="DB connection failed") from exc @@ -136,7 +125,7 @@ class DBClient: return result.inserted_id return None - except BaseException as exc: + except Exception as exc: print(f"DB connection failed: {exc}") raise HTTPException(status_code=500, detail="DB connection failed") from exc @@ -154,7 +143,7 @@ class DBClient: return object_id return None - except BaseException as exc: + except Exception as exc: print(f"DB connection failed: {exc}") raise HTTPException(status_code=500, detail="DB connection failed") from exc @@ -171,7 +160,7 @@ class DBClient: return object_id return None - except BaseException as exc: + except Exception as exc: print(f"DB connection failed: {exc}") raise HTTPException(status_code=500, detail="DB connection failed") from exc @@ -182,11 +171,11 @@ class DBClient: """ try: - result = await self.collection.estimated_document_count() + result = await self.collection.estimated_document_count(maxTimeMS=4000) if isinstance(result, int): return result return None - except BaseException as exc: + except Exception as exc: print(f"DB connection failed: {exc}") raise HTTPException(status_code=500, detail="DB connection failed") from exc diff --git a/src/soc_collector/main.py b/src/soc_collector/main.py index e70199b..fd1bded 100755 --- a/src/soc_collector/main.py +++ b/src/soc_collector/main.py @@ -10,9 +10,8 @@ from fastapi.responses import JSONResponse from .db import ( DBClient, SearchInput, - object_id_from_key, ) -from .schema import valid_schema +from .schema import valid_schema, object_id_from_data from .auth import authorize_client, load_api_keys # Get credentials @@ -52,10 +51,10 @@ async def search(request: Request, search_data: SearchInput) -> JSONResponse: @app.post("/sc/v0") -async def create(request: Request) -> JSONResponse: +async def insert(request: Request) -> JSONResponse: """/sc/v0, POST method - :param request: The request where we get the json body. + :param request: The client request. :return: JSONResponse """ @@ -75,19 +74,19 @@ async def create(request: Request) -> JSONResponse: if not valid_schema(json_data): return JSONResponse(content={"status": "error", "message": "Not our JSON schema"}, status_code=400) - key = await db.insert_one(json_data) + object_id = await db.insert_one(json_data) - if key is None: + if object_id is None: return JSONResponse(content={"status": "error", "message": "DB error"}, status_code=500) - return JSONResponse(content={"status": "success", "key": str(key)}) + return JSONResponse(content={"status": "success", "_id": str(object_id)}) @app.put("/sc/v0") async def replace(request: Request) -> JSONResponse: # pylint: disable=too-many-return-statements """/sc/v0, PUT method - :param request: The request where we get the json body. + :param request: The client request. :return: JSONResponse """ @@ -99,13 +98,9 @@ async def replace(request: Request) -> JSONResponse: # pylint: disable=too-many except JSONDecodeError: return JSONResponse(content={"status": "error", "message": "Invalid JSON"}, status_code=400) - if "_id" not in json_data: - return JSONResponse(content={"status": "error", "message": "Missing key '_id'"}, status_code=400) - # Get the key - if isinstance(json_data["_id"], str): - object_id = object_id_from_key(json_data["_id"]) - else: + object_id = object_id_from_data(json_data) + if object_id is None: return JSONResponse(content={"status": "error", "message": "Missing key '_id' with valid id"}, status_code=400) # Ensure the updating key exist @@ -123,24 +118,28 @@ async def replace(request: Request) -> JSONResponse: # pylint: disable=too-many json_data["_id"] = object_id returned_object_id = await db.replace_one(object_id, json_data) - if returned_object_id is None: + if returned_object_id is None or returned_object_id != object_id: return JSONResponse(content={"status": "error", "message": "DB error"}, status_code=500) - return JSONResponse(content={"status": "success", "key": str(object_id)}) + return JSONResponse(content={"status": "success", "_id": str(object_id)}) @app.get("/sc/v0/{key}") async def get(request: Request, key: str) -> JSONResponse: """/sc/v0/{key}, GET method - :param key: The document key in the database. + :param request: The client request. + :param key: The document id in the database. :return: JSONResponse """ # Ensure authorization authorize_client(request, API_KEYS) - object_id = object_id_from_key(key) + # Get the id + object_id = object_id_from_data(key) + if object_id is None: + return JSONResponse(content={"status": "error", "message": "Invalid id"}, status_code=400) document = await db.find_one(object_id) @@ -154,27 +153,32 @@ async def get(request: Request, key: str) -> JSONResponse: async def delete(request: Request, key: str) -> JSONResponse: """/sc/v0/{key}, DELETE method - :param key: The document key in the database. + :param request: The client request. + :param key: The document id in the database. :return: JSONResponse """ # Ensure authorization authorize_client(request, API_KEYS) - object_id = object_id_from_key(key) + # Get the id + object_id = object_id_from_data(key) + if object_id is None: + return JSONResponse(content={"status": "error", "message": "Invalid id"}, status_code=400) result = await db.delete_one(object_id) if result is None: return JSONResponse(content={"status": "error", "message": "Document not found"}, status_code=404) - return JSONResponse(content={"status": "success", "key": str(object_id)}) + return JSONResponse(content={"status": "success", "_id": str(object_id)}) @app.get("/info") async def info(request: Request) -> JSONResponse: """/info, GET method + :param request: The client request. :return: JSONResponse """ diff --git a/src/soc_collector/schema.py b/src/soc_collector/schema.py index 221990a..2c2dfb9 100644 --- a/src/soc_collector/schema.py +++ b/src/soc_collector/schema.py @@ -1,6 +1,8 @@ """Our schema module""" -from typing import Any, Dict +from typing import Any, Dict, Optional, Union import jsonschema +from bson import ObjectId +from bson.errors import InvalidId # fmt:off # NOTE: Commented out properties are left intentionally, so it is easier to see @@ -99,9 +101,33 @@ def valid_schema(json_data: Dict[str, Any]) -> bool: :param json_data: Json object :return: bool """ + try: jsonschema.validate(json_data, schema, format_checker=jsonschema.FormatChecker()) except jsonschema.exceptions.ValidationError as exc: print(f"Validation failed with error: {exc.message}") return False return True + + +def object_id_from_data(data: Union[str, Dict[str, Any]]) -> Optional[ObjectId]: + """Get ObjectId from key. None if invalid. + + :param data: Key. + :return: Optional[ObjectId] + """ + + if isinstance(data, str): + try: + return ObjectId(data) + except InvalidId: + return None + + elif isinstance(data, Dict): + if "_id" in data and isinstance(data["_id"], str): + try: + return ObjectId(data["_id"]) + except InvalidId: + return None + + return None diff --git a/src/soc_collector/soc_collector_cli.py b/src/soc_collector/soc_collector_cli.py index d7add30..4929655 100644 --- a/src/soc_collector/soc_collector_cli.py +++ b/src/soc_collector/soc_collector_cli.py @@ -8,31 +8,11 @@ from sys import exit as app_exit import json import requests -if "COLLECTOR_API_KEY" not in environ: - print("Missing 'COLLECTOR_API_KEY' in environment") - app_exit(1) -API_KEY = environ["COLLECTOR_API_KEY"] +from .schema import object_id_from_data -API_URL = "https://collector-dev.soc.sunet.se:8000" ROOT_CA_FILE = __file__.replace("soc_collector_cli.py", "data/collector_root_ca.crt") -def valid_key(key: str) -> None: - """Ensure the document key is valid. exit(1) otherwise. - - :param key: The key. - """ - valid_chars = ["a", "b", "c", "d", "e", "f", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0"] - if len(key) != 24: - print(f"ERROR: Invalid key '{key}'") - app_exit(1) - - for char in key: - if char not in valid_chars: - print(f"ERROR: Invalid key '{key}'") - app_exit(1) - - def json_load_data(data: str) -> Dict[str, Any]: """Load json from argument, json data or path to json file @@ -54,10 +34,14 @@ def json_load_data(data: str) -> Dict[str, Any]: app_exit(1) -def info_action() -> None: - """Get database info, currently number of documents.""" +def info_action(api_key: str, base_url: str = "https://collector-dev.soc.sunet.se:8000") -> None: + """Get database info, currently number of documents. + + :param api_key: The API key. + :param base_url: URL to the API. + """ - req = requests.get(f"{API_URL}/info", headers={"API-KEY": API_KEY}, timeout=5, verify=ROOT_CA_FILE) + req = requests.get(f"{base_url}/info", headers={"API-KEY": api_key}, timeout=5, verify=ROOT_CA_FILE) # Ensure ok status req.raise_for_status() @@ -67,16 +51,18 @@ def info_action() -> None: print(f"Estimated document count: {json_data['Estimated document count']}") -def search_action(data: str) -> None: +def search_action(data: str, api_key: str, base_url: str = "https://collector-dev.soc.sunet.se:8000") -> None: """Search for documents in the database. :param data: String with either json or path to a json file. + :param api_key: The API key. + :param base_url: URL to the API. """ search_data = json_load_data(data) req = requests.post( - f"{API_URL}/sc/v0/search", headers={"API-KEY": API_KEY}, json=search_data, timeout=5, verify=ROOT_CA_FILE + f"{base_url}/sc/v0/search", headers={"API-KEY": api_key}, json=search_data, timeout=5, verify=ROOT_CA_FILE ) # Ensure ok status @@ -87,24 +73,26 @@ def search_action(data: str) -> None: print(json.dumps(json_data["docs"], indent=4)) -def delete_action(data: str) -> None: +def delete_action(data: str, api_key: str, base_url: str = "https://collector-dev.soc.sunet.se:8000") -> None: """Delete a document in the DB. - :param data: key or path to a json file containing "_id". + :param data: id or path to a json file containing "_id". + :param api_key: The API key. + :param base_url: URL to the API. """ if data and isfile(data): json_data = json_load_data(data) - - if "_id" not in json_data or not isinstance(json_data["_id"], str): - print("ERROR: Valid '_id' key not in data") - app_exit(1) - key: str = json_data["_id"] + object_id = object_id_from_data(json_data) else: - key = data + object_id = object_id_from_data(data) - valid_key(key) + if object_id is None: + print("ERROR: id is not valid") + app_exit(1) - req = requests.delete(f"{API_URL}/sc/v0/{key}", headers={"API-KEY": API_KEY}, timeout=5, verify=ROOT_CA_FILE) + req = requests.delete( + f"{base_url}/sc/v0/{str(object_id)}", headers={"API-KEY": api_key}, timeout=5, verify=ROOT_CA_FILE + ) # Check status if req.status_code == 404: @@ -114,7 +102,7 @@ def delete_action(data: str) -> None: # Ensure ok status req.raise_for_status() - print(f"Deleted data OK - key: {key}") + print(f"Deleted data OK - key: {str(object_id)}") def update_local_action(data: str, update_data: str) -> None: @@ -133,21 +121,24 @@ def update_local_action(data: str, update_data: str) -> None: print(json.dumps(json_data, indent=4)) -def replace_action(data: str) -> None: +def replace_action(data: str, api_key: str, base_url: str = "https://collector-dev.soc.sunet.se:8000") -> None: """Replace the entire document in the database with this document, "_id" must exist as a key. :param data: json blob or path to json file, "_id" key must exist. + :param api_key: The API key. + :param base_url: URL to the API. """ json_data = json_load_data(data) - if "_id" not in json_data or not isinstance(json_data["_id"], str): + object_id = object_id_from_data(json_data) + if object_id is None: print("ERROR: Valid '_id' key not in data") app_exit(1) - valid_key(json_data["_id"]) - - req = requests.put(f"{API_URL}/sc/v0", json=json_data, headers={"API-KEY": API_KEY}, timeout=5, verify=ROOT_CA_FILE) + req = requests.put( + f"{base_url}/sc/v0", json=json_data, headers={"API-KEY": api_key}, timeout=5, verify=ROOT_CA_FILE + ) # Check status if req.status_code == 404: @@ -158,13 +149,15 @@ def replace_action(data: str) -> None: req.raise_for_status() json_data = json.loads(req.text) - print(f'Replaced data OK - key: {json_data["key"]}') + print(f'Replaced data OK - key: {json_data["_id"]}') -def insert_action(data: str) -> None: +def insert_action(data: str, api_key: str, base_url: str = "https://collector-dev.soc.sunet.se:8000") -> None: """Insert a new document into the database, "_id" must not exist in the document. :param data: json blob or path to json file, "_id" key must not exist. + :param api_key: The API key. + :param base_url: URL to the API. """ json_data = json_load_data(data) @@ -174,34 +167,35 @@ def insert_action(data: str) -> None: app_exit(1) req = requests.post( - f"{API_URL}/sc/v0", json=json_data, headers={"API-KEY": API_KEY}, timeout=5, verify=ROOT_CA_FILE + f"{base_url}/sc/v0", json=json_data, headers={"API-KEY": api_key}, timeout=5, verify=ROOT_CA_FILE ) # Ensure ok status req.raise_for_status() json_data = json.loads(req.text) - print(f'Inserted data OK - key: {json_data["key"]}') + print(f'Inserted data OK - key: {json_data["_id"]}') -def get_action(data: str) -> None: +def get_action(data: str, api_key: str, base_url: str = "https://collector-dev.soc.sunet.se:8000") -> None: """Get a document from the database. - :param data: key or path to a json file containing "_id". + :param data: id or path to a json file containing "_id". + :param api_key: The API key. + :param base_url: URL to the API. """ if data and isfile(data): json_data = json_load_data(data) - - if "_id" not in json_data or not isinstance(json_data["_id"], str): - print("ERROR: Valid '_id' key not in data") - app_exit(1) - key: str = json_data["_id"] + object_id = object_id_from_data(json_data) else: - key = data - - valid_key(key) + object_id = object_id_from_data(data) + if object_id is None: + print("ERROR: Invalid id") + app_exit(1) - req = requests.get(f"{API_URL}/sc/v0/{key}", headers={"API-KEY": API_KEY}, timeout=5, verify=ROOT_CA_FILE) + req = requests.get( + f"{base_url}/sc/v0/{str(object_id)}", headers={"API-KEY": api_key}, timeout=5, verify=ROOT_CA_FILE + ) # Check status if req.status_code == 404: @@ -218,6 +212,7 @@ def get_action(data: str) -> None: def main() -> None: """Main function.""" + parser = ArgumentParser(formatter_class=RawTextHelpFormatter, description="SOC Collector CLI") parser.add_argument( "action", @@ -231,7 +226,7 @@ def main() -> None: '{"filter": {"asn_country_code": "SE", "result": {"$exists": "cve_2015_0002"}}}' '{"filter": {}}' - get: key OR path to document using its "_id". + get: id OR path to document using its "_id". 637162378c92893fff92bf7e OR ./data.json insert: json blob OR path to file. Document MUST NOT contain "_id". @@ -244,7 +239,7 @@ def main() -> None: This does NOT send data to the database, use replace for that. 1st ARG: '{json_blob_here...}' OR ./data.json 2th ARG:'{"port": 555, "some_key": "some_data"}' OR ./data.json - delete: key OR path to file using its "_id". + delete: id OR path to file using its "_id". 637162378c92893fff92bf7e OR ./data.json @@ -288,20 +283,25 @@ def main() -> None: args = parser.parse_args() + if "COLLECTOR_API_KEY" not in environ: + print("Missing 'COLLECTOR_API_KEY' in environment") + app_exit(1) + api_key = environ["COLLECTOR_API_KEY"] + if args.action == "get": - get_action(args.data) + get_action(args.data, api_key) elif args.action == "insert": - insert_action(args.data) + insert_action(args.data, api_key) elif args.action == "replace": - replace_action(args.data) + replace_action(args.data, api_key) elif args.action == "update_local" and args.extra_data is not None: update_local_action(args.data, args.extra_data) elif args.action == "delete": - delete_action(args.data) + delete_action(args.data, api_key) elif args.action == "search": - search_action(args.data) + search_action(args.data, api_key) elif args.action == "info": - info_action() + info_action(api_key) else: print("ERROR: Wrong action") app_exit(1) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/__init__.py diff --git a/data/example_data_1.json b/tests/data/example_data_1.json index 69f5d85..69f5d85 100644 --- a/data/example_data_1.json +++ b/tests/data/example_data_1.json diff --git a/tests/data/example_data_1_replace_test.json b/tests/data/example_data_1_replace_test.json new file mode 100644 index 0000000..f56d82c --- /dev/null +++ b/tests/data/example_data_1_replace_test.json @@ -0,0 +1,57 @@ +{ + "document_version": 2, + "ip": "192.0.2.10", + "port": 444, + "whois_description": "SOMENET", + "asn": "AS65001", + "asn_country_code": "SE", + "ptr": "host10.test.soc.sunet.se", + "abuse_mail": "abuse@test.soc.sunet.se", + "domain": "sunet.se", + "timestamp": "2021-06-21T14:06:00Z", + "display_name": "Apache 2.1.3", + "description": "The Apache HTTP Server is a free and open-source cross-platform web server software, released under the terms of Apache License 2.0.", + "custom_data": { + "subject_cn": { + "data": "Apache", + "display_name": "Subject Common Name" + }, + "end_of_general_support": { + "data": false, + "display_name": "End of general support", + "description": "Is the software currently supported?" + } + }, + "result": { + "cve_2015_0049": { + "display_name": "CVE-2015-0049", + "vulnerable": false, + "description": "Allows remote attackers to execute arbitrary code or cause a denial of service (memory corruption)." + }, + "cve_2015_0050": { + "display_name": "CVE-2015-0050", + "vulnerable": false + }, + "cve_2015_0060": { + "display_name": "CVE-2015-0060", + "vulnerable": true, + "reliability": 2 + }, + "cve_2015_0063": { + "display_name": "CVE-2015-0063", + "vulnerable": false + }, + "insecure_cryptography": { + "display_name": "Insecure cryptography", + "vulnerable": true, + "reliability": 5, + "description": "Uses RSA instead of elliptic curve." + }, + "possible_webshell": { + "display_name": "Webshells (PST)", + "investigation_needed": true, + "reliability": 1, + "description": "A webshell of type PST was confirmed at /test/webshell.php" + } + } +} diff --git a/data/example_data_3.json b/tests/data/example_data_3.json index 44d483b..44d483b 100644 --- a/data/example_data_3.json +++ b/tests/data/example_data_3.json diff --git a/data/example_data_3_replace_test.json b/tests/data/example_data_3_replace_test.json index 31cc64d..31cc64d 100644 --- a/data/example_data_3_replace_test.json +++ b/tests/data/example_data_3_replace_test.json diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..d3fbf23 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,275 @@ +""" +Test our auth +""" +import unittest +import json + +import requests + +from src.soc_collector.auth import load_api_keys +from src.soc_collector.soc_collector_cli import json_load_data + +BASE_URL = "https://localhost:8000" + + +class TestAuth(unittest.TestCase): + """ + Test our auth + """ + + def test_auth_info(self) -> None: + """ + Test auth info + """ + + api_keys = load_api_keys("data/api_keys.txt") + + # Test no key + req = requests.get(f"{BASE_URL}/info", timeout=5, verify="./data/collector_root_ca.crt") + self.assertTrue(req.status_code == 401) + + # Test wrong api key + request_headers = {"API-KEY": "dummy"} + req = requests.get( + f"{BASE_URL}/info", headers=request_headers, timeout=4, verify="./data/collector_root_ca.crt" + ) + self.assertTrue(req.status_code == 401) + + # OK api key + request_headers = {"API-KEY": api_keys[-1]} + req = requests.get( + f"{BASE_URL}/info", headers=request_headers, timeout=4, verify="./data/collector_root_ca.crt" + ) + self.assertTrue(req.status_code == 200) + + def test_auth_insert(self) -> None: + """ + Test auth insert + """ + + api_keys = load_api_keys("data/api_keys.txt") + insert_data = json_load_data("./tests/data/example_data_1.json") + + # Test no key + req = requests.post(f"{BASE_URL}/sc/v0", json=insert_data, timeout=5, verify="./data/collector_root_ca.crt") + self.assertTrue(req.status_code == 401) + + # Test wrong api key + request_headers = {"API-KEY": "dummy"} + req = requests.post( + f"{BASE_URL}/sc/v0", + json=insert_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 401) + + # OK api key + request_headers = {"API-KEY": api_keys[-1]} + req = requests.post( + f"{BASE_URL}/sc/v0", + json=insert_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + + # Delete test data + key = json.loads(req.text)["_id"] + req = requests.delete( + f"{BASE_URL}/sc/v0/{key}", headers=request_headers, timeout=4, verify="./data/collector_root_ca.crt" + ) + self.assertTrue(req.status_code == 200) + + def test_auth_replace(self) -> None: + """ + Test auth replace + """ + + api_keys = load_api_keys("data/api_keys.txt") + insert_data = json_load_data("./tests/data/example_data_1.json") + replace_data = json_load_data("./tests/data/example_data_1_replace_test.json") + + request_headers = {"API-KEY": api_keys[0]} + req = requests.post( + f"{BASE_URL}/sc/v0", + json=insert_data, + headers=request_headers, + timeout=5, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + replace_data["_id"] = json.loads(req.text)["_id"] + + # Test no key + req = requests.put(f"{BASE_URL}/sc/v0", json=replace_data, timeout=5, verify="./data/collector_root_ca.crt") + self.assertTrue(req.status_code == 401) + + # Test wrong api key + request_headers = {"API-KEY": "dummy"} + req = requests.put( + f"{BASE_URL}/sc/v0", + json=replace_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 401) + + # OK api key + request_headers = {"API-KEY": api_keys[-1]} + req = requests.put( + f"{BASE_URL}/sc/v0", + json=replace_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + + # Delete test data + req = requests.delete( + f"{BASE_URL}/sc/v0/{replace_data['_id']}", + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + + def test_auth_get(self) -> None: + """ + Test auth get + """ + + api_keys = load_api_keys("data/api_keys.txt") + insert_data = json_load_data("./tests/data/example_data_1.json") + + request_headers = {"API-KEY": api_keys[-1]} + req = requests.post( + f"{BASE_URL}/sc/v0", + json=insert_data, + headers=request_headers, + timeout=5, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + key = json.loads(req.text)["_id"] + + # Test no key + req = requests.get(f"{BASE_URL}/sc/v0/{key}", timeout=5, verify="./data/collector_root_ca.crt") + self.assertTrue(req.status_code == 401) + + # Test wrong api key + request_headers = {"API-KEY": "dummy"} + req = requests.get( + f"{BASE_URL}/sc/v0/{key}", headers=request_headers, timeout=4, verify="./data/collector_root_ca.crt" + ) + self.assertTrue(req.status_code == 401) + + # OK api key + request_headers = {"API-KEY": api_keys[-1]} + req = requests.get( + f"{BASE_URL}/sc/v0/{key}", headers=request_headers, timeout=4, verify="./data/collector_root_ca.crt" + ) + self.assertTrue(req.status_code == 200) + + # Delete test data + req = requests.delete( + f"{BASE_URL}/sc/v0/{key}", headers=request_headers, timeout=4, verify="./data/collector_root_ca.crt" + ) + self.assertTrue(req.status_code == 200) + + def test_auth_delete(self) -> None: + """ + Test auth delete + """ + + api_keys = load_api_keys("data/api_keys.txt") + insert_data = json_load_data("./tests/data/example_data_1.json") + + request_headers = {"API-KEY": api_keys[-1]} + req = requests.post( + f"{BASE_URL}/sc/v0", + json=insert_data, + headers=request_headers, + timeout=5, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + key = json.loads(req.text)["_id"] + + # Test no key + req = requests.delete(f"{BASE_URL}/sc/v0/{key}", timeout=5, verify="./data/collector_root_ca.crt") + self.assertTrue(req.status_code == 401) + + # Test wrong api key + request_headers = {"API-KEY": "dummy"} + req = requests.delete( + f"{BASE_URL}/sc/v0/{key}", headers=request_headers, timeout=4, verify="./data/collector_root_ca.crt" + ) + self.assertTrue(req.status_code == 401) + + # OK api key + request_headers = {"API-KEY": api_keys[0]} + req = requests.delete( + f"{BASE_URL}/sc/v0/{key}", headers=request_headers, timeout=4, verify="./data/collector_root_ca.crt" + ) + self.assertTrue(req.status_code == 200) + + def test_auth_search(self) -> None: + """ + Test auth search + """ + + api_keys = load_api_keys("data/api_keys.txt") + insert_data = json_load_data("./tests/data/example_data_1.json") + insert_data["timestamp"] = "2021-06-21T15:06:00Z" + + request_headers = {"API-KEY": api_keys[-1]} + req = requests.post( + f"{BASE_URL}/sc/v0", + json=insert_data, + headers=request_headers, + timeout=5, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + key = json.loads(req.text)["_id"] + search_data = {"filter": {"timestamp": insert_data["timestamp"]}} + + # Test no key + req = requests.post( + f"{BASE_URL}/sc/v0/search", json=search_data, timeout=5, verify="./data/collector_root_ca.crt" + ) + self.assertTrue(req.status_code == 401) + + # Test wrong api key + request_headers = {"API-KEY": "dummy"} + req = requests.post( + f"{BASE_URL}/sc/v0/search", + json=search_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 401) + + # OK api key + request_headers = {"API-KEY": api_keys[0]} + req = requests.post( + f"{BASE_URL}/sc/v0/search", + json=search_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + + # Delete test data + request_headers = {"API-KEY": api_keys[0]} + req = requests.delete( + f"{BASE_URL}/sc/v0/{key}", headers=request_headers, timeout=4, verify="./data/collector_root_ca.crt" + ) + self.assertTrue(req.status_code == 200) diff --git a/tests/test_delete.py b/tests/test_delete.py new file mode 100644 index 0000000..855d987 --- /dev/null +++ b/tests/test_delete.py @@ -0,0 +1,86 @@ +""" +Test our delete +""" +import unittest +import json + +import requests + +from src.soc_collector.auth import load_api_keys +from src.soc_collector.soc_collector_cli import json_load_data + + +BASE_URL = "https://localhost:8000" + + +class TestAddress(unittest.TestCase): + """ + Test our delete + """ + + def test_delete(self) -> None: + """ + Test delete + """ + + api_keys = load_api_keys("data/api_keys.txt") + insert_data = json_load_data("./tests/data/example_data_1.json") + + request_headers = {"API-KEY": api_keys[-1]} + req = requests.post( + f"{BASE_URL}/sc/v0", + json=insert_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + key = json.loads(req.text)["_id"] + + req = requests.delete( + f"{BASE_URL}/sc/v0/dummy", + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 400) + + req = requests.delete( + f"{BASE_URL}/sc/v0/63765238890b48a0c3118f4f", + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 404) + + req = requests.get( + f"{BASE_URL}/sc/v0/{key}", + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + + req = requests.delete( + f"{BASE_URL}/sc/v0/{key}", + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + + req = requests.get( + f"{BASE_URL}/sc/v0/{key}", + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 404) + + req = requests.delete( + f"{BASE_URL}/sc/v0/{key}", + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 404) diff --git a/tests/test_get.py b/tests/test_get.py new file mode 100644 index 0000000..377f35a --- /dev/null +++ b/tests/test_get.py @@ -0,0 +1,82 @@ +""" +Test our get +""" +import unittest +import json + +import requests + +from src.soc_collector.auth import load_api_keys +from src.soc_collector.soc_collector_cli import json_load_data + + +BASE_URL = "https://localhost:8000" + + +class TestAddress(unittest.TestCase): + """ + Test our get + """ + + def test_get(self) -> None: + """ + Test get + """ + + api_keys = load_api_keys("data/api_keys.txt") + insert_data = json_load_data("./tests/data/example_data_1.json") + + request_headers = {"API-KEY": api_keys[-1]} + req = requests.post( + f"{BASE_URL}/sc/v0", + json=insert_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + key = json.loads(req.text)["_id"] + + req = requests.get( + f"{BASE_URL}/sc/v0/dummy", + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 400) + + req = requests.get( + f"{BASE_URL}/sc/v0/63765238890b48a0c3118f4f", + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 404) + + req = requests.get( + f"{BASE_URL}/sc/v0/{key}", + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + data1 = json.loads(req.text) + + req = requests.get( + f"{BASE_URL}/sc/v0/{key}", + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + data2 = json.loads(req.text) + self.assertTrue(data1 == data2) + + # Delete test data + req = requests.delete( + f"{BASE_URL}/sc/v0/{key}", + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) diff --git a/tests/test_info.py b/tests/test_info.py new file mode 100644 index 0000000..ef775d7 --- /dev/null +++ b/tests/test_info.py @@ -0,0 +1,37 @@ +""" +Test our info +""" +import unittest +import json + +import requests + +from src.soc_collector.auth import load_api_keys + +BASE_URL = "https://localhost:8000" + + +class TestAddress(unittest.TestCase): + """ + Test our info + """ + + def test_info(self) -> None: + """ + Test info + """ + + api_keys = load_api_keys("data/api_keys.txt") + request_headers = {"API-KEY": api_keys[-1]} + + req = requests.get( + f"{BASE_URL}/info", + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + data = json.loads(req.text) + self.assertTrue("Estimated document count" in data) + self.assertTrue(isinstance(data["Estimated document count"], int)) + self.assertTrue(data["Estimated document count"] >= 0) diff --git a/tests/test_insert.py b/tests/test_insert.py new file mode 100644 index 0000000..5bee72d --- /dev/null +++ b/tests/test_insert.py @@ -0,0 +1,110 @@ +""" +Test our insert +""" +import unittest +import json + +import requests + +from src.soc_collector.auth import load_api_keys +from src.soc_collector.soc_collector_cli import json_load_data + + +BASE_URL = "https://localhost:8000" + + +class TestAddress(unittest.TestCase): + """ + Test our insert + """ + + def test_insert(self) -> None: + """ + Test insert + """ + + api_keys = load_api_keys("data/api_keys.txt") + insert_data = json_load_data("./tests/data/example_data_1.json") + + request_headers = {"API-KEY": api_keys[-1]} + req = requests.post( + f"{BASE_URL}/sc/v0", + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 400) + + req = requests.post( + f"{BASE_URL}/sc/v0", + json={"dummy_data": "dummy"}, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 400) + + bad_data = insert_data.copy() + bad_data["_id"] = "63765238890b49a0c3118f4f" + req = requests.post( + f"{BASE_URL}/sc/v0", + json=bad_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 400) + del bad_data["_id"] + + del bad_data["ip"] + req = requests.post( + f"{BASE_URL}/sc/v0", + json=bad_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 400) + + req = requests.post( + f"{BASE_URL}/sc/v0", + json=insert_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + key = json.loads(req.text)["_id"] + + # Allow duplicate data but with different id + req = requests.post( + f"{BASE_URL}/sc/v0", + json=insert_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + key2 = json.loads(req.text)["_id"] + + req = requests.get( + f"{BASE_URL}/sc/v0/{key}", + json=insert_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + db_data = json.loads(req.text)["doc"] + del db_data["_id"] + self.assertTrue(insert_data == db_data) + + # Delete test data + req = requests.delete( + f"{BASE_URL}/sc/v0/{key}", headers=request_headers, timeout=4, verify="./data/collector_root_ca.crt" + ) + self.assertTrue(req.status_code == 200) + req = requests.delete( + f"{BASE_URL}/sc/v0/{key2}", headers=request_headers, timeout=4, verify="./data/collector_root_ca.crt" + ) + self.assertTrue(req.status_code == 200) diff --git a/tests/test_replace.py b/tests/test_replace.py new file mode 100644 index 0000000..291207f --- /dev/null +++ b/tests/test_replace.py @@ -0,0 +1,102 @@ +""" +Test our replace +""" +import unittest +import json + +import requests + +from src.soc_collector.auth import load_api_keys +from src.soc_collector.soc_collector_cli import json_load_data + + +BASE_URL = "https://localhost:8000" + + +class TestAddress(unittest.TestCase): + """ + Test our replace + """ + + def test_replace(self) -> None: + """ + Test replace + """ + + api_keys = load_api_keys("data/api_keys.txt") + insert_data = json_load_data("./tests/data/example_data_1.json") + replace_data = json_load_data("./tests/data/example_data_1_replace_test.json") + + request_headers = {"API-KEY": api_keys[-1]} + req = requests.post( + f"{BASE_URL}/sc/v0", + json=insert_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + replace_data["_id"] = json.loads(req.text)["_id"] + + bad_data = replace_data.copy() + + # ip missing + del bad_data["ip"] + req = requests.put( + f"{BASE_URL}/sc/v0", + json=bad_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 400) + + bad_data["ip"] = replace_data["ip"] + del bad_data["_id"] + req = requests.put( + f"{BASE_URL}/sc/v0", + json=bad_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 400) + + bad_data["_id"] = "sdvnsvdlac" + req = requests.put( + f"{BASE_URL}/sc/v0", + json=bad_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 400) + + req = requests.put( + f"{BASE_URL}/sc/v0", + json=replace_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + + req = requests.get( + f"{BASE_URL}/sc/v0/{replace_data['_id']}", + json=insert_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + db_data = json.loads(req.text)["doc"] + self.assertTrue(replace_data == db_data) + + # Delete test data + req = requests.delete( + f"{BASE_URL}/sc/v0/{replace_data['_id']}", + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) diff --git a/tests/test_search.py b/tests/test_search.py new file mode 100644 index 0000000..8a01332 --- /dev/null +++ b/tests/test_search.py @@ -0,0 +1,125 @@ +""" +Test our search +""" +from typing import Dict, Any +import unittest +import json + +import requests + +from src.soc_collector.auth import load_api_keys +from src.soc_collector.soc_collector_cli import json_load_data + + +BASE_URL = "https://localhost:8000" + + +class TestAddress(unittest.TestCase): + """ + Test our search + """ + + def test_search(self) -> None: + """ + Test search + """ + + api_keys = load_api_keys("data/api_keys.txt") + insert_data = json_load_data("./tests/data/example_data_1.json") + insert_data["ip"] = "test_dummy_ip1" + + request_headers = {"API-KEY": api_keys[-1]} + req = requests.post( + f"{BASE_URL}/sc/v0", + json=insert_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + key1 = json.loads(req.text)["_id"] + + insert_data["port"] = 4123 + req = requests.post( + f"{BASE_URL}/sc/v0", + json=insert_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + key2 = json.loads(req.text)["_id"] + + req = requests.post( + f"{BASE_URL}/sc/v0/search", + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 422) + + search_data: Dict[str, Any] = {"dummy": {"ip": "test_dummy_ip"}} + req = requests.post( + f"{BASE_URL}/sc/v0/search", + json=search_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 422) + + search_data = {"filter": {"ip": "test_dummy_ip"}} + req = requests.post( + f"{BASE_URL}/sc/v0/search", + json=search_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + data = json.loads(req.text)["docs"] + self.assertTrue(data == []) + + search_data = {"filter": {"ip": "test_dummy_ip1"}} + req = requests.post( + f"{BASE_URL}/sc/v0/search", + json=search_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + data = json.loads(req.text)["docs"] + self.assertTrue(len(data) == 2) + del data[0]["_id"] + del data[1]["_id"] + self.assertTrue(data[0] != insert_data) + self.assertTrue(data[1] == insert_data) + + search_data = {"filter": {"ip": "test_dummy_ip1", "port": 4123}} + req = requests.post( + f"{BASE_URL}/sc/v0/search", + json=search_data, + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + data = json.loads(req.text)["docs"] + self.assertTrue(len(data) == 1) + + # Delete test data + req = requests.delete( + f"{BASE_URL}/sc/v0/{key1}", + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) + req = requests.delete( + f"{BASE_URL}/sc/v0/{key2}", + headers=request_headers, + timeout=4, + verify="./data/collector_root_ca.crt", + ) + self.assertTrue(req.status_code == 200) diff --git a/tests/test_soc_collector_cli.py b/tests/test_soc_collector_cli.py new file mode 100644 index 0000000..f867d60 --- /dev/null +++ b/tests/test_soc_collector_cli.py @@ -0,0 +1,244 @@ +""" +Test our cli +""" +import unittest +import io +import sys +import json +import os + +from src.soc_collector.auth import load_api_keys + +from src.soc_collector.soc_collector_cli import ( + json_load_data, + info_action, + get_action, + replace_action, + delete_action, + insert_action, + update_local_action, + search_action, +) + +BASE_URL = "https://localhost:8000" + + +class TestAddress(unittest.TestCase): + """ + Test our cli + """ + + def test_json_load_data(self) -> None: + """ + Test cli json_load_data + """ + + insert_data = json_load_data("./tests/data/example_data_1.json") + with open("./json_load_data_test_here11.json", "w", encoding="utf-8") as f_data: + f_data.write(json.dumps(insert_data)) + + data = json_load_data("./json_load_data_test_here11.json") + self.assertTrue(isinstance(data, dict)) + + data = json_load_data('{"filter": {"ip": "123.123.4.123"}}') + self.assertTrue(isinstance(data, dict)) + + # Remove test data + os.remove("./json_load_data_test_here11.json") + + def test_cli_info(self) -> None: + """ + Test cli info + """ + + api_keys = load_api_keys("data/api_keys.txt") + output = io.StringIO() + sys.stdout = output + info_action(api_keys[-1], BASE_URL) + self.assertTrue("Estimated document count: " in output.getvalue()) + sys.stdout = sys.__stdout__ + + def test_cli_insert(self) -> None: + """ + Test cli insert + """ + + api_keys = load_api_keys("data/api_keys.txt") + insert_data = json_load_data("./tests/data/example_data_1.json") + + output = io.StringIO() + sys.stdout = output + insert_action(json.dumps(insert_data), api_keys[-1], BASE_URL) + self.assertTrue("Inserted data OK - key: " in output.getvalue()) + sys.stdout = sys.__stdout__ + key1 = output.getvalue().split(" OK - key: ")[1].strip() + + output = io.StringIO() + sys.stdout = output + insert_action("./tests/data/example_data_1.json", api_keys[-1], BASE_URL) + self.assertTrue("Inserted data OK - key: " in output.getvalue()) + sys.stdout = sys.__stdout__ + key2 = output.getvalue().split(" OK - key: ")[1].strip() + + # Delete test data + output = io.StringIO() + sys.stdout = output + delete_action(key1, api_keys[-1], BASE_URL) + delete_action(key2, api_keys[-1], BASE_URL) + sys.stdout = sys.__stdout__ + + def test_cli_replace(self) -> None: + """ + Test cli insert + """ + + api_keys = load_api_keys("data/api_keys.txt") + insert_data = json_load_data("./tests/data/example_data_1.json") + replace_vals = {"ip": "replace_test1"} + replace_data = insert_data.copy() + replace_data.update(replace_vals) + + output = io.StringIO() + sys.stdout = output + insert_action(json.dumps(insert_data), api_keys[-1], BASE_URL) + self.assertTrue("Inserted data OK - key: " in output.getvalue()) + sys.stdout = sys.__stdout__ + key1 = output.getvalue().split(" OK - key: ")[1].strip() + + output = io.StringIO() + sys.stdout = output + replace_data["_id"] = key1 + replace_action(json.dumps(replace_data), api_keys[-1], BASE_URL) + self.assertTrue("Replaced data OK - key: " in output.getvalue()) + sys.stdout = sys.__stdout__ + key2 = output.getvalue().split(" OK - key: ")[1].strip() + self.assertTrue(key1 == key2) + + replace_data["ip"] = "replace_test2" + with open("./replace_test_here11.json", "w", encoding="utf-8") as f_data: + f_data.write(json.dumps(replace_data)) + + output = io.StringIO() + sys.stdout = output + replace_action("./replace_test_here11.json", api_keys[-1], BASE_URL) + self.assertTrue("Replaced data OK - key: " in output.getvalue()) + sys.stdout = sys.__stdout__ + key3 = output.getvalue().split(" OK - key: ")[1].strip() + self.assertTrue(key1 == key2 == key3) + + # Delete test data + os.remove("./replace_test_here11.json") + output = io.StringIO() + sys.stdout = output + delete_action(key1, api_keys[-1], BASE_URL) + sys.stdout = sys.__stdout__ + + def test_cli_update_local(self) -> None: + """ + Test cli update_local + """ + + data = json_load_data("./tests/data/example_data_1.json") + update_vals = {"ip": "update_local_test1"} + updated_data = data.copy() + updated_data.update(update_vals) + + output = io.StringIO() + sys.stdout = output + update_local_action(json.dumps(data), json.dumps(update_vals)) + self.assertTrue(json.dumps(updated_data, indent=4) in output.getvalue()) + sys.stdout = sys.__stdout__ + + with open("./update_local_test_here11.json", "w", encoding="utf-8") as f_data: + f_data.write(json.dumps(data)) + + output = io.StringIO() + sys.stdout = output + update_local_action("./update_local_test_here11.json", json.dumps(update_vals)) + os.remove("./update_local_test_here11.json") + self.assertTrue(json.dumps(updated_data, indent=4) in output.getvalue()) + sys.stdout = sys.__stdout__ + + def test_cli_get(self) -> None: + """ + Test cli get + """ + + api_keys = load_api_keys("data/api_keys.txt") + insert_data = json_load_data("./tests/data/example_data_1.json") + + output = io.StringIO() + sys.stdout = output + insert_action(json.dumps(insert_data), api_keys[-1], BASE_URL) + self.assertTrue("Inserted data OK - key: " in output.getvalue()) + sys.stdout = sys.__stdout__ + key1 = output.getvalue().split(" OK - key: ")[1].strip() + + output = io.StringIO() + sys.stdout = output + get_action(key1, api_keys[-1], BASE_URL) + expected_data = {"_id": key1} + expected_data.update(insert_data) + self.assertTrue(json.dumps(expected_data, indent=4) in output.getvalue()) + sys.stdout = sys.__stdout__ + + with open("./get_test_here11.json", "w", encoding="utf-8") as f_data: + f_data.write(json.dumps(expected_data)) + + output = io.StringIO() + sys.stdout = output + get_action("./get_test_here11.json", api_keys[-1], BASE_URL) + expected_data = {"_id": key1} + expected_data.update(insert_data) + self.assertTrue(json.dumps(expected_data, indent=4) in output.getvalue()) + sys.stdout = sys.__stdout__ + + # Delete test data + os.remove("./get_test_here11.json") + output = io.StringIO() + sys.stdout = output + delete_action(key1, api_keys[-1], BASE_URL) + sys.stdout = sys.__stdout__ + + def test_cli_search(self) -> None: + """ + Test cli search + """ + + api_keys = load_api_keys("data/api_keys.txt") + insert_data = json_load_data("./tests/data/example_data_1.json") + insert_data["ip"] = "search_dummy_ip1" + search_data = {"filter": {"ip": "search_dummy_ip1"}} + + output = io.StringIO() + sys.stdout = output + insert_action(json.dumps(insert_data), api_keys[-1], BASE_URL) + self.assertTrue("Inserted data OK - key: " in output.getvalue()) + sys.stdout = sys.__stdout__ + key1 = output.getvalue().split(" OK - key: ")[1].strip() + + output = io.StringIO() + sys.stdout = output + search_action(json.dumps(search_data), api_keys[-1], BASE_URL) + expected_data = {"_id": key1} + expected_data.update(insert_data) + self.assertTrue(json.dumps([expected_data], indent=4) in output.getvalue()[:-1]) + sys.stdout = sys.__stdout__ + + with open("./search_test_here11.json", "w", encoding="utf-8") as f_data: + f_data.write(json.dumps(search_data)) + + output = io.StringIO() + sys.stdout = output + search_action("./search_test_here11.json", api_keys[-1], BASE_URL) + expected_data = {"_id": key1} + expected_data.update(insert_data) + self.assertTrue(json.dumps([expected_data], indent=4) in output.getvalue()[:-1]) + sys.stdout = sys.__stdout__ + + # Delete test data + os.remove("./search_test_here11.json") + output = io.StringIO() + sys.stdout = output + delete_action(key1, api_keys[-1], BASE_URL) + sys.stdout = sys.__stdout__ |