utils.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. from io import BytesIO
  2. from pathlib import Path
  3. from typing import Union, List
  4. import os
  5. from functools import partial
  6. import loguru
  7. import loguru._logger
  8. from memoization import cached, CachingAlgorithmFlag
  9. import numpy as np
  10. import cv2
  11. from PIL import UnidentifiedImageError, Image
  12. InputType = Union[str, np.ndarray, bytes, Path, Image.Image]
  13. class LoadImage:
  14. def __init__(
  15. self,
  16. ):
  17. pass
  18. def __call__(self, img: InputType) -> np.ndarray:
  19. if not isinstance(img, InputType.__args__):
  20. raise LoadImageError(
  21. f"The img type {type(img)} does not in {InputType.__args__}"
  22. )
  23. origin_img_type = type(img)
  24. img = self.load_img(img)
  25. img = self.convert_img(img, origin_img_type)
  26. return img
  27. def load_img(self, img: InputType) -> np.ndarray:
  28. if isinstance(img, (str, Path)):
  29. self.verify_exist(img)
  30. try:
  31. img = np.array(Image.open(img))
  32. except UnidentifiedImageError as e:
  33. raise LoadImageError(f"cannot identify image file {img}") from e
  34. return img
  35. if isinstance(img, bytes):
  36. img = np.array(Image.open(BytesIO(img)))
  37. return img
  38. if isinstance(img, BytesIO):
  39. img = np.array(Image.open(img))
  40. return img
  41. if isinstance(img, np.ndarray):
  42. return img
  43. if isinstance(img, Image.Image):
  44. return np.array(img)
  45. raise LoadImageError(f"{type(img)} is not supported!")
  46. def convert_img(self, img: np.ndarray, origin_img_type):
  47. if img.ndim == 2:
  48. return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
  49. if img.ndim == 3:
  50. channel = img.shape[2]
  51. if channel == 1:
  52. return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
  53. if channel == 2:
  54. return self.cvt_two_to_three(img)
  55. if channel == 3:
  56. if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
  57. return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  58. return img
  59. if channel == 4:
  60. return self.cvt_four_to_three(img)
  61. raise LoadImageError(
  62. f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
  63. )
  64. raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")
  65. @staticmethod
  66. def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
  67. """gray + alpha → BGR"""
  68. img_gray = img[..., 0]
  69. img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
  70. img_alpha = img[..., 1]
  71. not_a = cv2.bitwise_not(img_alpha)
  72. not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
  73. new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
  74. new_img = cv2.add(new_img, not_a)
  75. return new_img
  76. @staticmethod
  77. def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
  78. """RGBA → BGR"""
  79. r, g, b, a = cv2.split(img)
  80. new_img = cv2.merge((b, g, r))
  81. not_a = cv2.bitwise_not(a)
  82. not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
  83. new_img = cv2.bitwise_and(new_img, new_img, mask=a)
  84. new_img = cv2.add(new_img, not_a)
  85. return new_img
  86. @staticmethod
  87. def verify_exist(file_path: Union[str, Path]):
  88. if not Path(file_path).exists():
  89. raise LoadImageError(f"{file_path} does not exist.")
  90. class LoadImageError(Exception):
  91. pass
  92. def plot_rec_box_with_logic_info(img_path, logic_points, sorted_polygons, without_text=True):
  93. """
  94. :param img_path
  95. :param output_path
  96. :param logic_points: [row_start,row_end,col_start,col_end]
  97. :param sorted_polygons: [xmin,ymin,xmax,ymax]
  98. :return:
  99. """
  100. # 读取原图
  101. img = cv2.imread(img_path)
  102. img = cv2.copyMakeBorder(
  103. img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
  104. )
  105. # 绘制 polygons 矩形
  106. for idx, polygon in enumerate(sorted_polygons):
  107. x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3]
  108. x0 = round(x0)
  109. y0 = round(y0)
  110. x1 = round(x1)
  111. y1 = round(y1)
  112. cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
  113. # 增大字体大小和线宽
  114. font_scale = 1.0 # 原先是0.5
  115. thickness = 2 # 原先是1
  116. if without_text:
  117. return img
  118. cv2.putText(
  119. img,
  120. f"{idx}",
  121. (x1, y1),
  122. cv2.FONT_HERSHEY_PLAIN,
  123. font_scale,
  124. (0, 0, 255),
  125. thickness,
  126. )
  127. return img
  128. def plot_rec_box(img, sorted_polygons):
  129. """
  130. :param img_path
  131. :param output_path
  132. :param sorted_polygons: [xmin,ymin,xmax,ymax]
  133. :return:
  134. """
  135. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  136. # 处理ocr_res
  137. img = cv2.copyMakeBorder(
  138. img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
  139. )
  140. # 绘制 ocr_res 矩形
  141. for idx, polygon in enumerate(sorted_polygons):
  142. x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3]
  143. x0 = round(x0)
  144. y0 = round(y0)
  145. x1 = round(x1)
  146. y1 = round(y1)
  147. cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
  148. # 增大字体大小和线宽
  149. font_scale = 1.0 # 原先是0.5
  150. thickness = 2 # 原先是1
  151. # cv2.putText(
  152. # img,
  153. # str(idx),
  154. # (x1, y1),
  155. # cv2.FONT_HERSHEY_PLAIN,
  156. # font_scale,
  157. # (0, 0, 255),
  158. # thickness,
  159. # )
  160. return img
  161. def format_html(html: str):
  162. html = html.replace("<html>", "")
  163. html = html.replace("</html>", "")
  164. html = html.replace("<body>", "")
  165. html = html.replace("</body>", "")
  166. return f"""
  167. <!DOCTYPE html>
  168. <html lang="zh-CN">
  169. <head>
  170. <meta charset="UTF-8">
  171. <title>Complex Table Example</title>
  172. <style>
  173. table {{
  174. border-collapse: collapse;
  175. width: 100%;
  176. }}
  177. th, td {{
  178. border: 1px solid black;
  179. padding: 8px;
  180. text-align: center;
  181. }}
  182. th {{
  183. background-color: #f2f2f2;
  184. }}
  185. </style>
  186. </head>
  187. <body>
  188. {html}
  189. </body>
  190. </html>
  191. """
  192. def box_4_2_poly_to_box_4_1(poly_box: Union[np.ndarray, list]) -> List[float]:
  193. """
  194. 将poly_box转换为box_4_1
  195. :param poly_box:
  196. :return:
  197. """
  198. return [poly_box[0][0], poly_box[0][1], poly_box[2][0], poly_box[2][1]]
  199. def _filter_logs(record: dict) -> bool:
  200. # hide debug logs if Settings.basic_settings.log_verbose=False
  201. if record["level"].no <= 10:
  202. return False
  203. # hide traceback logs if Settings.basic_settings.log_verbose=False
  204. if record["level"].no == 40:
  205. record["exception"] = None
  206. return True
  207. @cached(max_size=100, algorithm=CachingAlgorithmFlag.LRU)
  208. def build_logger(log_file: str = "chatchat"):
  209. """
  210. build a logger with colorized output and a log file, for example:
  211. logger = build_logger("api")
  212. logger.info("<green>some message</green>")
  213. user can set basic_settings.log_verbose=True to output debug logs
  214. use logger.exception to log errors with exceptions
  215. """
  216. loguru.logger._core.handlers[0]._filter = _filter_logs
  217. logger = loguru.logger.opt(colors=True)
  218. logger.opt = partial(loguru.logger.opt, colors=True)
  219. # logger.error = partial(logger.exception)
  220. if log_file:
  221. if not log_file.endswith(".log"):
  222. log_file = f"{log_file}.log"
  223. logger.add(log_file, colorize=False, filter=_filter_logs)
  224. return logger