diff options
author | Kristofer Hallin <kristofer@sunet.se> | 2022-04-11 16:46:54 +0200 |
---|---|---|
committer | Kristofer Hallin <kristofer@sunet.se> | 2022-04-11 16:46:54 +0200 |
commit | 16f5009ac0d630c5f25c9d6cb4e8fb026ae628f9 (patch) | |
tree | 1e9e681a86e28189c833049b3478bef462b20ee5 | |
parent | 3335f65e5e4b3132a72b46b99e50d3c55c0c58b5 (diff) |
More database changes and endpoints.
-rw-r--r-- | src/db/sql.py | 37 | ||||
-rwxr-xr-x | src/main.py | 45 |
2 files changed, 58 insertions, 24 deletions
diff --git a/src/db/sql.py b/src/db/sql.py index f2a9ee2..fc20e36 100644 --- a/src/db/sql.py +++ b/src/db/sql.py @@ -1,9 +1,9 @@ import datetime -import os import sys from contextlib import contextmanager +import os -from sqlalchemy import (Boolean, Column, Date, Integer, String, Text, +from sqlalchemy import (Boolean, Column, Date, Integer, Serial, String, Text, create_engine, text) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker @@ -15,8 +15,9 @@ metadata = Base.metadata class Log(Base): __tablename__ = "log" - id = Column(Integer, primary_key=True) - timestamp = Column(Date, nullable=False, default=datetime.datetime.utcnow) + id = Column(Serial, primary_key=True) + timestamp = Column(Date, nullable=False, + default=datetime.datetime.utcnow) username = Column(Text, nullable=False) logtext = Column(Text, nullable=False) @@ -42,17 +43,20 @@ class Log(Base): class Scanner(Base): - __tablename__ = "scanner" + __tablename__ = 'scanner' - id = Column(Integer, primary_key=True) - runner = Column(Text, nullable=False, default="*") - name = Column(String(128), nullable=False, unique=True) + id = Column(Integer, primary_key=True, server_default=text( + "nextval('scanner_id_seq'::regclass)")) + runner = Column(Text, server_default=text("'*'::text")) + name = Column(String(128), nullable=False) active = Column(Boolean, nullable=False) - interval = Column(Integer, nullable=False, server_default=text("300")) + interval = Column(Integer, nullable=False, + server_default=text("300")) starttime = Column(Date) - hostname = Column(String(128), nullable=False, unique=True) + endtime = Column(Date) + maxruns = Column(Integer, server_default=text("1")) + hostname = Column(String(128), nullable=False) port = Column(Integer, nullable=False) - maxruns = Column(Integer, nullable=False, default=1) def as_dict(self): d = {} @@ -107,6 +111,17 @@ class Scanner(Base): results.append(scanner.as_dict()) return results + @classmethod + def edit(cls, name, active): + with sqla_session() as session: + scanners = session.query(Scanner).filter( + Scanner.name == name).all() + if not scanners: + return None + for scanner in scanners: + scanner.active = active + return True + def get_sqlalchemy_conn_str(**kwargs) -> str: try: diff --git a/src/main.py b/src/main.py index c93029f..e6bb8e2 100755 --- a/src/main.py +++ b/src/main.py @@ -3,7 +3,10 @@ import sys import time import uvicorn -from fastapi import Depends, FastAPI, Request + +from fastapi import Depends +from fastapi import FastAPI +from fastapi import Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from fastapi_jwt_auth import AuthJWT @@ -12,13 +15,14 @@ from pydantic import BaseModel from db.dictionary import DictDB from db.index import CouchIindex -from db.sql import Log, Scanner +from db.sql import Log +from db.sql import Scanner app = FastAPI() app.add_middleware( CORSMiddleware, - allow_origins=["http://localhost:8001"], + allow_origins=["http://localhost:8000"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], @@ -38,8 +42,7 @@ for i in range(10): try: db = DictDB() except Exception: - print( - f'Database not responding, will try again soon. Attempt {i + 1} of 10.') + print(f"Database not responding, will try again soon. Attempt {i + 1} of 10.") else: break time.sleep(1) @@ -90,25 +93,25 @@ class JWTConfig(BaseModel): authjwt_public_key: str = get_pubkey() -@AuthJWT.load_config +@ AuthJWT.load_config def jwt_config(): return JWTConfig() -@app.exception_handler(AuthJWTException) +@ app.exception_handler(AuthJWTException) def authjwt_exception_handler(request: Request, exc: AuthJWTException): return JSONResponse(content={"status": "error", "message": exc.message}, status_code=400) -@app.exception_handler(RuntimeError) +@ app.exception_handler(RuntimeError) def app_exception_handler(request: Request, exc: RuntimeError): return JSONResponse(content={"status": "error", "message": str(exc.with_traceback(None))}, status_code=400) -@app.get('/sc/v0/get') +@ app.get('/sc/v0/get') async def get(key=None, limit=25, skip=0, ip=None, port=None, asn=None, Authorize: AuthJWT = Depends()): @@ -134,7 +137,7 @@ async def get(key=None, limit=25, skip=0, ip=None, port=None, return JSONResponse(content={"status": "success", "docs": data}) -@app.get('/sc/v0/get/{key}') +@ app.get('/sc/v0/get/{key}') async def get_key(key=None, Authorize: AuthJWT = Depends()): Authorize.jwt_required() @@ -166,7 +169,7 @@ async def get_key(key=None, Authorize: AuthJWT = Depends()): return JSONResponse(content={"status": "success", "docs": data}) -@app.post('/sc/v0/add') +@ app.post('/sc/v0/add') async def add(data: Request, Authorize: AuthJWT = Depends()): # Maybe we should protect this enpoint too and let the scanner use @@ -180,7 +183,7 @@ async def add(data: Request, Authorize: AuthJWT = Depends()): return JSONResponse(content={"status": "success", "docs": key}) -@app.delete('/sc/v0/delete/{key}') +@ app.delete('/sc/v0/delete/{key}') async def delete(key, Authorize: AuthJWT = Depends()): Authorize.jwt_required() @@ -217,7 +220,7 @@ async def delete(key, Authorize: AuthJWT = Depends()): return JSONResponse(content={"status": "success", "docs": data}) -@app.get("/sc/v0/scanner/{name}") +@ app.get("/sc/v0/scanner/{name}") async def scanner_get(name, data: Request, Authorize: AuthJWT = Depends()): Authorize.jwt_required() @@ -226,6 +229,22 @@ async def scanner_get(name, data: Request, Authorize: AuthJWT = Depends()): return JSONResponse(content={"status": "success", "data": scanners}) +@ app.put("/sc/v0/scanner/{name}") +async def scanner_put(name, data: Request, Authorize: AuthJWT = Depends()): + errors = None + Authorize.jwt_required() + + json_data = await data.json() + + if "active" in json_data and isinstance(json_data["active"], bool): + errors = Scanner.active(name, json_data["active"]) + + if errors: + return JSONResponse(content={"status": "error", "message": "\n".join(errors)}, status_code=400) + + return JSONResponse(content={"status": "success", "data": Scanner.get(name)}, status_code=200) + + def main(standalone=False): if not standalone: return app |