summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKristofer Hallin <kristofer@sunet.se>2022-01-05 11:46:15 +0100
committerKristofer Hallin <kristofer@sunet.se>2022-01-05 11:46:15 +0100
commit571997129ba5275cc5e148a8ac1c0f64d895a9ef (patch)
tree607fd13bbbf5ac38f416da8172e89a5d3331d1d8
parent09677d03635da2b799cf117b2127c3b197a8babf (diff)
Added database and API endpoints for scanners.
-rwxr-xr-xsrc/db/db.py58
-rw-r--r--src/db/scanner.py87
-rw-r--r--src/routers/scanner.py85
3 files changed, 224 insertions, 6 deletions
diff --git a/src/db/db.py b/src/db/db.py
index 511748c..cbb87ce 100755
--- a/src/db/db.py
+++ b/src/db/db.py
@@ -10,6 +10,10 @@
import os
import sys
import time
+from contextlib import contextmanager
+
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
from db import couch
from db.index import CouchIindex
@@ -137,3 +141,57 @@ class DictDB():
return None
return key
+
+
+def get_conn_str():
+ try:
+ dialect = os.environ['SQL_DIALECT']
+ database = os.environ['SQL_DATABASE']
+ except KeyError:
+ print('The environment variables SQL_DIALECT and SQL_DATABASE must ' +
+ 'be set.')
+ sys.exit(-1)
+
+ if dialect != 'sqlite':
+ try:
+ hostname = os.environ['SQL_HOSTNAME']
+ username = os.environ['SQL_USERNAME']
+ password = os.environ['SQL_PASSWORD']
+ except KeyError:
+ print('The environment variables SQL_DIALECT, SQL_NAME, ' +
+ 'SQL_HOSTNAME, SQL_USERNAME and SQL_PASSWORD must ' +
+ 'be set.')
+ sys.exit(-1)
+
+ if dialect == 'sqlite':
+ conn_str = f"{dialect}:///{database}.db"
+ else:
+ conn_str = f"{dialect}://{username}:{password}@{hostname}" + \
+ "/{database}"
+
+ return conn_str
+
+
+class SqlDB():
+ def get_session(conn_str):
+ if 'sqlite' in conn_str:
+ engine = create_engine(conn_str)
+ else:
+ engine = create_engine(conn_str, pool_size=50, max_overflow=0)
+ Session = sessionmaker(bind=engine)
+
+ return Session()
+
+ @classmethod
+ @contextmanager
+ def sql_session(cls, **kwargs):
+ session = cls.get_session(get_conn_str())
+
+ try:
+ yield session
+ session.commit()
+ except Exception:
+ session.rollback()
+ raise
+ finally:
+ session.close()
diff --git a/src/db/scanner.py b/src/db/scanner.py
index 714551f..e9ac8c3 100644
--- a/src/db/scanner.py
+++ b/src/db/scanner.py
@@ -3,9 +3,10 @@ from datetime import datetime
from sqlalchemy import (Boolean, Column, DateTime, Integer, Unicode,
UniqueConstraint, create_engine)
+from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.declarative import declarative_base
-from db import SqlDB, get_conn_str
+from db.db import SqlDB, get_conn_str
Base = declarative_base()
engine = create_engine(get_conn_str())
@@ -16,6 +17,7 @@ class Scanner(Base):
__table_args__ = (
None,
UniqueConstraint('id'),
+ UniqueConstraint('uuid'),
)
id = Column(Integer, autoincrement=True, primary_key=True)
@@ -26,7 +28,7 @@ class Scanner(Base):
onupdate=datetime.utcnow, nullable=False)
comment = Column(Unicode(255), nullable=True)
scanners = Column(Unicode(2048), nullable=False)
- target = Column(Unicode(255), nullable=True)
+ targets = Column(Unicode(255), nullable=True)
def as_dict(self):
"""Return JSON serializable dict."""
@@ -42,5 +44,86 @@ class Scanner(Base):
d[col.name] = value
return d
+ @classmethod
+ def comment(cls, uuid, comment):
+ with SqlDB.sql_session() as session:
+ scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none()
+
+ if scanner:
+ scanner.comment = comment
+ else:
+ return None
+ return None
+
+ @classmethod
+ def enable(cls, uuid):
+ with SqlDB.sql_session() as session:
+ scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none()
+
+ if scanner:
+ scanner.enabled = True
+ else:
+ return None
+ return None
+
+ @classmethod
+ def disable(cls, uuid):
+ with SqlDB.sql_session() as session:
+ scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none()
+
+ if scanner:
+ scanner.enabled = False
+ else:
+ return None
+ return None
+
+ @classmethod
+ def get(cls, scanner_id=None, uuid=None):
+ if scanner_id is None and uuid is None:
+ raise ValueError('Either scanner_id or uuid must be present.')
+
+ with SqlDB.sql_session() as session:
+ if scanner_id:
+ scanner: Scanner = session.query(Scanner).filter(Scanner.id == scanner_id).one_or_none()
+ elif uuid:
+ scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none()
+ else:
+ return None
+
+ if scanner is None:
+ return None
+
+ return scanner.as_dict()
+
+ return None
+
+ @classmethod
+ def add(cls, uuid):
+ try:
+ with SqlDB.sql_session() as session:
+ scanner = Scanner()
+ scanner.uuid = uuid
+ scanner.enabled = False
+ scanner.scanners = "None"
+
+ session.add(scanner)
+ session.flush()
+
+ return scanner.id
+ except IntegrityError:
+ return None
+
+
+ @classmethod
+ def is_enabled(cls, uuid):
+ with SqlDB.sql_session() as session:
+ scanner: Scanner = session.query(Scanner).filter(Scanner.uuid == uuid).one_or_none()
+ if scanner is None:
+ return None
+
+ enabled = scanner.enabled
+
+ return enabled
+
Base.metadata.create_all(engine)
diff --git a/src/routers/scanner.py b/src/routers/scanner.py
index 956153b..9bb0f98 100644
--- a/src/routers/scanner.py
+++ b/src/routers/scanner.py
@@ -1,3 +1,6 @@
+from uuid import UUID
+
+from db.scanner import Scanner
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse
from fastapi_jwt_auth import AuthJWT
@@ -5,12 +8,86 @@ from fastapi_jwt_auth import AuthJWT
router = APIRouter()
-@router.get('/callhome')
-async def callhome(data: Request, Authorize: AuthJWT = Depends()):
+@router.post('/scanner/{uuid}')
+async def scanner(uuid, data: Request, Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
json_data = await data.json()
- if 'uuid' not in json_data:
+ if not Scanner.get(uuid=uuid):
return JSONResponse(content={"status": "error",
- "message": "UUID missing"})
+ "message": "Scanner don't exist."},
+ status_code=400)
+
+ if 'targets' in json_data:
+ if isinstance(json_data['targets'], str):
+ Scanner.comment(uuid, json_data['targets'])
+ else:
+ return JSONResponse(content={"status": "error",
+ "message": "Targets should be a string."},
+ status_code=400)
+ if 'scanner' in json_data:
+ if isinstance(json_data['comment'], str):
+ Scanner.comment(uuid, json_data['scanner'])
+ else:
+ return JSONResponse(content={"status": "error",
+ "message": "Scanner should be a string."},
+ status_code=400)
+ if 'comment' in json_data:
+ if isinstance(json_data['comment'], str):
+ Scanner.comment(uuid, json_data['comment'])
+ else:
+ return JSONResponse(content={"status": "error",
+ "message": "Comment should be a string."},
+ status_code=400)
+ if 'enabled' in json_data:
+ if isinstance(json_data['enabled'], bool):
+ if json_data['enabled'] is True:
+ Scanner.enable(uuid)
+ elif json_data['enabled'] is False:
+ Scanner.disable(uuid)
+ else:
+ return JSONResponse(content={"status": "error",
+ "message": "Enabled should be boolean."},
+ status_code=400)
+
+
+@router.get('/callhome/{uuid}')
+async def callhome(uuid, data: Request, Authorize: AuthJWT = Depends()):
+ Authorize.jwt_required()
+
+ raw_jwt = Authorize.get_raw_jwt()
+
+ if 'user' not in raw_jwt or raw_jwt['user'] != "scanner":
+ return JSONResponse(content={"status": "error",
+ "message": "Invalid token type."},
+ status_code=400)
+
+ try:
+ UUID(uuid).version
+ except ValueError:
+ return JSONResponse(content={"status": "error",
+ "message": "Invalid UUID."},
+ status_code=400)
+
+ scanner_data = Scanner.get(uuid=uuid)
+
+ if scanner_data:
+ if not Scanner.is_enabled(uuid):
+ return JSONResponse(content={"status": "error",
+ "message": "Scanner disabled."},
+ status_code=400)
+ else:
+ return JSONResponse(content={"status": "success",
+ "data": scanner_data},
+ status_code=200)
+
+ else:
+ if Scanner.add(uuid):
+ return JSONResponse(content={"status": "success",
+ "message": "Scanner added."},
+ status_code=200)
+ else:
+ return JSONResponse(content={"status": "error",
+ "message": "Failed to add scanner."},
+ status_code=400)