sheet_infer.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273
  1. # @Author : lightXu
  2. # @File : sheet_infer.py
  3. # @Time : 2019/9/26 0026 上午 10:18
  4. import itertools
  5. import os
  6. import re
  7. import traceback
  8. import xml.etree.cElementTree as ET
  9. from itertools import combinations
  10. import cv2
  11. import numpy as np
  12. from shapely.geometry import LineString, Polygon
  13. from segment.sheet_resolve.tools.utils import create_xml, crop_region_direct, crop_region, image_hash_detection_simple
  14. from segment.sheet_resolve.tools.brain_api import get_ocr_text_and_coordinate
# Aspect-ratio limit: infer_sheet_box only keeps candidate rectangles whose
# long/short side ratio is below 1.5 * ASPECT_FLAG.
ASPECT_FLAG = 4.0
# infer_sheet_box keeps adding inferred rectangles while the uncovered share of
# the image is larger than this ratio.
REMAIN_RATIO = 0.1
# Not referenced in the visible part of this file; presumably low/high
# gray-level thresholds -- TODO confirm against the code that uses them.
PIX_VALUE_LOW = 15.0
PIX_VALUE_HIGH = 245
# Not referenced in the visible part of this file -- purpose to be confirmed.
TYPE_SCORE_MNS = 0.5
  20. def _get_char_near_img(char_location, near):
  21. left = char_location['left']
  22. top = char_location['top']
  23. width = char_location['width']
  24. height = char_location['height']
  25. next_location = char_location
  26. if near == 'left':
  27. next_location = {'left': int(left - 1.5 * width), 'top': top, 'width': width, 'height': height}
  28. if near == 'right':
  29. next_location = {'left': int(left + 1.5 * width), 'top': top, 'width': width, 'height': height}
  30. if near == 'up':
  31. next_location = {'left': left, 'top': int(top - 1.5 * height), 'width': width, 'height': height}
  32. if near == 'down':
  33. next_location = {'left': left, 'top': int(top + 1.5 * height), 'width': width, 'height': height}
  34. return next_location
  35. def _get_board(image, location, direction):
  36. std = 0
  37. next_location = location
  38. while std < 10:
  39. next_location = _get_char_near_img(next_location, direction)
  40. box = (next_location['left'], next_location['top'],
  41. next_location['left'] + next_location['width'],
  42. next_location['top'] + next_location['height'],)
  43. region = crop_region_direct(image, box)
  44. std = np.var(region)
  45. return next_location
