diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/db/couch/__init__.py (renamed from src/couch/__init__.py) | 2 | ||||
-rw-r--r-- | src/db/couch/client.py (renamed from src/couch/client.py) | 14 | ||||
-rw-r--r-- | src/db/couch/exceptions.py (renamed from src/couch/exceptions.py) | 0 | ||||
-rw-r--r-- | src/db/couch/feedreader.py (renamed from src/couch/feedreader.py) | 0 | ||||
-rw-r--r-- | src/db/couch/resource.py (renamed from src/couch/resource.py) | 5 | ||||
-rw-r--r-- | src/db/couch/utils.py (renamed from src/couch/utils.py) | 0 | ||||
-rwxr-xr-x | src/db/dictionary.py (renamed from src/db.py) | 4 | ||||
-rw-r--r-- | src/db/index.py | 61 | ||||
-rw-r--r-- | src/db/schema.py (renamed from src/schema.py) | 0 | ||||
-rw-r--r-- | src/db/sql.py | 170 | ||||
-rwxr-xr-x | src/main.py | 48 |
11 files changed, 278 insertions, 26 deletions
diff --git a/src/couch/__init__.py b/src/db/couch/__init__.py index a7537bc..b099235 100644 --- a/src/couch/__init__.py +++ b/src/db/couch/__init__.py @@ -8,4 +8,4 @@ __email__ = "rinat.sabitov@gmail.com" __status__ = "Development" -from couch.client import Server # noqa: F401 +from db.couch.client import Server # noqa: F401 diff --git a/src/couch/client.py b/src/db/couch/client.py index 188e0de..73d85a1 100644 --- a/src/couch/client.py +++ b/src/db/couch/client.py @@ -1,18 +1,16 @@ # -*- coding: utf-8 -*- # Based on py-couchdb (https://github.com/histrio/py-couchdb) -import os -import json -import uuid import copy +import json import mimetypes +import os +import uuid import warnings -from couch import utils -from couch import feedreader -from couch import exceptions as exp -from couch.resource import Resource - +from db.couch import exceptions as exp +from db.couch import feedreader, utils +from db.couch.resource import Resource DEFAULT_BASE_URL = os.environ.get('COUCHDB_URL', 'http://localhost:5984/') diff --git a/src/couch/exceptions.py b/src/db/couch/exceptions.py index d7e037b..d7e037b 100644 --- a/src/couch/exceptions.py +++ b/src/db/couch/exceptions.py diff --git a/src/couch/feedreader.py b/src/db/couch/feedreader.py index e293932..e293932 100644 --- a/src/couch/feedreader.py +++ b/src/db/couch/feedreader.py diff --git a/src/couch/resource.py b/src/db/couch/resource.py index da1e0dd..8ff883b 100644 --- a/src/couch/resource.py +++ b/src/db/couch/resource.py @@ -5,10 +5,9 @@ from __future__ import unicode_literals import json -import requests -from couch import utils -from couch import exceptions +import requests +from db.couch import exceptions, utils class Resource(object): diff --git a/src/couch/utils.py b/src/db/couch/utils.py index 1cd21d8..1cd21d8 100644 --- a/src/couch/utils.py +++ b/src/db/couch/utils.py diff --git a/src/db.py b/src/db/dictionary.py index 6f25ec3..f0f5fe9 100755 --- a/src/db.py +++ b/src/db/dictionary.py @@ -11,8 +11,8 @@ import os import sys import time -import couch -from schema import as_index_list, validate_collector_data +from db import couch +from db.schema import as_index_list, validate_collector_data class DictDB(): diff --git a/src/db/index.py b/src/db/index.py new file mode 100644 index 0000000..688ceeb --- /dev/null +++ b/src/db/index.py @@ -0,0 +1,61 @@ +from pydantic import BaseSettings + + +class CouchIindex(BaseSettings): + domain: dict = { + "index": { + "fields": [ + "domain", + ] + }, + "name": "domain-json-index", + "type": "json" + } + ip: dict = { + "index": { + "fields": [ + "domain", + "ip" + ] + }, + "name": "ip-json-index", + "type": "json" + } + port: dict = { + "index": { + "fields": [ + "domain", + "port" + ] + }, + "name": "port-json-index", + "type": "json" + } + asn: dict = { + "index": { + "fields": [ + "domain", + "asn" + ] + }, + "name": "asn-json-index", + "type": "json" + } + asn_country_code: dict = { + "index": { + "fields": [ + "domain", + "asn_country_code" + ] + }, + "name": "asn-country-code-json-index", + "type": "json" + } + + +def as_list(): + index_list = list() + for item in CouchIindex().dict(): + index_list.append(CouchIindex().dict()[item]) + + return index_list diff --git a/src/schema.py b/src/db/schema.py index 9bdf130..9bdf130 100644 --- a/src/schema.py +++ b/src/db/schema.py diff --git a/src/db/sql.py b/src/db/sql.py new file mode 100644 index 0000000..c47a69c --- /dev/null +++ b/src/db/sql.py @@ -0,0 +1,170 @@ +import datetime +import os +import sys +from contextlib import contextmanager + +from sqlalchemy import (Boolean, Column, Date, Integer, String, Text, + create_engine, text) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +Base = declarative_base() +metadata = Base.metadata + + +class Log(Base): + __tablename__ = "log" + + id = Column(Integer, primary_key=True) + timestamp = Column(Date, nullable=False, + default=datetime.datetime.utcnow) + username = Column(Text, nullable=False) + logtext = Column(Text, nullable=False) + + def as_dict(self): + """Return JSON serializable dict.""" + d = {} + for col in self.__table__.columns: + value = getattr(self, col.name) + if issubclass(value.__class__, Base): + continue + elif issubclass(value.__class__, datetime.datetime): + value = str(value) + d[col.name] = value + return d + + @classmethod + def add(cls, username, logtext): + with sqla_session() as session: + logentry = Log() + logentry.username = username + logentry.logtext = logtext + session.add(logentry) + + +class Scanner(Base): + __tablename__ = 'scanner' + + id = Column(Integer, primary_key=True, server_default=text( + "nextval('scanner_id_seq'::regclass)")) + runner = Column(Text, server_default=text("'*'::text")) + name = Column(String(128), nullable=False) + active = Column(Boolean, nullable=False) + interval = Column(Integer, nullable=False, + server_default=text("300")) + starttime = Column(Date) + endtime = Column(Date) + maxruns = Column(Integer, server_default=text("1")) + hostname = Column(String(128), nullable=False) + port = Column(Integer, nullable=False) + + def as_dict(self): + d = {} + for col in self.__table__.columns: + value = getattr(self, col.name) + if issubclass(value.__class__, Base): + continue + elif issubclass(value.__class__, datetime.datetime): + value = str(value) + d[col.name] = value + return d + + @classmethod + def add(cls, name, hostname, port, active=False, interval=0, + starttime=None, + endtime=None, + maxruns=1): + errors = list() + if starttime and endtime: + if starttime > endtime: + errors.append("Endtime must be after the starttime.") + if interval < 0: + errors.append("Interval must be > 0") + if maxruns < 0: + errors.append("Max runs must be > 0") + with sqla_session() as session: + scanentry = Scanner() + scanentry.name = name + scanentry.active = active + scanentry.interval = interval + if starttime: + scanentry.starttime = starttime + if endtime: + scanentry.endtime = endtime + scanentry.maxruns = maxruns + scanentry.hostname = hostname + scanentry.port = port + session.add(scanentry) + return errors + + @classmethod + def get(cls, name): + results = list() + with sqla_session() as session: + scanners = session.query(Scanner).all() + if not scanners: + return None + for scanner in scanners: + if scanner.runner == "*": + results.append(scanner.as_dict()) + elif scanner.runner == name: + results.append(scanner.as_dict()) + return results + + @classmethod + def edit(cls, name, active): + with sqla_session() as session: + scanners = session.query(Scanner).filter( + Scanner.name == name).all() + if not scanners: + return None + for scanner in scanners: + scanner.active = active + return True + + +def get_sqlalchemy_conn_str(**kwargs) -> str: + try: + if "SQL_HOSTNAME" in os.environ: + hostname = os.environ["SQL_HOSTNAME"] + else: + hostname = "localhost" + print("SQL_HOSTNAME not set, falling back to localhost.") + if "SQL_PORT" in os.environ: + port = os.environ["SQL_PORT"] + else: + print("SQL_PORT not set, falling back to 5432.") + port = 5432 + username = os.environ["SQL_USERNAME"] + password = os.environ["SQL_PASSWORD"] + database = os.environ["SQL_DATABASE"] + except KeyError: + print("SQL_DATABASE, SQL_USERNAME, SQL_PASSWORD must be set.") + sys.exit(-2) + + return ( + f"postgresql://{username}:{password}@{hostname}:{port}/{database}" + ) + + +def get_session(conn_str=""): + if conn_str == "": + conn_str = get_sqlalchemy_conn_str() + + engine = create_engine(conn_str, pool_size=50, max_overflow=0) + Session = sessionmaker(bind=engine) + + return Session() + + +@contextmanager +def sqla_session(conn_str="", **kwargs): + session = get_session(conn_str) + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() diff --git a/src/main.py b/src/main.py index 9de8eb8..a62d77c 100755 --- a/src/main.py +++ b/src/main.py @@ -11,14 +11,14 @@ from fastapi_jwt_auth import AuthJWT from fastapi_jwt_auth.exceptions import AuthJWTException from pydantic import BaseModel -from db import DictDB -from schema import get_index_keys, validate_collector_data +from db.dictionary import DictDB +from db.schema import get_index_keys app = FastAPI() app.add_middleware( CORSMiddleware, - allow_origins=["http://localhost:8001"], + allow_origins=["http://localhost:8000"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -37,9 +37,8 @@ async def mock_x_total_count_header(request: Request, call_next): for i in range(10): try: db = DictDB() - except Exception: - print( - f'Database not responding, will try again soon. Attempt {i + 1} of 10.') + except Exception as e: + print(f"Database not responding, will try again soon: {e}") else: break time.sleep(1) @@ -90,25 +89,25 @@ class JWTConfig(BaseModel): authjwt_public_key: str = get_pubkey() -@AuthJWT.load_config +@ AuthJWT.load_config def jwt_config(): return JWTConfig() -@app.exception_handler(AuthJWTException) +@ app.exception_handler(AuthJWTException) def authjwt_exception_handler(request: Request, exc: AuthJWTException): return JSONResponse(content={"status": "error", "message": exc.message}, status_code=400) -@app.exception_handler(RuntimeError) +@ app.exception_handler(RuntimeError) def app_exception_handler(request: Request, exc: RuntimeError): return JSONResponse(content={"status": "error", "message": str(exc.with_traceback(None))}, status_code=400) -@app.get('/sc/v0/get') +@ app.get('/sc/v0/get') async def get(key=None, limit=25, skip=0, ip=None, port=None, asn=None, Authorize: AuthJWT = Depends()): @@ -134,7 +133,7 @@ async def get(key=None, limit=25, skip=0, ip=None, port=None, return JSONResponse(content={"status": "success", "docs": data}) -@app.get('/sc/v0/get/{key}') +@ app.get('/sc/v0/get/{key}') async def get_key(key=None, Authorize: AuthJWT = Depends()): Authorize.jwt_required() @@ -195,7 +194,7 @@ async def add(data: Request, Authorize: AuthJWT = Depends()): return JSONResponse(content={"status": "success", "docs": key}) -@app.delete('/sc/v0/delete/{key}') +@ app.delete('/sc/v0/delete/{key}') async def delete(key, Authorize: AuthJWT = Depends()): Authorize.jwt_required() @@ -232,6 +231,31 @@ async def delete(key, Authorize: AuthJWT = Depends()): return JSONResponse(content={"status": "success", "docs": data}) +@ app.get("/sc/v0/scanner/{name}") +async def scanner_get(name, data: Request, Authorize: AuthJWT = Depends()): + Authorize.jwt_required() + + scanners = Scanner.get(name) + + return JSONResponse(content={"status": "success", "data": scanners}) + + +@ app.put("/sc/v0/scanner/{name}") +async def scanner_put(name, data: Request, Authorize: AuthJWT = Depends()): + errors = None + Authorize.jwt_required() + + json_data = await data.json() + + if "active" in json_data and isinstance(json_data["active"], bool): + errors = Scanner.active(name, json_data["active"]) + + if errors: + return JSONResponse(content={"status": "error", "message": "\n".join(errors)}, status_code=400) + + return JSONResponse(content={"status": "success", "data": Scanner.get(name)}, status_code=200) + + def main(standalone=False): if not standalone: return app |