# paddleXtablePlus.py

import time
from io import BytesIO
from pathlib import Path
from typing import List, Tuple, Union

import cv2
import numpy as np
from PIL import Image, UnidentifiedImageError
from paddle.inference import Config, create_predictor

InputType = Union[str, np.ndarray, bytes, Path, Image.Image]

# Export the Paddle inference model to ONNX if needed, e.g.:
# paddle2onnx --model_dir C:\Users\51954\.paddlex\official_models\SLANet_plus --model_filename inference.pdmodel --params_filename inference.pdiparams --save_file ./table.onnx --opset_version 16 --enable_onnx_checker
# paddle2onnx --model_dir ./ --model_filename inference.pdmodel --params_filename inference.pdiparams --save_file ./table.onnx --opset_version 16 --enable_onnx_checker


def get_boxes_recs(
    ocr_result: List[Tuple[List[List[float]], str, float]], h: int, w: int
) -> Tuple[np.ndarray, List[Tuple[str, float]]]:
    """Split OCR results into clipped [x_min, y_min, x_max, y_max] boxes and (text, score) pairs."""
    dt_boxes, rec_res, scores = list(zip(*ocr_result))
    rec_res = list(zip(rec_res, scores))

    r_boxes = []
    for box in dt_boxes:
        box = np.array(box)
        # Expand each quadrilateral by one pixel and clip it to the image bounds.
        x_min = max(0, box[:, 0].min() - 1)
        x_max = min(w, box[:, 0].max() + 1)
        y_min = max(0, box[:, 1].min() - 1)
        y_max = min(h, box[:, 1].max() + 1)
        r_boxes.append([x_min, y_min, x_max, y_max])
    dt_boxes = np.array(r_boxes)
    return dt_boxes, rec_res
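
# Expected ocr_result layout (one entry per detected text line, e.g. from a
# PaddleOCR-style engine; the exact field order is inferred from the unpacking
# above):
#   ([[x1, y1], [x2, y2], [x3, y3], [x4, y4]], "cell text", 0.98)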


def distance(box_1, box_2):
    """L1 distance between two (x_min, y_min, x_max, y_max) boxes, plus an extra
    term for the nearer of the two corner pairs."""
    x1, y1, x2, y2 = box_1
    x3, y3, x4, y4 = box_2
    dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
    dis_2 = abs(x3 - x1) + abs(y3 - y1)  # top-left corner distance
    dis_3 = abs(x4 - x2) + abs(y4 - y2)  # bottom-right corner distance
    return dis + min(dis_2, dis_3)


def convert_corners_to_bounding_boxes(corners):
    """
    Convert corner coordinates to bounding-box coordinates (xmin, ymin, xmax, ymax).

    Args:
        corners : numpy.ndarray
            Array of shape (n, 8); each row holds the four corner points
            (x1, y1, x2, y2, x3, y3, x4, y4).

    Returns:
        bounding_boxes : numpy.ndarray
            Array of shape (n, 4); each row holds (xmin, ymin, xmax, ymax).
    """
    # Extract the x and y coordinates of the four corner points.
    x1, y1, x2, y2, x3, y3, x4, y4 = np.split(corners, 8, axis=1)

    # Compute xmin, ymin, xmax, ymax.
    xmin = np.min(np.hstack((x1, x2, x3, x4)), axis=1, keepdims=True)
    ymin = np.min(np.hstack((y1, y2, y3, y4)), axis=1, keepdims=True)
    xmax = np.max(np.hstack((x1, x2, x3, x4)), axis=1, keepdims=True)
    ymax = np.max(np.hstack((y1, y2, y3, y4)), axis=1, keepdims=True)

    # Stack into a new array.
    bounding_boxes = np.concatenate((xmin, ymin, xmax, ymax), axis=1)

    return bounding_boxes
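
# Illustrative example (values made up): one cell given by its four corners in
# clockwise order collapses to its axis-aligned bounding box:
#   convert_corners_to_bounding_boxes(np.array([[10, 20, 50, 20, 50, 60, 10, 60]]))
#   -> array([[10, 20, 50, 60]])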


def compute_iou(rec1, rec2):
    """
    Compute IoU of two axis-aligned boxes given as (x_min, y_min, x_max, y_max).
    The arithmetic is symmetric in the two axes, so a swapped (y0, x0, y1, x1)
    convention yields the same value.
    :return: scalar IoU value
    """
    # Compute the area of each rectangle.
    S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
    S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])

    # Total area of both rectangles.
    sum_area = S_rec1 + S_rec2

    # Edges of the intersection rectangle.
    left_line = max(rec1[1], rec2[1])
    right_line = min(rec1[3], rec2[3])
    top_line = max(rec1[0], rec2[0])
    bottom_line = min(rec1[2], rec2[2])

    # Check whether the rectangles intersect at all.
    if left_line >= right_line or top_line >= bottom_line:
        return 0.0
    else:
        intersect = (right_line - left_line) * (bottom_line - top_line)
        return (intersect / (sum_area - intersect)) * 1.0
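
# Worked example (illustrative): boxes (0, 0, 2, 2) and (1, 1, 3, 3) overlap in
# a 1x1 square, so IoU = 1 / (4 + 4 - 1):
#   compute_iou((0, 0, 2, 2), (1, 1, 3, 3))  # -> 0.14285714285714285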


class LoadImageError(Exception):
    pass


class LoadImage:
    def __init__(self):
        pass

    def __call__(self, img: InputType) -> np.ndarray:
        if not isinstance(img, InputType.__args__):
            raise LoadImageError(
                f"The img type {type(img)} is not in {InputType.__args__}"
            )

        origin_img_type = type(img)
        img = self.load_img(img)
        img = self.convert_img(img, origin_img_type)
        return img

    def load_img(self, img: InputType) -> np.ndarray:
        if isinstance(img, (str, Path)):
            self.verify_exist(img)
            try:
                img = np.array(Image.open(img))
            except UnidentifiedImageError as e:
                raise LoadImageError(f"cannot identify image file {img}") from e
            return img

        if isinstance(img, bytes):
            img = np.array(Image.open(BytesIO(img)))
            return img

        if isinstance(img, np.ndarray):
            return img

        if isinstance(img, Image.Image):
            return np.array(img)

        raise LoadImageError(f"{type(img)} is not supported!")

    def convert_img(self, img: np.ndarray, origin_img_type):
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if img.ndim == 3:
            channel = img.shape[2]
            if channel == 1:
                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

            if channel == 2:
                return self.cvt_two_to_three(img)

            if channel == 3:
                # PIL decodes to RGB; OpenCV expects BGR.
                if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
                    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                return img

            if channel == 4:
                return self.cvt_four_to_three(img)

            raise LoadImageError(
                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
            )

        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

    @staticmethod
    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
        """gray + alpha → BGR, compositing onto a white background"""
        img_gray = img[..., 0]
        img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)

        img_alpha = img[..., 1]
        not_a = cv2.bitwise_not(img_alpha)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
        """RGBA → BGR, compositing onto a white background"""
        r, g, b, a = cv2.split(img)
        new_img = cv2.merge((b, g, r))

        not_a = cv2.bitwise_not(a)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(new_img, new_img, mask=a)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def verify_exist(file_path: Union[str, Path]):
        if not Path(file_path).exists():
            raise LoadImageError(f"{file_path} does not exist.")
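
# Minimal usage sketch (the path is hypothetical): LoadImage normalizes any of
# the accepted input types to a 3-channel BGR numpy array.
#   loader = LoadImage()
#   bgr = loader("tables/sample_table.png")  # or bytes / Path / PIL.Image / ndarray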


class TableMatch:
    def __init__(self, filter_ocr_result=True, use_master=False):
        self.filter_ocr_result = filter_ocr_result
        self.use_master = use_master

    def __call__(self, pred_structures, pred_bboxes, dt_boxes, rec_res):
        if self.filter_ocr_result:
            dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes, rec_res)
        matched_index = self.match_result(dt_boxes, pred_bboxes)
        pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
        return pred_html

    def match_result(self, dt_boxes, pred_bboxes):
        matched = {}
        for i, gt_box in enumerate(dt_boxes):
            distances = []
            for j, pred_box in enumerate(pred_bboxes):
                if len(pred_box) == 8:
                    # Collapse an 8-value corner box to (xmin, ymin, xmax, ymax).
                    pred_box = [
                        np.min(pred_box[0::2]),
                        np.min(pred_box[1::2]),
                        np.max(pred_box[0::2]),
                        np.max(pred_box[1::2]),
                    ]
                distances.append(
                    (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box))
                )  # compute IoU and L1 distance
            # Select the predicted cell by IoU first, then by L1 distance.
            sorted_distances = sorted(distances, key=lambda item: (item[1], item[0]))
            if distances.index(sorted_distances[0]) not in matched.keys():
                matched[distances.index(sorted_distances[0])] = [i]
            else:
                matched[distances.index(sorted_distances[0])].append(i)
        return matched

    def get_pred_html(self, pred_structures, matched_index, ocr_contents):
        end_html = []
        td_index = 0
        for tag in pred_structures:
            if "</td>" not in tag:
                end_html.append(tag)
                continue

            if "<td></td>" == tag:
                end_html.extend("<td>")

            if td_index in matched_index.keys():
                b_with = False
                # ocr_contents entries are (text, score); check the text field.
                if (
                    "<b>" in ocr_contents[matched_index[td_index][0]][0]
                    and len(matched_index[td_index]) > 1
                ):
                    b_with = True
                    end_html.extend("<b>")

                for i, td_index_index in enumerate(matched_index[td_index]):
                    content = ocr_contents[td_index_index][0]
                    if len(matched_index[td_index]) > 1:
                        if len(content) == 0:
                            continue

                        if content[0] == " ":
                            content = content[1:]

                        if "<b>" in content:
                            content = content[3:]

                        if "</b>" in content:
                            content = content[:-4]

                    if len(content) == 0:
                        continue

                    if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
                        content += " "
                    end_html.extend(content)

                if b_with:
                    end_html.extend("</b>")

            if "<td></td>" == tag:
                end_html.append("</td>")
            else:
                end_html.append(tag)

            td_index += 1

        # Filter <thead></thead><tbody></tbody> elements.
        filter_elements = ["<thead>", "</thead>", "<tbody>", "</tbody>"]
        end_html = [v for v in end_html if v not in filter_elements]
        return "".join(end_html), end_html

    def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
        # Drop OCR boxes that lie entirely above the topmost predicted cell.
        y1 = pred_bboxes[:, 1::2].min()
        new_dt_boxes = []
        new_rec_res = []

        for box, rec in zip(dt_boxes, rec_res):
            if np.max(box[1::2]) < y1:
                continue
            new_dt_boxes.append(box)
            new_rec_res.append(rec)
        return new_dt_boxes, new_rec_res
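
# Matching strategy in brief: each OCR text box is assigned to the predicted
# cell with the highest IoU (ties broken by the corner-biased L1 distance), and
# the matched texts are spliced into the predicted <td> slots in structure order.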


class TablePredictor:
    def __init__(self, model_dir, model_prefix="inference"):
        model_file = f"{model_dir}/{model_prefix}.pdmodel"
        params_file = f"{model_dir}/{model_prefix}.pdiparams"
        config = Config(model_file, params_file)
        config.disable_gpu()
        config.disable_glog_info()
        config.enable_new_ir(True)
        config.enable_new_executor(True)
        config.enable_memory_optim()
        config.switch_ir_optim(True)
        # Disable feed, fetch OP, needed by zero_copy_run.
        config.switch_use_feed_fetch_ops(False)
        predictor = create_predictor(config)
        self.config = config
        self.predictor = predictor
        # Get input and output handles.
        input_names = predictor.get_input_names()
        input_names.sort()  # sort in place; note list.sort() returns None
        self.input_names = input_names
        self.input_handlers = []
        self.output_handlers = []
        for input_name in input_names:
            input_handler = predictor.get_input_handle(input_name)
            self.input_handlers.append(input_handler)
        self.output_names = predictor.get_output_names()
        for output_name in self.output_names:
            output_handler = predictor.get_output_handle(output_name)
            self.output_handlers.append(output_handler)

    def __call__(self, batch_imgs):
        self.input_handlers[0].reshape(batch_imgs.shape)
        self.input_handlers[0].copy_from_cpu(batch_imgs)
        self.predictor.run()
        output = []
        for out_tensor in self.output_handlers:
            batch = out_tensor.copy_to_cpu()
            output.append(batch)
        return self.format_output(output)

    def format_output(self, pred):
        # Regroup per-tensor batches into per-sample tuples.
        return [res for res in zip(*pred)]
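
# Minimal usage sketch (the model path is hypothetical; it expects an
# SLANet_plus inference export containing inference.pdmodel / inference.pdiparams):
#   predictor = TablePredictor("models/SLANet_plus")
#   outputs = predictor(batch_imgs)  # batch_imgs: float32 NCHW ndarray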


class SLANetPlus:
    def __init__(self, model_dir, model_prefix="inference"):
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.target_img_size = [488, 488]
        self.scale = 1 / 255
        self.order = "hwc"
        self.img_loader = LoadImage()
        self.target_size = 488
        self.pad_color = 0
        self.predictor = TablePredictor(model_dir, model_prefix)
        dict_character = ['sos', '<thead>', '</thead>', '<tbody>', '</tbody>', '<tr>', '</tr>', '<td', '>', '</td>', ' colspan="2"', ' colspan="3"', ' colspan="4"', ' colspan="5"', ' colspan="6"', ' colspan="7"', ' colspan="8"', ' colspan="9"', ' colspan="10"', ' colspan="11"', ' colspan="12"', ' colspan="13"', ' colspan="14"', ' colspan="15"', ' colspan="16"', ' colspan="17"', ' colspan="18"', ' colspan="19"', ' colspan="20"', ' rowspan="2"', ' rowspan="3"', ' rowspan="4"', ' rowspan="5"', ' rowspan="6"', ' rowspan="7"', ' rowspan="8"', ' rowspan="9"', ' rowspan="10"', ' rowspan="11"', ' rowspan="12"', ' rowspan="13"', ' rowspan="14"', ' rowspan="15"', ' rowspan="16"', ' rowspan="17"', ' rowspan="18"', ' rowspan="19"', ' rowspan="20"', '<td></td>', 'eos']
        self.beg_str = "sos"
        self.end_str = "eos"
        self.dict = {}
        self.table_matcher = TableMatch()
        for i, char in enumerate(dict_character):
            self.dict[char] = i
        self.character = dict_character
        self.td_token = ["<td>", "<td", "<td></td>"]

    def __call__(self, img, ocr_result):
        img = self.img_loader(img)
        h, w = img.shape[:2]
        n_img, h_resize, w_resize = self.resize(img)
        n_img = self.normalize(n_img)
        n_img = self.pad(n_img)
        n_img = n_img.transpose((2, 0, 1))  # HWC -> CHW
        n_img = np.expand_dims(n_img, axis=0)
        start = time.time()
        batch_output = self.predictor(n_img)
        elapse_time = time.time() - start
        ori_img_size = [[w, h]]
        output = self.decode(batch_output, ori_img_size)[0]
        corners = np.stack(output["bbox"], axis=0)
        dt_boxes, rec_res = get_boxes_recs(ocr_result, h, w)
        pred_html = self.table_matcher(
            output["structure"],
            convert_corners_to_bounding_boxes(corners),
            dt_boxes,
            rec_res,
        )
        return pred_html, output["bbox"], elapse_time

    def resize(self, img):
        h, w = img.shape[:2]
        scale = self.target_size / max(h, w)
        h_resize = round(h * scale)
        w_resize = round(w * scale)
        resized_img = cv2.resize(
            img, (w_resize, h_resize), interpolation=cv2.INTER_LINEAR
        )
        return resized_img, h_resize, w_resize

    def pad(self, img):
        h, w = img.shape[:2]
        tw, th = self.target_img_size
        ph = th - h
        pw = tw - w
        pad = (0, ph, 0, pw)  # top, bottom, left, right
        chns = 1 if img.ndim == 2 else img.shape[2]
        im = cv2.copyMakeBorder(
            img, *pad, cv2.BORDER_CONSTANT, value=(self.pad_color,) * chns
        )
        return im

    def normalize(self, img):
        img = img.astype("float32", copy=False)
        img *= self.scale
        img -= self.mean
        img /= self.std
        return img

    def decode(self, pred, ori_img_size):
        bbox_preds, structure_probs = [], []
        for bbox_pred, stru_prob in pred:
            bbox_preds.append(bbox_pred)
            structure_probs.append(stru_prob)
        bbox_preds = np.array(bbox_preds)
        structure_probs = np.array(structure_probs)

        bbox_list, structure_str_list, structure_score = self.decode_single(
            structure_probs, bbox_preds, [self.target_img_size], ori_img_size
        )
        structure_str_list = [
            (
                ["<html>", "<body>", "<table>"]
                + structure
                + ["</table>", "</body>", "</html>"]
            )
            for structure in structure_str_list
        ]
        return [
            {"bbox": bbox, "structure": structure, "structure_score": structure_score}
            for bbox, structure in zip(bbox_list, structure_str_list)
        ]

    def decode_single(self, structure_probs, bbox_preds, padding_size, ori_img_size):
        """Convert structure logits into HTML tokens and decoded cell boxes."""
        # Compare token indices, not the raw strings: char_idx below is an int.
        ignored_tokens = [self.dict[self.beg_str], self.dict[self.end_str]]
        end_idx = self.dict[self.end_str]

        structure_idx = structure_probs.argmax(axis=2)
        structure_probs = structure_probs.max(axis=2)

        structure_batch_list = []
        bbox_batch_list = []
        batch_size = len(structure_idx)
        for batch_idx in range(batch_size):
            structure_list = []
            bbox_list = []
            score_list = []
            for idx in range(len(structure_idx[batch_idx])):
                char_idx = int(structure_idx[batch_idx][idx])
                if idx > 0 and char_idx == end_idx:
                    break
                if char_idx in ignored_tokens:
                    continue
                text = self.character[char_idx]
                if text in self.td_token:
                    bbox = bbox_preds[batch_idx, idx]
                    bbox = self._bbox_decode(
                        bbox, padding_size[batch_idx], ori_img_size[batch_idx]
                    )
                    bbox_list.append(bbox.astype(int))
                structure_list.append(text)
                score_list.append(structure_probs[batch_idx, idx])
            structure_batch_list.append(structure_list)
            structure_score = np.mean(score_list)  # with batch size 1, the score of the sample
            bbox_batch_list.append(bbox_list)

        return bbox_batch_list, structure_batch_list, structure_score

    def _bbox_decode(self, bbox, padding_shape, ori_shape):
        # Map normalized coordinates back to the original image size, undoing
        # the resize-to-488-and-pad preprocessing.
        pad_w, pad_h = padding_shape
        w, h = ori_shape
        ratio_w = pad_w / w
        ratio_h = pad_h / h
        ratio = min(ratio_w, ratio_h)

        bbox[0::2] *= pad_w
        bbox[1::2] *= pad_h
        bbox[0::2] /= ratio
        bbox[1::2] /= ratio

        return bbox
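

if __name__ == "__main__":
    # Illustrative end-to-end sketch, not part of the original pipeline. The
    # image path and OCR results below are placeholders; any detector/recognizer
    # producing (quad, text, score) triples will do. The model directory matches
    # the PaddleX download location mentioned in the export comment above.
    demo_img = "table.png"  # hypothetical input image
    ocr_result = [
        # ([[x1, y1], [x2, y2], [x3, y3], [x4, y4]], text, score)
        ([[10, 10], [120, 10], [120, 40], [10, 40]], "Header A", 0.99),
        ([[130, 10], [240, 10], [240, 40], [130, 40]], "Header B", 0.98),
    ]
    table_engine = SLANetPlus(r"C:\Users\51954\.paddlex\official_models\SLANet_plus")
    html, cell_bboxes, elapse = table_engine(demo_img, ocr_result)
    print(f"elapsed: {elapse:.3f}s")
    print(html)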