summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorVictor Näslund <victor@sunet.se>2022-11-01 01:55:25 +0100
committerVictor Näslund <victor@sunet.se>2022-11-01 01:55:25 +0100
commitffb26f4a81a9ca61c4105df037f7e1beb8dc5fb0 (patch)
tree41094f051edbf300a6cd2c2de8dfb8435bfc18a4 /src
parent1b836e78db2737ba5d1ae43da9828601a5a5c114 (diff)
initial fresh up
Diffstat (limited to 'src')
-rw-r--r--src/couch/client.py21
-rw-r--r--src/couch/feedreader.py10
-rw-r--r--src/couch/resource.py27
-rw-r--r--src/couch/utils.py58
-rwxr-xr-xsrc/db.py22
-rwxr-xr-xsrc/main.py45
-rw-r--r--src/schema.py11
7 files changed, 109 insertions, 85 deletions
diff --git a/src/couch/client.py b/src/couch/client.py
index 188e0de..52477be 100644
--- a/src/couch/client.py
+++ b/src/couch/client.py
@@ -17,7 +17,7 @@ from couch.resource import Resource
DEFAULT_BASE_URL = os.environ.get('COUCHDB_URL', 'http://localhost:5984/')
-def _id_to_path(_id):
+def _id_to_path(_id: str) -> str:
if _id[:1] == "_":
return _id.split("/", 1)
return [_id]
@@ -360,7 +360,7 @@ class Database(object):
revision.
:returns: doc
"""
-
+
_doc = copy.copy(doc)
if "_id" not in _doc:
_doc['_id'] = uuid.uuid4().hex
@@ -371,12 +371,27 @@ class Database(object):
params = {}
data = utils.force_bytes(json.dumps(_doc))
+
+ print("gg1", flush=True)
+ print(data, flush=True)
+ print("vv1", flush=True)
+
(resp, result) = self.resource(_doc['_id']).put(
data=data, params=params)
+ print("gg3", flush=True)
+ print(resp.status_code)
+ print(resp.content)
+ #print(resp.contents)
+
+ print("gg2", flush=True)
+ print(data, flush=True)
+ print(result, flush=True)
+ print("vv2", flush=True)
+
if resp.status_code == 409:
raise exp.Conflict(result['reason'])
-
+
if "rev" in result and result["rev"] is not None:
_doc["_rev"] = result["rev"]
diff --git a/src/couch/feedreader.py b/src/couch/feedreader.py
index e293932..98401ab 100644
--- a/src/couch/feedreader.py
+++ b/src/couch/feedreader.py
@@ -1,13 +1,13 @@
# -*- coding: utf-8 -*-
# Based on py-couchdb (https://github.com/histrio/py-couchdb)
+from __future__ import annotations
-
-class BaseFeedReader(object):
+class BaseFeedReader:
"""
Base interface class for changes feed reader.
"""
- def __call__(self, db):
+ def __call__(self, db) -> BaseFeedReader:
self.db = db
return self
@@ -44,9 +44,9 @@ class SimpleFeedReader(BaseFeedReader):
a valid feed reader interface.
"""
- def __call__(self, db, callback):
+ def __call__(self, db, callback) -> BaseFeedReader:
self.callback = callback
return super(SimpleFeedReader, self).__call__(db)
- def on_message(self, message):
+ def on_message(self, message) -> None:
self.callback(message, db=self.db)
diff --git a/src/couch/resource.py b/src/couch/resource.py
index da1e0dd..364bff4 100644
--- a/src/couch/resource.py
+++ b/src/couch/resource.py
@@ -3,17 +3,20 @@
from __future__ import unicode_literals
+from typing import Union, Tuple
import json
import requests
+
+
from couch import utils
from couch import exceptions
-class Resource(object):
- def __init__(self, base_url, full_commit=True, session=None,
- credentials=None, authmethod="session", verify=False):
+class Resource:
+ def __init__(self, base_url: str, full_commit: bool = True, session: Union[requests.sessions.Session, None] = None,
+ credentials: Union[Tuple[str, str], None] = None, authmethod: str = "session", verify: bool = False) -> None:
self.base_url = base_url
# self.verify = verify
@@ -31,7 +34,7 @@ class Resource(object):
self.session = session
self.session.verify = verify
- def _authenticate(self, credentials, method):
+ def _authenticate(self, credentials: Union[Tuple[str, str], None], method: str) -> None:
if not credentials:
return
@@ -50,11 +53,11 @@ class Resource(object):
else:
raise RuntimeError("Invalid authentication method")
- def __call__(self, *path):
+ def __call__(self, *path: str):
base_url = utils.urljoin(self.base_url, *path)
return self.__class__(base_url, session=self.session)
- def _check_result(self, response, result):
+ def _check_result(self, response, result) -> None:
try:
error = result.get('error', None)
reason = result.get('reason', None)
@@ -74,7 +77,7 @@ class Resource(object):
raise exceptions.BadRequest(reason or "Bad request")
raise exceptions.GenericError(result)
- def request(self, method, path, params=None, data=None,
+ def request(self, method, path: str, params=None, data=None,
headers=None, stream=False, **kwargs):
if headers is None:
@@ -112,17 +115,17 @@ class Resource(object):
return response, result
- def get(self, path=None, **kwargs):
+ def get(self, path: Union[str, None] = None, **kwargs):
return self.request("GET", path, **kwargs)
- def put(self, path=None, **kwargs):
+ def put(self, path: Union[str, None] = None, **kwargs):
return self.request("PUT", path, **kwargs)
- def post(self, path=None, **kwargs):
+ def post(self, path: Union[str, None] = None, **kwargs):
return self.request("POST", path, **kwargs)
- def delete(self, path=None, **kwargs):
+ def delete(self, path: Union[str, None] = None, **kwargs):
return self.request("DELETE", path, **kwargs)
- def head(self, path=None, **kwargs):
+ def head(self, path: Union[str, None] = None, **kwargs):
return self.request("HEAD", path, **kwargs)
diff --git a/src/couch/utils.py b/src/couch/utils.py
index 1cd21d8..f0883a6 100644
--- a/src/couch/utils.py
+++ b/src/couch/utils.py
@@ -1,26 +1,15 @@
# -*- coding: utf-8 -*-
# Based on py-couchdb (https://github.com/histrio/py-couchdb)
-
+from typing import Tuple, Union, Dict, List, Any
import json
import sys
+import requests
+from urllib.parse import unquote as _unquote
+from urllib.parse import urlunsplit, urlsplit
-if sys.version_info[0] == 3:
- from urllib.parse import unquote as _unquote
- from urllib.parse import urlunsplit, urlsplit
-
- string_type = str
- bytes_type = bytes
-
- from functools import reduce
-
-else:
- from urllib import unquote as _unquote
- from urlparse import urlunsplit, urlsplit
-
- string_type = unicode # noqa: F821
- bytes_type = str
+from functools import reduce
URLSPLITTER = '/'
@@ -28,7 +17,7 @@ URLSPLITTER = '/'
json_encoder = json.JSONEncoder()
-def extract_credentials(url):
+def extract_credentials(url: str) -> Tuple[str, Union[Tuple[str, str], None]]:
"""
Extract authentication (user name and password) credentials from the
given URL.
@@ -46,19 +35,20 @@ def extract_credentials(url):
if '@' in netloc:
creds, netloc = netloc.split('@')
credentials = tuple(_unquote(i) for i in creds.split(':'))
- parts = list(parts)
- parts[1] = netloc
- else:
- credentials = None
- return urlunsplit(parts), credentials
+ parts_list = list(parts)
+ parts_list[1] = netloc
+ return urlunsplit(parts_list), (credentials[0], credentials[1])
+
+ parts_list = list(parts)
+ return urlunsplit(parts_list), None
-def _join(head, tail):
+def _join(head: str, tail: str) -> str:
parts = [head.rstrip(URLSPLITTER), tail.lstrip(URLSPLITTER)]
return URLSPLITTER.join(parts)
-def urljoin(base, *path):
+def urljoin(base: str, *path: str) -> str:
"""
Assemble a uri based on a base, any number of path segments, and query
string parameters.
@@ -87,18 +77,18 @@ def urljoin(base, *path):
"""
return reduce(_join, path, base)
-
-def as_json(response):
+# Probably bugs here
+def as_json(response: requests.models.Response) -> Union[Dict[str, Any], None]:
if "application/json" in response.headers['content-type']:
response_src = response.content.decode('utf-8')
if response.content != b'':
- return json.loads(response_src)
+ ret: Dict[str, Any] = json.loads(response_src)
+ return ret
else:
return response_src
return None
-
-def _path_from_name(name, type):
+def _path_from_name(name: str, type: str) -> List[str]:
"""
Expand a 'design/foo' style name to its full path as a list of
segments.
@@ -114,7 +104,7 @@ def _path_from_name(name, type):
return ['_design', design, type, name]
-def encode_view_options(options):
+def encode_view_options(options: Dict[str, Any]) -> Dict[str, Any]:
"""
Encode any items in the options dict that are sent as a JSON string to a
view/list function.
@@ -138,13 +128,13 @@ def encode_view_options(options):
return retval
-def force_bytes(data, encoding="utf-8"):
- if isinstance(data, string_type):
+def force_bytes(data: Union[str, bytes], encoding: str = "utf-8") -> bytes:
+ if isinstance(data, str):
data = data.encode(encoding)
return data
-def force_text(data, encoding="utf-8"):
- if isinstance(data, bytes_type):
+def force_text(data: Union[str, bytes], encoding: str = "utf-8") -> str:
+ if isinstance(data, bytes):
data = data.decode(encoding)
return data
diff --git a/src/db.py b/src/db.py
index 6f25ec3..5173dda 100755
--- a/src/db.py
+++ b/src/db.py
@@ -7,6 +7,7 @@
# value if you're too quick with generating the timestamps, ie
# invoking time.time() several times quickly enough.
+from typing import Dict, List, Tuple, Union, Any
import os
import sys
import time
@@ -16,7 +17,7 @@ from schema import as_index_list, validate_collector_data
class DictDB():
- def __init__(self):
+ def __init__(self) -> None:
"""
Check if the database exists, otherwise we will create it together
with the indexes specified in index.py.
@@ -35,7 +36,7 @@ class DictDB():
if 'COUCHDB_PORT' in os.environ:
couchdb_port = os.environ['COUCHDB_PORT']
else:
- couchdb_port = 5984
+ couchdb_port = "5984"
self.server = couch.client.Server(
f"http://{self.username}:{self.password}@{self.hostname}:{couchdb_port}/")
@@ -52,7 +53,7 @@ class DictDB():
self._ts = time.time()
- def unique_key(self):
+ def unique_key(self) -> int:
"""
Create a unique key based on the current time. We will use this as
the ID for any new documents we store in CouchDB.
@@ -65,18 +66,19 @@ class DictDB():
return self._ts
- def add(self, data, batch_write=False):
+ # Why batch_write???
+ def add(self, data: Union[List[Dict[str, Any]], Dict[str, Any]]) -> Union[str, Tuple[str, str]]:
"""
Store a document in CouchDB.
"""
- if type(data) is list:
+ if isinstance(data, List):
for item in data:
error = validate_collector_data(item)
if error != "":
return error
item['_id'] = str(self.unique_key())
- ret = self.couchdb.save_bulk(data)
+ ret: Tuple[str, str] = self.couchdb.save_bulk(data)
else:
error = validate_collector_data(data)
if error != "":
@@ -86,13 +88,13 @@ class DictDB():
return ret
- def get(self, key):
+ def get(self, key: int) -> Dict[str, Any]:
"""
Get a document based on its ID, return an empty dict if not found.
"""
try:
- doc = self.couchdb.get(key)
+ doc: Dict[str, Any] = self.couchdb.get(key)
except couch.exceptions.NotFound:
doc = {}
@@ -101,7 +103,7 @@ class DictDB():
def slice(self, key_from=None, key_to=None):
pass
- def search(self, limit=25, skip=0, **kwargs):
+ def search(self, limit: int = 25, skip: int = 0, **kwargs: Any) -> List[Dict[str, Any]]:
"""
Execute a Mango query, ideally we should have an index matching
the query otherwise things will be slow.
@@ -134,7 +136,7 @@ class DictDB():
return data
- def delete(self, key):
+ def delete(self, key: int) -> Union[int, None]:
"""
Delete a document based on its ID.
"""
diff --git a/src/main.py b/src/main.py
index 9de8eb8..2730b83 100755
--- a/src/main.py
+++ b/src/main.py
@@ -1,3 +1,4 @@
+from typing import Dict, Union, List, Any
import json
import os
import sys
@@ -48,7 +49,7 @@ else:
sys.exit(-1)
-def get_pubkey():
+def get_pubkey() -> str:
try:
if 'JWT_PUBKEY_PATH' in os.environ:
keypath = os.environ['JWT_PUBKEY_PATH']
@@ -64,12 +65,18 @@ def get_pubkey():
return pubkey
-def get_data(key=None, limit=25, skip=0, ip=None,
- port=None, asn=None, domain=None):
+
+def get_data(key: Union[int, None] = None,
+ limit: int = 25,
+ skip: int = 0,
+ ip: Union[str, None] = None,
+ port: Union[int, None] = None,
+ asn: Union[str, None] = None,
+ domain: Union[str, None] = None) -> List[Dict[str, Any]]:
if key:
- return db.get(key)
+ return [db.get(key)]
- selectors = dict()
+ selectors: Dict[str, Any] = {}
indexes = get_index_keys()
selectors['domain'] = domain
@@ -80,7 +87,7 @@ def get_data(key=None, limit=25, skip=0, ip=None,
if asn and 'asn' in indexes:
selectors['asn'] = asn
- data = db.search(**selectors, limit=limit, skip=skip)
+ data: List[Dict[str, Any]] = db.search(**selectors, limit=limit, skip=skip)
return data
@@ -96,21 +103,20 @@ def jwt_config():
@app.exception_handler(AuthJWTException)
-def authjwt_exception_handler(request: Request, exc: AuthJWTException):
+def authjwt_exception_handler(request: Request, exc: AuthJWTException) -> JSONResponse:
return JSONResponse(content={"status": "error", "message":
exc.message}, status_code=400)
@app.exception_handler(RuntimeError)
-def app_exception_handler(request: Request, exc: RuntimeError):
+def app_exception_handler(request: Request, exc: RuntimeError) -> JSONResponse:
return JSONResponse(content={"status": "error", "message":
str(exc.with_traceback(None))},
status_code=400)
@app.get('/sc/v0/get')
-async def get(key=None, limit=25, skip=0, ip=None, port=None,
- asn=None, Authorize: AuthJWT = Depends()):
+async def get(key: Union[int, None] = None, limit: int = 25, skip: int = 0, ip: Union[str, None] = None, port: Union[int, None] = None, asn: Union[str, None] = None, Authorize: AuthJWT = Depends()) -> JSONResponse:
Authorize.jwt_required()
@@ -135,7 +141,7 @@ async def get(key=None, limit=25, skip=0, ip=None, port=None,
@app.get('/sc/v0/get/{key}')
-async def get_key(key=None, Authorize: AuthJWT = Depends()):
+async def get_key(key: Union[int, None] = None, Authorize: AuthJWT = Depends()) -> JSONResponse:
Authorize.jwt_required()
@@ -152,7 +158,10 @@ async def get_key(key=None, Authorize: AuthJWT = Depends()):
else:
allowed_domains = raw_jwt["read"]
- data = get_data(key)
+ data_list = get_data(key)
+
+ # Handle if missing
+ data = data_list[0]
if data and data["domain"] not in allowed_domains:
return JSONResponse(
@@ -166,8 +175,9 @@ async def get_key(key=None, Authorize: AuthJWT = Depends()):
return JSONResponse(content={"status": "success", "docs": data})
+# WHY IS AUTH OUTCOMMENTED???
@app.post('/sc/v0/add')
-async def add(data: Request, Authorize: AuthJWT = Depends()):
+async def add(data: Request, Authorize: AuthJWT = Depends()) -> JSONResponse:
# Authorize.jwt_required()
try:
@@ -196,7 +206,7 @@ async def add(data: Request, Authorize: AuthJWT = Depends()):
@app.delete('/sc/v0/delete/{key}')
-async def delete(key, Authorize: AuthJWT = Depends()):
+async def delete(key: int, Authorize: AuthJWT = Depends()) -> JSONResponse:
Authorize.jwt_required()
@@ -213,7 +223,10 @@ async def delete(key, Authorize: AuthJWT = Depends()):
else:
allowed_domains = raw_jwt["write"]
- data = get_data(key)
+ data_list = get_data(key)
+
+ # Handle if missing
+ data = data_list[0]
if data and data["domain"] not in allowed_domains:
return JSONResponse(
@@ -232,7 +245,7 @@ async def delete(key, Authorize: AuthJWT = Depends()):
return JSONResponse(content={"status": "success", "docs": data})
-def main(standalone=False):
+def main(standalone: bool = False):
if not standalone:
return app
diff --git a/src/schema.py b/src/schema.py
index fe2b76c..2b479d2 100644
--- a/src/schema.py
+++ b/src/schema.py
@@ -1,3 +1,4 @@
+from typing import List, Any, Dict
import json
import sys
import traceback
@@ -95,15 +96,15 @@ schema = {
# fmt:on
-def get_index_keys():
- keys = list()
+def get_index_keys() -> List[str]:
+ keys: List[str] = []
for key in schema["properties"]:
keys.append(key)
return keys
-def as_index_list():
- index_list = list()
+def as_index_list() -> List[Dict[str, Any]]:
+ index_list: List[Dict[str, Any]] = []
for key in schema["properties"]:
name = f"{key}-json-index"
index = {
@@ -120,7 +121,7 @@ def as_index_list():
return index_list
-def validate_collector_data(json_blob):
+def validate_collector_data(json_blob: Dict[str, Any]) -> str:
try:
jsonschema.validate(json_blob, schema, format_checker=jsonschema.FormatChecker())
except jsonschema.exceptions.ValidationError as e: