"""FastAPI service that serves an HTML page at "/" and runs docTR OCR via POST /detect/{strength}."""
import enum
|
|
import os
|
|
from base64 import b64decode
|
|
from urllib.parse import unquote
|
|
|
|
import numpy as np
|
|
from doctr.io import DocumentFile
|
|
from doctr.models import ocr_predictor
|
|
from fastapi import FastAPI, HTTPException
|
|
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
|
from pydantic import BaseModel
|
|
from starlette.requests import Request
|
|
from starlette.responses import HTMLResponse, JSONResponse
|
|
|
|
# Application object; the route decorators in this module register on it.
app = FastAPI()

# Jinja2 environment used by render(): templates are loaded from the
# "templates" directory and autoescaping follows select_autoescape()'s defaults.
env = Environment(
    loader=FileSystemLoader("templates"),
    autoescape=select_autoescape(),
)
|
|
|
|
|
|
class Image(BaseModel):
    """Request body of the detection endpoint.

    Both fields are strings; ``post_detect`` URL-unquotes ``file_name`` and
    base64-decodes ``file_contents``.
    """

    # Name of the submitted file (passed through unquote() by the handler).
    file_name: str
    # File payload as text (passed through b64decode() by the handler).
    file_contents: str
|
|
|
|
|
|
class DetectionStrength(str, enum.Enum):
    """Detection strength accepted as the ``{strength}`` path parameter.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase value.
    """

    high = "high"
    medium = "medium"
    low = "low"
|
|
|
|
|
|
@app.get("/")
def get_root(request: Request):
    """Serve the ``main.html`` template; any request path other than "/" gets a 404."""
    requested_path = request.url.path
    if requested_path == "/":
        return render("main.html")

    raise HTTPException(status_code=404)
|
|
|
|
|
|
@app.post("/detect/{strength}")
def post_detect(file: Image, strength: DetectionStrength):
    """Run OCR on an uploaded image and return the exported prediction as JSON.

    The base64 payload is written to a temporary file under ``data/``, fed to
    a docTR predictor built for the requested strength, and the exported
    result is converted to JSON-friendly values before being returned.

    Args:
        file: Request body holding the URL-quoted file name and the
            base64-encoded file contents.
        strength: Selects the detection/recognition model pair.

    Returns:
        A ``JSONResponse`` with the converted prediction export.

    Raises:
        HTTPException: 422 when the file name is unusable or when decoding,
            writing, or OCR processing fails.
    """
    # The name comes from the client: keep only its final path component so
    # values such as "../../etc/passwd" cannot escape the data directory.
    file_name = os.path.basename(unquote(file.file_name))
    if file_name in ("", ".", ".."):
        raise HTTPException(status_code=422, detail="Invalid file name")

    print("\nWorking on: " + file_name)
    print("\nStrength: " + strength.value + "\n")

    # exist_ok avoids the check-then-create race between concurrent requests.
    os.makedirs("data", exist_ok=True)
    file_path = os.path.join("data", file_name)

    [det_model, rec_model] = get_models_for_detection_strength(strength)
    try:
        # Decoding and writing happen inside the try so a malformed payload is
        # reported as 422 like every other processing failure.
        with open(file_path, "wb") as f:
            f.write(b64decode(file.file_contents))

        predictor = ocr_predictor(
            det_model,
            rec_model,
            pretrained=True,
            straighten_pages=True,
            preserve_aspect_ratio=True,
        ).cuda()
        doc = DocumentFile.from_images(file_path)
        pred_res = predictor(doc)
        json_res = pred_res.export()
        converted = convert_dict_items_to_list(json_res)

        return JSONResponse(content=converted)
    except Exception as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    finally:
        # Always remove the temporary upload, including when OCR failed.
        # Best effort: the file may never have been created.
        try:
            os.unlink(file_path)
        except FileNotFoundError:
            pass
|
|
|
|
|
|
def convert_to_list(value):
    """Recursively replace NumPy arrays and scalars with plain Python values.

    Dicts, lists and tuples are walked at every nesting depth so that an
    array buried in, for example, a list of lists is converted too.

    Args:
        value: Any object; containers are traversed, everything else is
            returned unchanged.

    Returns:
        A structure of the same shape where each ``np.ndarray`` has become a
        (nested) list and each NumPy scalar a built-in Python scalar. Dicts
        and lists are rebuilt as new objects; tuples (including tuple
        subclasses) come back as plain tuples.
    """
    if isinstance(value, dict):
        return {key: convert_to_list(item) for key, item in value.items()}
    if isinstance(value, list):
        # Recurse into every element, not only dicts/arrays, so nested lists
        # are handled as well.
        return [convert_to_list(item) for item in value]
    if isinstance(value, tuple):
        return tuple(convert_to_list(item) for item in value)
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        # NumPy scalars such as np.float32 are not JSON serializable.
        return value.item()
    return value
|
|
|
|
|
|
def convert_dict_items_to_list(d: dict):
    """Return a new dict with ``convert_to_list`` applied to every value of ``d``.

    Keys are kept as they are and ``d`` itself is not modified.
    """
    return {key: convert_to_list(item) for key, item in d.items()}
|
|
|
|
|
|
def get_models_for_detection_strength(strength: DetectionStrength):
    """Return the model architecture names for a detection strength.

    Args:
        strength: A ``DetectionStrength`` member or its string value.

    Returns:
        A new two-element list ``[det_model, rec_model]`` with the detection
        and recognition architecture names.

    Raises:
        ValueError: If ``strength`` is not a valid ``DetectionStrength``.
            Without this the function would return ``None`` and callers
            unpacking the result would fail with an unrelated error.
    """
    models_by_strength = {
        DetectionStrength.high: ("linknet_resnet50", "vitstr_base"),
        DetectionStrength.medium: ("db_resnet50", "vitstr_base"),
        DetectionStrength.low: ("db_resnet50", "vitstr_small"),
    }
    # Normalising through the enum accepts both members and plain strings and
    # raises ValueError for anything else.
    det_model, rec_model = models_by_strength[DetectionStrength(strength)]
    return [det_model, rec_model]
|
|
|
|
|
|
def render(template, context=None):
    """Render a Jinja2 template into an HTML response.

    Args:
        template: Name of the template to load from the module-level ``env``.
        context: Optional mapping of variables passed to the template.
            Defaults to an empty context.

    Returns:
        An ``HTMLResponse`` containing the rendered template.
    """
    # A None default replaces the shared mutable ``{}`` default argument.
    if context is None:
        context = {}
    return HTMLResponse(env.get_template(template).render(context))
|