def infer_bar_code(image, ocr_dict_list, attention_region):
    """Infer bar-code regions from OCR text that mentions a bar code.

    For every OCR line whose text matches the bar-code keywords, the matched
    character run is widened to neighbouring characters that sit within two
    character widths.  If that text is not contained in any attention box,
    ``_get_board`` is used to walk left/right/up/down from the text until a
    non-flat cell is hit, and the resulting rectangle (padded by 5 px and
    clamped to the image) is recorded as a ``bar_code`` region.  Finally,
    candidates that touch an attention box are dropped.

    NOTE(review): the indentation of this file was lost; the nesting below --
    in particular where ``break`` and ``else: continue`` belong -- is a
    reconstruction and must be confirmed against the upstream source.

    :param image: page image (``image.shape`` is read, the image is cropped).
    :param ocr_dict_list: OCR lines, each with ``'words'`` and ``'chars'``
        (every char carrying a ``'location'`` dict).
    :param attention_region: regions with a ``'bounding_box'`` dict.
    :return: list of ``{'class_name': 'bar_code', 'bounding_box': {...}}``.
    """
    # Rectangle polygon for every attention box, used for containment tests.
    attention_polygon_list = []
    for attention in attention_region:
        coordinates = attention['bounding_box']
        xmin = coordinates['xmin']
        ymin = coordinates['ymin']
        xmax = coordinates['xmax']
        ymax = coordinates['ymax']
        attention_polygon = Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
        attention_polygon_list.append(attention_polygon)
    # NOTE: despite the names, img_cols holds shape[0] and img_rows holds
    # shape[1]; below x is clamped with img_rows and y with img_cols.
    img_cols, img_rows = image.shape[0], image.shape[1]
    pattern = r'条形码|条码|条形|形码'
    bar_code_dict_list = []
    for index, ele in enumerate(ocr_dict_list):
        words = ele['words'].replace(' ', '')
        chars_list = ele['chars']
        length = len(chars_list)
        match_list = [(m.group(), m.span()) for m in re.finditer(pattern, words) if m]
        if match_list:  # not empty
            for match in match_list:
                start_index = match[1][0]
                end_index = match[1][1] - 1
                # Grow the run to the left while the previous char is within
                # two of its own widths of the current run start.
                for i in range(start_index - 1, -1, -1):
                    xmin_start = chars_list[start_index]['location']['left']
                    start_tmp = chars_list[i]['location']['left'] + 2 * chars_list[i]['location']['width']
                    if xmin_start <= start_tmp:
                        start_index = i
                # Grow the run to the right in the same way.
                for i in range(end_index, length):
                    xmax_end = chars_list[end_index]['location']['left'] + 2 * chars_list[i]['location']['width']
                    end_tmp = chars_list[i]['location']['left']
                    if xmax_end >= end_tmp:
                        end_index = i
                bar_code_char_xmin = chars_list[start_index]['location']["left"]
                bar_code_char_xmax = chars_list[end_index]['location']["left"]+chars_list[end_index]['location']["width"]
                bar_code_char_ymin = chars_list[start_index]['location']["top"]
                bar_code_char_ymax = chars_list[end_index]['location']["top"]+chars_list[end_index]['location']["height"]
                bar_code_char_polygon = Polygon([(bar_code_char_xmin, bar_code_char_ymin),
                                                 (bar_code_char_xmax, bar_code_char_ymin),
                                                 (bar_code_char_xmax, bar_code_char_ymax),
                                                 (bar_code_char_xmin, bar_code_char_ymax)])
                contain_cond = [False]*len(attention_polygon_list)
                for i, attention_ele in enumerate(attention_polygon_list):
                    if attention_ele.contains(bar_code_char_polygon):
                        contain_cond[i] = True
                if True not in contain_cond:  # the bar-code text is not inside any attention box
                    left_board_location = _get_board(image, chars_list[start_index]['location'], 'left')
                    right_board_location = _get_board(image, chars_list[end_index]['location'], 'right')
                    up_board_location = _get_board(image, chars_list[start_index]['location'], 'up')
                    down_board_location = _get_board(image, chars_list[end_index]['location'], 'down')
                    xmin = left_board_location['left']
                    ymin = up_board_location['top']
                    xmax = right_board_location['left'] + right_board_location['width']
                    ymax = down_board_location['top'] + down_board_location['height']
                    # Pad the box by a few pixels and keep it inside the image.
                    bias = 5
                    xmin = max(1, int(xmin)-bias)
                    ymin = max(1, int(ymin)-bias)
                    xmax = min(int(xmax)+bias, img_rows - 1)
                    ymax = min(int(ymax)+bias, img_cols - 1)
                    bar_code_dict = {'class_name': 'bar_code',
                                     'bounding_box': {'xmin': xmin, 'ymin': ymin, 'xmax': xmax, 'ymax': ymax}}
                    bar_code_dict_list.append(bar_code_dict)
                    # print(bar_code_dict)
                    break  # assume there is only one bar code
                else:
                    continue
    # Drop candidates that coincide with an attention box.
    for bar_code in bar_code_dict_list.copy():
        coordinates = bar_code['bounding_box']
        xmin = coordinates['xmin']
        ymin = coordinates['ymin']
        xmax = coordinates['xmax']
        ymax = coordinates['ymax']
        bar_code_polygon = Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
        for attention_polygon in attention_polygon_list:
            # cond1: one rectangle lies completely inside the other.
            cond1 = bar_code_polygon.within(attention_polygon) or bar_code_polygon.contains(attention_polygon)
            cond2 = False
            cond3 = bar_code_polygon.overlaps(attention_polygon)
            if cond3:
                # Partial overlap only counts when it covers at least 1% of
                # the bar-code box (cond2) or of the attention box (cond3).
                intersection_poly = bar_code_polygon.intersection(attention_polygon)
                cond2 = intersection_poly.area / bar_code_polygon.area >= 0.01
                cond3 = intersection_poly.area / attention_polygon.area >= 0.01
            if cond1 or cond2 or cond3:
                bar_code_dict_list.remove(bar_code)
                break
    return bar_code_dict_list
  131. def infer_exam_number(image, ocr_dict_list, existed_regions, times_threshold=5):
  132. # existed_polygon_list = []
  133. # for region in existed_regions:
  134. # coordinates = region['bounding_box']
  135. # xmin = coordinates['xmin']
  136. # ymin = coordinates['ymin']
  137. # xmax = coordinates['xmax']
  138. # ymax = coordinates['ymax']
  139. # existed_polygon = Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
  140. # existed_polygon_list.append(existed_polygon)
  141. img_rows, img_cols = image.shape[0], image.shape[1]
  142. exam_number_dict_list = []
  143. xmin, ymin, xmax, ymax = 9999, 9999, 0, 0
  144. pattern = r'[0oO]|[2-9]' # 除去1,避免[]被识别为1
  145. exclude = r'分|题|[ABD]'
  146. key_digital = []
  147. all_height = []
  148. cols = []
  149. for index, ele in enumerate(ocr_dict_list):
  150. words = ele['words'].replace(' ', '')
  151. match_list = [(m.group(), m.span()) for m in re.finditer(pattern, words) if m]
  152. exclude_list = [(m.group(), m.span()) for m in re.finditer(exclude, words, re.I) if m]
  153. match_digital_arr = np.asarray([int(char[0].replace('o', '0').replace('O', '0')) for char in match_list])
  154. if len(match_digital_arr) > 0:
  155. counts = np.bincount(match_digital_arr)
  156. mode_times = np.max(counts)
  157. if mode_times >= times_threshold and len(exclude_list) < 1:
  158. mode_value = np.argmax(counts) # 众数,避免考号末尾出现的其他数字
  159. key_index = np.where(match_digital_arr == mode_value)[0]
  160. cols.append(len(key_index))
  161. start_index = match_list[key_index[0]][1][0]
  162. end_index = match_list[key_index[-1]][1][0]
  163. xmin_t = ele['chars'][start_index]['location']['left']
  164. ymin_t = ele['chars'][start_index]['location']['top']
  165. xmax_t = ele['chars'][end_index]['location']['left'] + ele['chars'][end_index]['location']['width']
  166. ymax_t = ele['chars'][end_index]['location']['top'] + ele['chars'][end_index]['location']['height']
  167. mean_width = sum([int(ele['chars'][match_list[i][1][0]]['location']['width'])
  168. for i in key_index]) // len(key_index)
  169. mean_height = sum([int(ele['chars'][match_list[i][1][0]]['location']['height'])
  170. for i in key_index]) // len(key_index)
  171. all_height.append(mean_height)
  172. xmin = min(xmin, xmin_t-mean_width)
  173. ymin = min(ymin, ymin_t)
  174. xmax = max(xmax, xmax_t+mean_width)
  175. ymax = max(ymax, ymax_t)
  176. xmin = int(xmin) if xmin >= 1 else 1
  177. ymin = int(ymin) if ymin >= 1 else 1
  178. xmax = int(xmax) if xmax <= img_cols - 1 else img_cols - 1
  179. ymax = int(ymax) if ymax <= img_rows - 1 else img_rows - 1
  180. key_digital.append(mode_value)
  181. if 9 in key_digital:
  182. break
  183. if 0 in key_digital and 9 in key_digital and len(key_digital) > 4:
  184. mean_height = sum(all_height)//10
  185. exam_number_dict = {'class_name': 'exam_number',
  186. 'bounding_box': {'xmin': xmin, 'ymin': ymin, 'xmax': xmax, 'ymax': ymax+mean_height},
  187. 'rows': 10,
  188. 'cols': max(cols)
  189. }
  190. exam_number_dict_list.append(exam_number_dict)
  191. return exam_number_dict_list
  192. else:
  193. if len(key_digital) > 1:
  194. dgt_min = min(key_digital)
  195. dgt_max = max(key_digital)
  196. mean_height = sum(all_height)//len(all_height)
  197. dif = dgt_max - dgt_min
  198. blank_height = ymax - ymin - mean_height * (dif+1)
  199. mean_blank = blank_height // dif
  200. upper_height = dgt_min * (mean_blank + mean_height) + mean_blank//2
  201. downward_height = (9-dgt_max) * (mean_blank + mean_height) + mean_blank
  202. exam_number_dict = {'class_name': 'exam_number',
  203. 'bounding_box': {'xmin': xmin, 'ymin': ymin-upper_height,
  204. 'xmax': xmax, 'ymax': ymax+downward_height},
  205. 'rows': 10,
  206. 'cols': max(cols)}
  207. exam_number_dict_list.append(exam_number_dict)
  208. if len(key_digital) == 1:
  209. dgt_min = dgt_max = min(key_digital)
  210. eval_height = sum(all_height)//len(all_height) * 1.5
  211. upper_height = dgt_min * eval_height
  212. downward_height = (9-dgt_max) * eval_height
  213. exam_number_dict = {'class_name': 'exam_number',
  214. 'bounding_box': {'xmin': xmin, 'ymin': ymin-upper_height,
  215. 'xmax': xmax, 'ymax': ymax+downward_height},
  216. 'rows': 10,
  217. 'cols': max(cols)}
  218. exam_number_dict_list.append(exam_number_dict)
  219. iou_cond = True
  220. exam_number_dict_list_check = []
  221. if len(exam_number_dict_list) > 0:
  222. for exam_number_dict in exam_number_dict_list:
  223. exam_number_polygon = Polygon([(exam_number_dict["xmin"], exam_number_dict["ymin"]),
  224. (exam_number_dict["xmax"], exam_number_dict["ymin"]),
  225. (exam_number_dict["xmax"], exam_number_dict["ymax"]),
  226. (exam_number_dict["xmin"], exam_number_dict["ymax"])])
  227. for region in existed_regions:
  228. class_name = region["class_name"]
  229. if class_name in ["attention", "solve", "choice", "choice_m", 'choice_s', "cloze", 'cloze_s',
  230. 'bar_code', 'qr_code', 'composition', 'solve0']:
  231. coordinates = region['bounding_box']
  232. xmin = coordinates['xmin']
  233. ymin = coordinates['ymin']
  234. xmax = coordinates['xmax']
  235. ymax = coordinates['ymax']
  236. existed_polygon = Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
  237. overlab_area = existed_polygon.intersection(exam_number_polygon).area
  238. iou = overlab_area / (exam_number_polygon.area + existed_polygon.area - overlab_area)
  239. if iou > 0:
  240. iou_cond = False
  241. break
  242. if iou_cond:
  243. exam_number_dict_list_check.append(exam_number_polygon)
  244. return exam_number_dict_list_check
  245. def adjust_exam_number(regions):
  246. exam_number_w_regions = list()
  247. exam_number_regions = list()
  248. for i in range(len(regions) - 1, -1, -1):
  249. region = regions[i]
  250. if region['class_name'] == 'exam_number_w':
  251. exam_number_w_regions.append(region)
  252. if region['class_name'] == 'exam_number':
  253. exam_number_regions.append(region)
  254. regions.pop(i)
  255. exam_number_region = exam_number_regions[0]
  256. if len(exam_number_regions) > 1:
  257. exam_number_regions = sorted(exam_number_regions, key=lambda x: x['bounding_box']['ymin'])
  258. exam_number_region = exam_number_regions[0]
  259. exam_number_w_index = 0
  260. if len(exam_number_w_regions) > 1:
  261. distance = [abs(int(ele['bounding_box']['ymax']) - int(exam_number_region['bounding_box']['ymin']))
  262. for ele in exam_number_w_regions]
  263. exam_number_w_index = distance.index(min(distance))
  264. exam_number_w_region = exam_number_w_regions[exam_number_w_index]
  265. standard = exam_number_w_region['bounding_box']
  266. exam_number_region['bounding_box'].update({'xmin': standard['xmin'], 'xmax': standard['xmax']})
  267. regions.append(exam_number_region)
  268. return regions
  269. def exam_number_infer_by_s(image, regions):
  270. exam_number_s_list = [ele for ele in regions if ele['class_name'] == 'exam_number_s'
  271. and (int(ele['bounding_box']['xmax'])-int(ele['bounding_box']['xmin']) <
  272. int(ele['bounding_box']['ymax'])-int(ele['bounding_box']['ymin']))]
  273. # 找边界
  274. exam_number_s_list = sorted(exam_number_s_list, key=lambda x: x['bounding_box']['xmin'])
  275. left_limit = exam_number_s_list[0]['bounding_box']['xmin']
  276. right_limit = exam_number_s_list[-1]['bounding_box']['xmax']
  277. left_image = crop_region(image, exam_number_s_list[0]['bounding_box'])
  278. right_image = crop_region(image, exam_number_s_list[-1]['bounding_box'])
  279. mean_width = sum([int(ele['bounding_box']['xmax'])-int(ele['bounding_box']['xmin'])
  280. for ele in exam_number_s_list]) // len(exam_number_s_list)
  281. top_limit = min([ele['bounding_box']['ymin'] for ele in exam_number_s_list])
  282. bottom_limit = max([ele['bounding_box']['ymax'] for ele in exam_number_s_list])
  283. left_infer = True
  284. while left_infer:
  285. infer_box_xmin = int(left_limit - 1.5*mean_width)
  286. infer_box_xmax = int(left_limit - 0.5*mean_width)
  287. infer_box_ymin = int(exam_number_s_list[0]['bounding_box']['ymin'])
  288. infer_box_ymax = int(exam_number_s_list[0]['bounding_box']['ymax'])
  289. infer_image = crop_region_direct(image, [infer_box_xmin, infer_box_ymin, infer_box_xmax, infer_box_ymax])
  290. simi = image_hash_detection_simple(left_image, infer_image)
  291. print('l:', simi)
  292. if simi >= 0.85:
  293. left_limit = infer_box_xmin
  294. else:
  295. left_infer = False
  296. right_infer = True
  297. while right_infer:
  298. infer_box_xmin = int(right_limit + 0.5 * mean_width)
  299. infer_box_xmax = int(right_limit + 1.5 * mean_width)
  300. infer_box_ymin = int(exam_number_s_list[-1]['bounding_box']['ymin'])
  301. infer_box_ymax = int(exam_number_s_list[-1]['bounding_box']['ymax'])
  302. infer_image = crop_region_direct(image, [infer_box_xmin, infer_box_ymin, infer_box_xmax, infer_box_ymax])
  303. simi = image_hash_detection_simple(right_image, infer_image)
  304. print('r:', simi)
  305. if simi >= 0.70:
  306. right_limit = infer_box_xmax
  307. else:
  308. right_infer = False
  309. infer_exam_number_region = {'xmin': left_limit, 'xmax': right_limit, 'ymin': top_limit, 'ymax': bottom_limit, }
  310. exam_dict_list = [{'class_name': 'exam_number', 'bounding_box': infer_exam_number_region}]
  311. # print(exam_dict_list)
  312. return exam_dict_list
  313. def draw_blank(image, infer_box):
  314. x_bias, y_bias = infer_box[:2]
  315. exam_number_infer_region = crop_region_direct(image, infer_box)
  316. ocr = get_ocr_text_and_coordinate(exam_number_infer_region)
  317. pattern = r'[\u4e00-\u9fa5]'
  318. for index, ele in enumerate(ocr):
  319. # words = ele['words'].replace(' ', '')
  320. for char in ele['chars']:
  321. if re.match(pattern, char['char']):
  322. loc = char['location']
  323. left, top, width, height = loc['left']+x_bias, loc['top']+y_bias, loc['width'], loc['height']
  324. image[top:top+2*height, left:left+width, :] = 255
  325. return image
  326. def exam_number_adjust_infer(image, regions):
  327. box = []
  328. box1 = []
  329. for i in range(len(regions) - 1, -1, -1):
  330. region = regions[i]
  331. if region['class_name'] == 'exam_number_s':
  332. loc = region['bounding_box']
  333. if loc['ymax'] - loc['ymin'] > loc['xmax'] - loc['xmin']:
  334. box.append([loc['xmin'], loc['ymin'], loc['xmax'], loc['ymax']])
  335. if region['class_name'] == 'exam_number':
  336. loc = region['bounding_box']
  337. box.append([loc['xmin'], loc['ymin'], loc['xmax'], loc['ymax']])
  338. regions.pop(i)
  339. if region['class_name'] == 'exam_number_w':
  340. loc = region['bounding_box']
  341. if loc['xmax'] - loc['xmin'] > loc['ymax'] - loc['ymin']:
  342. box1.append([loc['xmin'], loc['ymax'], loc['xmax'], loc['ymax']]) # exam_number_w 只取下界
  343. if not box:
  344. return image, regions
  345. box_array = np.array(box+box1)
  346. xmin, ymin = np.min(box_array, axis=0)[:2]
  347. xmax, ymax = np.max(box_array, axis=0)[2:]
  348. images = draw_blank(image, (xmin, ymin, xmax, ymax))
  349. exam_number_infer = {'class_name': 'exam_number',
  350. 'bounding_box': {'xmin': xmin, 'ymin': ymin,
  351. 'xmax': xmax, 'ymax': ymax}}
  352. regions.append(exam_number_infer)
  353. return images, regions
  354. def gen_xml_new(path, ocr_list):
  355. tree = ET.parse(r'../../tools/000000-template.xml') # xml tree
  356. for index, ele in enumerate(ocr_list):
  357. words = ele['words']
  358. location = ele['location']
  359. xmin = location['xmin']
  360. ymin = location['ymin']
  361. xmax = location['xmax']
  362. ymax = location['ymax']
  363. tree = create_xml('{}'.format(words), tree, str(xmin), str(ymin), str(xmax), str(ymax))
  364. # print(exam_items_bbox)
  365. tree.write(path.replace('.jpg', '.xml'))
