app.py 8.7 KB


  1. import threading
  2. import time
  3. import uvicorn
  4. import cv2
  5. import gradio as gr
  6. from lineless_table_rec import LinelessTableRecognition
  7. from paddleocr import PPStructure
  8. from rapid_table import RapidTable
  9. from rapidocr_onnxruntime import RapidOCR
  10. # from rapidocr_paddle import RapidOCR
  11. from slanet_plus_table import SLANetPlus
  12. from table_cls import TableCls
  13. from PIL import Image
  14. from wired_table_rec import WiredTableRecognition
  15. from fastapi import APIRouter, File, UploadFile
  16. from utils import plot_rec_box, LoadImage, format_html, box_4_2_poly_to_box_4_1, build_logger
# Project logger from utils.build_logger; the img_ocr endpoint logs request timings with it.
logger = build_logger()
# Image loader from utils; process_image passes every incoming image through it
# before any OCR/table engine sees the image.
img_loader = LoadImage()
  19. table_rec_path = "models/table_rec/ch_ppstructure_mobile_v2_SLANet.onnx"
  20. det_model_dir = {
  21. "mobile_det": "models/ocr/ch_PP-OCRv4_det_infer.onnx",
  22. }
  23. rec_model_dir = {
  24. "mobile_rec": "models/ocr/ch_PP-OCRv4_rec_infer.onnx",
  25. }
  26. table_engine_list = [
  27. "auto",
  28. "RapidTable(SLANet)",
  29. "RapidTable(SLANet-plus)",
  30. "wired_table_v2",
  31. "pp_table",
  32. "wired_table_v1",
  33. "lineless_table"
  34. ]
  35. # 示例图片路径
  36. example_images = [
  37. "images/wired1.png",
  38. "images/wired2.png",
  39. "images/wired3.png",
  40. "images/lineless1.png",
  41. "images/wired4.jpg",
  42. "images/lineless2.png",
  43. "images/wired5.jpg",
  44. "images/lineless3.jpg",
  45. "images/wired6.jpg",
  46. ]
  47. rapid_table_engine = RapidTable(model_path=table_rec_path)
  48. SLANet_plus_table_Engine = RapidTable()
  49. wired_table_engine_v1 = WiredTableRecognition(version="v1")
  50. wired_table_engine_v2 = WiredTableRecognition(version="v2")
  51. lineless_table_engine = LinelessTableRecognition()
  52. table_cls = TableCls()
  53. ocr_engine_dict = {}
  54. pp_engine_dict = {}
  55. for det_model in det_model_dir.keys():
  56. for rec_model in rec_model_dir.keys():
  57. det_model_path = det_model_dir[det_model]
  58. rec_model_path = rec_model_dir[rec_model]
  59. key = f"{det_model}_{rec_model}"
  60. ocr_engine_dict[key] = RapidOCR(det_model_path=det_model_path, rec_model_path=rec_model_path,
  61. rec_image_shape=[3, 48, 320])
  62. # ocr_engine_dict[key] = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
  63. pp_engine_dict[key] = PPStructure(
  64. layout=False,
  65. show_log=False,
  66. table=True,
  67. use_onnx=True,
  68. table_model_dir=table_rec_path,
  69. det_model_dir=det_model_path,
  70. rec_model_dir=rec_model_path
  71. )
  72. def select_ocr_model(det_model, rec_model):
  73. return ocr_engine_dict[f"{det_model}_{rec_model}"]
  74. def select_table_model(img, table_engine_type, det_model, rec_model):
  75. if table_engine_type == "RapidTable(SLANet)":
  76. return rapid_table_engine, table_engine_type
  77. elif table_engine_type == "RapidTable(SLANet-plus)":
  78. return SLANet_plus_table_Engine, table_engine_type
  79. elif table_engine_type == "wired_table_v1":
  80. return wired_table_engine_v1, table_engine_type
  81. elif table_engine_type == "wired_table_v2":
  82. print("使用v2 wired table")
  83. return wired_table_engine_v2, table_engine_type
  84. elif table_engine_type == "lineless_table":
  85. return lineless_table_engine, table_engine_type
  86. elif table_engine_type == "pp_table":
  87. return pp_engine_dict[f"{det_model}_{rec_model}"], 0
  88. elif table_engine_type == "auto":
  89. cls, elasp = table_cls(img)
  90. if cls == 'wired':
  91. table_engine = wired_table_engine_v2
  92. return table_engine, "wired_table_v2"
  93. return lineless_table_engine, "lineless_table"
  94. def process_image(img, table_engine_type, det_model, rec_model):
  95. img = img_loader(img)
  96. start = time.time()
  97. table_engine, table_type = select_table_model(img, table_engine_type, det_model, rec_model)
  98. ocr_engine = select_ocr_model(det_model, rec_model)
  99. if isinstance(table_engine, PPStructure):
  100. result = table_engine(img, return_ocr_result_in_table=True)
  101. html = result[0]['res']['html']
  102. polygons = result[0]['res']['cell_bbox']
  103. polygons = [[polygon[0], polygon[1], polygon[4], polygon[5]] for polygon in polygons]
  104. ocr_boxes = result[0]['res']['boxes']
  105. all_elapse = f"- `table all cost: {time.time() - start:.5f}`"
  106. else:
  107. ocr_res, ocr_infer_elapse = ocr_engine(img)
  108. det_cost, cls_cost, rec_cost = ocr_infer_elapse
  109. ocr_boxes = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res]
  110. if isinstance(table_engine, RapidTable):
  111. html, polygons, table_rec_elapse = table_engine(img, ocr_result=ocr_res)
  112. polygons = [[polygon[0], polygon[1], polygon[4], polygon[5]] for polygon in polygons]
  113. elif isinstance(table_engine, (WiredTableRecognition, LinelessTableRecognition)):
  114. html, table_rec_elapse, polygons, _, _ = table_engine(img, ocr_result=ocr_res)
  115. if not polygons:
  116. # RapidTable模型兜底
  117. table_engine, table_type = select_table_model(img, "RapidTable(SLANet)", det_model, rec_model)
  118. ocr_engine = select_ocr_model(det_model, rec_model)
  119. ocr_res, ocr_infer_elapse = ocr_engine(img)
  120. det_cost, cls_cost, rec_cost = ocr_infer_elapse
  121. ocr_boxes = [box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_res]
  122. html, polygons, table_rec_elapse = table_engine(img, ocr_result=ocr_res)
  123. polygons = [[polygon[0], polygon[1], polygon[4], polygon[5]] for polygon in polygons]
  124. sum_elapse = time.time() - start
  125. all_elapse = f"- table_type: {table_type}\n table all cost: {sum_elapse:.5f}\n - table rec cost: {table_rec_elapse:.5f}\n - ocr cost: {det_cost + cls_cost + rec_cost:.5f}"
  126. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  127. table_boxes_img = plot_rec_box(img.copy(), polygons)
  128. ocr_boxes_img = plot_rec_box(img.copy(), ocr_boxes)
  129. complete_html = format_html(html)
  130. return complete_html, table_boxes_img, ocr_boxes_img, all_elapse
  131. def main():
  132. det_models_labels = list(det_model_dir.keys())
  133. rec_models_labels = list(rec_model_dir.keys())
  134. with gr.Blocks(css="""
  135. .scrollable-container {
  136. overflow-x: auto;
  137. white-space: nowrap;
  138. }
  139. """) as demo:
  140. with gr.Row(): # 两列布局
  141. with gr.Tab("Options"):
  142. with gr.Column(variant="panel", scale=1): # 侧边栏,宽度比例为1
  143. img_input = gr.Image(label="Upload or Select Image", sources="upload", value="images/lineless3.jpg")
  144. # 示例图片选择器
  145. examples = gr.Examples(
  146. examples=example_images,
  147. inputs=img_input,
  148. fn=lambda x: x, # 简单返回图片路径
  149. outputs=img_input,
  150. cache_examples=True
  151. )
  152. table_engine_type = gr.Dropdown(table_engine_list, label="Select Recognition Table Engine",
  153. value=table_engine_list[0])
  154. det_model = gr.Dropdown(det_models_labels, label="Select OCR Detection Model",
  155. value=det_models_labels[0])
  156. rec_model = gr.Dropdown(rec_models_labels, label="Select OCR Recognition Model",
  157. value=rec_models_labels[0])
  158. run_button = gr.Button("Run")
  159. gr.Markdown("# Elapsed Time")
  160. elapse_text = gr.Text(label="") # 使用 `gr.Text` 组件展示字符串
  161. with gr.Column(scale=2): # 右边列
  162. # 使用 Markdown 标题分隔各个组件
  163. gr.Markdown("# Html Render")
  164. html_output = gr.HTML(label="", elem_classes="scrollable-container")
  165. gr.Markdown("# Table Boxes")
  166. table_boxes_output = gr.Image(label="")
  167. gr.Markdown("# OCR Boxes")
  168. ocr_boxes_output = gr.Image(label="")
  169. run_button.click(
  170. fn=process_image,
  171. inputs=[img_input, table_engine_type, det_model, rec_model],
  172. outputs=[html_output, table_boxes_output, ocr_boxes_output, elapse_text]
  173. )
  174. demo.launch(server_port=20334, share=True)
  175. img_router = APIRouter(prefix="/img", tags=["img ocr"])
  176. @img_router.post("/img_ocr", summary="图片ocr形成表格")
  177. async def img_ocr(file: UploadFile = File(..., description="上传图片"), ):
  178. start_time = time.time()
  179. img = Image.open(file.file)
  180. complete_html, table_boxes_img, ocr_boxes_img, all_elapse = process_image(img, "auto", "mobile_det", "mobile_rec")
  181. logger.info(f"finish ocr {file.filename},total use time:{time.time() - start_time}")
  182. return complete_html
  183. if __name__ == '__main__':
  184. # main()
  185. uvicorn.run(app="app:img_router", host='0.0.0.0', port=8512, workers=4, reload=True)