diff options
author | Victor Näslund <victor@sunet.se> | 2022-11-02 15:31:23 +0100 |
---|---|---|
committer | Victor Näslund <victor@sunet.se> | 2022-11-02 15:31:23 +0100 |
commit | 8baecf339e8061160bee519e87ffe837d1525c18 (patch) | |
tree | 22664c10f22382b1d4647b5f2e96bcea4220d879 | |
parent | ffb26f4a81a9ca61c4105df037f7e1beb8dc5fb0 (diff) |
more freshup
-rwxr-xr-x | dev-run.sh | 13 | ||||
-rw-r--r-- | docker/collector/Dockerfile | 46 | ||||
-rw-r--r-- | docker/collector/_dev_dockerfile_dev (renamed from docker/collector/Dockerfile-dev) | 0 | ||||
-rw-r--r-- | docker/collector/supervisord.conf | 2 | ||||
-rw-r--r-- | docker/docker-compose.yaml | 4 | ||||
-rwxr-xr-x | quickstart.sh | 2 | ||||
-rwxr-xr-x | quickstart_test.sh | 67 | ||||
-rw-r--r-- | src/collector/__init__.py | 4 | ||||
-rwxr-xr-x | src/collector/db.py (renamed from src/db.py) | 50 | ||||
-rwxr-xr-x | src/collector/main.py (renamed from src/main.py) | 109 | ||||
-rw-r--r-- | src/collector/py.typed (renamed from src/test/__init__.py) | 0 | ||||
-rw-r--r-- | src/collector/schema.py (renamed from src/schema.py) | 2 | ||||
-rw-r--r-- | src/couch/__init__.py | 2 | ||||
-rw-r--r-- | src/couch/client.py | 82 | ||||
-rw-r--r-- | src/couch/feedreader.py | 8 | ||||
-rw-r--r-- | src/couch/resource.py | 60 | ||||
-rw-r--r-- | src/couch/utils.py | 6 | ||||
-rw-r--r-- | tests/__init__.py | 0 | ||||
-rw-r--r-- | tests/test_api.py (renamed from src/test/test_api.py) | 0 |
19 files changed, 297 insertions, 160 deletions
@@ -1,15 +1,18 @@ #!/bin/bash + +echo "Checking package" +mypy --strict --namespace-packages --ignore-missing-imports --cache-dir=/dev/null src/collector/*.py # || exit 1 +black --line-length 120 src/collector/*.py # || exit 1 +pylint --max-line-length 120 src/collector/*.py # || exit 1 + + bash quickstart.sh -b || exit 1 -sleep 2 +sleep 3 JWT=$(curl -k http://localhost:8000/api/v1.0/auth -X POST -p -u usr:pwd | jq -r .access_token) || exit 1 curl -k --data-binary @example_data_1.json -H "Authorization: Bearer $JWT" https://localhost:1443/sc/v0/add || exit 1 exit 0 -echo "Checking package" -mypy --strict --namespace-packages --ignore-missing-imports --cache-dir=/dev/null src/*.py || exit 1 -black --line-length 120 src/*.py || exit 1 -pylint --max-line-length 120 src/*.py || exit 1 echo "Checking tests" #mypy --strict --namespace-packages --ignore-missing-imports --cache-dir=/dev/null tests/*.py || exit 1 diff --git a/docker/collector/Dockerfile b/docker/collector/Dockerfile index a62d78e..099bc0a 100644 --- a/docker/collector/Dockerfile +++ b/docker/collector/Dockerfile @@ -1,23 +1,49 @@ FROM debian:bullseye-20221024-slim@sha256:76cdda8fe5eb597ef5e712e4c9a9f5f1fb119e69f353daaa7bd6d0f6e66e541d # FROM debian:bullseye +# ENV DEBIAN_FRONTEND noninteractive +# RUN apt-get update +# RUN apt-get install -y git supervisor emacs-nox virtualenv procps -ENV DEBIAN_FRONTEND noninteractive +COPY ./requirements.txt /opt/collector/requirements.txt -RUN apt update -RUN apt install -y git supervisor emacs-nox virtualenv procps -RUN apt clean +RUN apt-get update \ + && apt-get install -y python3 python3-pip \ + && pip3 install -r /opt/collector/requirements.txt \ + && apt-get remove -y \ + gcc \ + curl \ + wget \ + python3-pip \ + python3-dev \ + && apt-get autoremove -y \ + && apt-get clean -WORKDIR /opt/ -RUN git clone https://git.sunet.se/soc_collector.git /opt/collector +# Remove setuid and setgid +RUN find / -xdev -perm /6000 -type f -exec chmod a-s {} \; || true + +# Add user +RUN useradd collector -u 1500 -s /usr/sbin/nologin + +COPY ./src /opt/collector/src WORKDIR /opt/collector/ -COPY setup.sh /opt/collector/ -COPY supervisord.conf /etc/supervisor/ +USER collector + +ENTRYPOINT ["uvicorn", "src.collector.main:app", "--host", "0.0.0.0", "--workers", "1", "--header", "server:collector"] +# ENTRYPOINT ["sleep", "300"] + +# RUN git clone https://git.sunet.se/soc_collector.git /opt/collector +# WORKDIR /opt/collector/ +# COPY setup.sh /opt/collector/ +# COPY supervisord.conf /etc/supervisor/ + +# RUN /opt/collector/setup.sh +# ENTRYPOINT supervisord -c /etc/supervisor/supervisord.conf + + -RUN /opt/collector/setup.sh -ENTRYPOINT supervisord -c /etc/supervisor/supervisord.conf diff --git a/docker/collector/Dockerfile-dev b/docker/collector/_dev_dockerfile_dev index 15a6ebe..15a6ebe 100644 --- a/docker/collector/Dockerfile-dev +++ b/docker/collector/_dev_dockerfile_dev diff --git a/docker/collector/supervisord.conf b/docker/collector/supervisord.conf index 7e260c6..2a2f5ca 100644 --- a/docker/collector/supervisord.conf +++ b/docker/collector/supervisord.conf @@ -3,7 +3,7 @@ nodaemon=true [program:uvicorn] directory = /opt/collector/src/ -command = /opt/collector/venv/bin/uvicorn --proxy-headers --host 0.0.0.0 --port 8000 main:app +command = /opt/collector/venv/bin/uvicorn --log-level debug --proxy-headers --host 0.0.0.0 --port 8000 main:app stdout_logfile=/dev/stdout stdout_logfile_maxbytes=0 stderr_logfile=/dev/stderr diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 23a543b..119d3a9 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -10,7 +10,9 @@ services: - certs:/etc/ssl/collector/ collector: - build: ./collector/ + build: + context: .. + dockerfile: docker/collector/Dockerfile environment: - COUCHDB_USER - COUCHDB_PASSWORD diff --git a/quickstart.sh b/quickstart.sh index 7c14a1f..56cc77a 100755 --- a/quickstart.sh +++ b/quickstart.sh @@ -42,5 +42,5 @@ if [ ! -f ${DOCKER_JWT_HTPASSWD_PATH}/userdb.yaml ]; then fi # Launch the containers. -docker-compose -f docker/docker-compose-dev.yaml up -d $build +docker-compose -f docker/docker-compose.yaml up -d $build docker-compose -f auth-server-poc/docker-compose.yml up -d $build diff --git a/quickstart_test.sh b/quickstart_test.sh new file mode 100755 index 0000000..9254271 --- /dev/null +++ b/quickstart_test.sh @@ -0,0 +1,67 @@ +# Usage: ./quickstart_test.sh [-v] [-c] [-- <args to pytest>] + +export COUCHDB_NAME=unittest +export COUCHDB_HOSTNAME=localhost +export COUCHDB_USER=test +export COUCHDB_PASSWORD=test + +export DOCKER_JWT_PUBKEY_PATH="`pwd`/test/unittest_cert/" +export JWT_PUBKEY_PATH="`pwd`/test/unittest_cert/public.pem" + +virtualenv=no +couchdb=no + +while getopts ":vc" flag +do + case "$flag" in + v) virtualenv=yes;; + c) couchdb=yes;; + esac +done + +if [ -d test/unittest_cert ]; then + rm -r test/unittest_cert +fi + +if [ $virtualenv == "yes" ]; then + shift + if [ -d test/unittest_venv ]; then + rm -r test/unittest_venv + fi + + virtualenv test/unittest_venv + source test/unittest_venv/bin/activate + pip3 install -r ../requirements.txt +fi + +if [ $couchdb == "yes" ]; then + shift + docker run -it -p 6123:5984 --rm -d --name unittest_couchdb -e COUCHDB_USER=$COUCHDB_USER -e COUCHDB_PASSWORD=$COUCHDB_PASSWORD couchdb + + docker inspect unittest_couchdb > /dev/null + + if (( $? != 0 )); then + echo "Failed to start CouchDB container." + exit + fi + + export COUCHDB_PORT=6123 +fi + +mkdir test/unittest_cert + +cat <<EOF > test/unittest_cert/public.pem +-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEGHX8ipqVWtr49TXyX0f/L4GPhEpg +N0Erzy7hHkXVrkgKpnHSRLYWgbW4rscLoJAJeEv7Be5iH0TM8l09w8Q3wQ== +-----END PUBLIC KEY----- +EOF + +shift +pytest --capture=tee-sys "$@" + +rm -r test/unittest_cert + +if [ $couchdb == "yes" ]; then + docker kill unittest_couchdb +fi diff --git a/src/collector/__init__.py b/src/collector/__init__.py new file mode 100644 index 0000000..6530fdd --- /dev/null +++ b/src/collector/__init__.py @@ -0,0 +1,4 @@ +"""Collector +""" + +__version__ = "1.03" diff --git a/src/db.py b/src/collector/db.py index 5173dda..0bfa014 100755 --- a/src/db.py +++ b/src/collector/db.py @@ -12,34 +12,37 @@ import os import sys import time -import couch -from schema import as_index_list, validate_collector_data +from src import couch +from .schema import as_index_list, validate_collector_data -class DictDB(): +class DictDB: def __init__(self) -> None: """ Check if the database exists, otherwise we will create it together with the indexes specified in index.py. """ + print(os.environ) + try: - self.database = os.environ['COUCHDB_NAME'] - self.hostname = os.environ['COUCHDB_HOSTNAME'] - self.username = os.environ['COUCHDB_USER'] - self.password = os.environ['COUCHDB_PASSWORD'] + self.database = os.environ["COUCHDB_NAME"] + self.hostname = os.environ["COUCHDB_HOSTNAME"] + self.username = os.environ["COUCHDB_USER"] + self.password = os.environ["COUCHDB_PASSWORD"] except KeyError: - print('The environment variables COUCHDB_NAME, COUCHDB_HOSTNAME,' + - ' COUCHDB_USER and COUCHDB_PASSWORD must be set.') + print( + "The environment variables COUCHDB_NAME, COUCHDB_HOSTNAME," + + " COUCHDB_USER and COUCHDB_PASSWORD must be set." + ) sys.exit(-1) - if 'COUCHDB_PORT' in os.environ: - couchdb_port = os.environ['COUCHDB_PORT'] + if "COUCHDB_PORT" in os.environ: + couchdb_port = os.environ["COUCHDB_PORT"] else: couchdb_port = "5984" - self.server = couch.client.Server( - f"http://{self.username}:{self.password}@{self.hostname}:{couchdb_port}/") + self.server = couch.client.Server(f"http://{self.username}:{self.password}@{self.hostname}:{couchdb_port}/") try: self.couchdb = self.server.database(self.database) @@ -77,13 +80,13 @@ class DictDB(): error = validate_collector_data(item) if error != "": return error - item['_id'] = str(self.unique_key()) + item["_id"] = str(self.unique_key()) ret: Tuple[str, str] = self.couchdb.save_bulk(data) else: error = validate_collector_data(data) if error != "": return error - data['_id'] = str(self.unique_key()) + data["_id"] = str(self.unique_key()) ret = self.couchdb.save(data) return ret @@ -100,8 +103,9 @@ class DictDB(): return doc - def slice(self, key_from=None, key_to=None): - pass + # + # def slice(self, key_from=None, key_to=None): + # pass def search(self, limit: int = 25, skip: int = 0, **kwargs: Any) -> List[Dict[str, Any]]: """ @@ -109,8 +113,8 @@ class DictDB(): the query otherwise things will be slow. """ - data = list() - selector = dict() + data: List[Dict[str, Any]] = [] + selector: Dict[str, Any] = {} try: limit = int(limit) @@ -120,16 +124,12 @@ class DictDB(): skip = 0 if kwargs: - selector = { - "limit": limit, - "skip": skip, - "selector": {} - } + selector = {"limit": limit, "skip": skip, "selector": {}} for key in kwargs: if kwargs[key] and kwargs[key].isnumeric(): kwargs[key] = int(kwargs[key]) - selector['selector'][key] = {'$eq': kwargs[key]} + selector["selector"][key] = {"$eq": kwargs[key]} for doc in self.couchdb.find(selector, wrapper=None, limit=5): data.append(doc) diff --git a/src/main.py b/src/collector/main.py index 2730b83..c363885 100755 --- a/src/main.py +++ b/src/collector/main.py @@ -1,19 +1,20 @@ -from typing import Dict, Union, List, Any +from typing import Dict, Union, List, Callable, Awaitable, Any import json import os import sys import time import uvicorn -from fastapi import Depends, FastAPI, Request +from fastapi import Depends, FastAPI, Request, Response from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from fastapi_jwt_auth import AuthJWT +from fastapi_jwt_auth.auth_config import AuthConfig from fastapi_jwt_auth.exceptions import AuthJWTException from pydantic import BaseModel -from db import DictDB -from schema import get_index_keys, validate_collector_data +from .db import DictDB +from .schema import get_index_keys, validate_collector_data app = FastAPI() @@ -30,31 +31,34 @@ app.add_middleware( @app.middleware("http") -async def mock_x_total_count_header(request: Request, call_next): - response = await call_next(request) +async def mock_x_total_count_header(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: + + print(type(call_next)) + + response: Response = await call_next(request) response.headers["X-Total-Count"] = "100" return response + for i in range(10): try: db = DictDB() - except Exception: - print( - f'Database not responding, will try again soon. Attempt {i + 1} of 10.') + except Exception as e: + print(f"Database not responding, will try again soon. Attempt {i + 1} of 10.") else: break time.sleep(1) else: - print('Database did not respond after 10 attempts, quitting.') + print("Database did not respond after 10 attempts, quitting.") sys.exit(-1) def get_pubkey() -> str: try: - if 'JWT_PUBKEY_PATH' in os.environ: - keypath = os.environ['JWT_PUBKEY_PATH'] + if "JWT_PUBKEY_PATH" in os.environ: + keypath = os.environ["JWT_PUBKEY_PATH"] else: - keypath = '/opt/certs/public.pem' + keypath = "/opt/certs/public.pem" with open(keypath, "r") as fd: pubkey = fd.read() @@ -65,27 +69,28 @@ def get_pubkey() -> str: return pubkey - -def get_data(key: Union[int, None] = None, - limit: int = 25, - skip: int = 0, - ip: Union[str, None] = None, - port: Union[int, None] = None, - asn: Union[str, None] = None, - domain: Union[str, None] = None) -> List[Dict[str, Any]]: +def get_data( + key: Union[int, None] = None, + limit: int = 25, + skip: int = 0, + ip: Union[str, None] = None, + port: Union[int, None] = None, + asn: Union[str, None] = None, + domain: Union[str, None] = None, +) -> List[Dict[str, Any]]: if key: return [db.get(key)] selectors: Dict[str, Any] = {} indexes = get_index_keys() - selectors['domain'] = domain + selectors["domain"] = domain - if ip and 'ip' in indexes: - selectors['ip'] = ip - if port and 'port' in indexes: - selectors['port'] = port - if asn and 'asn' in indexes: - selectors['asn'] = asn + if ip and "ip" in indexes: + selectors["ip"] = ip + if port and "port" in indexes: + selectors["port"] = port + if asn and "asn" in indexes: + selectors["asn"] = asn data: List[Dict[str, Any]] = db.search(**selectors, limit=limit, skip=skip) @@ -97,26 +102,31 @@ class JWTConfig(BaseModel): authjwt_public_key: str = get_pubkey() -@AuthJWT.load_config +@AuthJWT.load_config # type: ignore def jwt_config(): return JWTConfig() @app.exception_handler(AuthJWTException) def authjwt_exception_handler(request: Request, exc: AuthJWTException) -> JSONResponse: - return JSONResponse(content={"status": "error", "message": - exc.message}, status_code=400) + return JSONResponse(content={"status": "error", "message": exc.message}, status_code=400) @app.exception_handler(RuntimeError) def app_exception_handler(request: Request, exc: RuntimeError) -> JSONResponse: - return JSONResponse(content={"status": "error", "message": - str(exc.with_traceback(None))}, - status_code=400) + return JSONResponse(content={"status": "error", "message": str(exc.with_traceback(None))}, status_code=400) -@app.get('/sc/v0/get') -async def get(key: Union[int, None] = None, limit: int = 25, skip: int = 0, ip: Union[str, None] = None, port: Union[int, None] = None, asn: Union[str, None] = None, Authorize: AuthJWT = Depends()) -> JSONResponse: +@app.get("/sc/v0/get") +async def get( + key: Union[int, None] = None, + limit: int = 25, + skip: int = 0, + ip: Union[str, None] = None, + port: Union[int, None] = None, + asn: Union[str, None] = None, + Authorize: AuthJWT = Depends(), +) -> JSONResponse: Authorize.jwt_required() @@ -140,7 +150,7 @@ async def get(key: Union[int, None] = None, limit: int = 25, skip: int = 0, ip: return JSONResponse(content={"status": "success", "docs": data}) -@app.get('/sc/v0/get/{key}') +@app.get("/sc/v0/get/{key}") async def get_key(key: Union[int, None] = None, Authorize: AuthJWT = Depends()) -> JSONResponse: Authorize.jwt_required() @@ -176,7 +186,7 @@ async def get_key(key: Union[int, None] = None, Authorize: AuthJWT = Depends()) # WHY IS AUTH OUTCOMMENTED??? -@app.post('/sc/v0/add') +@app.post("/sc/v0/add") async def add(data: Request, Authorize: AuthJWT = Depends()) -> JSONResponse: # Authorize.jwt_required() @@ -205,7 +215,7 @@ async def add(data: Request, Authorize: AuthJWT = Depends()) -> JSONResponse: return JSONResponse(content={"status": "success", "docs": key}) -@app.delete('/sc/v0/delete/{key}') +@app.delete("/sc/v0/delete/{key}") async def delete(key: int, Authorize: AuthJWT = Depends()) -> JSONResponse: Authorize.jwt_required() @@ -238,21 +248,20 @@ async def delete(key: int, Authorize: AuthJWT = Depends()) -> JSONResponse: ) if db.delete(key) is None: - return JSONResponse(content={"status": "error", - "message": "Document not found"}, - status_code=400) + return JSONResponse(content={"status": "error", "message": "Document not found"}, status_code=400) return JSONResponse(content={"status": "success", "docs": data}) -def main(standalone: bool = False): - if not standalone: - return app +# def main(standalone: bool = False): +# print(type(app)) +# if not standalone: +# return app - uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug") +# uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug") -if __name__ == '__main__': - main(standalone=True) -else: - app = main() +# if __name__ == "__main__": +# main(standalone=True) +# else: +# app = main() diff --git a/src/test/__init__.py b/src/collector/py.typed index e69de29..e69de29 100644 --- a/src/test/__init__.py +++ b/src/collector/py.typed diff --git a/src/schema.py b/src/collector/schema.py index 2b479d2..e291f10 100644 --- a/src/schema.py +++ b/src/collector/schema.py @@ -114,7 +114,7 @@ def as_index_list() -> List[Dict[str, Any]]: ] }, "name": name, - "type": "json" + "type": "json", } index_list.append(index) diff --git a/src/couch/__init__.py b/src/couch/__init__.py index a7537bc..64e0252 100644 --- a/src/couch/__init__.py +++ b/src/couch/__init__.py @@ -8,4 +8,4 @@ __email__ = "rinat.sabitov@gmail.com" __status__ = "Development" -from couch.client import Server # noqa: F401 +from .client import Server # noqa: F401 diff --git a/src/couch/client.py b/src/couch/client.py index 52477be..96dc78a 100644 --- a/src/couch/client.py +++ b/src/couch/client.py @@ -8,10 +8,24 @@ import copy import mimetypes import warnings -from couch import utils -from couch import feedreader -from couch import exceptions as exp -from couch.resource import Resource +from .utils import ( + force_bytes, + force_text, + encode_view_options, + extract_credentials, +) +from .feedreader import ( + SimpleFeedReader, + BaseFeedReader, +) + +from .exceptions import ( + Conflict, + NotFound, + FeedReaderExited, + UnexpectedError, +) +from .resource import Resource DEFAULT_BASE_URL = os.environ.get('COUCHDB_URL', 'http://localhost:5984/') @@ -25,16 +39,16 @@ def _id_to_path(_id: str) -> str: def _listen_feed(object, node, feed_reader, **kwargs): if not callable(feed_reader): - raise exp.UnexpectedError("feed_reader must be callable or class") + raise UnexpectedError("feed_reader must be callable or class") - if isinstance(feed_reader, feedreader.BaseFeedReader): + if isinstance(feed_reader, BaseFeedReader): reader = feed_reader(object) else: - reader = feedreader.SimpleFeedReader()(object, feed_reader) + reader = SimpleFeedReader()(object, feed_reader) # Possible options: "continuous", "longpoll" kwargs.setdefault("feed", "continuous") - data = utils.force_bytes(json.dumps(kwargs.pop('data', {}))) + data = force_bytes(json.dumps(kwargs.pop('data', {}))) (resp, result) = object.resource(node).post( params=kwargs, data=data, stream=True) @@ -44,8 +58,8 @@ def _listen_feed(object, node, feed_reader, **kwargs): if not line: reader.on_heartbeat() else: - reader.on_message(json.loads(utils.force_text(line))) - except exp.FeedReaderExited: + reader.on_message(json.loads(force_text(line))) + except FeedReaderExited: reader.on_close() @@ -100,7 +114,7 @@ class Server(object): def __init__(self, base_url=DEFAULT_BASE_URL, full_commit=True, authmethod="basic", verify=False): - self.base_url, credentials = utils.extract_credentials(base_url) + self.base_url, credentials = extract_credentials(base_url) self.resource = Resource(self.base_url, full_commit, credentials=credentials, authmethod=authmethod, @@ -112,7 +126,7 @@ class Server(object): def __contains__(self, name): try: self.resource.head(name) - except exp.NotFound: + except NotFound: return False else: return True @@ -158,7 +172,7 @@ class Server(object): """ (r, result) = self.resource.head(name) if r.status_code == 404: - raise exp.NotFound("Database '{0}' does not exists".format(name)) + raise NotFound("Database '{0}' does not exists".format(name)) db = Database(self.resource(name), name) return db @@ -206,7 +220,7 @@ class Server(object): data = {'source': source, 'target': target} data.update(kwargs) - data = utils.force_bytes(json.dumps(data)) + data = force_bytes(json.dumps(data)) (resp, result) = self.resource.post('_replicate', data=data) return result @@ -244,7 +258,7 @@ class Database(object): try: (resp, result) = self.resource.head(_id_to_path(doc_id)) return resp.status_code < 206 - except exp.NotFound: + except NotFound: return False def config(self): @@ -308,14 +322,14 @@ class Database(object): if "_deleted" not in doc: doc["_deleted"] = True - data = utils.force_bytes(json.dumps({"docs": _docs})) + data = force_bytes(json.dumps({"docs": _docs})) params = {"all_or_nothing": "true" if transaction else "false"} (resp, results) = self.resource.post( "_bulk_docs", data=data, params=params) for result, doc in zip(results, _docs): if "error" in result: - raise exp.Conflict("one or more docs are not saved") + raise Conflict("one or more docs are not saved") return results @@ -370,7 +384,7 @@ class Database(object): else: params = {} - data = utils.force_bytes(json.dumps(_doc)) + data = force_bytes(json.dumps(_doc)) print("gg1", flush=True) print(data, flush=True) @@ -390,7 +404,7 @@ class Database(object): print("vv2", flush=True) if resp.status_code == 409: - raise exp.Conflict(result['reason']) + raise Conflict(result['reason']) if "rev" in result and result["rev"] is not None: _doc["_rev"] = result["rev"] @@ -420,7 +434,7 @@ class Database(object): if "_id" not in doc: doc["_id"] = uuid.uuid4().hex - data = utils.force_bytes(json.dumps({"docs": _docs})) + data = orce_bytes(json.dumps({"docs": _docs})) params = {"all_or_nothing": "true" if transaction else "false"} (resp, results) = self.resource.post("_bulk_docs", data=data, @@ -456,9 +470,9 @@ class Database(object): if "keys" in params: data = {"keys": params.pop("keys")} - data = utils.force_bytes(json.dumps(data)) + data = force_bytes(json.dumps(data)) - params = utils.encode_view_options(params) + params = encode_view_options(params) if data: (resp, result) = self.resource.post( "_all_docs", params=params, data=data) @@ -538,7 +552,7 @@ class Database(object): resource = self.resource(doc_id) (resp, result) = resource.get(params=params) if resp.status_code == 404: - raise exp.NotFound("Document id `{0}` not found".format(doc_id)) + raise NotFound("Document id `{0}` not found".format(doc_id)) for rev in result['_revs_info']: if status and rev['status'] == status: @@ -566,10 +580,10 @@ class Database(object): (resp, result) = resource.delete( filename, params={'rev': _doc['_rev']}) if resp.status_code == 404: - raise exp.NotFound("filename {0} not found".format(filename)) + raise NotFound("filename {0} not found".format(filename)) if resp.status_code > 205: - raise exp.Conflict(result['reason']) + raise Conflict(result['reason']) _doc['_rev'] = result['rev'] try: @@ -645,7 +659,7 @@ class Database(object): if resp.status_code < 206: return self.get(doc["_id"]) - raise exp.Conflict(result['reason']) + raise Conflict(result['reason']) def one(self, name, flat=None, wrapper=None, **kwargs): """ @@ -665,16 +679,16 @@ class Database(object): params = {"limit": 1} params.update(kwargs) - path = utils._path_from_name(name, '_view') + path = _path_from_name(name, '_view') data = None if "keys" in params: data = {"keys": params.pop('keys')} if data: - data = utils.force_bytes(json.dumps(data)) + data = force_bytes(json.dumps(data)) - params = utils.encode_view_options(params) + params = encode_view_options(params) result = list(self._query(self.resource(*path), wrapper=wrapper, flat=flat, params=params, data=data)) @@ -716,16 +730,16 @@ class Database(object): :returns: generator object """ params = copy.copy(kwargs) - path = utils._path_from_name(name, '_view') + path = _path_from_name(name, '_view') data = None if "keys" in params: data = {"keys": params.pop('keys')} if data: - data = utils.force_bytes(json.dumps(data)) + data = force_bytes(json.dumps(data)) - params = utils.encode_view_options(params) + params = encode_view_options(params) result = self._query(self.resource(*path), wrapper=wrapper, flat=flat, params=params, data=data) @@ -768,7 +782,7 @@ class Database(object): """ path = '_find' - data = utils.force_bytes(json.dumps(selector)) + data = force_bytes(json.dumps(selector)) (resp, result) = self.resource.post(path, data=data, params=kwargs) @@ -780,7 +794,7 @@ class Database(object): def index(self, index_doc, **kwargs): path = '_index' - data = utils.force_bytes(json.dumps(index_doc)) + data = force_bytes(json.dumps(index_doc)) (resp, result) = self.resource.post(path, data=data, params=kwargs) diff --git a/src/couch/feedreader.py b/src/couch/feedreader.py index 98401ab..aac51d3 100644 --- a/src/couch/feedreader.py +++ b/src/couch/feedreader.py @@ -11,7 +11,7 @@ class BaseFeedReader: self.db = db return self - def on_message(self, message): + def on_message(self, message: str) -> None: """ Callback method that is called when change message is received from couchdb. @@ -22,14 +22,14 @@ class BaseFeedReader: raise NotImplementedError() - def on_close(self): + def on_close(self) -> None: """ Callback method that is received when connection is closed with a server. By default, does nothing. """ pass - def on_heartbeat(self): + def on_heartbeat(self) -> None: """ Callback method invoked when a hearbeat (empty line) is received from the _changes stream. Override this to purge the reader's internal @@ -48,5 +48,5 @@ class SimpleFeedReader(BaseFeedReader): self.callback = callback return super(SimpleFeedReader, self).__call__(db) - def on_message(self, message) -> None: + def on_message(self, message: str) -> None: self.callback(message, db=self.db) diff --git a/src/couch/resource.py b/src/couch/resource.py index 364bff4..f110c8d 100644 --- a/src/couch/resource.py +++ b/src/couch/resource.py @@ -1,17 +1,25 @@ # -*- coding: utf-8 -*- # Based on py-couchdb (https://github.com/histrio/py-couchdb) - +from __future__ import annotations from __future__ import unicode_literals -from typing import Union, Tuple +from typing import Union, Tuple, Dict, Any import json import requests - - -from couch import utils -from couch import exceptions +from .utils import ( + urljoin, + as_json, + force_bytes, +) +from .exceptions import ( + GenericError, + NotFound, + BadRequest, + Conflict, + AuthenticationFailed, +) class Resource: @@ -40,12 +48,11 @@ class Resource: if method == "session": data = {"name": credentials[0], "password": credentials[1]} - data = utils.force_bytes(json.dumps(data)) - post_url = utils.urljoin(self.base_url, "_session") - r = self.session.post(post_url, data=data) - if r.status_code != 200: - raise exceptions.AuthenticationFailed() + post_url = urljoin(self.base_url, "_session") + r = self.session.post(post_url, data=force_bytes(json.dumps(data))) + if r and r.status_code != 200: + raise AuthenticationFailed() elif method == "basic": self.session.auth = credentials @@ -53,8 +60,8 @@ class Resource: else: raise RuntimeError("Invalid authentication method") - def __call__(self, *path: str): - base_url = utils.urljoin(self.base_url, *path) + def __call__(self, *path: str) -> Resource: + base_url = urljoin(self.base_url, *path) return self.__class__(base_url, session=self.session) def _check_result(self, response, result) -> None: @@ -68,17 +75,18 @@ class Resource: # This is here because couchdb can return http 201 # but containing a list of conflict errors if error == 'conflict' or error == "file_exists": - raise exceptions.Conflict(reason or "Conflict") + raise Conflict(reason or "Conflict") if response.status_code > 205: if response.status_code == 404 or error == 'not_found': - raise exceptions.NotFound(reason or 'Not found') + raise NotFound(reason or 'Not found') elif error == 'bad_request': - raise exceptions.BadRequest(reason or "Bad request") - raise exceptions.GenericError(result) + raise BadRequest(reason or "Bad request") + raise GenericError(result) - def request(self, method, path: str, params=None, data=None, - headers=None, stream=False, **kwargs): + + def request(self, method, path: Union[str, None], params=None, data=None, + headers=None, stream=False, **kwargs) -> Tuple[requests.models.Response, Union[Dict[str, Any], None]]: if headers is None: headers = {} @@ -88,7 +96,7 @@ class Resource: if path: if not isinstance(path, (list, tuple)): path = [path] - url = utils.urljoin(self.base_url, *path) + url = urljoin(self.base_url, *path) else: url = self.base_url @@ -102,7 +110,7 @@ class Resource: result = None self._check_result(response, result) else: - result = utils.as_json(response) + result = as_json(response) if result is None: return response, result @@ -115,17 +123,17 @@ class Resource: return response, result - def get(self, path: Union[str, None] = None, **kwargs): + def get(self, path: Union[str, None] = None, **kwargs: Any) -> Tuple[requests.models.Response, Union[Dict[str, Any], None]]: return self.request("GET", path, **kwargs) - def put(self, path: Union[str, None] = None, **kwargs): + def put(self, path: Union[str, None] = None, **kwargs: Any) -> Tuple[requests.models.Response, Union[Dict[str, Any], None]]: return self.request("PUT", path, **kwargs) - def post(self, path: Union[str, None] = None, **kwargs): + def post(self, path: Union[str, None] = None, **kwargs: Any) -> Tuple[requests.models.Response, Union[Dict[str, Any], None]]: return self.request("POST", path, **kwargs) - def delete(self, path: Union[str, None] = None, **kwargs): + def delete(self, path: Union[str, None] = None, **kwargs: Any) -> Tuple[requests.models.Response, Union[Dict[str, Any], None]]: return self.request("DELETE", path, **kwargs) - def head(self, path: Union[str, None] = None, **kwargs): + def head(self, path: Union[str, None] = None, **kwargs: Any) -> Tuple[requests.models.Response, Union[Dict[str, Any], None]]: return self.request("HEAD", path, **kwargs) diff --git a/src/couch/utils.py b/src/couch/utils.py index f0883a6..b3e5aa3 100644 --- a/src/couch/utils.py +++ b/src/couch/utils.py @@ -78,13 +78,17 @@ def urljoin(base: str, *path: str) -> str: return reduce(_join, path, base) # Probably bugs here -def as_json(response: requests.models.Response) -> Union[Dict[str, Any], None]: +def as_json(response: requests.models.Response) -> Union[Dict[str, Any], None, str]: if "application/json" in response.headers['content-type']: response_src = response.content.decode('utf-8') + print(response.content) if response.content != b'': ret: Dict[str, Any] = json.loads(response_src) return ret else: + print("fff") + print("fff") + print(type(response_src)) return response_src return None diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/__init__.py diff --git a/src/test/test_api.py b/tests/test_api.py index 371fcf2..371fcf2 100644 --- a/src/test/test_api.py +++ b/tests/test_api.py |