summaryrefslogtreecommitdiff
path: root/src/db/sql.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/db/sql.py')
-rw-r--r--src/db/sql.py155
1 files changed, 155 insertions, 0 deletions
diff --git a/src/db/sql.py b/src/db/sql.py
new file mode 100644
index 0000000..f2a9ee2
--- /dev/null
+++ b/src/db/sql.py
@@ -0,0 +1,155 @@
+import datetime
+import os
+import sys
+from contextlib import contextmanager
+
+from sqlalchemy import (Boolean, Column, Date, Integer, String, Text,
+ create_engine, text)
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import sessionmaker
+
+Base = declarative_base()
+metadata = Base.metadata
+
+
+class Log(Base):
+ __tablename__ = "log"
+
+ id = Column(Integer, primary_key=True)
+ timestamp = Column(Date, nullable=False, default=datetime.datetime.utcnow)
+ username = Column(Text, nullable=False)
+ logtext = Column(Text, nullable=False)
+
+ def as_dict(self):
+ """Return JSON serializable dict."""
+ d = {}
+ for col in self.__table__.columns:
+ value = getattr(self, col.name)
+ if issubclass(value.__class__, Base):
+ continue
+ elif issubclass(value.__class__, datetime.datetime):
+ value = str(value)
+ d[col.name] = value
+ return d
+
+ @classmethod
+ def add(cls, username, logtext):
+ with sqla_session() as session:
+ logentry = Log()
+ logentry.username = username
+ logentry.logtext = logtext
+ session.add(logentry)
+
+
+class Scanner(Base):
+ __tablename__ = "scanner"
+
+ id = Column(Integer, primary_key=True)
+ runner = Column(Text, nullable=False, default="*")
+ name = Column(String(128), nullable=False, unique=True)
+ active = Column(Boolean, nullable=False)
+ interval = Column(Integer, nullable=False, server_default=text("300"))
+ starttime = Column(Date)
+ hostname = Column(String(128), nullable=False, unique=True)
+ port = Column(Integer, nullable=False)
+ maxruns = Column(Integer, nullable=False, default=1)
+
+ def as_dict(self):
+ d = {}
+ for col in self.__table__.columns:
+ value = getattr(self, col.name)
+ if issubclass(value.__class__, Base):
+ continue
+ elif issubclass(value.__class__, datetime.datetime):
+ value = str(value)
+ d[col.name] = value
+ return d
+
+ @classmethod
+ def add(cls, name, hostname, port, active=False, interval=0,
+ starttime=None,
+ endtime=None,
+ maxruns=1):
+ errors = list()
+ if starttime and endtime:
+ if starttime > endtime:
+ errors.append("Endtime must be after the starttime.")
+ if interval < 0:
+ errors.append("Interval must be > 0")
+ if maxruns < 0:
+ errors.append("Max runs must be > 0")
+ with sqla_session() as session:
+ scanentry = Scanner()
+ scanentry.name = name
+ scanentry.active = active
+ scanentry.interval = interval
+ if starttime:
+ scanentry.starttime = starttime
+ if endtime:
+ scanentry.endtime = endtime
+ scanentry.maxruns = maxruns
+ scanentry.hostname = hostname
+ scanentry.port = port
+ session.add(scanentry)
+ return errors
+
+ @classmethod
+ def get(cls, name):
+ results = list()
+ with sqla_session() as session:
+ scanners = session.query(Scanner).all()
+ if not scanners:
+ return None
+ for scanner in scanners:
+ if scanner.runner == "*":
+ results.append(scanner.as_dict())
+ elif scanner.runner == name:
+ results.append(scanner.as_dict())
+ return results
+
+
+def get_sqlalchemy_conn_str(**kwargs) -> str:
+ try:
+ if "SQL_HOSTNAME" in os.environ:
+ hostname = os.environ["SQL_HOSTNAME"]
+ else:
+ hostname = "localhost"
+ print("SQL_HOSTNAME not set, falling back to localhost.")
+ if "SQL_PORT" in os.environ:
+ port = os.environ["SQL_PORT"]
+ else:
+ print("SQL_PORT not set, falling back to 5432.")
+ port = 5432
+ username = os.environ["SQL_USERNAME"]
+ password = os.environ["SQL_PASSWORD"]
+ database = os.environ["SQL_DATABASE"]
+ except KeyError:
+ print("SQL_DATABASE, SQL_USERNAME, SQL_PASSWORD must be set.")
+ sys.exit(-2)
+
+ return (
+ f"postgresql://{username}:{password}@{hostname}:{port}/{database}"
+ )
+
+
+def get_session(conn_str=""):
+ if conn_str == "":
+ conn_str = get_sqlalchemy_conn_str()
+
+ engine = create_engine(conn_str, pool_size=50, max_overflow=0)
+ Session = sessionmaker(bind=engine)
+
+ return Session()
+
+
+@contextmanager
+def sqla_session(conn_str="", **kwargs):
+ session = get_session(conn_str)
+ try:
+ yield session
+ session.commit()
+ except Exception:
+ session.rollback()
+ raise
+ finally:
+ session.close()