"""FastAPI service that runs docTR OCR on a base64-encoded image upload."""

import enum
import os
from base64 import b64decode
from urllib.parse import unquote

import numpy as np
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from fastapi import FastAPI, HTTPException
from jinja2 import Environment, FileSystemLoader, select_autoescape
from pydantic import BaseModel
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse

app = FastAPI()
env = Environment(loader=FileSystemLoader("templates"), autoescape=select_autoescape())

# Directory where uploads are written for the duration of a request.
DATA_DIR = "data"


class Image(BaseModel):
    """Request body for ``/detect``: a file name plus base64-encoded contents."""

    file_name: str  # URL-quoted name supplied by the client
    file_contents: str  # base64-encoded bytes of the image


class DetectionStrength(str, enum.Enum):
    """Quality preset selecting which detection/recognition models are used."""

    high = "high"
    medium = "medium"
    low = "low"


# (detection architecture, recognition architecture) for each preset.
_MODELS_BY_STRENGTH = {
    DetectionStrength.high: ("linknet_resnet50", "vitstr_base"),
    DetectionStrength.medium: ("db_resnet50", "vitstr_base"),
    DetectionStrength.low: ("db_resnet50", "vitstr_small"),
}


@app.get("/")
def get_root(request: Request):
    """Serve the main page; any other path reaching this handler is a 404."""
    if request.url.path != "/":
        raise HTTPException(status_code=404)
    return render("main.html")


@app.post("/detect/{strength}")
def post_detect(file: Image, strength: DetectionStrength):
    """Run OCR on the uploaded image and return the exported result as JSON.

    The decoded upload is written under ``DATA_DIR``, passed to a docTR
    predictor built for the requested ``strength``, and removed again
    whether or not processing succeeds.

    Raises:
        HTTPException: 422 if the file name is unusable, or if decoding,
            writing, or OCR processing fails (the detail carries the
            underlying error message).
    """
    file_name = unquote(file.file_name)
    print("\nWorking on: " + file_name)
    # ``.value`` keeps the output as the plain preset name on every Python
    # version, independent of how the enum member itself is stringified.
    print("\nStrength: " + strength.value + "\n")

    # The name is client-controlled: keep only its final component so that
    # values such as "../../etc/passwd" cannot escape DATA_DIR.
    safe_name = os.path.basename(file_name)
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=422, detail="Invalid file name")
    path = os.path.join(DATA_DIR, safe_name)

    det_model, rec_model = get_models_for_detection_strength(strength)
    os.makedirs(DATA_DIR, exist_ok=True)

    # NOTE(review): concurrent requests using the same file name still share
    # one path on disk, as in the original implementation.
    try:
        with open(path, "wb") as f:
            f.write(b64decode(file.file_contents))

        # NOTE(review): the predictor is rebuilt for each request and moved
        # to the GPU unconditionally; this requires a CUDA-capable host.
        predictor = ocr_predictor(
            det_model,
            rec_model,
            pretrained=True,
            straighten_pages=True,
            preserve_aspect_ratio=True,
        ).cuda()
        doc = DocumentFile.from_images(path)
        pred_res = predictor(doc)
        json_res = pred_res.export()
        converted = convert_dict_items_to_list(json_res)
        return JSONResponse(content=converted)
    except Exception as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    finally:
        _remove_if_present(path)


def _remove_if_present(path):
    """Delete ``path``, ignoring the case where it was never created."""
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass


def convert_to_list(value):
    """Recursively replace NumPy values inside ``value`` with plain Python ones.

    Dicts, lists and tuples are walked at any nesting depth (keeping their
    container type); ``np.ndarray`` becomes a nested list and NumPy scalars
    become the matching Python scalar. Anything else is returned unchanged.
    """
    if isinstance(value, dict):
        return {k: convert_to_list(v) for k, v in value.items()}
    if isinstance(value, list):
        return [convert_to_list(item) for item in value]
    if isinstance(value, tuple):
        return tuple(convert_to_list(item) for item in value)
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    return value


def convert_dict_items_to_list(d: dict):
    """Return a new dict with every value of ``d`` passed through ``convert_to_list``."""
    return {k: convert_to_list(v) for k, v in d.items()}


def get_models_for_detection_strength(strength: DetectionStrength):
    """Return ``[detection_arch, recognition_arch]`` for the given preset.

    Raises:
        ValueError: if ``strength`` is not a valid ``DetectionStrength``.
    """
    # Normalising through the enum also accepts the plain strings
    # "high"/"medium"/"low" and rejects anything else explicitly.
    preset = DetectionStrength(strength)
    return list(_MODELS_BY_STRENGTH[preset])


def render(template, context=None):
    """Render the named Jinja template with ``context`` into an ``HTMLResponse``."""
    # A fresh dict per call avoids sharing one mutable default across calls.
    if context is None:
        context = {}
    return HTMLResponse(env.get_template(template).render(context))