94 lines
2.5 KiB
Python
94 lines
2.5 KiB
Python
import os
|
|
import pathlib
|
|
from fastapi.responses import JSONResponse
|
|
import numpy as np
|
|
|
|
from doctr.io import DocumentFile
|
|
from doctr.models import ocr_predictor
|
|
from fastapi import FastAPI, HTTPException, Response
|
|
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
|
from urllib.parse import unquote
|
|
|
|
# ASGI application object; the route decorators below register on it.
app = FastAPI()

# Jinja2 environment that loads templates from the local "templates"
# directory, with select_autoescape()'s default auto-escaping rules.
env = Environment(
    loader=FileSystemLoader("templates"),
    autoescape=select_autoescape(),
)
|
|
|
|
|
|
@app.get("/")
def root():
    """Run OCR on the first image of the sample directory and render it as HTML.

    Lists the hard-coded sample directory, keeps the first entry whose suffix
    is ``.jpg`` or ``.png``, runs a ``db_resnet50`` doctr predictor on it and
    passes ``{file name: rendered text}`` to the ``main.html`` template under
    the ``results`` key.

    Returns:
        Response: the rendered template, served with a ``text/html`` media type.
    """
    base_dir = "/home/m/Projects/sandbox/data/app/upload/classifier/deal/389/"
    image_suffixes = (".jpg", ".png")
    max_files = 1  # only the first matching image is processed per request

    images = [
        name
        for name in os.listdir(base_dir)
        if pathlib.Path(name).suffix in image_suffixes
    ][:max_files]

    # NOTE: the predictor is built on every request, as in the original code.
    predictor = ocr_predictor(
        "db_resnet50",
        pretrained=True,
        assume_straight_pages=False,
        preserve_aspect_ratio=True,
    )

    result = {}
    for name in images:
        print("Working on: " + name)
        doc = DocumentFile.from_images(os.path.join(base_dir, name))
        result[name] = predictor(doc).render()

    # A bare Response has no media type, so no Content-Type header would be
    # sent for the rendered page; declare it as HTML explicitly.
    return Response(
        renderTemplate("main.html", {"results": result}),
        media_type="text/html",
    )
|
|
|
|
|
|
@app.get("/detect/")
def file(file_name: str):
    """Run OCR on one image below the sandbox directory and return it as JSON.

    Args:
        file_name: Image path relative to ``~/Projects/sandbox/``. The value
            is URL-unquoted and leading slashes are stripped before it is
            joined to the sandbox directory.

    Returns:
        JSONResponse: the predictor's exported result after it has been passed
        through ``convert_dict_items_to_list``.

    Raises:
        HTTPException: status 422 when the requested path falls outside the
            sandbox directory, or when loading, prediction or conversion
            fails (``detail`` carries the underlying error message).
    """
    sandbox_dir = os.path.abspath(os.path.expanduser("~/Projects/sandbox/"))
    relative = unquote(file_name).lstrip("/")
    file_name = os.path.abspath(os.path.join(sandbox_dir, relative))

    # file_name comes straight from the query string: reject anything that
    # normalises to a location outside the sandbox (e.g. via ".." segments).
    # The check is lexical only; symlinks inside the sandbox are not resolved.
    if not (file_name + os.sep).startswith(sandbox_dir + os.sep):
        raise HTTPException(
            status_code=422,
            detail="file_name must point inside the sandbox directory",
        )

    print("Working on: " + file_name)

    try:
        # NOTE: the predictor is built on every request, as in the original code.
        predictor = ocr_predictor(
            "db_resnet50",
            "vitstr_base",
            pretrained=True,
            assume_straight_pages=False,
            preserve_aspect_ratio=True,
        )
        doc = DocumentFile.from_images(file_name)
        pred_res = predictor(doc)
        json_res = pred_res.export()
        converted = convert_dict_items_to_list(json_res)
        return JSONResponse(content=converted)
    except Exception as e:
        # Route boundary: report any failure to the client as a 422, keeping
        # the original exception chained for server-side tracebacks.
        raise HTTPException(status_code=422, detail=str(e)) from e
|
|
|
|
|
|
def convert_to_list(value):
    """Recursively replace NumPy values inside *value* with plain Python ones.

    Dicts, lists and tuples are walked at any nesting depth; ``np.ndarray``
    objects become (nested) lists via ``tolist()`` and NumPy scalars become
    the matching Python scalar via ``item()``. Anything else is returned
    unchanged.

    Args:
        value: Arbitrary object, typically a nested structure that may hold
            NumPy arrays or scalars.

    Returns:
        A structure of the same shape with NumPy values converted. Dicts and
        lists are rebuilt as new objects; tuples are rebuilt as plain tuples.
    """
    if isinstance(value, dict):
        return {k: convert_to_list(v) for k, v in value.items()}
    if isinstance(value, list):
        # Recurse into every item so lists nested inside lists are converted
        # too, not only dicts and arrays sitting directly in the list.
        return [convert_to_list(item) for item in value]
    if isinstance(value, tuple):
        return tuple(convert_to_list(item) for item in value)
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        # NumPy scalars (e.g. np.float32) are not JSON serialisable as-is.
        return value.item()
    return value
|
|
|
|
def convert_dict_items_to_list(d: dict):
    """Return a new dict with every value of *d* passed through ``convert_to_list``.

    Keys are kept as they are; *d* itself is not modified.
    """
    return {key: convert_to_list(val) for key, val in d.items()}
|
|
|
|
|
|
def renderTemplate(template, context=None):
    """Render a template from the module-level Jinja2 ``env``.

    Args:
        template: Name of the template to load through ``env``.
        context: Optional mapping of template variables. ``None`` (the
            default) renders the template with an empty context.

    Returns:
        str: the rendered template text.
    """
    loaded = env.get_template(template)
    # Template.render() builds its context with dict(*args), so forwarding
    # None would raise TypeError; substitute an empty mapping instead.
    return loaded.render({} if context is None else context)
|