refactor: major overhaul
This commit is contained in:
parent
e1b66a87ac
commit
6459f3e53e
86
src/main.py
86
src/main.py
@ -1,71 +1,56 @@
|
||||
import enum
|
||||
import os
|
||||
import pathlib
|
||||
import numpy as np
|
||||
from base64 import b64decode
|
||||
from urllib.parse import unquote
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
import numpy as np
|
||||
from doctr.io import DocumentFile
|
||||
from doctr.models import ocr_predictor
|
||||
from fastapi import FastAPI, HTTPException, Response
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
||||
from urllib.parse import unquote
|
||||
from pydantic import BaseModel
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse, JSONResponse
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
env = Environment(loader=FileSystemLoader("templates"), autoescape=select_autoescape())
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def root():
|
||||
base_dir = "/home/m/Projects/sandbox/data/app/upload/classifier/deal/389/"
|
||||
files = os.listdir(base_dir)
|
||||
|
||||
predictor = ocr_predictor(
|
||||
"db_resnet50",
|
||||
pretrained=True,
|
||||
assume_straight_pages=False,
|
||||
preserve_aspect_ratio=True,
|
||||
).cuda()
|
||||
|
||||
result = {}
|
||||
i = 0
|
||||
for f in files:
|
||||
if i >= 1:
|
||||
break
|
||||
|
||||
if pathlib.Path(f).suffix not in [".jpg", ".png"]:
|
||||
continue
|
||||
|
||||
print("Working on: " + f)
|
||||
doc = DocumentFile.from_images(base_dir + f)
|
||||
|
||||
pred_res = predictor(doc)
|
||||
result[f] = pred_res.render()
|
||||
|
||||
i += 1
|
||||
|
||||
return Response(renderTemplate("main.html", {"results": result}))
|
||||
|
||||
|
||||
class Image(BaseModel):
|
||||
file_name: str
|
||||
file_contents: str
|
||||
|
||||
|
||||
@app.post("/detect/")
|
||||
def file(file: Image):
|
||||
class DetectionStrength(str, enum.Enum):
|
||||
high = "high"
|
||||
medium = "medium"
|
||||
low = "low"
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def get_root(request: Request):
|
||||
if request.url.path != "/":
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
return renderTemplate("main.html")
|
||||
|
||||
|
||||
@app.post("/detect/{strength}")
|
||||
def post_detect(file: Image, strength: DetectionStrength):
|
||||
file_name = unquote(file.file_name)
|
||||
|
||||
print("\nWorking on: " + file_name)
|
||||
print("\nStrength: " + strength + "\n")
|
||||
|
||||
with open("data/" + file_name, "wb") as f:
|
||||
f.write(b64decode(file.file_contents))
|
||||
|
||||
print("Working on: " + file_name)
|
||||
|
||||
[det_model, rec_model] = get_models_for_detection_strength(strength)
|
||||
try:
|
||||
predictor = ocr_predictor(
|
||||
"db_resnet50",
|
||||
"vitstr_base",
|
||||
det_model,
|
||||
rec_model,
|
||||
pretrained=True,
|
||||
straighten_pages=True,
|
||||
preserve_aspect_ratio=True,
|
||||
@ -110,6 +95,15 @@ def convert_dict_items_to_list(d: dict):
|
||||
return converted
|
||||
|
||||
|
||||
def renderTemplate(template, context=None):
|
||||
def get_models_for_detection_strength(strength: DetectionStrength):
|
||||
if strength == DetectionStrength.high:
|
||||
return ["linknet_resnet50", "vitstr_base"]
|
||||
elif strength == DetectionStrength.medium:
|
||||
return ["db_resnet50", "vitstr_base"]
|
||||
elif strength == DetectionStrength.low:
|
||||
return ["db_resnet50", "vitstr_tiny"]
|
||||
|
||||
|
||||
def renderTemplate(template, context={}):
|
||||
template = env.get_template(template)
|
||||
return template.render(context)
|
||||
return HTMLResponse(template.render(context))
|
||||
|
Loading…
x
Reference in New Issue
Block a user