Diffstat (limited to 'src')
-rwxr-xr-x  src/db.py    60
-rw-r--r--  src/job.py   84
-rw-r--r--  src/jobs.py  58
-rwxr-xr-x  src/main.py  29
4 files changed, 221 insertions, 10 deletions
diff --git a/src/db.py b/src/db.py
index 012dfac..57f4003 100755
--- a/src/db.py
+++ b/src/db.py
@@ -10,8 +10,12 @@
import os
import sys
import time
-import couch
+from contextlib import contextmanager
+
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+import couch
from index import CouchIindex
@@ -137,3 +141,57 @@ class DictDB():
return None
return key
+
+
+def get_conn_str():
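+    """Build an SQLAlchemy connection URL from SQL_* environment variables."""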
+ try:
+ dialect = os.environ['SQL_DIALECT']
+ database = os.environ['SQL_DATABASE']
+ except KeyError:
+ print('The environment variables SQL_DIALECT and SQL_DATABASE must ' +
+ 'be set.')
+ sys.exit(-1)
+
+ if dialect != 'sqlite':
+ try:
+ hostname = os.environ['SQL_HOSTNAME']
+ username = os.environ['SQL_USERNAME']
+ password = os.environ['SQL_PASSWORD']
+ except KeyError:
+            print('For non-SQLite dialects the environment variables ' +
+                  'SQL_HOSTNAME, SQL_USERNAME and SQL_PASSWORD must ' +
+                  'also be set.')
+ sys.exit(-1)
+
+ if dialect == 'sqlite':
+ conn_str = f"{dialect}:///{database}.db"
+ else:
+        conn_str = (f"{dialect}://{username}:{password}@{hostname}"
+                    f"/{database}")
+
+ return conn_str
+
+
+class SqlDB():
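+    """Helper for creating SQLAlchemy engines and transactional sessions."""
+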
+    @staticmethod
+    def get_session(conn_str):
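+        """Return a new session bound to an engine for conn_str."""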
+ if 'sqlite' in conn_str:
+ engine = create_engine(conn_str)
+ else:
+ engine = create_engine(conn_str, pool_size=50, max_overflow=0)
+ Session = sessionmaker(bind=engine)
+
+ return Session()
+
+ @classmethod
+ @contextmanager
+ def sql_session(cls, **kwargs):
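+        """Session context manager: commit on success, roll back on error."""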
+ session = cls.get_session(get_conn_str())
+
+ try:
+ yield session
+ session.commit()
+ except Exception:
+ session.rollback()
+ raise
+ finally:
+ session.close()
diff --git a/src/job.py b/src/job.py
new file mode 100644
index 0000000..32311de
--- /dev/null
+++ b/src/job.py
@@ -0,0 +1,84 @@
+import enum
+from datetime import datetime
+
+from sqlalchemy import (Column, DateTime, Enum, Integer, Unicode,
+ UniqueConstraint, create_engine)
+from sqlalchemy.ext.declarative import declarative_base
+
+from db import SqlDB, get_conn_str
+
+Base = declarative_base()
+engine = create_engine(get_conn_str())
+
+
+class JobStatus(enum.Enum):
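+    """Lifecycle states of a job."""
+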
+ UNKNOWN = 0
+ SCHEDULED = 1
+ RUNNING = 2
+ STOPPED = 3
+ ABORTED = 4
+ ERROR = 5
+    CLEARED = 6
+ DONE = 7
+
+
+class Job(Base):
+ __tablename__ = 'jobs'
+ __table_args__ = (
+ UniqueConstraint('id'),
+ )
+
+ id = Column(Integer, autoincrement=True, primary_key=True)
+ starttime = Column(DateTime, nullable=False)
+ stoptime = Column(DateTime, nullable=True)
+ updated = Column(DateTime, nullable=True)
+ status = Column(Enum(JobStatus), index=True, default=JobStatus.SCHEDULED,
+ nullable=True)
+ comment = Column(Unicode(255), nullable=True)
+ scanner = Column(Unicode(255), nullable=True)
+ target = Column(Unicode(255), nullable=True)
+
+ @classmethod
+ def starter(cls, func, **kwargs):
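+        """Run func and record the job's status and stop time."""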
+ job_id = kwargs['job_id']
+ retval = None
+
+ del kwargs['job_id']
+
+ with SqlDB.sql_session() as session:
+ job = session.query(Job).filter(Job.id == job_id).one_or_none()
+ job.status = JobStatus.RUNNING
+
+ session.commit()
+
+ try:
+                retval = func(**kwargs)
+ except Exception:
+ job.status = JobStatus.ERROR
+ job.stoptime = datetime.now()
+ session.commit()
+ print("Job raised an exception!")
+ else:
+ job.status = JobStatus.DONE
+ job.stoptime = datetime.now()
+ print("Job finished properly.")
+
+ return retval
+
+ def as_dict(self):
+ """Return JSON serializable dict."""
+ d = {}
+ for col in self.__table__.columns:
+ value = getattr(self, col.name)
+ if issubclass(value.__class__, enum.Enum):
+ value = value.value
+ elif issubclass(value.__class__, Base):
+ continue
+ elif issubclass(value.__class__, datetime):
+ value = str(value)
+ d[col.name] = value
+ return d
+
+
+Base.metadata.create_all(engine)
diff --git a/src/jobs.py b/src/jobs.py
new file mode 100644
index 0000000..67c4e1c
--- /dev/null
+++ b/src/jobs.py
@@ -0,0 +1,58 @@
+from datetime import datetime
+
+from apscheduler.executors.pool import ThreadPoolExecutor
+from apscheduler.jobstores.memory import MemoryJobStore
+from apscheduler.schedulers.background import BackgroundScheduler
+from pytz import utc
+
+from db import SqlDB
+from job import Job, JobStatus
+
+
+class JobScheduler(object):
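+    """APScheduler wrapper that records each job as a row in the SQL database."""
+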
+ def __init__(self, nr_threads=10):
+ self.__scheduler = BackgroundScheduler(
+ executors={"default": ThreadPoolExecutor(nr_threads)},
+ jobstores={"default": MemoryJobStore()},
+ job_defaults={},
+ timezone=utc,
+ )
+
+ def get(self):
+ return self.__scheduler
+
+ def start(self):
+ return self.__scheduler.start()
+
+ def stop(self):
+ return self.__scheduler.shutdown()
+
+ def add(self, func, comment='', **kwargs):
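+        """Create a Job row, schedule func and return the job id."""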
+ with SqlDB.sql_session() as session:
+ job = Job()
+ job.starttime = datetime.now()
+ job.comment = comment
+ job.status = JobStatus.SCHEDULED
+
+ session.add(job)
+ session.flush()
+
+ job_id = job.id
+ kwargs['job_id'] = job_id
+ kwargs['func'] = func
+
+ self.__scheduler.add_job(Job.starter, kwargs=kwargs)
+ return job_id
+
+ @classmethod
+ def get_jobs(cls):
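+        """Return all jobs as a list of JSON serializable dicts."""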
+ jobs = list()
+
+ with SqlDB.sql_session() as session:
+ query = session.query(Job).all()
+
+ for instance in query:
+ job_dict = instance.as_dict()
+ jobs.append(job_dict)
+
+ return jobs
diff --git a/src/main.py b/src/main.py
index f95a09c..4454f7d 100755
--- a/src/main.py
+++ b/src/main.py
@@ -1,17 +1,19 @@
import os
import sys
-import uvicorn
+import time
-from fastapi import FastAPI, Depends, Request
+import requests
+import uvicorn
+from fastapi import Depends, FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi_jwt_auth import AuthJWT
from fastapi_jwt_auth.exceptions import AuthJWTException
-from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
-from index import CouchIindex
-import time
+
from db import DictDB
-import requests
+from index import CouchIindex
+from jobs import JobScheduler
app = FastAPI()
@@ -37,8 +39,8 @@ for i in range(10):
try:
db = DictDB()
except requests.exceptions.ConnectionError:
- print(f'Database not responding, will try again soon.' +
- 'Attempt {i + 1} of 10.')
+        print('Database not responding, will try again soon. ' +
+              f'Attempt {i + 1} of 10.')
else:
break
time.sleep(10)
@@ -118,7 +120,7 @@ async def get(key=None, limit=25, skip=0, ip=None, port=None,
if 'domains' not in raw_jwt:
return JSONResponse(content={"status": "error",
- "message": "Could not find domains" +
+ "message": "Could not find domains " +
"claim in JWT token"},
status_code=400)
else:
@@ -169,6 +171,15 @@ async def delete(key, Authorize: AuthJWT = Depends()):
return JSONResponse(content={"status": "success", "docs": {}})
+@app.get('/sc/v0/jobs')
+async def jobs_get(Authorize: AuthJWT = Depends()):
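+    """Return all known jobs (requires a valid JWT)."""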
+ Authorize.jwt_required()
+
+ data = JobScheduler.get_jobs()
+
+ return JSONResponse(content={"status": "success", "jobs": data})
+
+
def main(standalone=False):
if not standalone:
return app