summaryrefslogtreecommitdiff
path: root/src/db/scanner.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/db/scanner.py')
-rw-r--r--src/db/scanner.py133
1 files changed, 133 insertions, 0 deletions
diff --git a/src/db/scanner.py b/src/db/scanner.py
new file mode 100644
index 0000000..625fd8e
--- /dev/null
+++ b/src/db/scanner.py
@@ -0,0 +1,133 @@
+import enum
+from datetime import datetime
+
+from sqlalchemy import (JSON, Boolean, Column, DateTime, Integer, Unicode,
+ UniqueConstraint, create_engine)
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy.ext.declarative import declarative_base
+
+from db.db import SqlDB, get_conn_str
+
+Base = declarative_base()
+engine = create_engine(get_conn_str())
+
+
+class Scanner(Base):
+ __tablename__ = 'scanners'
+ __table_args__ = (
+ None,
+ UniqueConstraint('id'),
+ UniqueConstraint('uuid'),
+ )
+
+ id = Column(Integer, autoincrement=True, primary_key=True)
+ uuid = Column(Unicode(37), nullable=False)
+ enabled = Column(Boolean, default=False, nullable=False)
+ first_seen = Column(DateTime, default=datetime.utcnow, nullable=False)
+ last_seen = Column(DateTime, default=datetime.utcnow,
+ onupdate=datetime.utcnow, nullable=False)
+ comment = Column(Unicode(255), default="", nullable=True)
+ scanners = Column(JSON, nullable=True)
+
+ def as_dict(self):
+ """Return JSON serializable dict."""
+ d = {}
+ for col in self.__table__.columns:
+ value = getattr(self, col.name)
+ if issubclass(value.__class__, enum.Enum):
+ value = value.value
+ elif issubclass(value.__class__, Base):
+ continue
+ elif issubclass(value.__class__, datetime):
+ value = str(value)
+ d[col.name] = value
+ return d
+
+ @classmethod
+ def comment(cls, uuid, comment):
+ with SqlDB.sql_session() as session:
+ scanner: Scanner = session.query(Scanner).filter(
+ Scanner.uuid == uuid).one_or_none()
+
+ if scanner:
+ scanner.comment = comment
+ else:
+ return None
+ return None
+
+ @classmethod
+ def enable(cls, uuid):
+ with SqlDB.sql_session() as session:
+ scanner: Scanner = session.query(Scanner).filter(
+ Scanner.uuid == uuid).one_or_none()
+
+ if scanner:
+ scanner.enabled = True
+ else:
+ return None
+ return None
+
+ @classmethod
+ def disable(cls, uuid):
+ with SqlDB.sql_session() as session:
+ scanner: Scanner = session.query(Scanner).filter(
+ Scanner.uuid == uuid).one_or_none()
+
+ if scanner:
+ scanner.enabled = False
+ else:
+ return None
+ return None
+
+ @classmethod
+ def get(cls, scanner_id=None, uuid=None):
+ if scanner_id is None and uuid is None:
+ raise ValueError('Either scanner_id or uuid must be present.')
+
+ with SqlDB.sql_session() as session:
+ if scanner_id:
+ scanner: Scanner = session.query(Scanner).filter(
+ Scanner.id == scanner_id).one_or_none()
+ elif uuid:
+ scanner: Scanner = session.query(Scanner).filter(
+ Scanner.uuid == uuid).one_or_none()
+ else:
+ return None
+
+ if scanner is None:
+ return None
+
+ return scanner.as_dict()
+
+ return None
+
+ @classmethod
+ def add(cls, uuid):
+ try:
+ with SqlDB.sql_session() as session:
+ scanner = Scanner()
+ scanner.uuid = uuid
+ scanner.enabled = False
+ scanner.scanners = {}
+
+ session.add(scanner)
+ session.flush()
+
+ return scanner.id
+ except IntegrityError:
+ return None
+
+ @classmethod
+ def is_enabled(cls, uuid):
+ with SqlDB.sql_session() as session:
+ scanner: Scanner = session.query(Scanner).filter(
+ Scanner.uuid == uuid).one_or_none()
+ if scanner is None:
+ return None
+
+ enabled = scanner.enabled
+
+ return enabled
+
+
+Base.metadata.create_all(engine)