refactor: major overhaul
This commit is contained in:
parent
e1b66a87ac
commit
6459f3e53e
86
src/main.py
86
src/main.py
@ -1,71 +1,56 @@
|
|||||||
|
import enum
|
||||||
import os
|
import os
|
||||||
import pathlib
|
|
||||||
import numpy as np
|
|
||||||
from base64 import b64decode
|
from base64 import b64decode
|
||||||
|
from urllib.parse import unquote
|
||||||
|
|
||||||
from fastapi.responses import JSONResponse
|
import numpy as np
|
||||||
from pydantic import BaseModel
|
|
||||||
from doctr.io import DocumentFile
|
from doctr.io import DocumentFile
|
||||||
from doctr.models import ocr_predictor
|
from doctr.models import ocr_predictor
|
||||||
from fastapi import FastAPI, HTTPException, Response
|
from fastapi import FastAPI, HTTPException
|
||||||
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
||||||
from urllib.parse import unquote
|
from pydantic import BaseModel
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import HTMLResponse, JSONResponse
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
env = Environment(loader=FileSystemLoader("templates"), autoescape=select_autoescape())
|
env = Environment(loader=FileSystemLoader("templates"), autoescape=select_autoescape())
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
|
||||||
def root():
|
|
||||||
base_dir = "/home/m/Projects/sandbox/data/app/upload/classifier/deal/389/"
|
|
||||||
files = os.listdir(base_dir)
|
|
||||||
|
|
||||||
predictor = ocr_predictor(
|
|
||||||
"db_resnet50",
|
|
||||||
pretrained=True,
|
|
||||||
assume_straight_pages=False,
|
|
||||||
preserve_aspect_ratio=True,
|
|
||||||
).cuda()
|
|
||||||
|
|
||||||
result = {}
|
|
||||||
i = 0
|
|
||||||
for f in files:
|
|
||||||
if i >= 1:
|
|
||||||
break
|
|
||||||
|
|
||||||
if pathlib.Path(f).suffix not in [".jpg", ".png"]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
print("Working on: " + f)
|
|
||||||
doc = DocumentFile.from_images(base_dir + f)
|
|
||||||
|
|
||||||
pred_res = predictor(doc)
|
|
||||||
result[f] = pred_res.render()
|
|
||||||
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
return Response(renderTemplate("main.html", {"results": result}))
|
|
||||||
|
|
||||||
|
|
||||||
class Image(BaseModel):
|
class Image(BaseModel):
|
||||||
file_name: str
|
file_name: str
|
||||||
file_contents: str
|
file_contents: str
|
||||||
|
|
||||||
|
|
||||||
@app.post("/detect/")
|
class DetectionStrength(str, enum.Enum):
|
||||||
def file(file: Image):
|
high = "high"
|
||||||
|
medium = "medium"
|
||||||
|
low = "low"
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
def get_root(request: Request):
    """Serve the landing page.

    Returns the rendered ``main.html`` template when the request path is
    exactly ``/``; any other path is answered with an HTTP 404.
    """
    if request.url.path == "/":
        return renderTemplate("main.html")

    raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/detect/{strength}")
|
||||||
|
def post_detect(file: Image, strength: DetectionStrength):
|
||||||
file_name = unquote(file.file_name)
|
file_name = unquote(file.file_name)
|
||||||
|
|
||||||
|
print("\nWorking on: " + file_name)
|
||||||
|
print("\nStrength: " + strength + "\n")
|
||||||
|
|
||||||
with open("data/" + file_name, "wb") as f:
|
with open("data/" + file_name, "wb") as f:
|
||||||
f.write(b64decode(file.file_contents))
|
f.write(b64decode(file.file_contents))
|
||||||
|
|
||||||
print("Working on: " + file_name)
|
[det_model, rec_model] = get_models_for_detection_strength(strength)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
predictor = ocr_predictor(
|
predictor = ocr_predictor(
|
||||||
"db_resnet50",
|
det_model,
|
||||||
"vitstr_base",
|
rec_model,
|
||||||
pretrained=True,
|
pretrained=True,
|
||||||
straighten_pages=True,
|
straighten_pages=True,
|
||||||
preserve_aspect_ratio=True,
|
preserve_aspect_ratio=True,
|
||||||
@ -110,6 +95,15 @@ def convert_dict_items_to_list(d: dict):
|
|||||||
return converted
|
return converted
|
||||||
|
|
||||||
|
|
||||||
def renderTemplate(template, context=None):
|
def get_models_for_detection_strength(strength: "DetectionStrength | str"):
    """Return the ``[detection, recognition]`` model names for *strength*.

    Args:
        strength: A ``DetectionStrength`` member or its plain string value
            (``"high"``, ``"medium"`` or ``"low"``).

    Returns:
        A new two-item list ``[det_model, rec_model]``.

    Raises:
        ValueError: If *strength* is not one of the known levels. Failing
            here gives a clear error instead of handing ``None`` to a
            caller that immediately unpacks the result.
    """
    models_by_strength = {
        "high": ("linknet_resnet50", "vitstr_base"),
        "medium": ("db_resnet50", "vitstr_base"),
        "low": ("db_resnet50", "vitstr_tiny"),
    }

    # Enum members carry their string in ``.value``; plain strings pass through.
    key = getattr(strength, "value", strength)

    try:
        # Copy into a list so every call returns a fresh, mutable result.
        return list(models_by_strength[key])
    except (KeyError, TypeError):
        raise ValueError(f"unknown detection strength: {strength!r}") from None
|
||||||
|
|
||||||
|
|
||||||
|
def renderTemplate(template, context=None):
    """Render a Jinja2 template from ``env`` and wrap it in an ``HTMLResponse``.

    Args:
        template: Name of the template to load via ``env.get_template``.
        context: Optional mapping of variables for the template. ``None``
            (the default) renders with an empty context.

    Returns:
        An ``HTMLResponse`` whose body is the rendered template.
    """
    # ``None`` sentinel instead of a shared mutable ``{}`` default argument.
    variables = {} if context is None else context
    return HTMLResponse(env.get_template(template).render(variables))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user