def subfield_answer_sheet(img0, answer_sheet):
    """Estimate the x positions of up to two column split lines of a sheet.

    A sheet that is taller than wide is treated as single-column and yields
    ``(0, 0)``.  Otherwise the key answer modules are sorted by ``xmin`` and
    the horizontal gaps between consecutive modules are examined: a gap wider
    than a quarter of the page fixes both split lines at once, a smaller gap
    (larger than minus one eighth of the page) places the first and then the
    second split line at the middle of the gap.  Remaining cases are decided
    from how far the right-most module is from the right page edge.

    NOTE(review): the indentation of this file was lost; the nesting below is
    a reconstruction (notably which statements follow the inner
    ``if modules11 == []`` / ``else`` pairs) and must be confirmed against the
    upstream source.
    NOTE(review): ``one_part``, ``w_int_1`` and ``w_int_2`` are assigned but
    never read in this function.
    NOTE(review): on a landscape sheet without any key module,
    ``modules_xmax[-1]`` raises ``IndexError``.

    :param img0: page image; only ``img0.shape[:2]`` is read.
    :param answer_sheet: region dicts with ``'class_name'`` and
        ``'bounding_box'``.
    :return: ``(line_xmax_1, line_xmax_2)``; 0 means "no split line".
    """
    h, w = img0.shape[:2]
    one_part = 0
    line_xmax_1 = 0
    line_xmax_2 = 0
    modules = []
    # x-max values of modules that overlap their successor and are still
    # waiting for the next real gap.
    modules11 = []
    w_int_1 = w
    w_int_2 = round(w / 2)
    w_int_3 = round(w / 3)
    w_int_4 = round(w / 4)
    w_int_8 = round(w / 8)
    if w_int_8 < 50:
        w_int_8 = 50
    key_modules_classes = ['choice', 'cloze', 'solve', 'solve0', 'composition0', 'composition', 'correction',
                           'ban_area', ]
    if h > w:  # for now, a sheet taller than wide is taken to be single-column
        one_part = 1
    else:
        # temp1 / temp2 flag whether the first / second split line was set.
        temp1 = 0
        temp2 = 0
        for ele in answer_sheet:
            if ele["class_name"] in key_modules_classes:
                modules.append(ele)
        modules_xmin = sorted(modules, key=lambda x: (x['bounding_box']['xmin']))
        modules_xmax = sorted(modules, key=lambda x: (x['bounding_box']['xmax']))
        for i in range(len(modules_xmin) - 1):
            # The left-most module already starts beyond a quarter of the page.
            if i == 0 and modules_xmin[0]['bounding_box']['xmin'] - 0 > w_int_4:
                temp1 = 1
            else:
                # Gap wider than a quarter page: both split lines hug the gap.
                if modules_xmin[i + 1]['bounding_box']['xmin'] - modules_xmin[i]['bounding_box']['xmax'] > w_int_4:
                    if modules11 == []:
                        line_xmax_1 = modules_xmin[i]['bounding_box']['xmax'] + 20
                        line_xmax_2 = modules_xmin[i + 1]['bounding_box']['xmin'] - 20
                    else:
                        modules11.append(modules_xmin[i]['bounding_box']['xmax'])
                        modules11_xmax = sorted(modules11)[-1]
                        line_xmax_1 = modules11_xmax + 20
                        line_xmax_2 = modules_xmin[i + 1]['bounding_box']['xmin'] - 20
                        modules11 = []
                    temp1 = 1
                    temp2 = 1
                    break
                # Ordinary gap: put the next unset split line in its middle.
                elif modules_xmin[i + 1]['bounding_box']['xmin'] - modules_xmin[i]['bounding_box']['xmax'] > -w_int_8:
                    if temp1 == 0:
                        if modules11 == []:
                            line_xmax_1 = int((modules_xmin[i + 1]['bounding_box']['xmin'] +
                                               modules_xmin[i]['bounding_box']['xmax']) / 2)
                        else:
                            modules11.append(modules_xmin[i]['bounding_box']['xmax'])
                            modules11_xmax = sorted(modules11)[-1]
                            line_xmax_1 = int((modules_xmin[i + 1]['bounding_box']['xmin'] +
                                               modules11_xmax) / 2)
                            modules11 = []
                        temp1 = 1
                    elif temp1 == 1:
                        if modules11 == []:
                            line_xmax_2 = int((modules_xmin[i + 1]['bounding_box']['xmin'] +
                                               modules_xmin[i]['bounding_box']['xmax']) / 2)
                        else:
                            modules11.append(modules_xmin[i]['bounding_box']['xmax'])
                            modules11_xmax = sorted(modules11)[-1]
                            line_xmax_2 = int((modules_xmin[i + 1]['bounding_box']['xmin'] +
                                               modules11_xmax) / 2)
                        temp2 = 1
                # Modules overlap strongly: remember this right edge for later.
                else:
                    modules11.append(modules_xmin[i]['bounding_box']['xmax'])
        # No split line found from the gaps: decide from the free space right
        # of the right-most module.
        if temp1 == 0 and temp2 == 0:
            if modules_xmax[-1]['bounding_box']['xmax'] - w < -(2 * w_int_4):
                line_xmax_1 = modules_xmax[-1]['bounding_box']['xmax'] + 20
                line_xmax_2 = 2 * w_int_3
            elif modules_xmax[-1]['bounding_box']['xmax'] - w < -w_int_4:
                line_xmax_1 = modules_xmax[-1]['bounding_box']['xmax'] + 20
        elif temp1 == 1 and temp2 == 0:
            if modules_xmax[-1]['bounding_box']['xmax'] - w < -w_int_4:
                line_xmax_2 = 2 * w_int_3
    return line_xmax_1, line_xmax_2
  443. def get_intersection_point(lines, orthogonal_lines, border):
  444. intersect_point_list = []
  445. for line in lines:
  446. width_min, height_min, width_max, height_max = border
  447. (x_l, y_u), (x_r, y_d) = line.coords
  448. x_l = x_l if x_l > width_min else width_min + 1 # 避免边界
  449. x_r = x_r if x_r < width_max else width_max - 1
  450. y_u = y_u if y_u > height_min else height_min + 1
  451. y_d = y_d if y_d < height_max else height_max - 1
  452. points_list = []
  453. if x_l == x_r:
  454. line_direction = 'lon'
  455. raw_line = LineString([(x_l, y_u), (x_r, y_d)])
  456. extend_line = LineString([(x_l, height_min), (x_r, height_max)])
  457. points_list.extend([height_min + 1, height_max - 1]) # 延长线与边界交点,并避免key_point位于现有边界上
  458. line_start, line_end = y_u, y_d
  459. else:
  460. line_direction = 'lat'
  461. raw_line = LineString([(x_l, y_u), (x_r, y_d)])
  462. extend_line = LineString([(width_min, y_u), (width_max, y_d)])
  463. points_list.extend([width_min + 1, width_max - 1]) # 延长线与边界交点,并避免key_point位于现有边界上
  464. line_start, line_end = x_l, x_r
  465. for ele in orthogonal_lines:
  466. cond1 = extend_line.intersects(ele) # T, L, 十交叉
  467. cond2 = extend_line.crosses(ele) # 十字交叉
  468. cond3 = raw_line.intersects(ele)
  469. cond4 = raw_line.crosses(ele)
  470. if line_direction == 'lat':
  471. if cond3:
  472. (xp, yp) = raw_line.intersection(ele).bounds[:2]
  473. intersect_point_list.append((xp, yp))
  474. elif cond1:
  475. (xp, yp) = extend_line.intersection(ele).bounds[:2]
  476. points_list.append(xp)
  477. if line_direction == 'lon':
  478. if cond3:
  479. (xp, yp) = raw_line.intersection(ele).bounds[:2]
  480. intersect_point_list.append((xp, yp))
  481. elif cond1:
  482. (xp, yp) = extend_line.intersection(ele).bounds[:2]
  483. points_list.append(yp)
  484. points_array = np.asarray(points_list, dtype=np.uint)
  485. left_key = np.max(points_array[points_array <= line_start])
  486. right_key = np.min(points_array[points_array >= line_end]) # 延长线两边延长并取得第一个交点
  487. if line_direction == 'lat':
  488. intersect_point = [(left_key, y_u), (right_key, y_d)]
  489. else:
  490. intersect_point = [(x_l, left_key), (x_r, right_key)]
  491. # print(intersect_point)
  492. intersect_point_list.extend(intersect_point)
  493. return intersect_point_list
  494. def infer_sheet_box(image, sheet_dict, lon_split_line, exclude_classes):
  495. height_max, width_max = image.shape[0], image.shape[1]
  496. height_min, width_min = 0, 0
  497. latitude = []
  498. longitude = []
  499. lines = []
  500. sheet_polygons = []
  501. all_sheet_polygons = []
  502. choice_polygon = []
  503. # exclude_classes = ['cloze_s', 'exam_number_s', 'choice_s', 'type_score',
  504. # 'mark', 'page', 'exam_number_s', 'cloze_score', 'name_w',
  505. # 'class_w',]
  506. h_min = []
  507. h_max = []
  508. for index, region_box in enumerate(sheet_dict):
  509. coordinates = region_box['bounding_box']
  510. xmin = coordinates['xmin']
  511. ymin = coordinates['ymin']
  512. xmax = coordinates['xmax']
  513. ymax = coordinates['ymax']
  514. if region_box['class_name'] == 'info_title': # 上限
  515. h_min.append(ymin)
  516. if region_box['class_name'] == 'page': # 下限
  517. h_max.append(ymin)
  518. if region_box['class_name'] == 'alarm_info':
  519. h_min.append(ymin)
  520. h_max.append(ymin)
  521. if h_min:
  522. hgt_min = min(h_min)
  523. if hgt_min < height_max / 4:
  524. height_min = hgt_min
  525. if h_max:
  526. hgt_max = max(h_max)
  527. if hgt_max > 3 * height_max / 4:
  528. height_max = hgt_max
  529. # height_min = h_min if h_min != 9999 else height_min
  530. # height_max = h_max if h_max != 0 else height_max
  531. for index, region_box in enumerate(sheet_dict):
  532. coordinates = region_box['bounding_box']
  533. xmin = coordinates['xmin']
  534. ymin = coordinates['ymin']
  535. xmax = coordinates['xmax']
  536. ymax = coordinates['ymax']
  537. box_polygon = Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
  538. if region_box['class_name'] not in exclude_classes:
  539. if region_box['class_name'] not in ['choice', 'cloze']: # 推断选择题区域内的choice_m
  540. sheet_polygons.append(box_polygon)
  541. if region_box['class_name'] == 'choice':
  542. choice_polygon.append(box_polygon)
  543. all_sheet_polygons.append(box_polygon)
  544. line1 = LineString([(xmin, ymin), (xmin, ymax)])
  545. line2 = LineString([(xmax, ymin), (xmax, ymax)])
  546. line3 = LineString([(xmin, ymin), (xmax, ymin)])
  547. line4 = LineString([(xmin, ymax), (xmax, ymax)])
  548. lines.extend([line1, line2, line3, line4])
  549. longitude.extend([line1, line2])
  550. latitude.extend([line3, line4])
  551. # sheet_polygons 去除包裹的情况
  552. sheet_polygons_ = list(combinations(sheet_polygons, 2))
  553. for polygons in sheet_polygons_:
  554. if polygons[0].within(polygons[1]) or polygons[0].contains(polygons[1]):
  555. area_list = [polygons[0].area, polygons[1].area]
  556. min_polygon = polygons[area_list.index(min(area_list))]
  557. if min_polygon in sheet_polygons:
  558. sheet_polygons.remove(min_polygon)
  559. min_polygon = sorted(all_sheet_polygons, key=lambda p: p.area)[0]
  560. avg_area = sum([polygon.area for polygon in sheet_polygons]) / len(sheet_polygons)
  561. # 所有矩形框的延长线与矩形框集图像边界的交点
  562. latitude = sorted(latitude, key=lambda x: x.bounds[1]) # y
  563. longitude = sorted(longitude, key=lambda x: x.bounds[0]) # x
  564. lat_intersect_point_list = get_intersection_point(latitude, longitude,
  565. (width_min, height_min, width_max, height_max))
  566. lon_intersect_point_list = get_intersection_point(longitude, latitude,
  567. (width_min, height_min, width_max, height_max))
  568. raw_corner = [(width_min + 1, height_min + 1), (width_min + 1, height_max - 1), (width_max - 1, 1),
  569. (width_max - 1, height_max - 1)]
  570. # raw_corner = []
  571. intersect_point_list = lat_intersect_point_list + lon_intersect_point_list + raw_corner
  572. intersect_point_list = list(set(intersect_point_list))
  573. intersect_point_dict = {k: index + 1 for index, k in enumerate(intersect_point_list)}
  574. def _filter_rect(p_list):
  575. flag = 0
  576. for ele in p_list:
  577. try:
  578. flag = intersect_point_dict[ele]
  579. except KeyError:
  580. flag = 0
  581. break
  582. if flag > 0:
  583. x_c = sum([ele[0] for ele in p_list]) / 4
  584. y_c = sum([ele[1] for ele in p_list]) / 4
  585. d1, d2, d3, d4 = [LineString([p, (x_c, y_c)]).length for p in p_list]
  586. return (0 not in [d1, d2, d3, d4]) and d1 == d2 and d1 == d3 and d1 == d4
  587. else:
  588. return False
  589. def _find_rect(point):
  590. (x1, y1) = point[0]
  591. (x2, y2) = point[1]
  592. if x1 != x2 and y1 != y2:
  593. xmin, ymin = min(x1, x2), min(y1, y2)
  594. xmax, ymax = max(x1, x2), max(y1, y2)
  595. points_4 = [(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)]
  596. w, h = xmax - xmin, ymax - ymin
  597. aspect_flag_extreme = max(w / h, h / w) < 1.5 * ASPECT_FLAG # 解决极端情况
  598. rect_flag = _filter_rect(points_4)
  599. if aspect_flag_extreme and rect_flag:
  600. gen_polygon = Polygon([(points_4[0]), (points_4[1]), (points_4[2]), (points_4[3])])
  601. flags = set()
  602. for polygon in sheet_polygons:
  603. decision = [gen_polygon.contains(polygon),
  604. gen_polygon.within(polygon),
  605. gen_polygon.overlaps(polygon)]
  606. if True in decision: # 边界问题
  607. flags.add(False)
  608. break
  609. else:
  610. flags.add(True)
  611. if False in flags:
  612. pass
  613. else:
  614. return gen_polygon
  615. def _filter_none(p):
  616. if p is not None:
  617. return True
  618. points_2 = combinations(intersect_point_list, 2)
  619. gen_polygon_list = map(_find_rect, points_2)
  620. gen_polygon_list = list(filter(_filter_none, gen_polygon_list))
  621. gen_polygon_list = sorted(gen_polygon_list, key=lambda p: p.area, reverse=True)
  622. # gen_polygon_list = [polygon for index, polygon in enumerate(gen_polygon_list) if index % 2 == 0]
  623. it = itertools.groupby(gen_polygon_list)
  624. gen_polygon_list = [k for k, g in it]
  625. # 在选择题区域的infer polygon
  626. # gen_choice = []
  627. # for ele in gen_polygon_list:
  628. # for choice_p in choice_polygon:
  629. # if ele.within(choice_p):
  630. # gen_choice.append(ele)
  631. sheet_box_area = sum([polygon.area for polygon in sheet_polygons])
  632. image_area = width_max * height_max
  633. blank_ratio = 1 - sheet_box_area / image_area
  634. polygon_index = 0
  635. include_polygon = []
  636. while blank_ratio > REMAIN_RATIO and polygon_index < len(gen_polygon_list):
  637. polygon = gen_polygon_list[polygon_index]
  638. blank_ratio = blank_ratio - polygon.area / image_area
  639. include_polygon.append(polygon)
  640. polygon_index += 1
  641. # gen_polygon_list = [polygon for index, polygon in enumerate(gen_polygon_list)
  642. # if polygon.area > 1.5 * min_polygon.area]
  643. for polygon in gen_polygon_list.copy():
  644. xi, yi, xx, yx = polygon.bounds
  645. w, h = xx - xi, yx - yi
  646. if polygon.area <= 1.5 * min_polygon.area or h / w > 2 and polygon.area < avg_area:
  647. gen_polygon_list.remove(polygon)
  648. polygon_2 = list(combinations(gen_polygon_list, 2))
  649. for polygons in polygon_2:
  650. try:
  651. cond2 = polygons[0].overlaps(polygons[1]) # 叠置关系二次分段
  652. if cond2:
  653. area_list = [polygons[0].area, polygons[1].area]
  654. min_index = area_list.index(min(area_list))
  655. smaller_polygon = polygons[min_index]
  656. larger_polygon = polygons[1 - min_index]
  657. new_polygon = smaller_polygon.difference(larger_polygon)
  658. if smaller_polygon in gen_polygon_list:
  659. gen_polygon_list.remove(smaller_polygon)
  660. if 'MultiPolygon' in str(type(new_polygon)):
  661. for ele in new_polygon:
  662. xm, ym, xx, yx = ele.bounds
  663. w, h = xx - xm, yx - ym
  664. if max(w / h, h / w) < 1.5 * ASPECT_FLAG and ele.area > 1.5 * min_polygon.area:
  665. gen_polygon_list.append(ele)
  666. elif len(set(new_polygon.exterior.coords)) == 4:
  667. xm, ym, xx, yx = new_polygon.bounds
  668. w, h = xx - xm, yx - ym
  669. if max(w / h, h / w) < 1.5 * ASPECT_FLAG and new_polygon.area > 1.5 * min_polygon.area:
  670. gen_polygon_list.append(new_polygon)
  671. except Exception as polygon_e:
  672. print(polygon_e)
  673. continue
  674. polygon_2 = list(combinations(gen_polygon_list, 2)) # 包含关系取大值
  675. for polygons in polygon_2:
  676. cond1 = polygons[0].equals(polygons[1])
  677. if cond1 and polygons[1] in gen_polygon_list:
  678. gen_polygon_list.remove(polygons[1])
  679. polygon_2 = list(combinations(gen_polygon_list, 2))
  680. for polygons in polygon_2:
  681. cond2 = polygons[0].contains(polygons[1]) or polygons[0].within(polygons[1])
  682. if cond2:
  683. area_list = [polygons[0].area, polygons[1].area]
  684. min_index = area_list.index(min(area_list))
  685. smaller_polygon = polygons[min_index]
  686. larger_polygon = polygons[1 - min_index]
  687. sxi, syi, sxx, syx = smaller_polygon.bounds
  688. bxi, byi, bxx, byx = larger_polygon.bounds
  689. # inner_touch_cond = '212F11FF2' == larger_polygon.relate(smaller_polygon)
  690. two_side_touch_cond = (sxi == bxi and sxx == bxx) or (syi == byi and syx == byx)
  691. if two_side_touch_cond:
  692. dif_polygon = larger_polygon.difference(smaller_polygon)
  693. if larger_polygon in gen_polygon_list:
  694. gen_polygon_list.remove(larger_polygon)
  695. if 'MultiPolygon' in str(type(dif_polygon)):
  696. for ele in dif_polygon:
  697. xm, ym, xx, yx = ele.bounds
  698. w, h = xx - xm, yx - ym
  699. if max(w / h, h / w) < 1.5 * ASPECT_FLAG and ele.area > 1.5 * min_polygon.area:
  700. gen_polygon_list.append(ele)
  701. elif len(set(dif_polygon.exterior.coords)) == 4: # empty
  702. xm, ym, xx, yx = dif_polygon.bounds
  703. w, h = xx - xm, yx - ym
  704. if max(w / h, h / w) < 1.5 * ASPECT_FLAG and dif_polygon.area > 1.5 * min_polygon.area:
  705. gen_polygon_list.append(dif_polygon)
  706. else:
  707. if smaller_polygon in gen_polygon_list:
  708. gen_polygon_list.remove(smaller_polygon)
  709. polygon_2 = list(combinations(gen_polygon_list, 2)) # 包含关系取大值
  710. for polygons in polygon_2:
  711. cond1 = polygons[0].equals(polygons[1])
  712. if cond1 and polygons[1] in gen_polygon_list:
  713. gen_polygon_list.remove(polygons[1])
  714. if len(lon_split_line) > 0:
  715. for line in lon_split_line:
  716. # line = LineString([(286, 1), (286, 599)])
  717. for poly in gen_polygon_list.copy():
  718. cond1 = line.intersects(poly)
  719. cond2 = line.touches(poly)
  720. if cond1 and not cond2:
  721. dif_polygons = poly.difference(line)
  722. corner_list = list(set(dif_polygons.exterior.coords))
  723. sorted_corner_list = sorted(corner_list, key=lambda x: x[0])
  724. if len(sorted_corner_list) == 6:
  725. left = sorted(sorted_corner_list[0:2], key=lambda x: x[1])
  726. middle = sorted(sorted_corner_list[2:4], key=lambda x: x[1])
  727. right = sorted(sorted_corner_list[4:6], key=lambda x: x[1])
  728. tmp_corner_list = [middle[0], left[0], left[1], middle[1], right[1], right[0], middle[0]]
  729. polygon1 = Polygon(tmp_corner_list[:4])
  730. polygon2 = Polygon(tmp_corner_list[3:])
  731. gen_polygon_list.remove(poly)
  732. for p in [polygon1, polygon2]:
  733. xi, yi, xx, yx = p.bounds
  734. w, h = xx - xi, yx - yi
  735. aspect_flag = max(w / h, h / w) < ASPECT_FLAG
  736. if aspect_flag:
  737. gen_polygon_list.append(p)
  738. gen_polygon_list = [polygon for index, polygon in enumerate(gen_polygon_list) if polygon.area > min_polygon.area]
  739. # if gen_choice:
  740. # gen_choice = sorted(gen_choice, key=lambda x: x.area)[-1]
  741. # gen_polygon_list.append(gen_choice)
  742. return gen_polygon_list
  743. def infer_class(image, sheet_dict_list, infer_polygon, image_cols, ocr_dict_list=''):
  744. res = []
  745. all_type_score_polygon = []
  746. all_choice_polygon = []
  747. all_cloze_polygon = []
  748. all_solve_polygon = []
  749. all_choice_s_width = []
  750. for region_box in sheet_dict_list:
  751. if region_box['class_name'] in ['type_score', 'choice', 'cloze', 'solve', 'choice_s']:
  752. coordinates = region_box['bounding_box']
  753. xmin = coordinates['xmin']
  754. ymin = coordinates['ymin']
  755. xmax = coordinates['xmax']
  756. ymax = coordinates['ymax']
  757. box_polygon = Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
  758. if region_box['class_name'] == 'type_score':
  759. all_type_score_polygon.append(box_polygon)
  760. if region_box['class_name'] == 'choice':
  761. all_choice_polygon.append(box_polygon)
  762. if region_box['class_name'] == 'cloze':
  763. all_cloze_polygon.append(box_polygon)
  764. if region_box['class_name'] == 'solve':
  765. all_solve_polygon.append(box_polygon)
  766. if region_box['class_name'] == 'choice_s':
  767. all_choice_s_width.append(int(xmax)-int(xmin))
  768. for poly in infer_polygon.copy(): # infer type_score solve
  769. p_xmin, p_ymin, p_xmax, p_ymax = poly.bounds
  770. type_score_num = 0
  771. type_score_ymin = []
  772. for type_score_polygon in all_type_score_polygon:
  773. cond1 = type_score_polygon.within(poly)
  774. cond2 = False
  775. cond3 = type_score_polygon.overlaps(poly)
  776. if cond3:
  777. intersection_poly = type_score_polygon.intersection(poly)
  778. d1 = intersection_poly.area / type_score_polygon.area >= TYPE_SCORE_MNS
  779. print('type_score:', intersection_poly.area / type_score_polygon.area)
  780. d2 = type_score_polygon.area < 0.2 * poly.area
  781. cond2 = d1 and d2
  782. if cond1 or cond2:
  783. type_score_num += 1
  784. t_xmin, t_ymin, t_xmax, t_ymax = type_score_polygon.bounds
  785. type_score_ymin.append(t_ymin)
  786. t_height = t_ymax - t_ymin
  787. if t_ymin - p_ymin > 3 * t_height:
  788. type_score_num += 1
  789. type_score_ymin.append(p_ymin)
  790. if type_score_num == 1:
  791. in_xmin, in_ymin, in_xmax, in_ymax = poly.bounds
  792. solve_box = {'class_name': 'solve',
  793. 'bounding_box': {'xmin': int(in_xmin), 'ymin': int(in_ymin),
  794. 'xmax': int(in_xmax), 'ymax': int(in_ymax)}}
  795. sheet_dict_list.append(solve_box)
  796. infer_polygon.remove(poly)
  797. res.append(solve_box)
  798. if type_score_num > 1: # 多type_score
  799. type_score_ymin = sorted(type_score_ymin)
  800. type_score_ymin[0] = min(p_ymin, type_score_ymin[0])
  801. type_score_ymin.append(p_ymax)
  802. for i in range(0, len(type_score_ymin) - 1):
  803. w = p_xmax - p_xmin
  804. h = type_score_ymin[i + 1] - type_score_ymin[i]
  805. if max(w / h, h / w) < ASPECT_FLAG:
  806. solve_box = {'class_name': 'solve',
  807. 'bounding_box': {'xmin': int(p_xmin), 'ymin': int(type_score_ymin[i]),
  808. 'xmax': int(p_xmax), 'ymax': int(type_score_ymin[i + 1])}}
  809. sheet_dict_list.append(solve_box)
  810. res.append(solve_box)
  811. infer_polygon.remove(poly)
  812. # for poly in infer_polygon.copy(): # infer choice_m
  813. # for choice_polygon in all_choice_polygon:
  814. # cond1 = choice_polygon.within(poly) or choice_polygon.contains(poly)
  815. # cond2 = False
  816. # cond3 = choice_polygon.overlaps(poly)
  817. # if cond3:
  818. # intersection_poly = choice_polygon.intersection(poly)
  819. # cond2 = intersection_poly.area / poly.area >= 0.8
  820. #
  821. # if cond1 or cond2:
  822. # in_xmin, in_ymin, in_xmax, in_ymax = poly.bounds
  823. # choice_m_img = crop_region_direct(image, (int(in_xmin), int(in_ymin),
  824. # int(in_xmax), int(in_ymax)))
  825. # # cv2.imshow('m', choice_m_img)
  826. # # cv2.waitKey(0)
  827. # ocr_res = get_ocr_text_and_coordinate(choice_m_img)
  828. # char_a_min = []
  829. # char_d_max = []
  830. # for index, chars in enumerate(ocr_res):
  831. # for char in chars['chars']:
  832. # left, top = char['location']['left'], char['location']['top']
  833. # width, height = char['location']['width'], char['location']['height']
  834. # if char['char'] in 'abcdlABCD[]aabbccddAABBCCDD[[]]':
  835. # xm, ym = int(left - width / 2), int(top - height / 2)
  836. # char_a_min.append((xm, ym))
  837. # xx, yx = int(left + 3 * width / 2), int(top + 3 * height / 2)
  838. # char_d_max.append((xx, yx))
  839. # if char_a_min and char_d_max:
  840. # char_a_min_arr, char_d_max_arr = np.array(char_a_min), np.array(char_d_max)
  841. # tmp_min = np.min(char_a_min_arr, axis=0)
  842. # tmp_max = np.max(char_d_max_arr, axis=0)
  843. #
  844. # m_xmin, m_ymin, m_xmax, m_ymax = tmp_min[0], tmp_min[1], tmp_max[0], tmp_max[1]
  845. # dif_width = sum(all_choice_s_width) // len(all_choice_s_width) - (m_xmax - m_xmin)
  846. # choice_box = {'class_name': 'choice_m',
  847. # 'bounding_box': {'xmin': int(m_xmin) + int(in_xmin) - dif_width // 2,
  848. # 'ymin': int(m_ymin) + int(in_ymin),
  849. # 'xmax': int(m_xmax) + int(in_xmin) + dif_width // 2,
  850. # 'ymax': int(m_ymax) + int(in_ymin)
  851. # }}
  852. #
  853. # sheet_dict_list.append(choice_box)
  854. # infer_polygon.remove(poly)
  855. # res.append(choice_box)
  856. # break
  857. for poly in infer_polygon.copy(): # infer ocr blank
  858. flag = []
  859. for ocr in ocr_dict_list:
  860. location = ocr['location']
  861. xmin = location['left']
  862. ymin = location['top']
  863. xmax = location['left'] + location['width']
  864. ymax = location['top'] + location['height']
  865. box_polygon = Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
  866. cond1 = poly.within(box_polygon) or poly.contains(box_polygon)
  867. cond2 = False
  868. cond3 = box_polygon.overlaps(poly)
  869. if cond3:
  870. intersection_poly = box_polygon.intersection(poly)
  871. cond2 = intersection_poly.area / poly.area >= 0.2
  872. flag.append(cond1 or cond2 or False) # True 不是blank
  873. if True not in flag:
  874. in_xmin, in_ymin, in_xmax, in_ymax = poly.bounds
  875. blank_box = {'class_name': 'blank',
  876. 'bounding_box': {'xmin': int(in_xmin), 'ymin': int(in_ymin),
  877. 'xmax': int(in_xmax), 'ymax': int(in_ymax)}}
  878. # sheet_dict_list.append(solve_box)
  879. infer_polygon.remove(poly)
  880. res.append(blank_box)
  881. for poly in infer_polygon.copy(): # infer blank
  882. bounds = [int(ele) for ele in poly.bounds]
  883. img_region = crop_region_direct(image, bounds)
  884. img = cv2.threshold(img_region, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
  885. img_mean = np.mean(img)
  886. img_raw_mean = np.mean(img_region)
  887. # print(img_mean, img_raw_mean)
  888. cond = img_mean < PIX_VALUE_LOW or img_raw_mean > PIX_VALUE_HIGH
  889. if cond:
  890. in_xmin, in_ymin, in_xmax, in_ymax = bounds
  891. blank_box = {'class_name': 'blank',
  892. 'bounding_box': {'xmin': int(in_xmin), 'ymin': int(in_ymin),
  893. 'xmax': int(in_xmax), 'ymax': int(in_ymax)}}
  894. # sheet_dict_list.append(solve_box)
  895. infer_polygon.remove(poly)
  896. res.append(blank_box)
  897. # for poly in infer_polygon.copy(): # infer cloze_s
  898. # for cloze_polygon in all_cloze_polygon:
  899. # cond1 = cloze_polygon.within(poly) or cloze_polygon.contains(poly)
  900. # cond2 = False
  901. # cond3 = cloze_polygon.overlaps(poly)
  902. # if cond3:
  903. # intersection_poly = cloze_polygon.intersection(poly)
  904. # cond2 = intersection_poly.area / poly.area >= 0.8
  905. #
  906. # if cond1 or cond2:
  907. # in_xmin, in_ymin, in_xmax, in_ymax = poly.bounds
  908. # solve_box = {'class_name': 'cloze_s',
  909. # 'bounding_box': {'xmin': int(in_xmin), 'ymin': int(in_ymin),
  910. # 'xmax': int(in_xmax), 'ymax': int(in_ymax)}}
  911. #
  912. # sheet_dict_list.append(solve_box)
  913. # infer_polygon.remove(poly)
  914. # res.append(solve_box)
  915. # break
  916. for poly in infer_polygon.copy(): # infer solve
  917. in_xmin, in_ymin, in_xmax, in_ymax = poly.bounds
  918. w, h = in_xmax - in_xmin, in_ymax - in_ymin
  919. aspect_flag = max(w / h, h / w) < ASPECT_FLAG
  920. if aspect_flag:
  921. solve_box = {'class_name': 'solve_infer',
  922. 'bounding_box': {'xmin': int(in_xmin), 'ymin': int(in_ymin),
  923. 'xmax': int(in_xmax), 'ymax': int(in_ymax)}}
  924. else:
  925. solve_box = {'class_name': 'blank',
  926. 'bounding_box': {'xmin': int(in_xmin), 'ymin': int(in_ymin),
  927. 'xmax': int(in_xmax), 'ymax': int(in_ymax)}}
  928. sheet_dict_list.append(solve_box)
  929. infer_polygon.remove(poly)
  930. res.append(solve_box)
  931. if all_type_score_polygon:
  932. type_score_area = sum([ele.area for ele in all_type_score_polygon])
  933. mean_type_score_area = type_score_area/len(all_type_score_polygon)
  934. solve_filter = []
  935. for index, sheet_box in enumerate(sheet_dict_list.copy()):
  936. if sheet_box['class_name'] == 'solve_infer':
  937. w = sheet_box['bounding_box']['xmax'] - sheet_box['bounding_box']['xmin']
  938. h = sheet_box['bounding_box']['ymin'] - sheet_box['bounding_box']['ymin']
  939. if w * h < mean_type_score_area * 3:
  940. sheet_dict_list.remove(sheet_box)
  941. for ele in sheet_dict_list:
  942. if ele['class_name'] == 'solve_infer':
  943. ele.update({'class_name': 'solve'})
  944. return sheet_dict_list
  945. def box_infer_and_complete(image, sheet_region_dict, ocr=''):
  946. if len(image.shape) == 3:
  947. image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
  948. if len(image.shape) == 4:
  949. image = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
  950. exclude_classes = [
  951. 'cloze_s',
  952. 'exam_number_s',
  953. 'type_score',
  954. 'page',
  955. 'alarm_info',
  956. # 'score_collect',
  957. 'choice_m'
  958. 'choice_s',
  959. ] # 不找这些区域的延长线
  960. y, x = image.shape[0], image.shape[1]
  961. x1, x2 = subfield_answer_sheet(image, sheet_region_dict)
  962. # lon_split_line = []
  963. lon_split_line = [LineString([(px, 1), (px, y - 1)]) for px in [x1, x2] if px != 0]
  964. split_line_poly = [(px, 1, px + 1, y - 1) for px in [x1, x2] if px != 0]
  965. poly_list = infer_sheet_box(image, sheet_region_dict, lon_split_line, exclude_classes)
  966. image_cols = len(lon_split_line) + 1
  967. sheet_region_dict = infer_class(image, sheet_region_dict, poly_list, image_cols, ocr)
  968. return sheet_region_dict
  969. def infer_solve(sheet_dict_list, left, right, top, bottom, h, w, col_regions, col_split, subject='math'):
  970. col_split.insert(0, left)
  971. col_split.append(right)
  972. boundary_list = [(split, 1, col_split[i+1]-1, h-1) for i, split in enumerate(col_split[:-1])]
  973. infer_polygon = []
  974. for col_i, col in enumerate(col_regions):
  975. if not col:
  976. continue
  977. class_names = [ele['class_name'] for ele in col]
  978. # 确定是正面第一栏
  979. if ((subject != 'chinese')
  980. and
  981. (any([ele in class_names for ele in ['infer_title', 'bar_code', 'choice_m', 'cloze_s']]))):
  982. bottom_box = col[-1]
  983. bottom_box_ymax = bottom_box['bounding_box']['ymax']
  984. infer_loc = {'xmin': boundary_list[col_i][0], 'ymin': bottom_box_ymax+5,
  985. 'xmax': boundary_list[col_i][2], 'ymax': boundary_list[col_i][3]}
  986. box_polygon = Polygon([(infer_loc['xmin'], infer_loc['ymin']),
  987. (infer_loc['xmax'], infer_loc['ymin']),
  988. (infer_loc['xmax'], infer_loc['ymax']),
  989. (infer_loc['xmin'], infer_loc['ymax']),])
  990. infer_polygon.append(box_polygon)
  991. else:
  992. y_axis = np.zeros(h)
  993. y_axis[top:top+1] = 1
  994. y_axis[bottom-1:bottom] = 1
  995. for ele in col:
  996. ymin = ele['bounding_box']['ymin']
  997. ymax = ele['bounding_box']['ymax']
  998. y_axis[ymin:ymax+1] = 1
  999. split_index_list = []
  1000. split_index = 0
  1001. for i, ele in enumerate(y_axis):
  1002. split_index = split_index % 2
  1003. if ele == split_index:
  1004. # print(i)
  1005. split_index = split_index + 1
  1006. split_index_list.append(i)
  1007. for i in range(0, len(split_index_list)-1, 2):
  1008. ymin = split_index_list[i] + 3
  1009. ymax = split_index_list[i+1] - 3
  1010. infer_loc = {'xmin': boundary_list[col_i][0], 'ymin': ymin,
  1011. 'xmax': boundary_list[col_i][2], 'ymax': ymax}
  1012. box_polygon = Polygon([(infer_loc['xmin'], infer_loc['ymin']),
  1013. (infer_loc['xmax'], infer_loc['ymin']),
  1014. (infer_loc['xmax'], infer_loc['ymax']),
  1015. (infer_loc['xmin'], infer_loc['ymax']), ])
  1016. infer_polygon.append(box_polygon)
  1017. res = []
  1018. all_type_score_polygon = []
  1019. for region_box in sheet_dict_list:
  1020. if region_box['class_name'] in ['type_score']:
  1021. coordinates = region_box['bounding_box']
  1022. xmin = coordinates['xmin']
  1023. ymin = coordinates['ymin']
  1024. xmax = coordinates['xmax']
  1025. ymax = coordinates['ymax']
  1026. box_polygon = Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
  1027. all_type_score_polygon.append(box_polygon)
  1028. for poly in infer_polygon.copy(): # infer type_score solve
  1029. p_xmin, p_ymin, p_xmax, p_ymax = poly.bounds
  1030. type_score_num = 0
  1031. type_score_ymin = []
  1032. for type_score_polygon in all_type_score_polygon:
  1033. cond1 = type_score_polygon.within(poly)
  1034. cond2 = False
  1035. cond3 = type_score_polygon.overlaps(poly)
  1036. if cond3:
  1037. intersection_poly = type_score_polygon.intersection(poly)
  1038. d1 = intersection_poly.area / type_score_polygon.area >= TYPE_SCORE_MNS
  1039. print('type_score:', intersection_poly.area / type_score_polygon.area)
  1040. d2 = type_score_polygon.area < 0.5 * poly.area
  1041. cond2 = d1 and d2
  1042. if cond1 or cond2:
  1043. t_xmin, t_ymin, t_xmax, t_ymax = type_score_polygon.bounds
  1044. x_dif = t_xmin - p_xmin
  1045. type_score_num += 1
  1046. type_score_ymin.append(t_ymin)
  1047. t_height = t_ymax - t_ymin
  1048. if t_ymin - p_ymin > 3 * t_height:
  1049. type_score_num += 1
  1050. type_score_ymin.append(p_ymin)
  1051. if type_score_num == 1:
  1052. w, h = p_xmax-p_xmin, p_ymax-p_ymin
  1053. w, h = max(w, 0.1), max(h, 0.1)
  1054. aspect_flag = max(w / h, h / w) < ASPECT_FLAG
  1055. if aspect_flag:
  1056. solve_box = {'class_name': 'solve',
  1057. 'bounding_box': {'xmin': int(p_xmin), 'ymin': int(p_ymin),
  1058. 'xmax': int(p_xmax), 'ymax': int(p_ymax)}}
  1059. infer_polygon.remove(poly)
  1060. res.append(solve_box)
  1061. else:
  1062. solve_box = {'class_name': 'solve_with_type_score_without_aspect',
  1063. 'bounding_box': {'xmin': int(p_xmin), 'ymin': int(p_ymin),
  1064. 'xmax': int(p_xmax), 'ymax': int(p_ymax)}}
  1065. infer_polygon.remove(poly)
  1066. res.append(solve_box)
  1067. if type_score_num > 1: # 多type_score
  1068. type_score_ymin = sorted(type_score_ymin)
  1069. type_score_ymin[0] = min(p_ymin, type_score_ymin[0])
  1070. type_score_ymin.append(p_ymax)
  1071. for col_i in range(0, len(type_score_ymin) - 1):
  1072. w = p_xmax - p_xmin
  1073. h = type_score_ymin[col_i + 1] - type_score_ymin[col_i]
  1074. w, h = max(w, 0.1), max(h, 0.1)
  1075. aspect_flag = max(w / h, h / w) < ASPECT_FLAG
  1076. if aspect_flag:
  1077. solve_box = {'class_name': 'solve',
  1078. 'bounding_box': {'xmin': int(p_xmin), 'ymin': int(type_score_ymin[col_i]-5),
  1079. 'xmax': int(p_xmax), 'ymax': int(type_score_ymin[col_i + 1])}}
  1080. res.append(solve_box)
  1081. infer_polygon.remove(poly)
  1082. for poly in infer_polygon.copy(): # infer solve
  1083. in_xmin, in_ymin, in_xmax, in_ymax = poly.bounds
  1084. w, h = in_xmax - in_xmin, in_ymax - in_ymin
  1085. w, h = max(w, 0.1), max(h, 0.1)
  1086. aspect_flag = max(w / h, h / w) < ASPECT_FLAG
  1087. if aspect_flag:
  1088. solve_box = {'class_name': 'solve_without_type_score',
  1089. 'bounding_box': {'xmin': int(in_xmin), 'ymin': int(in_ymin),
  1090. 'xmax': int(in_xmax), 'ymax': int(in_ymax)}}
  1091. else:
  1092. solve_box = {'class_name': 'w_h_blank',
  1093. 'bounding_box': {'xmin': int(in_xmin), 'ymin': int(in_ymin),
  1094. 'xmax': int(in_xmax), 'ymax': int(in_ymax)}}
  1095. infer_polygon.remove(poly)
  1096. res.append(solve_box)
  1097. return res