diff options
author | Victor Näslund <victor@sunet.se> | 2022-11-01 01:55:25 +0100 |
---|---|---|
committer | Victor Näslund <victor@sunet.se> | 2022-11-01 01:55:25 +0100 |
commit | ffb26f4a81a9ca61c4105df037f7e1beb8dc5fb0 (patch) | |
tree | 41094f051edbf300a6cd2c2de8dfb8435bfc18a4 /src | |
parent | 1b836e78db2737ba5d1ae43da9828601a5a5c114 (diff) |
initial fresh up
Diffstat (limited to 'src')
-rw-r--r-- | src/couch/client.py | 21 | ||||
-rw-r--r-- | src/couch/feedreader.py | 10 | ||||
-rw-r--r-- | src/couch/resource.py | 27 | ||||
-rw-r--r-- | src/couch/utils.py | 58 | ||||
-rwxr-xr-x | src/db.py | 22 | ||||
-rwxr-xr-x | src/main.py | 45 | ||||
-rw-r--r-- | src/schema.py | 11 |
7 files changed, 109 insertions, 85 deletions
diff --git a/src/couch/client.py b/src/couch/client.py index 188e0de..52477be 100644 --- a/src/couch/client.py +++ b/src/couch/client.py @@ -17,7 +17,7 @@ from couch.resource import Resource DEFAULT_BASE_URL = os.environ.get('COUCHDB_URL', 'http://localhost:5984/') -def _id_to_path(_id): +def _id_to_path(_id: str) -> str: if _id[:1] == "_": return _id.split("/", 1) return [_id] @@ -360,7 +360,7 @@ class Database(object): revision. :returns: doc """ - + _doc = copy.copy(doc) if "_id" not in _doc: _doc['_id'] = uuid.uuid4().hex @@ -371,12 +371,27 @@ class Database(object): params = {} data = utils.force_bytes(json.dumps(_doc)) + + print("gg1", flush=True) + print(data, flush=True) + print("vv1", flush=True) + (resp, result) = self.resource(_doc['_id']).put( data=data, params=params) + print("gg3", flush=True) + print(resp.status_code) + print(resp.content) + #print(resp.contents) + + print("gg2", flush=True) + print(data, flush=True) + print(result, flush=True) + print("vv2", flush=True) + if resp.status_code == 409: raise exp.Conflict(result['reason']) - + if "rev" in result and result["rev"] is not None: _doc["_rev"] = result["rev"] diff --git a/src/couch/feedreader.py b/src/couch/feedreader.py index e293932..98401ab 100644 --- a/src/couch/feedreader.py +++ b/src/couch/feedreader.py @@ -1,13 +1,13 @@ # -*- coding: utf-8 -*- # Based on py-couchdb (https://github.com/histrio/py-couchdb) +from __future__ import annotations - -class BaseFeedReader(object): +class BaseFeedReader: """ Base interface class for changes feed reader. """ - def __call__(self, db): + def __call__(self, db) -> BaseFeedReader: self.db = db return self @@ -44,9 +44,9 @@ class SimpleFeedReader(BaseFeedReader): a valid feed reader interface. """ - def __call__(self, db, callback): + def __call__(self, db, callback) -> BaseFeedReader: self.callback = callback return super(SimpleFeedReader, self).__call__(db) - def on_message(self, message): + def on_message(self, message) -> None: self.callback(message, db=self.db) diff --git a/src/couch/resource.py b/src/couch/resource.py index da1e0dd..364bff4 100644 --- a/src/couch/resource.py +++ b/src/couch/resource.py @@ -3,17 +3,20 @@ from __future__ import unicode_literals +from typing import Union, Tuple import json import requests + + from couch import utils from couch import exceptions -class Resource(object): - def __init__(self, base_url, full_commit=True, session=None, - credentials=None, authmethod="session", verify=False): +class Resource: + def __init__(self, base_url: str, full_commit: bool = True, session: Union[requests.sessions.Session, None] = None, + credentials: Union[Tuple[str, str], None] = None, authmethod: str = "session", verify: bool = False) -> None: self.base_url = base_url # self.verify = verify @@ -31,7 +34,7 @@ class Resource(object): self.session = session self.session.verify = verify - def _authenticate(self, credentials, method): + def _authenticate(self, credentials: Union[Tuple[str, str], None], method: str) -> None: if not credentials: return @@ -50,11 +53,11 @@ class Resource(object): else: raise RuntimeError("Invalid authentication method") - def __call__(self, *path): + def __call__(self, *path: str): base_url = utils.urljoin(self.base_url, *path) return self.__class__(base_url, session=self.session) - def _check_result(self, response, result): + def _check_result(self, response, result) -> None: try: error = result.get('error', None) reason = result.get('reason', None) @@ -74,7 +77,7 @@ class Resource(object): raise exceptions.BadRequest(reason or "Bad request") raise exceptions.GenericError(result) - def request(self, method, path, params=None, data=None, + def request(self, method, path: str, params=None, data=None, headers=None, stream=False, **kwargs): if headers is None: @@ -112,17 +115,17 @@ class Resource(object): return response, result - def get(self, path=None, **kwargs): + def get(self, path: Union[str, None] = None, **kwargs): return self.request("GET", path, **kwargs) - def put(self, path=None, **kwargs): + def put(self, path: Union[str, None] = None, **kwargs): return self.request("PUT", path, **kwargs) - def post(self, path=None, **kwargs): + def post(self, path: Union[str, None] = None, **kwargs): return self.request("POST", path, **kwargs) - def delete(self, path=None, **kwargs): + def delete(self, path: Union[str, None] = None, **kwargs): return self.request("DELETE", path, **kwargs) - def head(self, path=None, **kwargs): + def head(self, path: Union[str, None] = None, **kwargs): return self.request("HEAD", path, **kwargs) diff --git a/src/couch/utils.py b/src/couch/utils.py index 1cd21d8..f0883a6 100644 --- a/src/couch/utils.py +++ b/src/couch/utils.py @@ -1,26 +1,15 @@ # -*- coding: utf-8 -*- # Based on py-couchdb (https://github.com/histrio/py-couchdb) - +from typing import Tuple, Union, Dict, List, Any import json import sys +import requests +from urllib.parse import unquote as _unquote +from urllib.parse import urlunsplit, urlsplit -if sys.version_info[0] == 3: - from urllib.parse import unquote as _unquote - from urllib.parse import urlunsplit, urlsplit - - string_type = str - bytes_type = bytes - - from functools import reduce - -else: - from urllib import unquote as _unquote - from urlparse import urlunsplit, urlsplit - - string_type = unicode # noqa: F821 - bytes_type = str +from functools import reduce URLSPLITTER = '/' @@ -28,7 +17,7 @@ URLSPLITTER = '/' json_encoder = json.JSONEncoder() -def extract_credentials(url): +def extract_credentials(url: str) -> Tuple[str, Union[Tuple[str, str], None]]: """ Extract authentication (user name and password) credentials from the given URL. @@ -46,19 +35,20 @@ def extract_credentials(url): if '@' in netloc: creds, netloc = netloc.split('@') credentials = tuple(_unquote(i) for i in creds.split(':')) - parts = list(parts) - parts[1] = netloc - else: - credentials = None - return urlunsplit(parts), credentials + parts_list = list(parts) + parts_list[1] = netloc + return urlunsplit(parts_list), (credentials[0], credentials[1]) + + parts_list = list(parts) + return urlunsplit(parts_list), None -def _join(head, tail): +def _join(head: str, tail: str) -> str: parts = [head.rstrip(URLSPLITTER), tail.lstrip(URLSPLITTER)] return URLSPLITTER.join(parts) -def urljoin(base, *path): +def urljoin(base: str, *path: str) -> str: """ Assemble a uri based on a base, any number of path segments, and query string parameters. @@ -87,18 +77,18 @@ def urljoin(base, *path): """ return reduce(_join, path, base) - -def as_json(response): +# Probably bugs here +def as_json(response: requests.models.Response) -> Union[Dict[str, Any], None]: if "application/json" in response.headers['content-type']: response_src = response.content.decode('utf-8') if response.content != b'': - return json.loads(response_src) + ret: Dict[str, Any] = json.loads(response_src) + return ret else: return response_src return None - -def _path_from_name(name, type): +def _path_from_name(name: str, type: str) -> List[str]: """ Expand a 'design/foo' style name to its full path as a list of segments. @@ -114,7 +104,7 @@ def _path_from_name(name, type): return ['_design', design, type, name] -def encode_view_options(options): +def encode_view_options(options: Dict[str, Any]) -> Dict[str, Any]: """ Encode any items in the options dict that are sent as a JSON string to a view/list function. @@ -138,13 +128,13 @@ def encode_view_options(options): return retval -def force_bytes(data, encoding="utf-8"): - if isinstance(data, string_type): +def force_bytes(data: Union[str, bytes], encoding: str = "utf-8") -> bytes: + if isinstance(data, str): data = data.encode(encoding) return data -def force_text(data, encoding="utf-8"): - if isinstance(data, bytes_type): +def force_text(data: Union[str, bytes], encoding: str = "utf-8") -> str: + if isinstance(data, bytes): data = data.decode(encoding) return data @@ -7,6 +7,7 @@ # value if you're too quick with generating the timestamps, ie # invoking time.time() several times quickly enough. +from typing import Dict, List, Tuple, Union, Any import os import sys import time @@ -16,7 +17,7 @@ from schema import as_index_list, validate_collector_data class DictDB(): - def __init__(self): + def __init__(self) -> None: """ Check if the database exists, otherwise we will create it together with the indexes specified in index.py. @@ -35,7 +36,7 @@ class DictDB(): if 'COUCHDB_PORT' in os.environ: couchdb_port = os.environ['COUCHDB_PORT'] else: - couchdb_port = 5984 + couchdb_port = "5984" self.server = couch.client.Server( f"http://{self.username}:{self.password}@{self.hostname}:{couchdb_port}/") @@ -52,7 +53,7 @@ class DictDB(): self._ts = time.time() - def unique_key(self): + def unique_key(self) -> int: """ Create a unique key based on the current time. We will use this as the ID for any new documents we store in CouchDB. @@ -65,18 +66,19 @@ class DictDB(): return self._ts - def add(self, data, batch_write=False): + # Why batch_write??? + def add(self, data: Union[List[Dict[str, Any]], Dict[str, Any]]) -> Union[str, Tuple[str, str]]: """ Store a document in CouchDB. """ - if type(data) is list: + if isinstance(data, List): for item in data: error = validate_collector_data(item) if error != "": return error item['_id'] = str(self.unique_key()) - ret = self.couchdb.save_bulk(data) + ret: Tuple[str, str] = self.couchdb.save_bulk(data) else: error = validate_collector_data(data) if error != "": @@ -86,13 +88,13 @@ class DictDB(): return ret - def get(self, key): + def get(self, key: int) -> Dict[str, Any]: """ Get a document based on its ID, return an empty dict if not found. """ try: - doc = self.couchdb.get(key) + doc: Dict[str, Any] = self.couchdb.get(key) except couch.exceptions.NotFound: doc = {} @@ -101,7 +103,7 @@ class DictDB(): def slice(self, key_from=None, key_to=None): pass - def search(self, limit=25, skip=0, **kwargs): + def search(self, limit: int = 25, skip: int = 0, **kwargs: Any) -> List[Dict[str, Any]]: """ Execute a Mango query, ideally we should have an index matching the query otherwise things will be slow. @@ -134,7 +136,7 @@ class DictDB(): return data - def delete(self, key): + def delete(self, key: int) -> Union[int, None]: """ Delete a document based on its ID. """ diff --git a/src/main.py b/src/main.py index 9de8eb8..2730b83 100755 --- a/src/main.py +++ b/src/main.py @@ -1,3 +1,4 @@ +from typing import Dict, Union, List, Any import json import os import sys @@ -48,7 +49,7 @@ else: sys.exit(-1) -def get_pubkey(): +def get_pubkey() -> str: try: if 'JWT_PUBKEY_PATH' in os.environ: keypath = os.environ['JWT_PUBKEY_PATH'] @@ -64,12 +65,18 @@ def get_pubkey(): return pubkey -def get_data(key=None, limit=25, skip=0, ip=None, - port=None, asn=None, domain=None): + +def get_data(key: Union[int, None] = None, + limit: int = 25, + skip: int = 0, + ip: Union[str, None] = None, + port: Union[int, None] = None, + asn: Union[str, None] = None, + domain: Union[str, None] = None) -> List[Dict[str, Any]]: if key: - return db.get(key) + return [db.get(key)] - selectors = dict() + selectors: Dict[str, Any] = {} indexes = get_index_keys() selectors['domain'] = domain @@ -80,7 +87,7 @@ def get_data(key=None, limit=25, skip=0, ip=None, if asn and 'asn' in indexes: selectors['asn'] = asn - data = db.search(**selectors, limit=limit, skip=skip) + data: List[Dict[str, Any]] = db.search(**selectors, limit=limit, skip=skip) return data @@ -96,21 +103,20 @@ def jwt_config(): @app.exception_handler(AuthJWTException) -def authjwt_exception_handler(request: Request, exc: AuthJWTException): +def authjwt_exception_handler(request: Request, exc: AuthJWTException) -> JSONResponse: return JSONResponse(content={"status": "error", "message": exc.message}, status_code=400) @app.exception_handler(RuntimeError) -def app_exception_handler(request: Request, exc: RuntimeError): +def app_exception_handler(request: Request, exc: RuntimeError) -> JSONResponse: return JSONResponse(content={"status": "error", "message": str(exc.with_traceback(None))}, status_code=400) @app.get('/sc/v0/get') -async def get(key=None, limit=25, skip=0, ip=None, port=None, - asn=None, Authorize: AuthJWT = Depends()): +async def get(key: Union[int, None] = None, limit: int = 25, skip: int = 0, ip: Union[str, None] = None, port: Union[int, None] = None, asn: Union[str, None] = None, Authorize: AuthJWT = Depends()) -> JSONResponse: Authorize.jwt_required() @@ -135,7 +141,7 @@ async def get(key=None, limit=25, skip=0, ip=None, port=None, @app.get('/sc/v0/get/{key}') -async def get_key(key=None, Authorize: AuthJWT = Depends()): +async def get_key(key: Union[int, None] = None, Authorize: AuthJWT = Depends()) -> JSONResponse: Authorize.jwt_required() @@ -152,7 +158,10 @@ async def get_key(key=None, Authorize: AuthJWT = Depends()): else: allowed_domains = raw_jwt["read"] - data = get_data(key) + data_list = get_data(key) + + # Handle if missing + data = data_list[0] if data and data["domain"] not in allowed_domains: return JSONResponse( @@ -166,8 +175,9 @@ async def get_key(key=None, Authorize: AuthJWT = Depends()): return JSONResponse(content={"status": "success", "docs": data}) +# WHY IS AUTH OUTCOMMENTED??? @app.post('/sc/v0/add') -async def add(data: Request, Authorize: AuthJWT = Depends()): +async def add(data: Request, Authorize: AuthJWT = Depends()) -> JSONResponse: # Authorize.jwt_required() try: @@ -196,7 +206,7 @@ async def add(data: Request, Authorize: AuthJWT = Depends()): @app.delete('/sc/v0/delete/{key}') -async def delete(key, Authorize: AuthJWT = Depends()): +async def delete(key: int, Authorize: AuthJWT = Depends()) -> JSONResponse: Authorize.jwt_required() @@ -213,7 +223,10 @@ async def delete(key, Authorize: AuthJWT = Depends()): else: allowed_domains = raw_jwt["write"] - data = get_data(key) + data_list = get_data(key) + + # Handle if missing + data = data_list[0] if data and data["domain"] not in allowed_domains: return JSONResponse( @@ -232,7 +245,7 @@ async def delete(key, Authorize: AuthJWT = Depends()): return JSONResponse(content={"status": "success", "docs": data}) -def main(standalone=False): +def main(standalone: bool = False): if not standalone: return app diff --git a/src/schema.py b/src/schema.py index fe2b76c..2b479d2 100644 --- a/src/schema.py +++ b/src/schema.py @@ -1,3 +1,4 @@ +from typing import List, Any, Dict import json import sys import traceback @@ -95,15 +96,15 @@ schema = { # fmt:on -def get_index_keys(): - keys = list() +def get_index_keys() -> List[str]: + keys: List[str] = [] for key in schema["properties"]: keys.append(key) return keys -def as_index_list(): - index_list = list() +def as_index_list() -> List[Dict[str, Any]]: + index_list: List[Dict[str, Any]] = [] for key in schema["properties"]: name = f"{key}-json-index" index = { @@ -120,7 +121,7 @@ def as_index_list(): return index_list -def validate_collector_data(json_blob): +def validate_collector_data(json_blob: Dict[str, Any]) -> str: try: jsonschema.validate(json_blob, schema, format_checker=jsonschema.FormatChecker()) except jsonschema.exceptions.ValidationError as e: |