summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKristofer Hallin <kristofer@sunet.se>2022-04-11 16:46:54 +0200
committerKristofer Hallin <kristofer@sunet.se>2022-04-11 16:46:54 +0200
commit16f5009ac0d630c5f25c9d6cb4e8fb026ae628f9 (patch)
tree1e9e681a86e28189c833049b3478bef462b20ee5
parent3335f65e5e4b3132a72b46b99e50d3c55c0c58b5 (diff)
More database changes and endpoints.
-rw-r--r--src/db/sql.py37
-rwxr-xr-xsrc/main.py45
2 files changed, 58 insertions, 24 deletions
diff --git a/src/db/sql.py b/src/db/sql.py
index f2a9ee2..fc20e36 100644
--- a/src/db/sql.py
+++ b/src/db/sql.py
@@ -1,9 +1,9 @@
import datetime
-import os
import sys
from contextlib import contextmanager
+import os
-from sqlalchemy import (Boolean, Column, Date, Integer, String, Text,
+from sqlalchemy import (Boolean, Column, Date, Integer, Serial, String, Text,
create_engine, text)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
@@ -15,8 +15,9 @@ metadata = Base.metadata
class Log(Base):
__tablename__ = "log"
- id = Column(Integer, primary_key=True)
- timestamp = Column(Date, nullable=False, default=datetime.datetime.utcnow)
+ id = Column(Serial, primary_key=True)
+ timestamp = Column(Date, nullable=False,
+ default=datetime.datetime.utcnow)
username = Column(Text, nullable=False)
logtext = Column(Text, nullable=False)
@@ -42,17 +43,20 @@ class Log(Base):
class Scanner(Base):
- __tablename__ = "scanner"
+ __tablename__ = 'scanner'
- id = Column(Integer, primary_key=True)
- runner = Column(Text, nullable=False, default="*")
- name = Column(String(128), nullable=False, unique=True)
+ id = Column(Integer, primary_key=True, server_default=text(
+ "nextval('scanner_id_seq'::regclass)"))
+ runner = Column(Text, server_default=text("'*'::text"))
+ name = Column(String(128), nullable=False)
active = Column(Boolean, nullable=False)
- interval = Column(Integer, nullable=False, server_default=text("300"))
+ interval = Column(Integer, nullable=False,
+ server_default=text("300"))
starttime = Column(Date)
- hostname = Column(String(128), nullable=False, unique=True)
+ endtime = Column(Date)
+ maxruns = Column(Integer, server_default=text("1"))
+ hostname = Column(String(128), nullable=False)
port = Column(Integer, nullable=False)
- maxruns = Column(Integer, nullable=False, default=1)
def as_dict(self):
d = {}
@@ -107,6 +111,17 @@ class Scanner(Base):
results.append(scanner.as_dict())
return results
+ @classmethod
+ def edit(cls, name, active):
+ with sqla_session() as session:
+ scanners = session.query(Scanner).filter(
+ Scanner.name == name).all()
+ if not scanners:
+ return None
+ for scanner in scanners:
+ scanner.active = active
+ return True
+
def get_sqlalchemy_conn_str(**kwargs) -> str:
try:
diff --git a/src/main.py b/src/main.py
index c93029f..e6bb8e2 100755
--- a/src/main.py
+++ b/src/main.py
@@ -3,7 +3,10 @@ import sys
import time
import uvicorn
-from fastapi import Depends, FastAPI, Request
+
+from fastapi import Depends
+from fastapi import FastAPI
+from fastapi import Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi_jwt_auth import AuthJWT
@@ -12,13 +15,14 @@ from pydantic import BaseModel
from db.dictionary import DictDB
from db.index import CouchIindex
-from db.sql import Log, Scanner
+from db.sql import Log
+from db.sql import Scanner
app = FastAPI()
app.add_middleware(
CORSMiddleware,
- allow_origins=["http://localhost:8001"],
+ allow_origins=["http://localhost:8000"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
@@ -38,8 +42,7 @@ for i in range(10):
try:
db = DictDB()
except Exception:
- print(
- f'Database not responding, will try again soon. Attempt {i + 1} of 10.')
+ print(f"Database not responding, will try again soon. Attempt {i + 1} of 10.")
else:
break
time.sleep(1)
@@ -90,25 +93,25 @@ class JWTConfig(BaseModel):
authjwt_public_key: str = get_pubkey()
-@AuthJWT.load_config
+@ AuthJWT.load_config
def jwt_config():
return JWTConfig()
-@app.exception_handler(AuthJWTException)
+@ app.exception_handler(AuthJWTException)
def authjwt_exception_handler(request: Request, exc: AuthJWTException):
return JSONResponse(content={"status": "error", "message":
exc.message}, status_code=400)
-@app.exception_handler(RuntimeError)
+@ app.exception_handler(RuntimeError)
def app_exception_handler(request: Request, exc: RuntimeError):
return JSONResponse(content={"status": "error", "message":
str(exc.with_traceback(None))},
status_code=400)
-@app.get('/sc/v0/get')
+@ app.get('/sc/v0/get')
async def get(key=None, limit=25, skip=0, ip=None, port=None,
asn=None, Authorize: AuthJWT = Depends()):
@@ -134,7 +137,7 @@ async def get(key=None, limit=25, skip=0, ip=None, port=None,
return JSONResponse(content={"status": "success", "docs": data})
-@app.get('/sc/v0/get/{key}')
+@ app.get('/sc/v0/get/{key}')
async def get_key(key=None, Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
@@ -166,7 +169,7 @@ async def get_key(key=None, Authorize: AuthJWT = Depends()):
return JSONResponse(content={"status": "success", "docs": data})
-@app.post('/sc/v0/add')
+@ app.post('/sc/v0/add')
async def add(data: Request, Authorize: AuthJWT = Depends()):
# Maybe we should protect this enpoint too and let the scanner use
@@ -180,7 +183,7 @@ async def add(data: Request, Authorize: AuthJWT = Depends()):
return JSONResponse(content={"status": "success", "docs": key})
-@app.delete('/sc/v0/delete/{key}')
+@ app.delete('/sc/v0/delete/{key}')
async def delete(key, Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
@@ -217,7 +220,7 @@ async def delete(key, Authorize: AuthJWT = Depends()):
return JSONResponse(content={"status": "success", "docs": data})
-@app.get("/sc/v0/scanner/{name}")
+@ app.get("/sc/v0/scanner/{name}")
async def scanner_get(name, data: Request, Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
@@ -226,6 +229,22 @@ async def scanner_get(name, data: Request, Authorize: AuthJWT = Depends()):
return JSONResponse(content={"status": "success", "data": scanners})
+@ app.put("/sc/v0/scanner/{name}")
+async def scanner_put(name, data: Request, Authorize: AuthJWT = Depends()):
+ errors = None
+ Authorize.jwt_required()
+
+ json_data = await data.json()
+
+ if "active" in json_data and isinstance(json_data["active"], bool):
+ errors = Scanner.active(name, json_data["active"])
+
+ if errors:
+ return JSONResponse(content={"status": "error", "message": "\n".join(errors)}, status_code=400)
+
+ return JSONResponse(content={"status": "success", "data": Scanner.get(name)}, status_code=200)
+
+
def main(standalone=False):
if not standalone:
return app