refactor: major overhaul

This commit is contained in:
Mark Bailey 2024-10-26 19:20:39 -04:00
parent e1b66a87ac
commit 6459f3e53e

View File

@ -1,71 +1,56 @@
import enum
import os import os
import pathlib
import numpy as np
from base64 import b64decode from base64 import b64decode
from urllib.parse import unquote
from fastapi.responses import JSONResponse import numpy as np
from pydantic import BaseModel
from doctr.io import DocumentFile from doctr.io import DocumentFile
from doctr.models import ocr_predictor from doctr.models import ocr_predictor
from fastapi import FastAPI, HTTPException, Response from fastapi import FastAPI, HTTPException
from jinja2 import Environment, FileSystemLoader, select_autoescape from jinja2 import Environment, FileSystemLoader, select_autoescape
from urllib.parse import unquote from pydantic import BaseModel
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse
app = FastAPI() app = FastAPI()
env = Environment(loader=FileSystemLoader("templates"), autoescape=select_autoescape()) env = Environment(loader=FileSystemLoader("templates"), autoescape=select_autoescape())
@app.get("/")
def root():
    """Render ``main.html`` with OCR output for one image from the sample directory.

    Lists the hard-coded sample directory, builds a ``db_resnet50`` predictor
    (moved to the GPU with ``.cuda()``), and runs it on the first directory
    entry whose suffix is exactly ``.jpg`` or ``.png``. The rendered
    prediction is stored under the file name and passed to the template as
    ``results``; the mapping is empty when no entry matches.
    """
    base_dir = "/home/m/Projects/sandbox/data/app/upload/classifier/deal/389/"
    candidates = os.listdir(base_dir)
    predictor = ocr_predictor(
        "db_resnet50",
        pretrained=True,
        assume_straight_pages=False,
        preserve_aspect_ratio=True,
    ).cuda()

    result = {}
    for name in candidates:
        if pathlib.Path(name).suffix in (".jpg", ".png"):
            print("Working on: " + name)
            document = DocumentFile.from_images(base_dir + name)
            result[name] = predictor(document).render()
            # Only the first matching image is processed.
            break

    return Response(renderTemplate("main.html", {"results": result}))
class Image(BaseModel):
    """Request body describing one uploaded image."""

    # File name as sent by the client; the detect handler URL-decodes it
    # with unquote() before use.
    file_name: str
    # File content as base64 text; the detect handler decodes it with
    # b64decode() before writing it to disk.
    file_contents: str
@enum.unique
class DetectionStrength(str, enum.Enum):
    """Detection strength levels accepted as the ``{strength}`` path segment.

    Members are also ``str`` instances, so they compare equal to their plain
    string values and can be concatenated into log messages. ``enum.unique``
    makes a duplicated value fail at import time instead of silently
    creating an alias of an existing level.
    """

    high = "high"
    medium = "medium"
    low = "low"
@app.get("/")
def get_root(request: Request):
    """Serve the main page.

    Returns the rendered ``main.html`` template when the request path is
    exactly ``/``; any other path raises a 404 ``HTTPException``.
    """
    if request.url.path == "/":
        return renderTemplate("main.html")
    raise HTTPException(status_code=404)
@app.post("/detect/{strength}")
def post_detect(file: Image, strength: DetectionStrength):
file_name = unquote(file.file_name) file_name = unquote(file.file_name)
print("\nWorking on: " + file_name)
print("\nStrength: " + strength + "\n")
with open("data/" + file_name, "wb") as f: with open("data/" + file_name, "wb") as f:
f.write(b64decode(file.file_contents)) f.write(b64decode(file.file_contents))
print("Working on: " + file_name) [det_model, rec_model] = get_models_for_detection_strength(strength)
try: try:
predictor = ocr_predictor( predictor = ocr_predictor(
"db_resnet50", det_model,
"vitstr_base", rec_model,
pretrained=True, pretrained=True,
straighten_pages=True, straighten_pages=True,
preserve_aspect_ratio=True, preserve_aspect_ratio=True,
@ -110,6 +95,15 @@ def convert_dict_items_to_list(d: dict):
return converted return converted
# Detection / recognition architecture names for each strength level, keyed
# by the level's plain string value.
_MODELS_BY_STRENGTH = {
    "high": ("linknet_resnet50", "vitstr_base"),
    "medium": ("db_resnet50", "vitstr_base"),
    "low": ("db_resnet50", "vitstr_tiny"),
}


def get_models_for_detection_strength(strength: "DetectionStrength"):
    """Return the ``[detection_model, recognition_model]`` names for *strength*.

    Args:
        strength: A ``DetectionStrength`` member, or its plain string value
            (``"high"``, ``"medium"`` or ``"low"``).

    Returns:
        A new two-item list: the detection architecture name followed by the
        recognition architecture name.

    Raises:
        ValueError: If *strength* is not a known level. Previously the
            function fell through and returned ``None``, which only failed
            later when the caller unpacked the result.
    """
    # Enum members expose ``.value``; plain strings are used as-is.
    key = getattr(strength, "value", strength)
    try:
        models = _MODELS_BY_STRENGTH[key]
    except (KeyError, TypeError):
        raise ValueError(f"Unknown detection strength: {strength!r}") from None
    # Hand out a fresh list so callers cannot mutate the shared table.
    return list(models)
def renderTemplate(template, context=None):
    """Render a template from ``env`` and wrap the result in an ``HTMLResponse``.

    Args:
        template: Template name passed to ``env.get_template``.
        context: Optional mapping of template variables. When omitted or
            ``None``, the template is rendered with no variables.

    Returns:
        An ``HTMLResponse`` whose body is the rendered template text.
    """
    # Build the empty context per call rather than using a mutable ``{}``
    # default, which would be one dict shared by every call.
    variables = {} if context is None else context
    return HTMLResponse(env.get_template(template).render(variables))