refactor: major overhaul

This commit is contained in:
Mark Bailey 2024-10-26 19:20:39 -04:00
parent e1b66a87ac
commit 6459f3e53e

View File

@ -1,71 +1,56 @@
import enum
import os
import pathlib
import numpy as np
from base64 import b64decode
from urllib.parse import unquote
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import numpy as np
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from fastapi import FastAPI, HTTPException, Response
from fastapi import FastAPI, HTTPException
from jinja2 import Environment, FileSystemLoader, select_autoescape
from urllib.parse import unquote
from pydantic import BaseModel
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse
# FastAPI application object; the route decorators in this module register
# their handlers on it.
app = FastAPI()

# Jinja2 environment used by renderTemplate(). Templates are loaded from the
# "templates" directory and autoescaping is decided by select_autoescape().
env = Environment(
    loader=FileSystemLoader("templates"),
    autoescape=select_autoescape(),
)
@app.get("/")
def root():
    """Run OCR on the first image in a fixed directory and render the result.

    Lists the hard-coded ``base_dir``, builds a ``db_resnet50`` predictor moved
    to CUDA via ``.cuda()``, and processes only the first entry whose suffix is
    exactly ``.jpg`` or ``.png`` (the comparison is case-sensitive).

    Returns:
        A ``Response`` wrapping ``main.html`` rendered with
        ``{"results": result}``, where ``result`` maps the processed file name
        to the output of the prediction's ``render()`` call. ``result`` is
        empty when no entry has a matching suffix.
    """
    base_dir = "/home/m/Projects/sandbox/data/app/upload/classifier/deal/389/"
    entries = os.listdir(base_dir)

    predictor = ocr_predictor(
        "db_resnet50",
        pretrained=True,
        assume_straight_pages=False,
        preserve_aspect_ratio=True,
    ).cuda()

    image_suffixes = [".jpg", ".png"]
    result = {}
    for name in entries:
        if pathlib.Path(name).suffix not in image_suffixes:
            continue

        print("Working on: " + name)
        document = DocumentFile.from_images(base_dir + name)
        prediction = predictor(document)
        result[name] = prediction.render()
        # Only one image is handled per request, so stop after the first hit.
        break

    return Response(renderTemplate("main.html", {"results": result}))
class Image(BaseModel):
    """Request body describing one uploaded image for the detect endpoint.

    Both fields arrive as text; the detect handler in this module decodes
    them before use.
    """

    # File name as sent by the client; the handler passes it through
    # ``unquote`` before using it.
    file_name: str

    # File payload as text; the handler passes it through ``b64decode``
    # before writing it to disk.
    file_contents: str
@app.post("/detect/")
def file(file: Image):
class DetectionStrength(str, enum.Enum):
    """Closed set of strength levels accepted by ``/detect/{strength}``.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value. ``get_models_for_detection_strength`` maps every
    level to a pair of model names.
    """

    high = "high"
    medium = "medium"
    low = "low"
@app.get("/")
def get_root(request: Request):
    """Serve the main page.

    Args:
        request: Incoming request; only its URL path is inspected.

    Returns:
        Whatever ``renderTemplate("main.html")`` produces.

    Raises:
        HTTPException: With status 404 when the request path is not exactly
            ``"/"``.
    """
    is_root_path = request.url.path == "/"
    if not is_root_path:
        raise HTTPException(status_code=404)

    return renderTemplate("main.html")
@app.post("/detect/{strength}")
def post_detect(file: Image, strength: DetectionStrength):
file_name = unquote(file.file_name)
print("\nWorking on: " + file_name)
print("\nStrength: " + strength + "\n")
with open("data/" + file_name, "wb") as f:
f.write(b64decode(file.file_contents))
print("Working on: " + file_name)
[det_model, rec_model] = get_models_for_detection_strength(strength)
try:
predictor = ocr_predictor(
"db_resnet50",
"vitstr_base",
det_model,
rec_model,
pretrained=True,
straighten_pages=True,
preserve_aspect_ratio=True,
@ -110,6 +95,15 @@ def convert_dict_items_to_list(d: dict):
return converted
def renderTemplate(template, context=None):
def get_models_for_detection_strength(strength: DetectionStrength):
    """Return the model names to use for the given detection strength.

    Args:
        strength: One of the ``DetectionStrength`` members (a plain string
            equal to a member's value also matches, since the enum mixes in
            ``str``).

    Returns:
        A new two-item list ``[detection_model, recognition_model]``. The
        caller unpacks it and passes both names to ``ocr_predictor``.

    Raises:
        ValueError: If ``strength`` matches no known level. Without this the
            function would fall through and return ``None``, which only
            surfaces later as an unpacking ``TypeError`` in the caller.
    """
    if strength == DetectionStrength.high:
        return ["linknet_resnet50", "vitstr_base"]
    if strength == DetectionStrength.medium:
        return ["db_resnet50", "vitstr_base"]
    if strength == DetectionStrength.low:
        return ["db_resnet50", "vitstr_tiny"]

    raise ValueError(f"Unknown detection strength: {strength!r}")
def renderTemplate(template, context=None):
    """Render a Jinja2 template from ``env`` and wrap it in an HTML response.

    Args:
        template: Name of the template, resolved with ``env.get_template``.
        context: Optional mapping of variables made available to the
            template. When omitted, an empty context is used.

    Returns:
        An ``HTMLResponse`` whose body is the rendered template text.
    """
    # ``None`` stands in for "no context" so that no mutable dict is shared
    # between calls as a default argument; a fresh dict is created instead.
    render_context = {} if context is None else context

    compiled = env.get_template(template)
    return HTMLResponse(compiled.render(render_context))