summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/db/couch/__init__.py (renamed from src/couch/__init__.py)2
-rw-r--r--src/db/couch/client.py (renamed from src/couch/client.py)14
-rw-r--r--src/db/couch/exceptions.py (renamed from src/couch/exceptions.py)0
-rw-r--r--src/db/couch/feedreader.py (renamed from src/couch/feedreader.py)0
-rw-r--r--src/db/couch/resource.py (renamed from src/couch/resource.py)5
-rw-r--r--src/db/couch/utils.py (renamed from src/couch/utils.py)0
-rwxr-xr-xsrc/db/dictionary.py (renamed from src/db.py)4
-rw-r--r--src/db/index.py61
-rw-r--r--src/db/schema.py (renamed from src/schema.py)0
-rw-r--r--src/db/sql.py170
-rwxr-xr-xsrc/main.py48
11 files changed, 278 insertions, 26 deletions
diff --git a/src/couch/__init__.py b/src/db/couch/__init__.py
index a7537bc..b099235 100644
--- a/src/couch/__init__.py
+++ b/src/db/couch/__init__.py
@@ -8,4 +8,4 @@ __email__ = "rinat.sabitov@gmail.com"
__status__ = "Development"
-from couch.client import Server # noqa: F401
+from db.couch.client import Server # noqa: F401
diff --git a/src/couch/client.py b/src/db/couch/client.py
index 188e0de..73d85a1 100644
--- a/src/couch/client.py
+++ b/src/db/couch/client.py
@@ -1,18 +1,16 @@
# -*- coding: utf-8 -*-
# Based on py-couchdb (https://github.com/histrio/py-couchdb)
-import os
-import json
-import uuid
import copy
+import json
import mimetypes
+import os
+import uuid
import warnings
-from couch import utils
-from couch import feedreader
-from couch import exceptions as exp
-from couch.resource import Resource
-
+from db.couch import exceptions as exp
+from db.couch import feedreader, utils
+from db.couch.resource import Resource
DEFAULT_BASE_URL = os.environ.get('COUCHDB_URL', 'http://localhost:5984/')
diff --git a/src/couch/exceptions.py b/src/db/couch/exceptions.py
index d7e037b..d7e037b 100644
--- a/src/couch/exceptions.py
+++ b/src/db/couch/exceptions.py
diff --git a/src/couch/feedreader.py b/src/db/couch/feedreader.py
index e293932..e293932 100644
--- a/src/couch/feedreader.py
+++ b/src/db/couch/feedreader.py
diff --git a/src/couch/resource.py b/src/db/couch/resource.py
index da1e0dd..8ff883b 100644
--- a/src/couch/resource.py
+++ b/src/db/couch/resource.py
@@ -5,10 +5,9 @@
from __future__ import unicode_literals
import json
-import requests
-from couch import utils
-from couch import exceptions
+import requests
+from db.couch import exceptions, utils
class Resource(object):
diff --git a/src/couch/utils.py b/src/db/couch/utils.py
index 1cd21d8..1cd21d8 100644
--- a/src/couch/utils.py
+++ b/src/db/couch/utils.py
diff --git a/src/db.py b/src/db/dictionary.py
index 6f25ec3..f0f5fe9 100755
--- a/src/db.py
+++ b/src/db/dictionary.py
@@ -11,8 +11,8 @@ import os
import sys
import time
-import couch
-from schema import as_index_list, validate_collector_data
+from db import couch
+from db.schema import as_index_list, validate_collector_data
class DictDB():
diff --git a/src/db/index.py b/src/db/index.py
new file mode 100644
index 0000000..688ceeb
--- /dev/null
+++ b/src/db/index.py
@@ -0,0 +1,61 @@
+from pydantic import BaseSettings
+
+
+class CouchIindex(BaseSettings):
+ domain: dict = {
+ "index": {
+ "fields": [
+ "domain",
+ ]
+ },
+ "name": "domain-json-index",
+ "type": "json"
+ }
+ ip: dict = {
+ "index": {
+ "fields": [
+ "domain",
+ "ip"
+ ]
+ },
+ "name": "ip-json-index",
+ "type": "json"
+ }
+ port: dict = {
+ "index": {
+ "fields": [
+ "domain",
+ "port"
+ ]
+ },
+ "name": "port-json-index",
+ "type": "json"
+ }
+ asn: dict = {
+ "index": {
+ "fields": [
+ "domain",
+ "asn"
+ ]
+ },
+ "name": "asn-json-index",
+ "type": "json"
+ }
+ asn_country_code: dict = {
+ "index": {
+ "fields": [
+ "domain",
+ "asn_country_code"
+ ]
+ },
+ "name": "asn-country-code-json-index",
+ "type": "json"
+ }
+
+
+def as_list():
+ index_list = list()
+ for item in CouchIindex().dict():
+ index_list.append(CouchIindex().dict()[item])
+
+ return index_list
diff --git a/src/schema.py b/src/db/schema.py
index 9bdf130..9bdf130 100644
--- a/src/schema.py
+++ b/src/db/schema.py
diff --git a/src/db/sql.py b/src/db/sql.py
new file mode 100644
index 0000000..c47a69c
--- /dev/null
+++ b/src/db/sql.py
@@ -0,0 +1,170 @@
+import datetime
+import os
+import sys
+from contextlib import contextmanager
+
+from sqlalchemy import (Boolean, Column, Date, Integer, String, Text,
+ create_engine, text)
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import sessionmaker
+
+Base = declarative_base()
+metadata = Base.metadata
+
+
+class Log(Base):
+ __tablename__ = "log"
+
+ id = Column(Integer, primary_key=True)
+ timestamp = Column(Date, nullable=False,
+ default=datetime.datetime.utcnow)
+ username = Column(Text, nullable=False)
+ logtext = Column(Text, nullable=False)
+
+ def as_dict(self):
+ """Return JSON serializable dict."""
+ d = {}
+ for col in self.__table__.columns:
+ value = getattr(self, col.name)
+ if issubclass(value.__class__, Base):
+ continue
+ elif issubclass(value.__class__, datetime.datetime):
+ value = str(value)
+ d[col.name] = value
+ return d
+
+ @classmethod
+ def add(cls, username, logtext):
+ with sqla_session() as session:
+ logentry = Log()
+ logentry.username = username
+ logentry.logtext = logtext
+ session.add(logentry)
+
+
+class Scanner(Base):
+ __tablename__ = 'scanner'
+
+ id = Column(Integer, primary_key=True, server_default=text(
+ "nextval('scanner_id_seq'::regclass)"))
+ runner = Column(Text, server_default=text("'*'::text"))
+ name = Column(String(128), nullable=False)
+ active = Column(Boolean, nullable=False)
+ interval = Column(Integer, nullable=False,
+ server_default=text("300"))
+ starttime = Column(Date)
+ endtime = Column(Date)
+ maxruns = Column(Integer, server_default=text("1"))
+ hostname = Column(String(128), nullable=False)
+ port = Column(Integer, nullable=False)
+
+ def as_dict(self):
+ d = {}
+ for col in self.__table__.columns:
+ value = getattr(self, col.name)
+ if issubclass(value.__class__, Base):
+ continue
+ elif issubclass(value.__class__, datetime.datetime):
+ value = str(value)
+ d[col.name] = value
+ return d
+
+ @classmethod
+ def add(cls, name, hostname, port, active=False, interval=0,
+ starttime=None,
+ endtime=None,
+ maxruns=1):
+ errors = list()
+ if starttime and endtime:
+ if starttime > endtime:
+ errors.append("Endtime must be after the starttime.")
+ if interval < 0:
+ errors.append("Interval must be > 0")
+ if maxruns < 0:
+ errors.append("Max runs must be > 0")
+ with sqla_session() as session:
+ scanentry = Scanner()
+ scanentry.name = name
+ scanentry.active = active
+ scanentry.interval = interval
+ if starttime:
+ scanentry.starttime = starttime
+ if endtime:
+ scanentry.endtime = endtime
+ scanentry.maxruns = maxruns
+ scanentry.hostname = hostname
+ scanentry.port = port
+ session.add(scanentry)
+ return errors
+
+ @classmethod
+ def get(cls, name):
+ results = list()
+ with sqla_session() as session:
+ scanners = session.query(Scanner).all()
+ if not scanners:
+ return None
+ for scanner in scanners:
+ if scanner.runner == "*":
+ results.append(scanner.as_dict())
+ elif scanner.runner == name:
+ results.append(scanner.as_dict())
+ return results
+
+ @classmethod
+ def edit(cls, name, active):
+ with sqla_session() as session:
+ scanners = session.query(Scanner).filter(
+ Scanner.name == name).all()
+ if not scanners:
+ return None
+ for scanner in scanners:
+ scanner.active = active
+ return True
+
+
+def get_sqlalchemy_conn_str(**kwargs) -> str:
+ try:
+ if "SQL_HOSTNAME" in os.environ:
+ hostname = os.environ["SQL_HOSTNAME"]
+ else:
+ hostname = "localhost"
+ print("SQL_HOSTNAME not set, falling back to localhost.")
+ if "SQL_PORT" in os.environ:
+ port = os.environ["SQL_PORT"]
+ else:
+ print("SQL_PORT not set, falling back to 5432.")
+ port = 5432
+ username = os.environ["SQL_USERNAME"]
+ password = os.environ["SQL_PASSWORD"]
+ database = os.environ["SQL_DATABASE"]
+ except KeyError:
+ print("SQL_DATABASE, SQL_USERNAME, SQL_PASSWORD must be set.")
+ sys.exit(-2)
+
+ return (
+ f"postgresql://{username}:{password}@{hostname}:{port}/{database}"
+ )
+
+
+def get_session(conn_str=""):
+ if conn_str == "":
+ conn_str = get_sqlalchemy_conn_str()
+
+ engine = create_engine(conn_str, pool_size=50, max_overflow=0)
+ Session = sessionmaker(bind=engine)
+
+ return Session()
+
+
+@contextmanager
+def sqla_session(conn_str="", **kwargs):
+ session = get_session(conn_str)
+ try:
+ yield session
+ session.commit()
+ except Exception:
+ session.rollback()
+ raise
+ finally:
+ session.close()
diff --git a/src/main.py b/src/main.py
index 9de8eb8..a62d77c 100755
--- a/src/main.py
+++ b/src/main.py
@@ -11,14 +11,14 @@ from fastapi_jwt_auth import AuthJWT
from fastapi_jwt_auth.exceptions import AuthJWTException
from pydantic import BaseModel
-from db import DictDB
-from schema import get_index_keys, validate_collector_data
+from db.dictionary import DictDB
+from db.schema import get_index_keys
app = FastAPI()
app.add_middleware(
CORSMiddleware,
- allow_origins=["http://localhost:8001"],
+ allow_origins=["http://localhost:8000"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
@@ -37,9 +37,8 @@ async def mock_x_total_count_header(request: Request, call_next):
for i in range(10):
try:
db = DictDB()
- except Exception:
- print(
- f'Database not responding, will try again soon. Attempt {i + 1} of 10.')
+ except Exception as e:
+ print(f"Database not responding, will try again soon: {e}")
else:
break
time.sleep(1)
@@ -90,25 +89,25 @@ class JWTConfig(BaseModel):
authjwt_public_key: str = get_pubkey()
-@AuthJWT.load_config
+@ AuthJWT.load_config
def jwt_config():
return JWTConfig()
-@app.exception_handler(AuthJWTException)
+@ app.exception_handler(AuthJWTException)
def authjwt_exception_handler(request: Request, exc: AuthJWTException):
return JSONResponse(content={"status": "error", "message":
exc.message}, status_code=400)
-@app.exception_handler(RuntimeError)
+@ app.exception_handler(RuntimeError)
def app_exception_handler(request: Request, exc: RuntimeError):
return JSONResponse(content={"status": "error", "message":
str(exc.with_traceback(None))},
status_code=400)
-@app.get('/sc/v0/get')
+@ app.get('/sc/v0/get')
async def get(key=None, limit=25, skip=0, ip=None, port=None,
asn=None, Authorize: AuthJWT = Depends()):
@@ -134,7 +133,7 @@ async def get(key=None, limit=25, skip=0, ip=None, port=None,
return JSONResponse(content={"status": "success", "docs": data})
-@app.get('/sc/v0/get/{key}')
+@ app.get('/sc/v0/get/{key}')
async def get_key(key=None, Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
@@ -195,7 +194,7 @@ async def add(data: Request, Authorize: AuthJWT = Depends()):
return JSONResponse(content={"status": "success", "docs": key})
-@app.delete('/sc/v0/delete/{key}')
+@ app.delete('/sc/v0/delete/{key}')
async def delete(key, Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
@@ -232,6 +231,31 @@ async def delete(key, Authorize: AuthJWT = Depends()):
return JSONResponse(content={"status": "success", "docs": data})
+@ app.get("/sc/v0/scanner/{name}")
+async def scanner_get(name, data: Request, Authorize: AuthJWT = Depends()):
+ Authorize.jwt_required()
+
+ scanners = Scanner.get(name)
+
+ return JSONResponse(content={"status": "success", "data": scanners})
+
+
+@ app.put("/sc/v0/scanner/{name}")
+async def scanner_put(name, data: Request, Authorize: AuthJWT = Depends()):
+ errors = None
+ Authorize.jwt_required()
+
+ json_data = await data.json()
+
+ if "active" in json_data and isinstance(json_data["active"], bool):
+ errors = Scanner.active(name, json_data["active"])
+
+ if errors:
+ return JSONResponse(content={"status": "error", "message": "\n".join(errors)}, status_code=400)
+
+ return JSONResponse(content={"status": "success", "data": Scanner.get(name)}, status_code=200)
+
+
def main(standalone=False):
if not standalone:
return app