From 6459f3e53ec69c55bb8a8e16053a4dcfab2d95f4 Mon Sep 17 00:00:00 2001 From: Mark Bailey Date: Sat, 26 Oct 2024 19:20:39 -0400 Subject: [PATCH] refactor: major overhaul --- src/main.py | 86 +++++++++++++++++++++++++---------------------------- 1 file changed, 40 insertions(+), 46 deletions(-) diff --git a/src/main.py b/src/main.py index 49954e7..c7091d5 100644 --- a/src/main.py +++ b/src/main.py @@ -1,71 +1,56 @@ +import enum import os -import pathlib -import numpy as np from base64 import b64decode +from urllib.parse import unquote -from fastapi.responses import JSONResponse -from pydantic import BaseModel +import numpy as np from doctr.io import DocumentFile from doctr.models import ocr_predictor -from fastapi import FastAPI, HTTPException, Response +from fastapi import FastAPI, HTTPException from jinja2 import Environment, FileSystemLoader, select_autoescape -from urllib.parse import unquote +from pydantic import BaseModel +from starlette.requests import Request +from starlette.responses import HTMLResponse, JSONResponse app = FastAPI() env = Environment(loader=FileSystemLoader("templates"), autoescape=select_autoescape()) -@app.get("/") -def root(): - base_dir = "/home/m/Projects/sandbox/data/app/upload/classifier/deal/389/" - files = os.listdir(base_dir) - - predictor = ocr_predictor( - "db_resnet50", - pretrained=True, - assume_straight_pages=False, - preserve_aspect_ratio=True, - ).cuda() - - result = {} - i = 0 - for f in files: - if i >= 1: - break - - if pathlib.Path(f).suffix not in [".jpg", ".png"]: - continue - - print("Working on: " + f) - doc = DocumentFile.from_images(base_dir + f) - - pred_res = predictor(doc) - result[f] = pred_res.render() - - i += 1 - - return Response(renderTemplate("main.html", {"results": result})) - - class Image(BaseModel): file_name: str file_contents: str -@app.post("/detect/") -def file(file: Image): +class DetectionStrength(str, enum.Enum): + high = "high" + medium = "medium" + low = "low" + + +@app.get("/") +def get_root(request: Request): + if request.url.path != "/": + raise HTTPException(status_code=404) + + return renderTemplate("main.html") + + +@app.post("/detect/{strength}") +def post_detect(file: Image, strength: DetectionStrength): file_name = unquote(file.file_name) + print("\nWorking on: " + file_name) + print("\nStrength: " + strength + "\n") + with open("data/" + file_name, "wb") as f: f.write(b64decode(file.file_contents)) - print("Working on: " + file_name) - + [det_model, rec_model] = get_models_for_detection_strength(strength) try: predictor = ocr_predictor( - "db_resnet50", - "vitstr_base", + det_model, + rec_model, pretrained=True, straighten_pages=True, preserve_aspect_ratio=True, @@ -110,6 +95,15 @@ def convert_dict_items_to_list(d: dict): return converted -def renderTemplate(template, context=None): +def get_models_for_detection_strength(strength: DetectionStrength): + if strength == DetectionStrength.high: + return ["linknet_resnet50", "vitstr_base"] + elif strength == DetectionStrength.medium: + return ["db_resnet50", "vitstr_base"] + elif strength == DetectionStrength.low: + return ["db_resnet50", "vitstr_tiny"] + + +def renderTemplate(template, context={}): template = env.get_template(template) - return template.render(context) + return HTMLResponse(template.render(context))