diff options
author | Kristofer Hallin <kristofer@sunet.se> | 2022-01-05 11:46:15 +0100 |
---|---|---|
committer | Kristofer Hallin <kristofer@sunet.se> | 2022-01-05 11:46:15 +0100 |
commit | 571997129ba5275cc5e148a8ac1c0f64d895a9ef (patch) | |
tree | 607fd13bbbf5ac38f416da8172e89a5d3331d1d8 | |
parent | 09677d03635da2b799cf117b2127c3b197a8babf (diff) |
Added database and API endpoints for scanners.
-rwxr-xr-x | src/db/db.py | 58 | ||||
-rw-r--r-- | src/db/scanner.py | 87 | ||||
-rw-r--r-- | src/routers/scanner.py | 85 |
3 files changed, 224 insertions, 6 deletions
diff --git a/src/db/db.py b/src/db/db.py index 511748c..cbb87ce 100755 --- a/src/db/db.py +++ b/src/db/db.py @@ -10,6 +10,10 @@ import os import sys import time +from contextlib import contextmanager + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker from db import couch from db.index import CouchIindex @@ -137,3 +141,57 @@ class DictDB(): return None return key + + +def get_conn_str(): + try: + dialect = os.environ['SQL_DIALECT'] + database = os.environ['SQL_DATABASE'] + except KeyError: + print('The environment variables SQL_DIALECT and SQL_DATABASE must ' + + 'be set.') + sys.exit(-1) + + if dialect != 'sqlite': + try: + hostname = os.environ['SQL_HOSTNAME'] + username = os.environ['SQL_USERNAME'] + password = os.environ['SQL_PASSWORD'] + except KeyError: + print('The environment variables SQL_DIALECT, SQL_NAME, ' + + 'SQL_HOSTNAME, SQL_USERNAME and SQL_PASSWORD must ' + + 'be set.') + sys.exit(-1) + + if dialect == 'sqlite': + conn_str = f"{dialect}:///{database}.db" + else: + conn_str = f"{dialect}://{username}:{password}@{hostname}" + \ + "/{database}" + + return conn_str + + +class SqlDB(): + def get_session(conn_str): + if 'sqlite' in conn_str: + engine = create_engine(conn_str) + else: + engine = create_engine(conn_str, pool_size=50, max_overflow=0) + Session = sessionmaker(bind=engine) + + return Session() + + @classmethod + @contextmanager + def sql_session(cls, **kwargs): + session = cls.get_session(get_conn_str()) + + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() diff --git a/src/db/scanner.py b/src/db/scanner.py index 714551f..e9ac8c3 100644 --- a/src/db/scanner.py +++ b/src/db/scanner.py @@ -3,9 +3,10 @@ from datetime import datetime from sqlalchemy import (Boolean, Column, DateTime, Integer, Unicode, UniqueConstraint, create_engine) +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.declarative import declarative_base -from db import SqlDB, get_conn_str +from db.db import SqlDB, get_conn_str Base = declarative_base() engine = create_engine(get_conn_str()) @@ -16,6 +17,7 @@ class Scanner(Base): __table_args__ = ( None, UniqueConstraint('id'), + UniqueConstraint('uuid'), ) id = Column(Integer, autoincrement=True, primary_key=True) @@ -26,7 +28,7 @@ class Scanner(Base): onupdate=datetime.utcnow, nullable=False) comment = Column(Unicode(255), nullable=True) scanners = Column(Unicode(2048), nullable=False) - target = Column(Unicode(255), nullable=True) + targets = Column(Unicode(255), nullable=True) def as_dict(self): """Return JSON serializable dict.""" @@ -42,5 +44,86 @@ class Scanner(Base): d[col.name] = value return d + @classmethod + def comment(cls, uuid, comment): + with SqlDB.sql_session() as session: + scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none() + + if scanner: + scanner.comment = comment + else: + return None + return None + + @classmethod + def enable(cls, uuid): + with SqlDB.sql_session() as session: + scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none() + + if scanner: + scanner.enabled = True + else: + return None + return None + + @classmethod + def disable(cls, uuid): + with SqlDB.sql_session() as session: + scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none() + + if scanner: + scanner.enabled = False + else: + return None + return None + + @classmethod + def get(cls, scanner_id=None, uuid=None): + if scanner_id is None and uuid is None: + raise ValueError('Either scanner_id or uuid must be present.') + + with SqlDB.sql_session() as session: + if scanner_id: + scanner: Scanner = session.query(Scanner).filter(Scanner.id == scanner_id).one_or_none() + elif uuid: + scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none() + else: + return None + + if scanner is None: + return None + + return scanner.as_dict() + + return None + + @classmethod + def add(cls, uuid): + try: + with SqlDB.sql_session() as session: + scanner = Scanner() + scanner.uuid = uuid + scanner.enabled = False + scanner.scanners = "None" + + session.add(scanner) + session.flush() + + return scanner.id + except IntegrityError: + return None + + + @classmethod + def is_enabled(cls, uuid): + with SqlDB.sql_session() as session: + scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none() + if scanner is None: + return None + + enabled = scanner.enabled + + return enabled + Base.metadata.create_all(engine) diff --git a/src/routers/scanner.py b/src/routers/scanner.py index 956153b..9bb0f98 100644 --- a/src/routers/scanner.py +++ b/src/routers/scanner.py @@ -1,3 +1,6 @@ +from uuid import UUID + +from db.scanner import Scanner from fastapi import APIRouter, Depends, Request from fastapi.responses import JSONResponse from fastapi_jwt_auth import AuthJWT @@ -5,12 +8,86 @@ from fastapi_jwt_auth import AuthJWT router = APIRouter() -@router.get('/callhome') -async def callhome(data: Request, Authorize: AuthJWT = Depends()): +@router.post('/scanner/{uuid}') +async def scanner(uuid, data: Request, Authorize: AuthJWT = Depends()): Authorize.jwt_required() json_data = await data.json() - if 'uuid' not in json_data: + if not Scanner.get(uuid=uuid): return JSONResponse(content={"status": "error", - "message": "UUID missing"}) + "message": "Scanner don't exist."}, + status_code=400) + + if 'targets' in json_data: + if isinstance(json_data['targets'], str): + Scanner.comment(uuid, json_data['targets']) + else: + return JSONResponse(content={"status": "error", + "message": "Targets should be a string."}, + status_code=400) + if 'scanner' in json_data: + if isinstance(json_data['comment'], str): + Scanner.comment(uuid, json_data['scanner']) + else: + return JSONResponse(content={"status": "error", + "message": "Scanner should be a string."}, + status_code=400) + if 'comment' in json_data: + if isinstance(json_data['comment'], str): + Scanner.comment(uuid, json_data['comment']) + else: + return JSONResponse(content={"status": "error", + "message": "Comment should be a string."}, + status_code=400) + if 'enabled' in json_data: + if isinstance(json_data['enabled'], bool): + if json_data['enabled'] is True: + Scanner.enable(uuid) + elif json_data['enabled'] is False: + Scanner.disable(uuid) + else: + return JSONResponse(content={"status": "error", + "message": "Enabled should be boolean."}, + status_code=400) + + +@router.get('/callhome/{uuid}') +async def callhome(uuid, data: Request, Authorize: AuthJWT = Depends()): + Authorize.jwt_required() + + raw_jwt = Authorize.get_raw_jwt() + + if 'user' not in raw_jwt or raw_jwt['user'] != "scanner": + return JSONResponse(content={"status": "error", + "message": "Invalid token type."}, + status_code=400) + + try: + UUID(uuid).version + except ValueError: + return JSONResponse(content={"status": "error", + "message": "Invalid UUID."}, + status_code=400) + + scanner_data = Scanner.get(uuid=uuid) + + if scanner_data: + if not Scanner.is_enabled(uuid): + return JSONResponse(content={"status": "error", + "message": "Scanner disabled."}, + status_code=400) + else: + return JSONResponse(content={"status": "success", + "data": scanner_data}, + status_code=200) + + else: + if Scanner.add(uuid): + return JSONResponse(content={"status": "success", + "message": "Scanner added."}, + status_code=200) + else: + return JSONResponse(content={"status": "error", + "message": "Failed to add scanner."}, + status_code=400) |