diff options
Diffstat (limited to 'src')
| -rwxr-xr-x | src/db.py | 60 | ||||
| -rw-r--r-- | src/job.py | 84 | ||||
| -rw-r--r-- | src/jobs.py | 58 | ||||
| -rwxr-xr-x | src/main.py | 29 |
4 files changed, 221 insertions, 10 deletions
@@ -10,8 +10,12 @@ import os import sys import time -import couch +from contextlib import contextmanager + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +import couch from index import CouchIindex @@ -137,3 +141,57 @@ class DictDB(): return None return key + + +def get_conn_str(): + try: + dialect = os.environ['SQL_DIALECT'] + database = os.environ['SQL_DATABASE'] + except KeyError: + print('The environment variables SQL_DIALECT and SQL_DATABASE must ' + + 'be set.') + sys.exit(-1) + + if dialect != 'sqlite': + try: + hostname = os.environ['SQL_HOSTNAME'] + username = os.environ['SQL_USERNAME'] + password = os.environ['SQL_PASSWORD'] + except KeyError: + print('The environment variables SQL_DIALECT, SQL_NAME, ' + + 'SQL_HOSTNAME, SQL_USERNAME and SQL_PASSWORD must ' + + 'be set.') + sys.exit(-1) + + if dialect == 'sqlite': + conn_str = f"{dialect}:///{database}.db" + else: + conn_str = f"{dialect}://{username}:{password}@{hostname}" + \ + "/{database}" + + return conn_str + + +class SqlDB(): + def get_session(conn_str): + if 'sqlite' in conn_str: + engine = create_engine(conn_str) + else: + engine = create_engine(conn_str, pool_size=50, max_overflow=0) + Session = sessionmaker(bind=engine) + + return Session() + + @classmethod + @contextmanager + def sql_session(cls, **kwargs): + session = cls.get_session(get_conn_str()) + + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() diff --git a/src/job.py b/src/job.py new file mode 100644 index 0000000..32311de --- /dev/null +++ b/src/job.py @@ -0,0 +1,84 @@ +import enum +from datetime import datetime + +from sqlalchemy import (Column, DateTime, Enum, Integer, Unicode, + UniqueConstraint, create_engine) +from sqlalchemy.ext.declarative import declarative_base + +from db import SqlDB, get_conn_str + +Base = declarative_base() +engine = create_engine(get_conn_str()) + + +class JobStatus(enum.Enum): + UNKNOWN = 0 + SCHEDULED = 1 + RUNNING = 2 + STOPPED = 3 + ABORTED = 4 + ERROR = 5 + CLEARED = 6, + DONE = 7 + + +class Job(Base): + __tablename__ = 'jobs' + __table_args__ = ( + None, + UniqueConstraint('id'), + ) + + id = Column(Integer, autoincrement=True, primary_key=True) + starttime = Column(DateTime, nullable=False) + stoptime = Column(DateTime, nullable=True) + updated = Column(DateTime, nullable=True) + status = Column(Enum(JobStatus), index=True, default=JobStatus.SCHEDULED, + nullable=True) + comment = Column(Unicode(255), nullable=True) + scanner = Column(Unicode(255), nullable=True) + target = Column(Unicode(255), nullable=True) + + @classmethod + def starter(cls, func, **kwargs): + job_id = kwargs['job_id'] + retval = None + + del kwargs['job_id'] + + with SqlDB.sql_session() as session: + job = session.query(Job).filter(Job.id == job_id).one_or_none() + job.status = JobStatus.RUNNING + + session.commit() + + try: + retval = func(*kwargs) + except Exception: + job.status = JobStatus.ERROR + job.stoptime = datetime.now() + session.commit() + print("Job raised an exception!") + else: + job.status = JobStatus.DONE + job.stoptime = datetime.now() + print("Job finished properly.") + + return retval + + def as_dict(self): + """Return JSON serializable dict.""" + d = {} + for col in self.__table__.columns: + value = getattr(self, col.name) + if issubclass(value.__class__, enum.Enum): + value = value.value + elif issubclass(value.__class__, Base): + continue + elif issubclass(value.__class__, datetime): + value = str(value) + d[col.name] = value + return d + + +Base.metadata.create_all(engine) diff --git a/src/jobs.py b/src/jobs.py new file mode 100644 index 0000000..67c4e1c --- /dev/null +++ b/src/jobs.py @@ -0,0 +1,58 @@ +from datetime import datetime + +from apscheduler.executors.pool import ThreadPoolExecutor +from apscheduler.jobstores.memory import MemoryJobStore +from apscheduler.schedulers.background import BackgroundScheduler +from pytz import utc + +from db import SqlDB +from job import Job, JobStatus + + +class JobScheduler(object): + def __init__(self, nr_threads=10): + self.__scheduler = BackgroundScheduler( + executors={"default": ThreadPoolExecutor(nr_threads)}, + jobstores={"default": MemoryJobStore()}, + job_defaults={}, + timezone=utc, + ) + + def get(self): + return self.__scheduler + + def start(self): + return self.__scheduler.start() + + def stop(self): + return self.__scheduler.shutdown() + + def add(self, func, comment='', **kwargs): + with SqlDB.sql_session() as session: + job = Job() + job.starttime = datetime.now() + job.comment = comment + job.status = JobStatus.SCHEDULED + + session.add(job) + session.flush() + + job_id = job.id + kwargs['job_id'] = job_id + kwargs['func'] = func + + self.__scheduler.add_job(Job.starter, kwargs=kwargs) + return job_id + + @classmethod + def get_jobs(cls): + jobs = list() + + with SqlDB.sql_session() as session: + query = session.query(Job).all() + + for instance in query: + job_dict = instance.as_dict() + jobs.append(job_dict) + + return jobs diff --git a/src/main.py b/src/main.py index f95a09c..4454f7d 100755 --- a/src/main.py +++ b/src/main.py @@ -1,17 +1,19 @@ import os import sys -import uvicorn +import time -from fastapi import FastAPI, Depends, Request +import requests +import uvicorn +from fastapi import Depends, FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from fastapi_jwt_auth import AuthJWT from fastapi_jwt_auth.exceptions import AuthJWTException -from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel -from index import CouchIindex -import time + from db import DictDB -import requests +from index import CouchIindex +from jobs import JobScheduler app = FastAPI() @@ -37,8 +39,8 @@ for i in range(10): try: db = DictDB() except requests.exceptions.ConnectionError: - print(f'Database not responding, will try again soon.' + - 'Attempt {i + 1} of 10.') + print('Database not responding, will try again soon.' + + f'Attempt {i + 1} of 10.') else: break time.sleep(10) @@ -118,7 +120,7 @@ async def get(key=None, limit=25, skip=0, ip=None, port=None, if 'domains' not in raw_jwt: return JSONResponse(content={"status": "error", - "message": "Could not find domains" + + "message": "Could not find domains " + "claim in JWT token"}, status_code=400) else: @@ -169,6 +171,15 @@ async def delete(key, Authorize: AuthJWT = Depends()): return JSONResponse(content={"status": "success", "docs": {}}) +@app.get('/sc/v0/jobs') +async def jobs_get(Authorize: AuthJWT = Depends()): + Authorize.jwt_required() + + data = JobScheduler.get_jobs() + + return JSONResponse(content={"status": "success", "jobs": data}) + + def main(standalone=False): if not standalone: return app |
