diff options
-rw-r--r-- | requirements.txt | 7 | ||||
-rw-r--r-- | src/db/couch/__init__.py (renamed from src/couch/__init__.py) | 2 | ||||
-rw-r--r-- | src/db/couch/client.py (renamed from src/couch/client.py) | 14 | ||||
-rw-r--r-- | src/db/couch/exceptions.py (renamed from src/couch/exceptions.py) | 0 | ||||
-rw-r--r-- | src/db/couch/feedreader.py (renamed from src/couch/feedreader.py) | 0 | ||||
-rw-r--r-- | src/db/couch/resource.py (renamed from src/couch/resource.py) | 5 | ||||
-rw-r--r-- | src/db/couch/utils.py (renamed from src/couch/utils.py) | 0 | ||||
-rwxr-xr-x | src/db/db.py (renamed from src/db.py) | 71 | ||||
-rw-r--r-- | src/db/index.py (renamed from src/index.py) | 0 | ||||
-rw-r--r-- | src/db/scanner.py | 133 | ||||
-rw-r--r-- | src/log.py | 16 | ||||
-rwxr-xr-x | src/main.py | 158 | ||||
-rw-r--r-- | src/routers/__init__.py | 9 | ||||
-rw-r--r-- | src/routers/collector.py | 152 | ||||
-rw-r--r-- | src/routers/scanner.py | 93 | ||||
-rw-r--r-- | tools/jwt_producer.py | 24 |
16 files changed, 502 insertions, 182 deletions
diff --git a/requirements.txt b/requirements.txt index ce2f921..7922e54 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ anyio==3.3.4 +APScheduler==3.8.1 asgiref==3.4.1 attrs==21.2.0 certifi==2021.10.8 @@ -20,10 +21,16 @@ pydantic==1.8.2 PyJWT==1.7.1 pyparsing==3.0.6 pytest==6.2.5 +pytz==2021.3 +pytz-deprecation-shim==0.1.0.post0 requests==2.26.0 +six==1.16.0 sniffio==1.2.0 +SQLAlchemy==1.4.29 starlette==0.16.0 toml==0.10.2 typing-extensions==3.10.0.2 +tzdata==2021.5 +tzlocal==4.1 urllib3==1.26.7 uvicorn==0.15.0 diff --git a/src/couch/__init__.py b/src/db/couch/__init__.py index a7537bc..b099235 100644 --- a/src/couch/__init__.py +++ b/src/db/couch/__init__.py @@ -8,4 +8,4 @@ __email__ = "rinat.sabitov@gmail.com" __status__ = "Development" -from couch.client import Server # noqa: F401 +from db.couch.client import Server # noqa: F401 diff --git a/src/couch/client.py b/src/db/couch/client.py index 188e0de..73d85a1 100644 --- a/src/couch/client.py +++ b/src/db/couch/client.py @@ -1,18 +1,16 @@ # -*- coding: utf-8 -*- # Based on py-couchdb (https://github.com/histrio/py-couchdb) -import os -import json -import uuid import copy +import json import mimetypes +import os +import uuid import warnings -from couch import utils -from couch import feedreader -from couch import exceptions as exp -from couch.resource import Resource - +from db.couch import exceptions as exp +from db.couch import feedreader, utils +from db.couch.resource import Resource DEFAULT_BASE_URL = os.environ.get('COUCHDB_URL', 'http://localhost:5984/') diff --git a/src/couch/exceptions.py b/src/db/couch/exceptions.py index d7e037b..d7e037b 100644 --- a/src/couch/exceptions.py +++ b/src/db/couch/exceptions.py diff --git a/src/couch/feedreader.py b/src/db/couch/feedreader.py index e293932..e293932 100644 --- a/src/couch/feedreader.py +++ b/src/db/couch/feedreader.py diff --git a/src/couch/resource.py b/src/db/couch/resource.py index da1e0dd..8ff883b 100644 --- a/src/couch/resource.py +++ b/src/db/couch/resource.py @@ -5,10 +5,9 @@ from __future__ import unicode_literals import json -import requests -from couch import utils -from couch import exceptions +import requests +from db.couch import exceptions, utils class Resource(object): diff --git a/src/couch/utils.py b/src/db/couch/utils.py index 1cd21d8..1cd21d8 100644 --- a/src/couch/utils.py +++ b/src/db/couch/utils.py @@ -1,18 +1,13 @@ -# A database storing dictionaries, keyed on a timestamp. value = A -# dict which will be stored as a JSON object encoded in UTF-8. Note -# that dict keys of type integer or float will become strings while -# values will keep their type. - -# Note that there's a (slim) chance that you'd stomp on the previous -# value if you're too quick with generating the timestamps, ie -# invoking time.time() several times quickly enough. - import os import sys import time -import couch +from contextlib import contextmanager + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker -from index import CouchIindex +from db import couch +from db.index import CouchIindex class DictDB(): @@ -137,3 +132,57 @@ class DictDB(): return None return key + + +def get_conn_str(): + try: + dialect = os.environ['SQL_DIALECT'] + database = os.environ['SQL_DATABASE'] + except KeyError: + print('The environment variables SQL_DIALECT and SQL_DATABASE must ' + + 'be set.') + sys.exit(-1) + + if dialect != 'sqlite': + try: + hostname = os.environ['SQL_HOSTNAME'] + username = os.environ['SQL_USERNAME'] + password = os.environ['SQL_PASSWORD'] + except KeyError: + print('The environment variables SQL_DIALECT, SQL_NAME, ' + + 'SQL_HOSTNAME, SQL_USERNAME and SQL_PASSWORD must ' + + 'be set.') + sys.exit(-1) + + if dialect == 'sqlite': + conn_str = f"{dialect}:///{database}.db" + else: + conn_str = f"{dialect}://{username}:{password}@{hostname}" + \ + "/{database}" + + return conn_str + + +class SqlDB(): + def get_session(conn_str): + if 'sqlite' in conn_str: + engine = create_engine(conn_str) + else: + engine = create_engine(conn_str, pool_size=50, max_overflow=0) + Session = sessionmaker(bind=engine) + + return Session() + + @classmethod + @contextmanager + def sql_session(cls, **kwargs): + session = cls.get_session(get_conn_str()) + + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() diff --git a/src/index.py b/src/db/index.py index 3541ec7..3541ec7 100644 --- a/src/index.py +++ b/src/db/index.py diff --git a/src/db/scanner.py b/src/db/scanner.py new file mode 100644 index 0000000..625fd8e --- /dev/null +++ b/src/db/scanner.py @@ -0,0 +1,133 @@ +import enum +from datetime import datetime + +from sqlalchemy import (JSON, Boolean, Column, DateTime, Integer, Unicode, + UniqueConstraint, create_engine) +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.declarative import declarative_base + +from db.db import SqlDB, get_conn_str + +Base = declarative_base() +engine = create_engine(get_conn_str()) + + +class Scanner(Base): + __tablename__ = 'scanners' + __table_args__ = ( + None, + UniqueConstraint('id'), + UniqueConstraint('uuid'), + ) + + id = Column(Integer, autoincrement=True, primary_key=True) + uuid = Column(Unicode(37), nullable=False) + enabled = Column(Boolean, default=False, nullable=False) + first_seen = Column(DateTime, default=datetime.utcnow, nullable=False) + last_seen = Column(DateTime, default=datetime.utcnow, + onupdate=datetime.utcnow, nullable=False) + comment = Column(Unicode(255), default="", nullable=True) + scanners = Column(JSON, nullable=True) + + def as_dict(self): + """Return JSON serializable dict.""" + d = {} + for col in self.__table__.columns: + value = getattr(self, col.name) + if issubclass(value.__class__, enum.Enum): + value = value.value + elif issubclass(value.__class__, Base): + continue + elif issubclass(value.__class__, datetime): + value = str(value) + d[col.name] = value + return d + + @classmethod + def comment(cls, uuid, comment): + with SqlDB.sql_session() as session: + scanner: Scanner = session.query(Scanner).filter( + Scanner.uuid == uuid).one_or_none() + + if scanner: + scanner.comment = comment + else: + return None + return None + + @classmethod + def enable(cls, uuid): + with SqlDB.sql_session() as session: + scanner: Scanner = session.query(Scanner).filter( + Scanner.uuid == uuid).one_or_none() + + if scanner: + scanner.enabled = True + else: + return None + return None + + @classmethod + def disable(cls, uuid): + with SqlDB.sql_session() as session: + scanner: Scanner = session.query(Scanner).filter( + Scanner.uuid == uuid).one_or_none() + + if scanner: + scanner.enabled = False + else: + return None + return None + + @classmethod + def get(cls, scanner_id=None, uuid=None): + if scanner_id is None and uuid is None: + raise ValueError('Either scanner_id or uuid must be present.') + + with SqlDB.sql_session() as session: + if scanner_id: + scanner: Scanner = session.query(Scanner).filter( + Scanner.id == scanner_id).one_or_none() + elif uuid: + scanner: Scanner = session.query(Scanner).filter( + Scanner.uuid == uuid).one_or_none() + else: + return None + + if scanner is None: + return None + + return scanner.as_dict() + + return None + + @classmethod + def add(cls, uuid): + try: + with SqlDB.sql_session() as session: + scanner = Scanner() + scanner.uuid = uuid + scanner.enabled = False + scanner.scanners = {} + + session.add(scanner) + session.flush() + + return scanner.id + except IntegrityError: + return None + + @classmethod + def is_enabled(cls, uuid): + with SqlDB.sql_session() as session: + scanner: Scanner = session.query(Scanner).filter( + Scanner.uuid == uuid).one_or_none() + if scanner is None: + return None + + enabled = scanner.enabled + + return enabled + + +Base.metadata.create_all(engine) diff --git a/src/log.py b/src/log.py new file mode 100644 index 0000000..de0a6ea --- /dev/null +++ b/src/log.py @@ -0,0 +1,16 @@ +import logging + + +def get_logger(): + logger = logging.getLogger('soc-collector') + + if not logger.handlers: + formatter = logging.Formatter('%(levelname)s: %(message)s') + + handler = logging.StreamHandler() + handler.setFormatter(formatter) + logger.addHandler(handler) + + logger.setLevel(logging.DEBUG) + + return logger diff --git a/src/main.py b/src/main.py index c3e5ad9..a65971d 100755 --- a/src/main.py +++ b/src/main.py @@ -1,19 +1,20 @@ import os import sys -import uvicorn -from fastapi import FastAPI, Depends, Request +import uvicorn +from fastapi import Depends, FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from fastapi_jwt_auth import AuthJWT from fastapi_jwt_auth.exceptions import AuthJWTException -from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel -from index import CouchIindex -import time -from db import DictDB + +import routers app = FastAPI() +app.include_router(routers.router, prefix='/sc/v0') + app.add_middleware( CORSMiddleware, allow_origins=["http://localhost:8001"], @@ -23,8 +24,6 @@ app.add_middleware( expose_headers=["X-Total-Count"], ) -# TODO: X-Total-Count - @app.middleware("http") async def mock_x_total_count_header(request: Request, call_next): @@ -32,19 +31,6 @@ async def mock_x_total_count_header(request: Request, call_next): response.headers["X-Total-Count"] = "100" return response -for i in range(10): - try: - db = DictDB() - except Exception: - print( - f'Database not responding, will try again soon. Attempt {i + 1} of 10.') - else: - break - time.sleep(10) -else: - print('Database did not respond after 10 attempts, quitting.') - sys.exit(-1) - def get_pubkey(): try: @@ -62,27 +48,6 @@ def get_pubkey(): return pubkey -def get_data(key=None, limit=25, skip=0, ip=None, - port=None, asn=None, domain=None): - if key: - return db.get(key) - - selectors = dict() - indexes = CouchIindex().dict() - selectors['domain'] = domain - - if ip and 'ip' in indexes: - selectors['ip'] = ip - if port and 'port' in indexes: - selectors['port'] = port - if asn and 'asn' in indexes: - selectors['asn'] = asn - - data = db.search(**selectors, limit=limit, skip=skip) - - return data - - class JWTConfig(BaseModel): authjwt_algorithm: str = "ES256" authjwt_public_key: str = get_pubkey() @@ -106,115 +71,6 @@ def app_exception_handler(request: Request, exc: RuntimeError): status_code=400) -@app.get('/sc/v0/get') -async def get(key=None, limit=25, skip=0, ip=None, port=None, - asn=None, Authorize: AuthJWT = Depends()): - - Authorize.jwt_required() - - data = [] - raw_jwt = Authorize.get_raw_jwt() - - if "read" not in raw_jwt: - return JSONResponse( - content={ - "status": "error", - "message": "Could not find read claim in JWT token", - }, - status_code=400, - ) - else: - domains = raw_jwt["read"] - - for domain in domains: - data.extend(get_data(key, limit, skip, ip, port, asn, domain)) - - return JSONResponse(content={"status": "success", "docs": data}) - - -@app.get('/sc/v0/get/{key}') -async def get_key(key=None, Authorize: AuthJWT = Depends()): - - Authorize.jwt_required() - - raw_jwt = Authorize.get_raw_jwt() - - if "read" not in raw_jwt: - return JSONResponse( - content={ - "status": "error", - "message": "Could not find read claim in JWT token", - }, - status_code=400, - ) - else: - allowed_domains = raw_jwt["read"] - - data = get_data(key) - - if data["domain"] not in allowed_domains: - return JSONResponse( - content={ - "status": "error", - "message": "User not authorized to view this object", - }, - status_code=400, - ) - - return JSONResponse(content={"status": "success", "docs": data}) - - -@app.post('/sc/v0/add') -async def add(data: Request, Authorize: AuthJWT = Depends()): - - # Maybe we should protect this enpoint too and let the scanner use - # a JWT token as well. - # Authorize.jwt_required() - - json_data = await data.json() - - key = db.add(json_data) - - return JSONResponse(content={"status": "success", "docs": key}) - - -@app.delete('/sc/v0/delete/{key}') -async def delete(key, Authorize: AuthJWT = Depends()): - - Authorize.jwt_required() - - raw_jwt = Authorize.get_raw_jwt() - - if "write" not in raw_jwt: - return JSONResponse( - content={ - "status": "error", - "message": "Could not find write claim in JWT token", - }, - status_code=400, - ) - else: - allowed_domains = raw_jwt["write"] - - data = get_data(key) - - if data["domain"] not in allowed_domains: - return JSONResponse( - content={ - "status": "error", - "message": "User not authorized to delete this object", - }, - status_code=400, - ) - - if db.delete(key) is None: - return JSONResponse(content={"status": "error", - "message": "Document not found"}, - status_code=400) - - return JSONResponse(content={"status": "success", "docs": data}) - - def main(standalone=False): if not standalone: return app diff --git a/src/routers/__init__.py b/src/routers/__init__.py new file mode 100644 index 0000000..300ed3a --- /dev/null +++ b/src/routers/__init__.py @@ -0,0 +1,9 @@ +from fastapi import APIRouter + +from .collector import router as collector_router +from .scanner import router as scanner_router + +router = APIRouter() + +router.include_router(collector_router, tags=['collector']) +router.include_router(scanner_router, tags=['scanner']) diff --git a/src/routers/collector.py b/src/routers/collector.py new file mode 100644 index 0000000..7d91609 --- /dev/null +++ b/src/routers/collector.py @@ -0,0 +1,152 @@ +import sys +import time + +import requests +from db.db import DictDB +from db.index import CouchIindex +from fastapi import APIRouter, Depends, Request +from fastapi.responses import JSONResponse +from fastapi_jwt_auth import AuthJWT + +router = APIRouter() + +for i in range(10): + try: + db = DictDB() + except requests.exceptions.ConnectionError: + print('Database not responding, will try again soon.' + + f'Attempt {i + 1} of 10.') + else: + break + time.sleep(10) +else: + print('Database did not respond after 10 attempts, quitting.') + sys.exit(-1) + + +def get_data(key=None, limit=25, skip=0, ip=None, + port=None, asn=None, domain=None): + if key: + return db.get(key) + + selectors = dict() + indexes = CouchIindex().dict() + selectors['domain'] = domain + + if ip and 'ip' in indexes: + selectors['ip'] = ip + if port and 'port' in indexes: + selectors['port'] = port + if asn and 'asn' in indexes: + selectors['asn'] = asn + + data = db.search(**selectors, limit=limit, skip=skip) + + return data + + +@router.get('/get') +async def get(key=None, limit=25, skip=0, ip=None, port=None, + asn=None, Authorize: AuthJWT = Depends()): + + Authorize.jwt_required() + + data = [] + raw_jwt = Authorize.get_raw_jwt() + + if "read" not in raw_jwt: + return JSONResponse( + content={ + "status": "error", + "message": "Could not find read claim in JWT token", + }, + status_code=400, + ) + else: + domains = raw_jwt["read"] + + for domain in domains: + data.extend(get_data(key, limit, skip, ip, port, asn, domain)) + + return JSONResponse(content={"status": "success", "docs": data}) + + +@router.get('/get/{key}') +async def get_key(key=None, Authorize: AuthJWT = Depends()): + + Authorize.jwt_required() + + raw_jwt = Authorize.get_raw_jwt() + + if "read" not in raw_jwt: + return JSONResponse( + content={ + "status": "error", + "message": "Could not find read claim in JWT token", + }, + status_code=400, + ) + else: + allowed_domains = raw_jwt["read"] + + data = get_data(key) + + if data["domain"] not in allowed_domains: + return JSONResponse( + content={ + "status": "error", + "message": "User not authorized to view this object", + }, + status_code=400, + ) + + return JSONResponse(content={"status": "success", "docs": data}) + + +@router.post('/add') +async def add(data: Request, Authorize: AuthJWT = Depends()): + + Authorize.jwt_required() + + json_data = await data.json() + + key = db.add(json_data) + + return JSONResponse(content={"status": "success", "docs": key}) + + +@router.delete('/delete/{key}') +async def delete(key, Authorize: AuthJWT = Depends()): + + Authorize.jwt_required() + + raw_jwt = Authorize.get_raw_jwt() + + if "write" not in raw_jwt: + return JSONResponse( + content={ + "status": "error", + "message": "Could not find write claim in JWT token", + }, + status_code=400, + ) + else: + allowed_domains = raw_jwt["write"] + + data = get_data(key) + + if data["domain"] not in allowed_domains: + return JSONResponse( + content={ + "status": "error", + "message": "User not authorized to delete this object", + }, + status_code=400, + ) + + if db.delete(key) is None: + return JSONResponse(content={"status": "error", + "message": "Document not found"}, + status_code=400) + + return JSONResponse(content={"status": "success", "docs": data}) diff --git a/src/routers/scanner.py b/src/routers/scanner.py new file mode 100644 index 0000000..645cd74 --- /dev/null +++ b/src/routers/scanner.py @@ -0,0 +1,93 @@ +from uuid import UUID + +from db.scanner import Scanner +from fastapi import APIRouter, Depends, Request +from fastapi.responses import JSONResponse +from fastapi_jwt_auth import AuthJWT + +router = APIRouter() + + +@router.post('/scanner/{uuid}') +async def scanner(uuid, data: Request, Authorize: AuthJWT = Depends()): + Authorize.jwt_required() + + json_data = await data.json() + + if not Scanner.get(uuid=uuid): + return JSONResponse(content={"status": "error", + "message": "Scanner don't exist."}, + status_code=400) + + if 'targets' in json_data: + if isinstance(json_data['targets'], str): + Scanner.comment(uuid, json_data['targets']) + else: + return JSONResponse(content={"status": "error", + "message": "Targets should be a string."}, + status_code=400) + if 'scanner' in json_data: + if isinstance(json_data['comment'], str): + Scanner.comment(uuid, json_data['scanner']) + else: + return JSONResponse(content={"status": "error", + "message": "Scanner should be a string."}, + status_code=400) + if 'comment' in json_data: + if isinstance(json_data['comment'], str): + Scanner.comment(uuid, json_data['comment']) + else: + return JSONResponse(content={"status": "error", + "message": "Comment should be a string."}, + status_code=400) + if 'enabled' in json_data: + if isinstance(json_data['enabled'], bool): + if json_data['enabled'] is True: + Scanner.enable(uuid) + elif json_data['enabled'] is False: + Scanner.disable(uuid) + else: + return JSONResponse(content={"status": "error", + "message": "Enabled should be boolean."}, + status_code=400) + + +@router.get('/callhome/{uuid}') +async def callhome(uuid, data: Request, Authorize: AuthJWT = Depends()): + Authorize.jwt_required() + + raw_jwt = Authorize.get_raw_jwt() + + if 'user' not in raw_jwt or raw_jwt['user'] != "scanner": + return JSONResponse(content={"status": "error", + "message": "Invalid token type."}, + status_code=400) + + try: + UUID(uuid).version + except ValueError: + return JSONResponse(content={"status": "error", + "message": "Invalid UUID."}, + status_code=400) + + scanner_data = Scanner.get(uuid=uuid) + + if scanner_data: + if not Scanner.is_enabled(uuid): + return JSONResponse(content={"status": "error", + "message": "Scanner disabled."}, + status_code=400) + else: + return JSONResponse(content={"status": "success", + "data": scanner_data}, + status_code=200) + + else: + if Scanner.add(uuid): + return JSONResponse(content={"status": "error", + "message": "Scanner added but disabled."}, + status_code=400) + else: + return JSONResponse(content={"status": "error", + "message": "Failed to add scanner."}, + status_code=400) diff --git a/tools/jwt_producer.py b/tools/jwt_producer.py index 3f8094d..ea033a6 100644 --- a/tools/jwt_producer.py +++ b/tools/jwt_producer.py @@ -1,20 +1,23 @@ +import getopt import sys + import jwt -import getopt def usage(): progname = sys.argv[0] - print(f'{progname} -p <path to public key> -s <path to private key>' + - '-d <domain>, for example sunet.se>') + print(f'{progname} -p <path to private key> ' + + '-d <domain, for example sunet.se> ', + '-t <type, can be access or scanner>') sys.exit(0) -def create_token(private_key, domain): +def create_token(private_key, token_type, domain): payload = { 'type': 'access', - 'domains': [domain] # We'll just do one domain now + 'domains': [domain], # We'll just do one domain now + 'user': token_type } with open(private_key, "r") as fd: @@ -25,11 +28,11 @@ def create_token(private_key, domain): if __name__ == '__main__': try: - opts, args = getopt.getopt(sys.argv[1:], 'p:d:') + opts, args = getopt.getopt(sys.argv[1:], 'p:d:t:') except getopt.GetoptError: usage() - if len(sys.argv) != 5: + if len(sys.argv) != 7: usage() for opt, arg in opts: @@ -37,9 +40,14 @@ if __name__ == '__main__': private_key = arg elif opt == '-d': domain = arg + elif opt == '-t': + token_type = arg + + if token_type != "access" and token_type != "scanner": + usage() else: usage() - token = create_token(private_key, domain).decode('utf-8') + token = create_token(private_key, token_type, domain).decode('utf-8') print(f'{token}') |