diff options
author | Kristofer Hallin <kristofer@sunet.se> | 2022-04-08 14:10:45 +0200 |
---|---|---|
committer | Kristofer Hallin <kristofer@sunet.se> | 2022-04-08 14:10:45 +0200 |
commit | 3335f65e5e4b3132a72b46b99e50d3c55c0c58b5 (patch) | |
tree | 23f6e3edb708cc08e0e7a5e1f6a5b912e9af3ff5 | |
parent | 208089fa95e63d6e29e7a1d86726bfec804de211 (diff) |
Random SQL stuff for handling scanners.
New API endpoint for scanners.
-rw-r--r-- | docker/postgres/schema.sql | 10 | ||||
-rw-r--r-- | src/db/sql.py | 155 | ||||
-rwxr-xr-x | src/main.py | 10 |
3 files changed, 171 insertions, 4 deletions
diff --git a/docker/postgres/schema.sql b/docker/postgres/schema.sql index 38ff863..c6e1323 100644 --- a/docker/postgres/schema.sql +++ b/docker/postgres/schema.sql @@ -25,7 +25,7 @@ SET default_table_access_method = heap; --
CREATE TABLE public.log (
- id integer NOT NULL,
+ id SERIAL PRIMARY KEY,
"timestamp" date NOT NULL,
username text NOT NULL,
logtext text
@@ -39,13 +39,16 @@ ALTER TABLE public.log OWNER TO test; --
CREATE TABLE public.scanner (
- id integer NOT NULL,
+ id SERIAL PRIMARY KEY,
+ runner text DEFAULT '*',
name character varying(128) NOT NULL,
active boolean NOT NULL,
"interval" integer DEFAULT 300 NOT NULL,
starttime date,
endtime date,
- maxruns integer DEFAULT 1
+ maxruns integer DEFAULT 1,
+ hostname character varying(128) NOT NULL,
+ port integer NOT NULL
);
@@ -94,4 +97,3 @@ ALTER TABLE ONLY public.scanner --
-- PostgreSQL database dump complete
--
-
diff --git a/src/db/sql.py b/src/db/sql.py new file mode 100644 index 0000000..f2a9ee2 --- /dev/null +++ b/src/db/sql.py @@ -0,0 +1,155 @@ +import datetime +import os +import sys +from contextlib import contextmanager + +from sqlalchemy import (Boolean, Column, Date, Integer, String, Text, + create_engine, text) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +Base = declarative_base() +metadata = Base.metadata + + +class Log(Base): + __tablename__ = "log" + + id = Column(Integer, primary_key=True) + timestamp = Column(Date, nullable=False, default=datetime.datetime.utcnow) + username = Column(Text, nullable=False) + logtext = Column(Text, nullable=False) + + def as_dict(self): + """Return JSON serializable dict.""" + d = {} + for col in self.__table__.columns: + value = getattr(self, col.name) + if issubclass(value.__class__, Base): + continue + elif issubclass(value.__class__, datetime.datetime): + value = str(value) + d[col.name] = value + return d + + @classmethod + def add(cls, username, logtext): + with sqla_session() as session: + logentry = Log() + logentry.username = username + logentry.logtext = logtext + session.add(logentry) + + +class Scanner(Base): + __tablename__ = "scanner" + + id = Column(Integer, primary_key=True) + runner = Column(Text, nullable=False, default="*") + name = Column(String(128), nullable=False, unique=True) + active = Column(Boolean, nullable=False) + interval = Column(Integer, nullable=False, server_default=text("300")) + starttime = Column(Date) + hostname = Column(String(128), nullable=False, unique=True) + port = Column(Integer, nullable=False) + maxruns = Column(Integer, nullable=False, default=1) + + def as_dict(self): + d = {} + for col in self.__table__.columns: + value = getattr(self, col.name) + if issubclass(value.__class__, Base): + continue + elif issubclass(value.__class__, datetime.datetime): + value = str(value) + d[col.name] = value + return d + + @classmethod + def add(cls, name, hostname, port, active=False, interval=0, + starttime=None, + endtime=None, + maxruns=1): + errors = list() + if starttime and endtime: + if starttime > endtime: + errors.append("Endtime must be after the starttime.") + if interval < 0: + errors.append("Interval must be > 0") + if maxruns < 0: + errors.append("Max runs must be > 0") + with sqla_session() as session: + scanentry = Scanner() + scanentry.name = name + scanentry.active = active + scanentry.interval = interval + if starttime: + scanentry.starttime = starttime + if endtime: + scanentry.endtime = endtime + scanentry.maxruns = maxruns + scanentry.hostname = hostname + scanentry.port = port + session.add(scanentry) + return errors + + @classmethod + def get(cls, name): + results = list() + with sqla_session() as session: + scanners = session.query(Scanner).all() + if not scanners: + return None + for scanner in scanners: + if scanner.runner == "*": + results.append(scanner.as_dict()) + elif scanner.runner == name: + results.append(scanner.as_dict()) + return results + + +def get_sqlalchemy_conn_str(**kwargs) -> str: + try: + if "SQL_HOSTNAME" in os.environ: + hostname = os.environ["SQL_HOSTNAME"] + else: + hostname = "localhost" + print("SQL_HOSTNAME not set, falling back to localhost.") + if "SQL_PORT" in os.environ: + port = os.environ["SQL_PORT"] + else: + print("SQL_PORT not set, falling back to 5432.") + port = 5432 + username = os.environ["SQL_USERNAME"] + password = os.environ["SQL_PASSWORD"] + database = os.environ["SQL_DATABASE"] + except KeyError: + print("SQL_DATABASE, SQL_USERNAME, SQL_PASSWORD must be set.") + sys.exit(-2) + + return ( + f"postgresql://{username}:{password}@{hostname}:{port}/{database}" + ) + + +def get_session(conn_str=""): + if conn_str == "": + conn_str = get_sqlalchemy_conn_str() + + engine = create_engine(conn_str, pool_size=50, max_overflow=0) + Session = sessionmaker(bind=engine) + + return Session() + + +@contextmanager +def sqla_session(conn_str="", **kwargs): + session = get_session(conn_str) + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() diff --git a/src/main.py b/src/main.py index d3d02fd..c93029f 100755 --- a/src/main.py +++ b/src/main.py @@ -12,6 +12,7 @@ from pydantic import BaseModel from db.dictionary import DictDB from db.index import CouchIindex +from db.sql import Log, Scanner app = FastAPI() @@ -216,6 +217,15 @@ async def delete(key, Authorize: AuthJWT = Depends()): return JSONResponse(content={"status": "success", "docs": data}) +@app.get("/sc/v0/scanner/{name}") +async def scanner_get(name, data: Request, Authorize: AuthJWT = Depends()): + Authorize.jwt_required() + + scanners = Scanner.get(name) + + return JSONResponse(content={"status": "success", "data": scanners}) + + def main(standalone=False): if not standalone: return app |