doctr-api/src/main.py

112 lines
2.9 KiB
Python

import enum
import os
import tempfile
from base64 import b64decode
from urllib.parse import unquote

import numpy as np
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from fastapi import FastAPI, HTTPException
from jinja2 import Environment, FileSystemLoader, select_autoescape
from pydantic import BaseModel
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse
# FastAPI application object; every route in this module is registered on it.
app = FastAPI()

# Jinja2 environment used by render(). Templates are loaded from the
# "templates" directory (a relative path, so it is resolved against the
# process's working directory). Autoescaping is decided by
# select_autoescape() with its default settings.
env = Environment(
    autoescape=select_autoescape(),
    loader=FileSystemLoader("templates"),
)
class Image(BaseModel):
    """Request body for ``POST /detect/{strength}``."""

    # Image file name, URL-quoted by the client (post_detect unquotes it).
    file_name: str
    # Image file contents, base64-encoded (post_detect b64-decodes them).
    file_contents: str
class DetectionStrength(str, enum.Enum):
    """Detection strength accepted in the path of ``POST /detect/{strength}``.

    Mixing in ``str`` makes each member compare equal to (and concatenate
    like) its plain string value. ``get_models_for_detection_strength`` maps
    each member to a (detection, recognition) architecture pair:

    - high:   linknet_resnet50 + vitstr_base
    - medium: db_resnet50 + vitstr_base
    - low:    db_resnet50 + vitstr_small
    """

    high = "high"
    medium = "medium"
    low = "low"
@app.get("/")
def get_root(request: Request):
    """Serve the landing page rendered from the ``main.html`` template.

    A request whose URL path is not exactly ``/`` is answered with 404.
    """
    if request.url.path == "/":
        return render("main.html")
    raise HTTPException(status_code=404)
@app.post("/detect/{strength}")
def post_detect(file: Image, strength: DetectionStrength):
    """OCR a base64-encoded image with doctr and return the exported result.

    The decoded image is written to a uniquely named file under ``data/``.
    Only the extension of the client-supplied name is reused, so the name
    cannot point outside that directory and parallel uploads with the same
    name cannot clobber each other. The file is then fed to an
    ``ocr_predictor`` built for *strength* and is always deleted afterwards,
    on success and on failure alike.

    Args:
        file: Upload payload. ``file_name`` is URL-quoted and is used only
            for logging and for its extension. ``file_contents`` is base64
            image data.
        strength: Selects the detection/recognition architectures via
            ``get_models_for_detection_strength``.

    Returns:
        JSONResponse carrying the predictor's ``export()`` output, with numpy
        data converted to plain Python lists.

    Raises:
        HTTPException: status 422 if ``file_contents`` is not valid base64,
            or if building or running the predictor fails.
    """
    file_name = unquote(file.file_name)
    print("\nWorking on: " + file_name)
    print("\nStrength: " + strength + "\n")
    try:
        image_bytes = b64decode(file.file_contents)
    except ValueError as e:  # binascii.Error is a ValueError subclass
        raise HTTPException(status_code=422, detail=str(e)) from e
    det_model, rec_model = get_models_for_detection_strength(strength)
    image_path = _save_upload(file_name, image_bytes)
    try:
        # NOTE(review): a new predictor is built and moved to CUDA on every
        # request; caching one per strength would avoid reloading weights.
        predictor = ocr_predictor(
            det_model,
            rec_model,
            pretrained=True,
            straighten_pages=True,
            preserve_aspect_ratio=True,
        ).cuda()
        doc = DocumentFile.from_images(image_path)
        json_res = predictor(doc).export()
        return JSONResponse(content=convert_dict_items_to_list(json_res))
    except Exception as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    finally:
        # Runs on success and on failure, so uploads never pile up in data/.
        os.unlink(image_path)


def _save_upload(file_name: str, data: bytes) -> str:
    """Write *data* to a new, uniquely named file in ``data/`` and return its path.

    Only a plain alphanumeric extension (e.g. ``.png``) of *file_name* is kept.
    Directory components such as ``../`` are discarded, so the file always
    lands inside ``data/``.
    """
    os.makedirs("data", exist_ok=True)
    suffix = os.path.splitext(os.path.basename(file_name))[1]
    if not suffix[1:].isalnum():
        # Not a simple extension (empty, odd characters, NUL bytes...): drop it.
        suffix = ""
    with tempfile.NamedTemporaryFile(dir="data", suffix=suffix, delete=False) as tmp:
        tmp.write(data)
    return tmp.name
def convert_to_list(value):
    """Recursively replace numpy data in *value* with JSON-friendly Python objects.

    - ``dict``: values are converted recursively; keys are left untouched.
    - ``list``: items are converted recursively.
    - ``tuple``: items are converted recursively. The result is a tuple, and
      namedtuples keep their type.
    - ``np.ndarray``: converted with ``ndarray.tolist()``.
    - numpy scalar (``np.generic``, e.g. ``np.float32``): ``.item()``.
    - anything else: returned unchanged.

    Args:
        value: Arbitrary nested structure, e.g. a doctr ``export()`` dict.

    Returns:
        An equivalent structure containing no numpy arrays or numpy scalars.
    """
    if isinstance(value, dict):
        return {k: convert_to_list(v) for k, v in value.items()}
    if isinstance(value, list):
        # Recurse into every item, including nested lists.
        return [convert_to_list(item) for item in value]
    if isinstance(value, tuple):
        items = [convert_to_list(item) for item in value]
        if hasattr(value, "_fields"):
            # namedtuple constructors take positional fields, not an iterable.
            return type(value)(*items)
        return tuple(items)
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        # numpy scalars are not JSON-serialisable; unwrap to a Python scalar.
        return value.item()
    return value
def convert_dict_items_to_list(d: dict):
    """Return a new dict whose values have been passed through ``convert_to_list``.

    Keys and their order are kept; *d* itself is not modified.
    """
    return {key: convert_to_list(value) for key, value in d.items()}
# (detection architecture, recognition architecture) for each strength.
_MODELS_BY_STRENGTH = {
    DetectionStrength.high: ("linknet_resnet50", "vitstr_base"),
    DetectionStrength.medium: ("db_resnet50", "vitstr_base"),
    DetectionStrength.low: ("db_resnet50", "vitstr_small"),
}


def get_models_for_detection_strength(strength: DetectionStrength):
    """Return ``[detection_arch, recognition_arch]`` for *strength*.

    Args:
        strength: A ``DetectionStrength`` member or its plain string value
            (``"high"``, ``"medium"`` or ``"low"``).

    Returns:
        A new two-element list of doctr architecture names.

    Raises:
        ValueError: if *strength* is not a valid ``DetectionStrength`` value.
            Previously the function silently returned ``None`` in that case.
    """
    # DetectionStrength(...) normalises plain strings and rejects unknown ones.
    det_model, rec_model = _MODELS_BY_STRENGTH[DetectionStrength(strength)]
    return [det_model, rec_model]
def render(template, context=None):
    """Render a Jinja2 template from ``env`` and wrap it in an HTMLResponse.

    Args:
        template: Template name, resolved by ``env``'s loader.
        context: Optional mapping of template variables. ``None`` (the
            default) means no variables. A fresh dict is created per call
            instead of sharing one mutable default.

    Returns:
        HTMLResponse whose body is the rendered template.
    """
    if context is None:
        context = {}
    return HTMLResponse(env.get_template(template).render(context))