formula_segment.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. # @Author : lightXu
  2. # @File : formula_segment.py
  3. # @Time : 2019/1/24 0024 下午 13:24
  4. import time
  5. import re
  6. import copy
  7. import math
  8. import cv2
  9. import numpy as np
  10. import xml.etree.cElementTree as ET
  11. from segment.formula import mathpix_ocr
  12. from segment.server import get_ocr_text_and_coordinate_formula
  13. from segment.image_operation import utils
  14. def get_coordinates(word_res, formula_words_list):
  15. res_list = []
  16. for formula_raw in formula_words_list:
  17. coordinates_start_index = formula_raw[1][0]
  18. coordinates_end_index = formula_raw[1][1] - 1
  19. coordinates_start = word_res['chars'][coordinates_start_index]['location']
  20. coordinates_end = word_res['chars'][coordinates_end_index]['location']
  21. coordinates = (coordinates_start['left'], # xmin
  22. min(coordinates_start['top'], coordinates_end['top']), # ymin
  23. coordinates_end['left'] + coordinates_end['width'], # xmax
  24. max(coordinates_start['top'] + coordinates_start['height'],
  25. coordinates_end['top'] + coordinates_end['height'])) # ymax
  26. tmp_dict = {'chars': formula_raw[0],
  27. 'raw_chars': formula_raw[0],
  28. 'coordinates': coordinates,
  29. 'middle': (coordinates[0] + int((coordinates[2] - coordinates[0]) // 2),
  30. coordinates[1] + int((coordinates[3] - coordinates[1]) // 2))}
  31. res_list.append(tmp_dict)
  32. return res_list
  33. def generate_char(words, index_pair, zh=True):
  34. if index_pair:
  35. # new_words = words.copy()
  36. length = index_pair[1] - index_pair[0]
  37. gen = ''
  38. if zh:
  39. for i in range(length):
  40. gen = '中' + gen
  41. else:
  42. for i in range(length):
  43. gen = 'F' + gen
  44. words = words.replace(words[index_pair[0]:index_pair[1]], gen)
  45. return words
  46. else:
  47. return words
  48. def segment(img, save_path, access_token):
  49. # raw_img = img.copy()
  50. # img = utils.preprocess(raw_img, None)
  51. word_result_list = get_ocr_text_and_coordinate_formula(img, access_token)
  52. formula_coordinates_dict_list = []
  53. zh_coordinates_dict_list = []
  54. zh_char_height = 20 # default
  55. zh_char_width = 15 # default
  56. zh_char_height_list = []
  57. zh_char_width_list = []
  58. exclude = r'{}|{}|{}|{}|{}|{}'.format(
  59. '^[((]*[\d]+[))]',
  60. '[ABCD]\.',
  61. '[\u4e00-\u9fa5][,;:。,;:.]',
  62. '[①②③④⑤⑥⑦⑧⑨⑩]',
  63. '[((][))]',
  64. '[\u4e00-\u9fa5][\d]+[\u4e00-\u9fa5]')
  65. for index, word_res in enumerate(word_result_list):
  66. words = word_res['words'].replace(' ', '').replace('兀', 'π') # 去除空格,baidu_api bug
  67. abcd_words_m = re.finditer(exclude, words)
  68. abcd_index_list = [(m.group(), m.span()) for m in abcd_words_m if m]
  69. words_tmp_zh = copy.copy(words)
  70. for ele in abcd_index_list:
  71. words_tmp_zh = generate_char(words_tmp_zh, ele[1], zh=True)
  72. formula_words_m = re.finditer(r'[^\u4e00-\u9fa5._"“”]+', words_tmp_zh)
  73. formula_index_list = [(m.group(), m.span()) for m in formula_words_m if m]
  74. formula_list = get_coordinates(word_res, formula_index_list)
  75. formula_coordinates_dict_list = formula_coordinates_dict_list + formula_list
  76. words_tmp_formula = copy.copy(words)
  77. for ele in abcd_index_list:
  78. words_tmp_formula = generate_char(words_tmp_formula, ele[1], zh=False)
  79. zh_words_m = re.finditer(r'[\u4e00-\u9fa5._"“”]+', words_tmp_formula)
  80. zh_index_list = [(m.group(), m.span()) for m in zh_words_m if m]
  81. zh_list = get_coordinates(word_res, zh_index_list + abcd_index_list)
  82. zh_coordinates_dict_list = zh_coordinates_dict_list + zh_list
  83. one_zh_char_m = re.match(r'[\u4e00-\u9fa5]+', words)
  84. if one_zh_char_m:
  85. index = one_zh_char_m.span()[0]
  86. zh_char_height_list.append(word_res['chars'][index]['location']['height'])
  87. zh_char_width_list.append(word_res['chars'][index]['location']['width'])
  88. if len(zh_char_width_list) > 0 and len(zh_char_height_list) > 0:
  89. zh_char_height = np.mean(zh_char_height_list)
  90. zh_char_width = np.mean(zh_char_width_list)
  91. formula_coordinates_list = [ele['coordinates'] for ele in formula_coordinates_dict_list]
  92. formula_combine_list = combine(formula_coordinates_list, zh_char_height, zh_char_width) # 欧式距离
  93. formula_combine_dict_list = []
  94. for i, ele in enumerate(formula_combine_list):
  95. middle = (ele[0] + int((ele[2] - ele[0]) // 2), ele[1] + int((ele[3] - ele[1]) // 2))
  96. ocr_region = utils.crop_region_direct(img, ele)
  97. y, x = ocr_region.shape[0], ocr_region.shape[1]
  98. if min(y, x) <= 50:
  99. ocr_region = utils.resize_by_percent(ocr_region, 2.00) # 放大若干倍
  100. try:
  101. mathpix_raw_chars, latex_confidence = mathpix_ocr.mathpix_api(ocr_region) # 识别公式
  102. render_mathpix_chars = '<latex>{}</latex>'.format(mathpix_raw_chars)
  103. if latex_confidence < 0.2:
  104. for item in formula_coordinates_dict_list:
  105. if ele == item['coordinates']:
  106. mathpix_raw_chars = item['chars']
  107. render_mathpix_chars = '<latex>{}</latex>'.format(item['chars'])
  108. break
  109. except Exception:
  110. render_mathpix_chars = 'formula'
  111. mathpix_raw_chars = 'formula'
  112. for item in formula_coordinates_dict_list:
  113. if ele == item['coordinates']:
  114. mathpix_raw_chars = item['chars']
  115. render_mathpix_chars = '<latex>{}</latex>'.format(item['chars'])
  116. break
  117. # print(render_mathpix_chars)
  118. tmp_dict = {'chars': render_mathpix_chars, 'middle': middle, 'coordinates': ele, 'raw_chars': mathpix_raw_chars}
  119. formula_combine_dict_list.append(tmp_dict)
  120. # res_dict = {'formula': formula_combine_list, 'zh_chars': zh_coordinates_dict_list}
  121. all_dict_list = formula_combine_dict_list + zh_coordinates_dict_list
  122. all_dict_list = sorted(all_dict_list, key=lambda k: k.get('middle')[1])
  123. # 相邻y做差
  124. former = np.array([ele['middle'][1] for ele in all_dict_list[:-1]])
  125. rear = np.array([ele['middle'][1] for ele in all_dict_list[1:]])
  126. dif = rear - former
  127. split_x_index = [index for index, ele in enumerate(dif) if ele >= zh_char_height] # y轴排序
  128. if not split_x_index:
  129. all_dict_list = sorted(all_dict_list, key=lambda k: k.get('middle')[0]) # x轴排序
  130. lines = [ele['chars'] for ele in all_dict_list]
  131. raw_lines = [ele['raw_chars'] for ele in all_dict_list]
  132. return lines, raw_lines
  133. else:
  134. res_list = []
  135. split_x_index = [ele + 1 for ele in split_x_index] # 索引值扩大
  136. split_x_index.insert(0, 0)
  137. split_x_index.insert(-1, len(all_dict_list))
  138. split_x_index = sorted(list(set(split_x_index)))
  139. for i, split in enumerate(split_x_index[1:]):
  140. one_line = all_dict_list[split_x_index[i]:split_x_index[i + 1]]
  141. one_line = sorted(one_line, key=lambda k: k.get('middle')[0]) # x轴排序
  142. res_list.append(one_line)
  143. lines = []
  144. raw_lines = []
  145. for ele in res_list:
  146. line_chars = ''
  147. raw_lines_chars = ''
  148. for ele1 in ele:
  149. chars = ele1['chars']
  150. raw_chars = ele1['raw_chars']
  151. line_chars = line_chars + chars
  152. raw_lines_chars = raw_lines_chars + raw_chars
  153. lines.append(line_chars + '\n')
  154. raw_lines.append(raw_lines_chars + '\n')
  155. # print(lines)
  156. return lines, raw_lines
  157. def combine(formula_coordinates_list, zh_char_height, zh_char_width):
  158. formula_coordinates_list = sorted(formula_coordinates_list, key=lambda k: k[0])
  159. formula_coordinates_list = sorted(formula_coordinates_list, key=lambda k: k[1]) # 先x轴,再y轴排序
  160. recursion_flag = False
  161. del_list = []
  162. temp_list = formula_coordinates_list.copy()
  163. for i, outer in enumerate(temp_list): # xmin, ymin, xmax, ymax
  164. for j, inner in enumerate(temp_list): # xmin, ymin, xmax, ymax
  165. if not i == j:
  166. min_distance, flag = get_min_distance(outer, inner)
  167. combine_coordinate = ()
  168. if flag == 'i':
  169. recursion_flag = True
  170. combine_coordinate = (min(outer[0], inner[0]), min(outer[1], inner[1]),
  171. max(outer[2], inner[2]), max(outer[3], inner[3]))
  172. elif flag == 'h' and min_distance <= 1:
  173. recursion_flag = True
  174. combine_coordinate = (min(outer[0], inner[0]), min(outer[1], inner[1]),
  175. max(outer[2], inner[2]), max(outer[3], inner[3]))
  176. elif flag == 'w' and min_distance <= zh_char_width:
  177. recursion_flag = True
  178. combine_coordinate = (min(outer[0], inner[0]), min(outer[1], inner[1]),
  179. max(outer[2], inner[2]), max(outer[3], inner[3]))
  180. elif flag == 'c' and min_distance <= 1:
  181. recursion_flag = True
  182. combine_coordinate = (min(outer[0], inner[0]), min(outer[1], inner[1]),
  183. max(outer[2], inner[2]), max(outer[3], inner[3]))
  184. if combine_coordinate:
  185. if not combine_coordinate == outer and not combine_coordinate == inner: # 避免全包围的情况
  186. del_list.append(outer)
  187. del_list.append(inner)
  188. if combine_coordinate == outer:
  189. del_list.append(inner)
  190. if combine_coordinate == inner:
  191. del_list.append(outer)
  192. formula_coordinates_list.append(combine_coordinate)
  193. res = list(set(formula_coordinates_list) - set(del_list))
  194. if recursion_flag:
  195. return combine(res, zh_char_height, zh_char_width)
  196. else:
  197. return res
  198. def get_min_distance_square(coordinate1, coordinate2): # 顶点间欧式距离最小值的平方和
  199. all_points1 = [(x, y) for x in [coordinate1[0], coordinate1[2]] for y in [coordinate1[1], coordinate1[3]]]
  200. all_points2 = [(x, y) for x in [coordinate2[0], coordinate2[2]] for y in [coordinate2[1], coordinate2[3]]]
  201. distance_list = []
  202. for index1, point1 in enumerate(all_points1):
  203. for index2, point2 in enumerate(all_points2):
  204. distance = (point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2
  205. distance_list.append(distance)
  206. min_distance = min(distance_list)
  207. return min_distance
  208. def get_min_distance(coordinate1, coordinate2): # 欧式距离最小值
  209. def dist(point1, point2):
  210. distance = (point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2
  211. return math.sqrt(distance)
  212. (x1, y1, x1b, y1b) = coordinate1
  213. (x2, y2, x2b, y2b) = coordinate2
  214. left = x2b < x1 # 2在1的坐标左边
  215. right = x1b < x2 # 2在1的坐标右边
  216. bottom = y2b < y1 # 2在1的坐标下边
  217. top = y1b < y2 # 2在1的坐标上边
  218. if top and left:
  219. return dist((x1, y1b), (x2b, y2)), 'c'
  220. elif left and bottom:
  221. return dist((x1, y1), (x2b, y2b)), 'c'
  222. elif bottom and right:
  223. return dist((x1b, y1), (x2, y2b)), 'c'
  224. elif right and top:
  225. return dist((x1b, y1b), (x2, y2)), 'c'
  226. elif left:
  227. return x1 - x2b, 'w'
  228. elif right:
  229. return x2 - x1b, 'w'
  230. elif bottom:
  231. return y1 - y2b, 'h'
  232. elif top:
  233. return y2 - y1b, 'h'
  234. else: # rectangles intersect
  235. return 0, 'i'