Diffstat (limited to 'src/db')
-rw-r--r-- | src/db/couch/__init__.py   |  11
-rw-r--r-- | src/db/couch/client.py     | 770
-rw-r--r-- | src/db/couch/exceptions.py |  38
-rw-r--r-- | src/db/couch/feedreader.py |  52
-rw-r--r-- | src/db/couch/resource.py   | 127
-rw-r--r-- | src/db/couch/utils.py      | 150
-rwxr-xr-x | src/db/dictionary.py       | 146
-rw-r--r-- | src/db/index.py            |  61
-rw-r--r-- | src/db/schema.py           | 135
-rw-r--r-- | src/db/sql.py              | 170
10 files changed, 1660 insertions, 0 deletions
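The patch below adds a CouchDB client package (src/db/couch), a dictionary-style wrapper with schema validation and Mango indexes (src/db/dictionary.py, src/db/schema.py, src/db/index.py), and an SQLAlchemy layer (src/db/sql.py). The following is a minimal usage sketch of the dictionary wrapper; it is not part of the patch, assumes the COUCHDB_* environment variables documented in dictionary.py are set and CouchDB is reachable, and uses placeholder field values that merely satisfy the required keys in src/db/schema.py.

# Usage sketch (not part of the patch): store and query collector documents
# through DictDB. Assumes COUCHDB_NAME, COUCHDB_HOSTNAME, COUCHDB_USER and
# COUCHDB_PASSWORD are exported; all field values below are placeholders.
from db.dictionary import DictDB

db = DictDB()  # creates the database and the JSON indexes on first use

doc = {
    "document_version": 1,
    "ip": "192.0.2.10",
    "port": 443,
    "whois_description": "Example network",
    "asn": "AS64496",
    "asn_country_code": "SE",
    "ptr": "host.example.org",
    "abuse_mail": "abuse@example.org",
    "domain": "example.org",
    "timestamp_in_utc": "2023-01-01T00:00:00Z",
    "display_name": "example.org",
    "result": {},
}

ret = db.add(doc)  # validates against db.schema; returns an error string on failure
hits = db.search(limit=10, domain="example.org")  # Mango query via _find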
diff --git a/src/db/couch/__init__.py b/src/db/couch/__init__.py new file mode 100644 index 0000000..b099235 --- /dev/null +++ b/src/db/couch/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +__author__ = "Andrey Antukh" +__license__ = "BSD" +__version__ = "1.14.1" +__maintainer__ = "Rinat Sabitov" +__email__ = "rinat.sabitov@gmail.com" +__status__ = "Development" + + +from db.couch.client import Server # noqa: F401 diff --git a/src/db/couch/client.py b/src/db/couch/client.py new file mode 100644 index 0000000..73d85a1 --- /dev/null +++ b/src/db/couch/client.py @@ -0,0 +1,770 @@ +# -*- coding: utf-8 -*- +# Based on py-couchdb (https://github.com/histrio/py-couchdb) + +import copy +import json +import mimetypes +import os +import uuid +import warnings + +from db.couch import exceptions as exp +from db.couch import feedreader, utils +from db.couch.resource import Resource + +DEFAULT_BASE_URL = os.environ.get('COUCHDB_URL', 'http://localhost:5984/') + + +def _id_to_path(_id): + if _id[:1] == "_": + return _id.split("/", 1) + return [_id] + + +def _listen_feed(object, node, feed_reader, **kwargs): + if not callable(feed_reader): + raise exp.UnexpectedError("feed_reader must be callable or class") + + if isinstance(feed_reader, feedreader.BaseFeedReader): + reader = feed_reader(object) + else: + reader = feedreader.SimpleFeedReader()(object, feed_reader) + + # Possible options: "continuous", "longpoll" + kwargs.setdefault("feed", "continuous") + data = utils.force_bytes(json.dumps(kwargs.pop('data', {}))) + + (resp, result) = object.resource(node).post( + params=kwargs, data=data, stream=True) + try: + for line in resp.iter_lines(): + # ignore heartbeats + if not line: + reader.on_heartbeat() + else: + reader.on_message(json.loads(utils.force_text(line))) + except exp.FeedReaderExited: + reader.on_close() + + +class _StreamResponse(object): + """ + Proxy object for python-requests stream response. + + See more on: + http://docs.python-requests.org/en/latest/user/advanced/#streaming-requests + """ + + def __init__(self, response): + self._response = response + + def iter_content(self, chunk_size=1, decode_unicode=False): + return self._response.iter_content(chunk_size=chunk_size, + decode_unicode=decode_unicode) + + def iter_lines(self, chunk_size=512, decode_unicode=None): + return self._response.iter_lines(chunk_size=chunk_size, + decode_unicode=decode_unicode) + + @property + def raw(self): + return self._response.raw + + @property + def url(self): + return self._response.url + + +class Server(object): + """ + Class that represents a couchdb connection. + + :param verify: setup ssl verification. + :param base_url: a full url to couchdb (can contain auth data). + :param full_commit: If ``False``, couchdb not commits all data on a + request is finished. + :param authmethod: specify a authentication method. By default "basic" + method is used but also exists "session" (that requires + some server configuration changes). + + .. versionchanged: 1.4 + Set basic auth method as default instead of session method. + + .. 
versionchanged: 1.5 + Add verify parameter for setup ssl verificaton + + """ + + def __init__(self, base_url=DEFAULT_BASE_URL, full_commit=True, + authmethod="basic", verify=False): + + self.base_url, credentials = utils.extract_credentials(base_url) + self.resource = Resource(self.base_url, full_commit, + credentials=credentials, + authmethod=authmethod, + verify=verify) + + def __repr__(self): + return '<CouchDB Server "{}">'.format(self.base_url) + + def __contains__(self, name): + try: + self.resource.head(name) + except exp.NotFound: + return False + else: + return True + + def __iter__(self): + (r, result) = self.resource.get('_all_dbs') + return iter(result) + + def __len__(self): + (r, result) = self.resource.get('_all_dbs') + return len(result) + + def info(self): + """ + Get server info. + + :returns: dict with all data that couchdb returns. + :rtype: dict + """ + (r, result) = self.resource.get() + return result + + def delete(self, name): + """ + Delete some database. + + :param name: database name + :raises: :py:exc:`~pycouchdb.exceptions.NotFound` + if a database does not exists + """ + + self.resource.delete(name) + + def database(self, name): + """ + Get a database instance. + + :param name: database name + :raises: :py:exc:`~pycouchdb.exceptions.NotFound` + if a database does not exists + + :returns: a :py:class:`~pycouchdb.client.Database` instance + """ + (r, result) = self.resource.head(name) + if r.status_code == 404: + raise exp.NotFound("Database '{0}' does not exists".format(name)) + + db = Database(self.resource(name), name) + return db + + # TODO: Config in 2.0 are applicable for nodes only + # TODO: Reimplement when nodes endpoint will be ready + # def config(self): + # pass + + def version(self): + """ + Get the current version of a couchdb server. + """ + (resp, result) = self.resource.get() + return result["version"] + + # TODO: Stats in 2.0 are applicable for nodes only + # TODO: Reimplement when nodes endpoint will be ready + # def stats(self, name=None): + # pass + + def create(self, name): + """ + Create a database. + + :param name: database name + :raises: :py:exc:`~pycouchdb.exceptions.Conflict` + if a database already exists + :returns: a :py:class:`~pycouchdb.client.Database` instance + """ + (resp, result) = self.resource.put(name) + if resp.status_code in (200, 201): + return self.database(name) + + def replicate(self, source, target, **kwargs): + """ + Replicate the source database to the target one. + + .. versionadded:: 1.3 + + :param source: full URL to the source database + :param target: full URL to the target database + """ + + data = {'source': source, 'target': target} + data.update(kwargs) + + data = utils.force_bytes(json.dumps(data)) + + (resp, result) = self.resource.post('_replicate', data=data) + return result + + def changes_feed(self, feed_reader, **kwargs): + """ + Subscribe to changes feed of the whole CouchDB server. + + Note: this method is blocking. + + + :param feed_reader: callable or :py:class:`~BaseFeedReader` + instance + + .. [Ref] http://docs.couchdb.org/en/1.6.1/api/server/common.html#db-updates + .. versionadded: 1.10 + """ + object = self + _listen_feed(object, "_db_updates", feed_reader, **kwargs) + + +class Database(object): + """ + Class that represents a couchdb database. 
+ """ + + def __init__(self, resource, name): + self.resource = resource + self.name = name + + def __repr__(self): + return '<CouchDB Database "{}">'.format(self.name) + + def __contains__(self, doc_id): + try: + (resp, result) = self.resource.head(_id_to_path(doc_id)) + return resp.status_code < 206 + except exp.NotFound: + return False + + def config(self): + """ + Get database status data such as document count, update sequence etc. + :return: dict + """ + (resp, result) = self.resource.get() + return result + + def __nonzero__(self): + """Is the database available""" + resp, _ = self.resource.head() + return resp.status_code == 200 + + def __len__(self): + return self.config()['doc_count'] + + def delete(self, doc_or_id): + """ + Delete document by id. + + .. versionchanged:: 1.2 + Accept document or id. + + :param doc_or_id: document or id + :raises: :py:exc:`~pycouchdb.exceptions.NotFound` if a document + not exists + :raises: :py:exc:`~pycouchdb.exceptions.Conflict` if delete with + wrong revision. + """ + + _id = None + if isinstance(doc_or_id, dict): + if "_id" not in doc_or_id: + raise ValueError("Invalid document, missing _id attr") + _id = doc_or_id['_id'] + else: + _id = doc_or_id + + resource = self.resource(*_id_to_path(_id)) + + (r, result) = resource.head() + (r, result) = resource.delete( + params={"rev": r.headers["etag"].strip('"')}) + + def delete_bulk(self, docs, transaction=True): + """ + Delete a bulk of documents. + + .. versionadded:: 1.2 + + :param docs: list of docs + :raises: :py:exc:`~pycouchdb.exceptions.Conflict` if a delete + is not success + :returns: raw results from server + """ + + _docs = copy.copy(docs) + for doc in _docs: + if "_deleted" not in doc: + doc["_deleted"] = True + + data = utils.force_bytes(json.dumps({"docs": _docs})) + params = {"all_or_nothing": "true" if transaction else "false"} + (resp, results) = self.resource.post( + "_bulk_docs", data=data, params=params) + + for result, doc in zip(results, _docs): + if "error" in result: + raise exp.Conflict("one or more docs are not saved") + + return results + + def get(self, doc_id, params=None, **kwargs): + """ + Get a document by id. + + .. versionadded: 1.5 + Now the prefered method to pass params is via **kwargs + instead of params argument. **params** argument is now + deprecated and will be deleted in future versions. + + :param doc_id: document id + :raises: :py:exc:`~pycouchdb.exceptions.NotFound` if a document + not exists + + :returns: document (dict) + """ + + if params: + warnings.warn("params parameter is now deprecated in favor to" + "**kwargs usage.", DeprecationWarning) + + if params is None: + params = {} + + params.update(kwargs) + + (resp, result) = self.resource(*_id_to_path(doc_id)).get(params=params) + return result + + def save(self, doc, batch=False): + """ + Save or update a document. + + .. versionchanged:: 1.2 + Now returns a new document instead of modify the original. + + :param doc: document + :param batch: allow batch=ok inserts (default False) + :raises: :py:exc:`~pycouchdb.exceptions.Conflict` if save with wrong + revision. 
+ :returns: doc + """ + + _doc = copy.copy(doc) + if "_id" not in _doc: + _doc['_id'] = uuid.uuid4().hex + + if batch: + params = {'batch': 'ok'} + else: + params = {} + + data = utils.force_bytes(json.dumps(_doc)) + (resp, result) = self.resource(_doc['_id']).put( + data=data, params=params) + + if resp.status_code == 409: + raise exp.Conflict(result['reason']) + + if "rev" in result and result["rev"] is not None: + _doc["_rev"] = result["rev"] + + return _doc + + def save_bulk(self, docs, try_setting_ids=True, transaction=True): + """ + Save a bulk of documents. + + .. versionchanged:: 1.2 + Now returns a new document list instead of modify the original. + + :param docs: list of docs + :param try_setting_ids: if ``True``, we loop through docs and generate/set + an id in each doc if none exists + :param transaction: if ``True``, couchdb do a insert in transaction + model. + :returns: docs + """ + + _docs = copy.deepcopy(docs) + + # Insert _id field if it not exists and try_setting_ids is true + if try_setting_ids: + for doc in _docs: + if "_id" not in doc: + doc["_id"] = uuid.uuid4().hex + + data = utils.force_bytes(json.dumps({"docs": _docs})) + params = {"all_or_nothing": "true" if transaction else "false"} + + (resp, results) = self.resource.post("_bulk_docs", data=data, + params=params) + + for result, doc in zip(results, _docs): + if "rev" in result: + doc['_rev'] = result['rev'] + + return _docs + + def all(self, wrapper=None, flat=None, as_list=False, **kwargs): + """ + Execute a builtin view for get all documents. + + :param wrapper: wrap result into a specific class. + :param as_list: return a list of results instead of a + default lazy generator. + :param flat: get a specific field from a object instead + of a complete object. + + .. versionadded: 1.4 + Add as_list parameter. + Add flat parameter. + + :returns: generator object + """ + + params = {"include_docs": "true"} + params.update(kwargs) + + data = None + + if "keys" in params: + data = {"keys": params.pop("keys")} + data = utils.force_bytes(json.dumps(data)) + + params = utils.encode_view_options(params) + if data: + (resp, result) = self.resource.post( + "_all_docs", params=params, data=data) + else: + (resp, result) = self.resource.get("_all_docs", params=params) + + if wrapper is None: + def wrapper(doc): return doc + + if flat is not None: + def wrapper(doc): return doc[flat] + + def _iterate(): + for row in result["rows"]: + yield wrapper(row) + + if as_list: + return list(_iterate()) + return _iterate() + + def cleanup(self): + """ + Execute a cleanup operation. + """ + (r, result) = self.resource('_view_cleanup').post() + return result + + def commit(self): + """ + Send commit message to server. + """ + (resp, result) = self.resource.post('_ensure_full_commit') + return result + + def compact(self): + """ + Send compact message to server. Compacting write-heavy databases + should be avoided, otherwise the process may not catch up with + the writes. Read load has no effect. + """ + (r, result) = self.resource("_compact").post() + return result + + def compact_view(self, ddoc): + """ + Execute compact over design view. + + :raises: :py:exc:`~pycouchdb.exceptions.NotFound` + if a view does not exists. + """ + (r, result) = self.resource("_compact", ddoc).post() + return result + + def revisions(self, doc_id, status='available', params=None, **kwargs): + """ + Get all revisions of one document. 
+ + :param doc_id: document id + :param status: filter of revision status, set empty to list all + :raises: :py:exc:`~pycouchdb.exceptions.NotFound` + if a view does not exists. + + :returns: generator object + """ + if params: + warnings.warn("params parameter is now deprecated in favor to" + "**kwargs usage.", DeprecationWarning) + + if params is None: + params = {} + + params.update(kwargs) + + if not params.get('revs_info'): + params['revs_info'] = 'true' + + resource = self.resource(doc_id) + (resp, result) = resource.get(params=params) + if resp.status_code == 404: + raise exp.NotFound("Document id `{0}` not found".format(doc_id)) + + for rev in result['_revs_info']: + if status and rev['status'] == status: + yield self.get(doc_id, rev=rev['rev']) + elif not status: + yield self.get(doc_id, rev=rev['rev']) + + def delete_attachment(self, doc, filename): + """ + Delete attachment by filename from document. + + .. versionchanged:: 1.2 + Now returns a new document instead of modify the original. + + :param doc: document dict + :param filename: name of attachment. + :raises: :py:exc:`~pycouchdb.exceptions.Conflict` + if save with wrong revision. + :returns: doc + """ + + _doc = copy.deepcopy(doc) + resource = self.resource(_doc['_id']) + + (resp, result) = resource.delete( + filename, params={'rev': _doc['_rev']}) + if resp.status_code == 404: + raise exp.NotFound("filename {0} not found".format(filename)) + + if resp.status_code > 205: + raise exp.Conflict(result['reason']) + + _doc['_rev'] = result['rev'] + try: + del _doc['_attachments'][filename] + + if not _doc['_attachments']: + del _doc['_attachments'] + except KeyError: + pass + + return _doc + + def get_attachment(self, doc, filename, stream=False, **kwargs): + """ + Get attachment by filename from document. + + :param doc: document dict + :param filename: attachment file name. + :param stream: setup streaming output (default: False) + + .. versionchanged: 1.5 + Add stream parameter for obtain very large attachments + without load all file to the memory. + + :returns: binary data or + """ + + params = {"rev": doc["_rev"]} + params.update(kwargs) + + r, result = self.resource(doc['_id']).get(filename, stream=stream, + params=params) + if stream: + return _StreamResponse(r) + + return r.content + + def put_attachment(self, doc, content, filename=None, content_type=None): + """ + Put a attachment to a document. + + .. versionchanged:: 1.2 + Now returns a new document instead of modify the original. + + :param doc: document dict. + :param content: the content to upload, either a file-like object or + bytes + :param filename: the name of the attachment file; if omitted, this + function tries to get the filename from the file-like + object passed as the `content` argument value + :raises: :py:exc:`~pycouchdb.exceptions.Conflict` + if save with wrong revision. 
+ :raises: ValueError + :returns: doc + """ + + if filename is None: + if hasattr(content, 'name'): + filename = os.path.basename(content.name) + else: + raise ValueError('no filename specified for attachment') + + if content_type is None: + content_type = ';'.join( + filter(None, mimetypes.guess_type(filename))) + + headers = {"Content-Type": content_type} + resource = self.resource(doc['_id']) + + (resp, result) = resource.put( + filename, data=content, params={'rev': doc['_rev']}, headers=headers) + + if resp.status_code < 206: + return self.get(doc["_id"]) + + raise exp.Conflict(result['reason']) + + def one(self, name, flat=None, wrapper=None, **kwargs): + """ + Execute a design document view query and returns a first + result. + + :param name: name of the view (eg: docidname/viewname). + :param wrapper: wrap result into a specific class. + :param flat: get a specific field from a object instead + of a complete object. + + .. versionadded: 1.4 + + :returns: object or None + """ + + params = {"limit": 1} + params.update(kwargs) + + path = utils._path_from_name(name, '_view') + data = None + + if "keys" in params: + data = {"keys": params.pop('keys')} + + if data: + data = utils.force_bytes(json.dumps(data)) + + params = utils.encode_view_options(params) + result = list(self._query(self.resource(*path), wrapper=wrapper, + flat=flat, params=params, data=data)) + + return result[0] if len(result) > 0 else None + + def _query(self, resource, data=None, params=None, headers=None, + flat=None, wrapper=None): + + if data is None: + (resp, result) = resource.get(params=params, headers=headers) + else: + (resp, result) = resource.post( + data=data, params=params, headers=headers) + + if wrapper is None: + def wrapper(row): return row + + if flat is not None: + def wrapper(row): return row[flat] + + for row in result["rows"]: + yield wrapper(row) + + def query(self, name, wrapper=None, flat=None, as_list=False, **kwargs): + """ + Execute a design document view query. + + :param name: name of the view (eg: docidname/viewname). + :param wrapper: wrap result into a specific class. + :param as_list: return a list of results instead of a + default lazy generator. + :param flat: get a specific field from a object instead + of a complete object. + + .. versionadded: 1.4 + Add as_list parameter. + Add flat parameter. + + :returns: generator object + """ + params = copy.copy(kwargs) + path = utils._path_from_name(name, '_view') + data = None + + if "keys" in params: + data = {"keys": params.pop('keys')} + + if data: + data = utils.force_bytes(json.dumps(data)) + + params = utils.encode_view_options(params) + result = self._query(self.resource(*path), wrapper=wrapper, + flat=flat, params=params, data=data) + + if as_list: + return list(result) + return result + + def changes_feed(self, feed_reader, **kwargs): + """ + Subscribe to changes feed of couchdb database. + + Note: this method is blocking. + + + :param feed_reader: callable or :py:class:`~BaseFeedReader` + instance + + .. versionadded: 1.5 + """ + + object = self + _listen_feed(object, "_changes", feed_reader, **kwargs) + + def changes_list(self, **kwargs): + """ + Obtain a list of changes from couchdb. + + .. versionadded: 1.5 + """ + + (resp, result) = self.resource("_changes").get(params=kwargs) + return result['last_seq'], result['results'] + + def find(self, selector, wrapper=None, **kwargs): + """ + Execute Mango querys using _find. + + :param selector: data to search + :param wrapper: wrap result into a specific class. 
+ + """ + path = '_find' + data = utils.force_bytes(json.dumps(selector)) + + (resp, result) = self.resource.post(path, data=data, params=kwargs) + + if wrapper is None: + def wrapper(doc): return doc + + for doc in result["docs"]: + yield wrapper(doc) + + def index(self, index_doc, **kwargs): + path = '_index' + data = utils.force_bytes(json.dumps(index_doc)) + + (resp, result) = self.resource.post(path, data=data, params=kwargs) + + return result diff --git a/src/db/couch/exceptions.py b/src/db/couch/exceptions.py new file mode 100644 index 0000000..d7e037b --- /dev/null +++ b/src/db/couch/exceptions.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Based on py-couchdb (https://github.com/histrio/py-couchdb) + + +class Error(Exception): + pass + + +class UnexpectedError(Error): + pass + + +class FeedReaderExited(Error): + pass + + +class ApiError(Error): + pass + + +class GenericError(ApiError): + pass + + +class Conflict(ApiError): + pass + + +class NotFound(ApiError): + pass + + +class BadRequest(ApiError): + pass + + +class AuthenticationFailed(ApiError): + pass diff --git a/src/db/couch/feedreader.py b/src/db/couch/feedreader.py new file mode 100644 index 0000000..e293932 --- /dev/null +++ b/src/db/couch/feedreader.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Based on py-couchdb (https://github.com/histrio/py-couchdb) + + +class BaseFeedReader(object): + """ + Base interface class for changes feed reader. + """ + + def __call__(self, db): + self.db = db + return self + + def on_message(self, message): + """ + Callback method that is called when change + message is received from couchdb. + + :param message: change object + :returns: None + """ + + raise NotImplementedError() + + def on_close(self): + """ + Callback method that is received when connection + is closed with a server. By default, does nothing. + """ + pass + + def on_heartbeat(self): + """ + Callback method invoked when a hearbeat (empty line) is received + from the _changes stream. Override this to purge the reader's internal + buffers (if any) if it waited too long without receiving anything. + """ + pass + + +class SimpleFeedReader(BaseFeedReader): + """ + Simple feed reader that encapsule any callable in + a valid feed reader interface. 
+ """ + + def __call__(self, db, callback): + self.callback = callback + return super(SimpleFeedReader, self).__call__(db) + + def on_message(self, message): + self.callback(message, db=self.db) diff --git a/src/db/couch/resource.py b/src/db/couch/resource.py new file mode 100644 index 0000000..8ff883b --- /dev/null +++ b/src/db/couch/resource.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# Based on py-couchdb (https://github.com/histrio/py-couchdb) + + +from __future__ import unicode_literals + +import json + +import requests +from db.couch import exceptions, utils + + +class Resource(object): + def __init__(self, base_url, full_commit=True, session=None, + credentials=None, authmethod="session", verify=False): + + self.base_url = base_url +# self.verify = verify + + if not session: + self.session = requests.session() + + self.session.headers.update({"accept": "application/json", + "content-type": "application/json"}) + self._authenticate(credentials, authmethod) + + if not full_commit: + self.session.headers.update({'X-Couch-Full-Commit': 'false'}) + else: + self.session = session + self.session.verify = verify + + def _authenticate(self, credentials, method): + if not credentials: + return + + if method == "session": + data = {"name": credentials[0], "password": credentials[1]} + data = utils.force_bytes(json.dumps(data)) + + post_url = utils.urljoin(self.base_url, "_session") + r = self.session.post(post_url, data=data) + if r.status_code != 200: + raise exceptions.AuthenticationFailed() + + elif method == "basic": + self.session.auth = credentials + + else: + raise RuntimeError("Invalid authentication method") + + def __call__(self, *path): + base_url = utils.urljoin(self.base_url, *path) + return self.__class__(base_url, session=self.session) + + def _check_result(self, response, result): + try: + error = result.get('error', None) + reason = result.get('reason', None) + except AttributeError: + error = None + reason = '' + + # This is here because couchdb can return http 201 + # but containing a list of conflict errors + if error == 'conflict' or error == "file_exists": + raise exceptions.Conflict(reason or "Conflict") + + if response.status_code > 205: + if response.status_code == 404 or error == 'not_found': + raise exceptions.NotFound(reason or 'Not found') + elif error == 'bad_request': + raise exceptions.BadRequest(reason or "Bad request") + raise exceptions.GenericError(result) + + def request(self, method, path, params=None, data=None, + headers=None, stream=False, **kwargs): + + if headers is None: + headers = {} + + headers.setdefault('Accept', 'application/json') + + if path: + if not isinstance(path, (list, tuple)): + path = [path] + url = utils.urljoin(self.base_url, *path) + else: + url = self.base_url + + response = self.session.request(method, url, stream=stream, + data=data, params=params, + headers=headers, **kwargs) + # Ignore result validation if + # request is with stream mode + + if stream and response.status_code < 400: + result = None + self._check_result(response, result) + else: + result = utils.as_json(response) + + if result is None: + return response, result + + if isinstance(result, list): + for res in result: + self._check_result(response, res) + else: + self._check_result(response, result) + + return response, result + + def get(self, path=None, **kwargs): + return self.request("GET", path, **kwargs) + + def put(self, path=None, **kwargs): + return self.request("PUT", path, **kwargs) + + def post(self, path=None, **kwargs): + return 
self.request("POST", path, **kwargs) + + def delete(self, path=None, **kwargs): + return self.request("DELETE", path, **kwargs) + + def head(self, path=None, **kwargs): + return self.request("HEAD", path, **kwargs) diff --git a/src/db/couch/utils.py b/src/db/couch/utils.py new file mode 100644 index 0000000..1cd21d8 --- /dev/null +++ b/src/db/couch/utils.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +# Based on py-couchdb (https://github.com/histrio/py-couchdb) + + +import json +import sys + + +if sys.version_info[0] == 3: + from urllib.parse import unquote as _unquote + from urllib.parse import urlunsplit, urlsplit + + string_type = str + bytes_type = bytes + + from functools import reduce + +else: + from urllib import unquote as _unquote + from urlparse import urlunsplit, urlsplit + + string_type = unicode # noqa: F821 + bytes_type = str + +URLSPLITTER = '/' + + +json_encoder = json.JSONEncoder() + + +def extract_credentials(url): + """ + Extract authentication (user name and password) credentials from the + given URL. + + >>> extract_credentials('http://localhost:5984/_config/') + ('http://localhost:5984/_config/', None) + >>> extract_credentials('http://joe:secret@localhost:5984/_config/') + ('http://localhost:5984/_config/', ('joe', 'secret')) + >>> extract_credentials('http://joe%40example.com:secret@' + ... 'localhost:5984/_config/') + ('http://localhost:5984/_config/', ('joe@example.com', 'secret')) + """ + parts = urlsplit(url) + netloc = parts[1] + if '@' in netloc: + creds, netloc = netloc.split('@') + credentials = tuple(_unquote(i) for i in creds.split(':')) + parts = list(parts) + parts[1] = netloc + else: + credentials = None + return urlunsplit(parts), credentials + + +def _join(head, tail): + parts = [head.rstrip(URLSPLITTER), tail.lstrip(URLSPLITTER)] + return URLSPLITTER.join(parts) + + +def urljoin(base, *path): + """ + Assemble a uri based on a base, any number of path segments, and query + string parameters. + + >>> urljoin('http://example.org', '_all_dbs') + 'http://example.org/_all_dbs' + + A trailing slash on the uri base is handled gracefully: + + >>> urljoin('http://example.org/', '_all_dbs') + 'http://example.org/_all_dbs' + + And multiple positional arguments become path parts: + + >>> urljoin('http://example.org/', 'foo', 'bar') + 'http://example.org/foo/bar' + + >>> urljoin('http://example.org/', 'foo/bar') + 'http://example.org/foo/bar' + + >>> urljoin('http://example.org/', 'foo', '/bar/') + 'http://example.org/foo/bar/' + + >>> urljoin('http://example.com', 'org.couchdb.user:username') + 'http://example.com/org.couchdb.user:username' + """ + return reduce(_join, path, base) + + +def as_json(response): + if "application/json" in response.headers['content-type']: + response_src = response.content.decode('utf-8') + if response.content != b'': + return json.loads(response_src) + else: + return response_src + return None + + +def _path_from_name(name, type): + """ + Expand a 'design/foo' style name to its full path as a list of + segments. + + >>> _path_from_name("_design/test", '_view') + ['_design', 'test'] + >>> _path_from_name("design/test", '_view') + ['_design', 'design', '_view', 'test'] + """ + if name.startswith('_'): + return name.split('/') + design, name = name.split('/', 1) + return ['_design', design, type, name] + + +def encode_view_options(options): + """ + Encode any items in the options dict that are sent as a JSON string to a + view/list function. 
+ + >>> opts = {'key': 'foo', "notkey":"bar"} + >>> res = encode_view_options(opts) + >>> res["key"], res["notkey"] + ('"foo"', 'bar') + + >>> opts = {'startkey': 'foo', "endkey":"bar"} + >>> res = encode_view_options(opts) + >>> res['startkey'], res['endkey'] + ('"foo"', '"bar"') + """ + retval = {} + + for name, value in options.items(): + if name in ('key', 'startkey', 'endkey'): + value = json_encoder.encode(value) + retval[name] = value + return retval + + +def force_bytes(data, encoding="utf-8"): + if isinstance(data, string_type): + data = data.encode(encoding) + return data + + +def force_text(data, encoding="utf-8"): + if isinstance(data, bytes_type): + data = data.decode(encoding) + return data diff --git a/src/db/dictionary.py b/src/db/dictionary.py new file mode 100755 index 0000000..f0f5fe9 --- /dev/null +++ b/src/db/dictionary.py @@ -0,0 +1,146 @@ +# A database storing dictionaries, keyed on a timestamp. value = A +# dict which will be stored as a JSON object encoded in UTF-8. Note +# that dict keys of type integer or float will become strings while +# values will keep their type. + +# Note that there's a (slim) chance that you'd stomp on the previous +# value if you're too quick with generating the timestamps, ie +# invoking time.time() several times quickly enough. + +import os +import sys +import time + +from db import couch +from db.schema import as_index_list, validate_collector_data + + +class DictDB(): + def __init__(self): + """ + Check if the database exists, otherwise we will create it together + with the indexes specified in index.py. + """ + + try: + self.database = os.environ['COUCHDB_NAME'] + self.hostname = os.environ['COUCHDB_HOSTNAME'] + self.username = os.environ['COUCHDB_USER'] + self.password = os.environ['COUCHDB_PASSWORD'] + except KeyError: + print('The environment variables COUCHDB_NAME, COUCHDB_HOSTNAME,' + + ' COUCHDB_USER and COUCHDB_PASSWORD must be set.') + sys.exit(-1) + + if 'COUCHDB_PORT' in os.environ: + couchdb_port = os.environ['COUCHDB_PORT'] + else: + couchdb_port = 5984 + + self.server = couch.client.Server( + f"http://{self.username}:{self.password}@{self.hostname}:{couchdb_port}/") + + try: + self.couchdb = self.server.database(self.database) + print("Database already exists") + except couch.exceptions.NotFound: + print("Creating database and indexes.") + self.couchdb = self.server.create(self.database) + + for i in as_index_list(): + self.couchdb.index(i) + + self._ts = time.time() + + def unique_key(self): + """ + Create a unique key based on the current time. We will use this as + the ID for any new documents we store in CouchDB. + """ + + ts = time.time() + while round(ts * 1000) == self._ts: + ts = time.time() + self._ts = round(ts * 1000) + + return self._ts + + def add(self, data, batch_write=False): + """ + Store a document in CouchDB. + """ + + if type(data) is list: + for item in data: + error = validate_collector_data(item) + if error != "": + return error + item['_id'] = str(self.unique_key()) + ret = self.couchdb.save_bulk(data) + else: + error = validate_collector_data(data) + if error != "": + return error + data['_id'] = str(self.unique_key()) + ret = self.couchdb.save(data) + + return ret + + def get(self, key): + """ + Get a document based on its ID, return an empty dict if not found. 
+ """ + + try: + doc = self.couchdb.get(key) + except couch.exceptions.NotFound: + doc = {} + + return doc + + def slice(self, key_from=None, key_to=None): + pass + + def search(self, limit=25, skip=0, **kwargs): + """ + Execute a Mango query, ideally we should have an index matching + the query otherwise things will be slow. + """ + + data = list() + selector = dict() + + try: + limit = int(limit) + skip = int(skip) + except ValueError: + limit = 25 + skip = 0 + + if kwargs: + selector = { + "limit": limit, + "skip": skip, + "selector": {} + } + + for key in kwargs: + if kwargs[key] and kwargs[key].isnumeric(): + kwargs[key] = int(kwargs[key]) + selector['selector'][key] = {'$eq': kwargs[key]} + + for doc in self.couchdb.find(selector, wrapper=None, limit=5): + data.append(doc) + + return data + + def delete(self, key): + """ + Delete a document based on its ID. + """ + try: + self.couchdb.delete(key) + except couch.exceptions.NotFound: + return None + + return key diff --git a/src/db/index.py b/src/db/index.py new file mode 100644 index 0000000..688ceeb --- /dev/null +++ b/src/db/index.py @@ -0,0 +1,61 @@ +from pydantic import BaseSettings + + +class CouchIindex(BaseSettings): + domain: dict = { + "index": { + "fields": [ + "domain", + ] + }, + "name": "domain-json-index", + "type": "json" + } + ip: dict = { + "index": { + "fields": [ + "domain", + "ip" + ] + }, + "name": "ip-json-index", + "type": "json" + } + port: dict = { + "index": { + "fields": [ + "domain", + "port" + ] + }, + "name": "port-json-index", + "type": "json" + } + asn: dict = { + "index": { + "fields": [ + "domain", + "asn" + ] + }, + "name": "asn-json-index", + "type": "json" + } + asn_country_code: dict = { + "index": { + "fields": [ + "domain", + "asn_country_code" + ] + }, + "name": "asn-country-code-json-index", + "type": "json" + } + + +def as_list(): + index_list = list() + for item in CouchIindex().dict(): + index_list.append(CouchIindex().dict()[item]) + + return index_list diff --git a/src/db/schema.py b/src/db/schema.py new file mode 100644 index 0000000..9bdf130 --- /dev/null +++ b/src/db/schema.py @@ -0,0 +1,135 @@ +import json +import sys +import traceback + +import jsonschema + +# fmt:off +# NOTE: Commented out properties are left intentionally, so it is easier to see +# what properties are optional. 
+schema = { + "$schema": "http://json-schema.org/schema#", + "type": "object", + "properties": { + "document_version": {"type": "integer"}, + "ip": {"type": "string"}, + "port": {"type": "integer"}, + "whois_description": {"type": "string"}, + "asn": {"type": "string"}, + "asn_country_code": {"type": "string"}, + "ptr": {"type": "string"}, + "abuse_mail": {"type": "string"}, + "domain": {"type": "string"}, + "timestamp_in_utc": {"type": "string"}, + "display_name": {"type": "string"}, + "description": {"type": "string"}, + "custom_data": { + "type": "object", + "patternProperties": { + ".*": { + "type": "object", + "properties": { + "display_name": {"type": "string"}, + "data": {"type": ["string", "boolean", "integer"]}, + "description": {"type": "string"}, + }, + "required": [ + "display_name", + "data", + # "description" + ] + }, + }, + }, + "result": { + "type": "object", + "patternProperties": { + ".*": { + "type": "object", + "properties": { + "display_name": {"type": "string"}, + "vulnerable": {"type": "boolean"}, + "investigation_needed": {"type": "boolean"}, + "reliability": {"type": "integer"}, + "description": {"type": "string"}, + }, + "oneOf": [ + { + "required": [ + "display_name", + "vulnerable", + # "reliability", # TODO: reliability is required if vulnerable = true + # "description", + ] + }, + { + "required": [ + "display_name", + "investigation_needed", + # "reliability", # TODO: reliability is required if investigation_needed = true + # "description", + ] + }, + ] + }, + }, + }, + }, + "required": [ + "document_version", + "ip", + "port", + "whois_description", + "asn", + "asn_country_code", + "ptr", + "abuse_mail", + "domain", + "timestamp_in_utc", + "display_name", + # "description", + # "custom_data", + "result", + ], +} +# fmt:on + + +def get_index_keys(): + keys = list() + for key in schema["properties"]: + keys.append(key) + return keys + + +def as_index_list(): + index_list = list() + for key in schema["properties"]: + name = f"{key}-json-index" + index = { + "index": { + "fields": [ + key, + ] + }, + "name": name, + "type": "json" + } + index_list.append(index) + + return index_list + + +def validate_collector_data(json_blob): + try: + jsonschema.validate(json_blob, schema) + except jsonschema.exceptions.ValidationError as e: + return f"Validation failed with error: {e.message}" + return "" + + +if __name__ == "__main__": + with open(sys.argv[1]) as fd: + json_data = json.loads(fd.read()) + + print(validate_collector_data(json_data)) diff --git a/src/db/sql.py b/src/db/sql.py new file mode 100644 index 0000000..c47a69c --- /dev/null +++ b/src/db/sql.py @@ -0,0 +1,170 @@ +import datetime +import os +import sys +from contextlib import contextmanager + +from sqlalchemy import (Boolean, Column, Date, Integer, String, Text, + create_engine, text) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +Base = declarative_base() +metadata = Base.metadata + + +class Log(Base): + __tablename__ = "log" + + id = Column(Integer, primary_key=True) + timestamp = Column(Date, nullable=False, + default=datetime.datetime.utcnow) + username = Column(Text, nullable=False) + logtext = Column(Text, nullable=False) + + def as_dict(self): + """Return JSON serializable dict.""" + d = {} + for col in self.__table__.columns: + value = getattr(self, col.name) + if issubclass(value.__class__, Base): + continue + elif issubclass(value.__class__, datetime.datetime): + value = str(value) + d[col.name] = value + return d + + @classmethod + def 
add(cls, username, logtext): + with sqla_session() as session: + logentry = Log() + logentry.username = username + logentry.logtext = logtext + session.add(logentry) + + +class Scanner(Base): + __tablename__ = 'scanner' + + id = Column(Integer, primary_key=True, server_default=text( + "nextval('scanner_id_seq'::regclass)")) + runner = Column(Text, server_default=text("'*'::text")) + name = Column(String(128), nullable=False) + active = Column(Boolean, nullable=False) + interval = Column(Integer, nullable=False, + server_default=text("300")) + starttime = Column(Date) + endtime = Column(Date) + maxruns = Column(Integer, server_default=text("1")) + hostname = Column(String(128), nullable=False) + port = Column(Integer, nullable=False) + + def as_dict(self): + d = {} + for col in self.__table__.columns: + value = getattr(self, col.name) + if issubclass(value.__class__, Base): + continue + elif issubclass(value.__class__, datetime.datetime): + value = str(value) + d[col.name] = value + return d + + @classmethod + def add(cls, name, hostname, port, active=False, interval=0, + starttime=None, + endtime=None, + maxruns=1): + errors = list() + if starttime and endtime: + if starttime > endtime: + errors.append("Endtime must be after the starttime.") + if interval < 0: + errors.append("Interval must be > 0") + if maxruns < 0: + errors.append("Max runs must be > 0") + with sqla_session() as session: + scanentry = Scanner() + scanentry.name = name + scanentry.active = active + scanentry.interval = interval + if starttime: + scanentry.starttime = starttime + if endtime: + scanentry.endtime = endtime + scanentry.maxruns = maxruns + scanentry.hostname = hostname + scanentry.port = port + session.add(scanentry) + return errors + + @classmethod + def get(cls, name): + results = list() + with sqla_session() as session: + scanners = session.query(Scanner).all() + if not scanners: + return None + for scanner in scanners: + if scanner.runner == "*": + results.append(scanner.as_dict()) + elif scanner.runner == name: + results.append(scanner.as_dict()) + return results + + @classmethod + def edit(cls, name, active): + with sqla_session() as session: + scanners = session.query(Scanner).filter( + Scanner.name == name).all() + if not scanners: + return None + for scanner in scanners: + scanner.active = active + return True + + +def get_sqlalchemy_conn_str(**kwargs) -> str: + try: + if "SQL_HOSTNAME" in os.environ: + hostname = os.environ["SQL_HOSTNAME"] + else: + hostname = "localhost" + print("SQL_HOSTNAME not set, falling back to localhost.") + if "SQL_PORT" in os.environ: + port = os.environ["SQL_PORT"] + else: + print("SQL_PORT not set, falling back to 5432.") + port = 5432 + username = os.environ["SQL_USERNAME"] + password = os.environ["SQL_PASSWORD"] + database = os.environ["SQL_DATABASE"] + except KeyError: + print("SQL_DATABASE, SQL_USERNAME, SQL_PASSWORD must be set.") + sys.exit(-2) + + return ( + f"postgresql://{username}:{password}@{hostname}:{port}/{database}" + ) + + +def get_session(conn_str=""): + if conn_str == "": + conn_str = get_sqlalchemy_conn_str() + + engine = create_engine(conn_str, pool_size=50, max_overflow=0) + Session = sessionmaker(bind=engine) + + return Session() + + +@contextmanager +def sqla_session(conn_str="", **kwargs): + session = get_session(conn_str) + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() |
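For the relational side in src/db/sql.py, a comparable sketch (not part of the patch) follows. It assumes the SQL_* environment variables are set, that the log and scanner tables already exist in the target PostgreSQL database, and that the names used here ("alice", "tls-check", "scanner-1.example.org") are placeholders.

# Usage sketch (not part of the patch): audit logging and scanner profiles
# via the SQLAlchemy models. Assumes SQL_USERNAME, SQL_PASSWORD and
# SQL_DATABASE (and optionally SQL_HOSTNAME/SQL_PORT) are exported and the
# tables already exist.
from db.sql import Log, Scanner

Log.add("alice", "created scanner profile tls-check")  # one audit row, committed by sqla_session()

errors = Scanner.add(name="tls-check", hostname="scanner-1.example.org",
                     port=5000, active=True, interval=300)
if errors:  # Scanner.add() returns a list of validation messages
    print(errors)

# Scanner.get() matches on the runner column: rows with runner "*" plus
# rows whose runner equals the given name.
for scanner in Scanner.get("runner-01") or []:
    print(scanner["name"], scanner["active"])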