# doctr-api/src/main.py
#
# 94 lines
# 2.5 KiB
# Python

import os
import pathlib
from fastapi.responses import JSONResponse
import numpy as np
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from fastapi import FastAPI, HTTPException, Response
from jinja2 import Environment, FileSystemLoader, select_autoescape
from urllib.parse import unquote
# Jinja environment used by renderTemplate(); templates are looked up in the
# "templates" directory relative to the process working directory.
env = Environment(
    autoescape=select_autoescape(),
    loader=FileSystemLoader("templates"),
)

# FastAPI application object the route decorators below register on.
app = FastAPI()
@app.get("/")
def root():
    """Run OCR on the first image of the sample directory and render it as HTML.

    Lists the hard-coded sample directory, takes the first ``.jpg``/``.png``
    entry (in sorted order), runs the docTR predictor on it and passes the
    predictor's ``render()`` output to the ``main.html`` template.

    Returns:
        Response: the rendered template, served as ``text/html``.
    """
    base_dir = pathlib.Path(
        "/home/m/Projects/sandbox/data/app/upload/classifier/deal/389/"
    )
    # Only the first matching image is processed per request.
    max_files = 1

    # Sorted so that the file picked is deterministic; os.listdir() returns
    # entries in arbitrary order.
    files = sorted(os.listdir(base_dir))

    predictor = ocr_predictor(
        "db_resnet50",
        pretrained=True,
        assume_straight_pages=False,
        preserve_aspect_ratio=True,
    )

    result = {}
    for f in files:
        if len(result) >= max_files:
            break
        if pathlib.Path(f).suffix not in (".jpg", ".png"):
            continue
        print("Working on: " + f)
        doc = DocumentFile.from_images(str(base_dir / f))
        pred_res = predictor(doc)
        result[f] = pred_res.render()

    # A bare Response has no media type, so the page would be sent without a
    # Content-Type header; declare it as HTML explicitly.
    return Response(
        renderTemplate("main.html", {"results": result}),
        media_type="text/html",
    )
@app.get("/detect/")
def file(file_name: str):
    """Run OCR on one image inside the sandbox directory and return JSON.

    Args:
        file_name: URL-encoded path of the image, relative to
            ``~/Projects/sandbox/``. Leading slashes are ignored.

    Returns:
        JSONResponse: the predictor's ``export()`` output with NumPy arrays
        converted to plain lists.

    Raises:
        HTTPException: status 422 if the path resolves outside the sandbox
            directory, or if loading the image or running OCR fails.
    """
    sandbox_root = pathlib.Path(os.path.expanduser("~/Projects/sandbox/")).resolve()
    image_path = (sandbox_root / unquote(file_name).lstrip("/")).resolve()

    # file_name comes straight from the query string: reject anything that
    # resolves outside the sandbox (e.g. "../../etc/passwd", or a symlink
    # pointing elsewhere) instead of handing it to the image loader.
    if sandbox_root not in image_path.parents:
        raise HTTPException(
            status_code=422,
            detail="file_name must point to a file inside the sandbox directory",
        )

    print("Working on: " + str(image_path))
    try:
        # NOTE(review): the predictor is rebuilt on every request; consider
        # creating it once if model construction turns out to be expensive.
        predictor = ocr_predictor(
            "db_resnet50",
            "vitstr_base",
            pretrained=True,
            assume_straight_pages=False,
            preserve_aspect_ratio=True,
        )
        doc = DocumentFile.from_images(str(image_path))
        pred_res = predictor(doc)
        json_res = pred_res.export()
        converted = convert_dict_items_to_list(json_res)
        return JSONResponse(content=converted)
    except Exception as e:
        # Endpoint boundary: report any loading/OCR failure to the client as
        # 422, keeping the original exception chained for server logs.
        raise HTTPException(status_code=422, detail=str(e)) from e
def convert_to_list(value):
    """Recursively replace NumPy values with plain Python equivalents.

    Walks dicts, lists and tuples at any nesting depth. ``np.ndarray`` values
    become (nested) lists via ``tolist()`` and NumPy scalars become Python
    scalars via ``item()``, so the result can be JSON-serialized.

    Args:
        value: any object; containers are traversed, everything else is
            returned unchanged.

    Returns:
        A new structure of the same shape. Dicts stay dicts, lists stay
        lists, tuples stay tuples; arrays become lists.
    """
    if isinstance(value, dict):
        return {k: convert_to_list(v) for k, v in value.items()}
    if isinstance(value, list):
        # Recurse into every item so arrays inside nested lists are converted.
        return [convert_to_list(item) for item in value]
    if isinstance(value, tuple):
        return tuple(convert_to_list(item) for item in value)
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        # NumPy scalar (e.g. np.float32) -> built-in Python scalar.
        return value.item()
    return value
def convert_dict_items_to_list(d: dict):
    """Return a new dict with every value passed through ``convert_to_list``.

    Args:
        d: mapping whose values may contain NumPy arrays.

    Returns:
        dict: a fresh dict with the same keys; ``d`` itself is not modified.
    """
    return {k: convert_to_list(v) for k, v in d.items()}
def renderTemplate(template, context=None):
    """Render a template from the module-level Jinja ``env`` to a string.

    Args:
        template: name of the template to load via ``env.get_template``.
        context: optional mapping of variables for the template. ``None``
            (the default) renders with no variables.

    Returns:
        str: the rendered template.
    """
    compiled = env.get_template(template)
    # Template.render() builds dict(context); passing None through would
    # raise TypeError, so substitute an empty mapping for the default.
    return compiled.render({} if context is None else context)