diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/db/sql.py | 155 | ||||
-rwxr-xr-x | src/main.py | 10 |
2 files changed, 165 insertions, 0 deletions
diff --git a/src/db/sql.py b/src/db/sql.py new file mode 100644 index 0000000..f2a9ee2 --- /dev/null +++ b/src/db/sql.py @@ -0,0 +1,155 @@ +import datetime +import os +import sys +from contextlib import contextmanager + +from sqlalchemy import (Boolean, Column, Date, Integer, String, Text, + create_engine, text) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +Base = declarative_base() +metadata = Base.metadata + + +class Log(Base): + __tablename__ = "log" + + id = Column(Integer, primary_key=True) + timestamp = Column(Date, nullable=False, default=datetime.datetime.utcnow) + username = Column(Text, nullable=False) + logtext = Column(Text, nullable=False) + + def as_dict(self): + """Return JSON serializable dict.""" + d = {} + for col in self.__table__.columns: + value = getattr(self, col.name) + if issubclass(value.__class__, Base): + continue + elif issubclass(value.__class__, datetime.datetime): + value = str(value) + d[col.name] = value + return d + + @classmethod + def add(cls, username, logtext): + with sqla_session() as session: + logentry = Log() + logentry.username = username + logentry.logtext = logtext + session.add(logentry) + + +class Scanner(Base): + __tablename__ = "scanner" + + id = Column(Integer, primary_key=True) + runner = Column(Text, nullable=False, default="*") + name = Column(String(128), nullable=False, unique=True) + active = Column(Boolean, nullable=False) + interval = Column(Integer, nullable=False, server_default=text("300")) + starttime = Column(Date) + hostname = Column(String(128), nullable=False, unique=True) + port = Column(Integer, nullable=False) + maxruns = Column(Integer, nullable=False, default=1) + + def as_dict(self): + d = {} + for col in self.__table__.columns: + value = getattr(self, col.name) + if issubclass(value.__class__, Base): + continue + elif issubclass(value.__class__, datetime.datetime): + value = str(value) + d[col.name] = value + return d + + @classmethod + def add(cls, name, hostname, port, active=False, interval=0, + starttime=None, + endtime=None, + maxruns=1): + errors = list() + if starttime and endtime: + if starttime > endtime: + errors.append("Endtime must be after the starttime.") + if interval < 0: + errors.append("Interval must be > 0") + if maxruns < 0: + errors.append("Max runs must be > 0") + with sqla_session() as session: + scanentry = Scanner() + scanentry.name = name + scanentry.active = active + scanentry.interval = interval + if starttime: + scanentry.starttime = starttime + if endtime: + scanentry.endtime = endtime + scanentry.maxruns = maxruns + scanentry.hostname = hostname + scanentry.port = port + session.add(scanentry) + return errors + + @classmethod + def get(cls, name): + results = list() + with sqla_session() as session: + scanners = session.query(Scanner).all() + if not scanners: + return None + for scanner in scanners: + if scanner.runner == "*": + results.append(scanner.as_dict()) + elif scanner.runner == name: + results.append(scanner.as_dict()) + return results + + +def get_sqlalchemy_conn_str(**kwargs) -> str: + try: + if "SQL_HOSTNAME" in os.environ: + hostname = os.environ["SQL_HOSTNAME"] + else: + hostname = "localhost" + print("SQL_HOSTNAME not set, falling back to localhost.") + if "SQL_PORT" in os.environ: + port = os.environ["SQL_PORT"] + else: + print("SQL_PORT not set, falling back to 5432.") + port = 5432 + username = os.environ["SQL_USERNAME"] + password = os.environ["SQL_PASSWORD"] + database = os.environ["SQL_DATABASE"] + except KeyError: + print("SQL_DATABASE, SQL_USERNAME, SQL_PASSWORD must be set.") + sys.exit(-2) + + return ( + f"postgresql://{username}:{password}@{hostname}:{port}/{database}" + ) + + +def get_session(conn_str=""): + if conn_str == "": + conn_str = get_sqlalchemy_conn_str() + + engine = create_engine(conn_str, pool_size=50, max_overflow=0) + Session = sessionmaker(bind=engine) + + return Session() + + +@contextmanager +def sqla_session(conn_str="", **kwargs): + session = get_session(conn_str) + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() diff --git a/src/main.py b/src/main.py index d3d02fd..c93029f 100755 --- a/src/main.py +++ b/src/main.py @@ -12,6 +12,7 @@ from pydantic import BaseModel from db.dictionary import DictDB from db.index import CouchIindex +from db.sql import Log, Scanner app = FastAPI() @@ -216,6 +217,15 @@ async def delete(key, Authorize: AuthJWT = Depends()): return JSONResponse(content={"status": "success", "docs": data}) +@app.get("/sc/v0/scanner/{name}") +async def scanner_get(name, data: Request, Authorize: AuthJWT = Depends()): + Authorize.jwt_required() + + scanners = Scanner.get(name) + + return JSONResponse(content={"status": "success", "data": scanners}) + + def main(standalone=False): if not standalone: return app |