sheet_infer.py 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192
  1. # @Author : lightXu
  2. # @File : sheet_infer.py
  3. # @Time : 2019/9/26 0026 上午 10:18
  4. import itertools
  5. import os
  6. import re
  7. import traceback
  8. import xml.etree.cElementTree as ET
  9. from itertools import combinations
  10. import cv2
  11. import numpy as np
  12. from shapely.geometry import LineString, Polygon
  13. from segment.sheet_resolve.tools.utils import create_xml, crop_region_direct, crop_region, image_hash_detection_simple
  14. from segment.sheet_resolve.tools.brain_api import get_ocr_text_and_coordinate
  15. ASPECT_FLAG = 4.0
  16. REMAIN_RATIO = 0.1
  17. PIX_VALUE_LOW = 15.0
  18. PIX_VALUE_HIGH = 245
  19. TYPE_SCORE_MNS = 0.5
  20. def _get_char_near_img(char_location, near):
  21. left = char_location['left']
  22. top = char_location['top']
  23. width = char_location['width']
  24. height = char_location['height']
  25. next_location = char_location
  26. if near == 'left':
  27. next_location = {'left': int(left - 1.5 * width), 'top': top, 'width': width, 'height': height}
  28. if near == 'right':
  29. next_location = {'left': int(left + 1.5 * width), 'top': top, 'width': width, 'height': height}
  30. if near == 'up':
  31. next_location = {'left': left, 'top': int(top - 1.5 * height), 'width': width, 'height': height}
  32. if near == 'down':
  33. next_location = {'left': left, 'top': int(top + 1.5 * height), 'width': width, 'height': height}
  34. return next_location
  35. def _get_board(image, location, direction):
  36. std = 0
  37. next_location = location
  38. while std < 10:
  39. next_location = _get_char_near_img(next_location, direction)
  40. box = (next_location['left'], next_location['top'],
  41. next_location['left'] + next_location['width'],
  42. next_location['top'] + next_location['height'],)
  43. region = crop_region_direct(image, box)
  44. std = np.var(region)
  45. return next_location
  46. def infer_bar_code(image, ocr_dict_list, attention_region):
  47. attention_polygon_list = []
  48. for attention in attention_region:
  49. coordinates = attention['bounding_box']
  50. xmin = coordinates['xmin']
  51. ymin = coordinates['ymin']
  52. xmax = coordinates['xmax']
  53. ymax = coordinates['ymax']
  54. attention_polygon = Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
  55. attention_polygon_list.append(attention_polygon)
  56. img_cols, img_rows = image.shape[0], image.shape[1]
  57. pattern = r'条形码|条码|条形|形码'
  58. bar_code_dict_list = []
  59. for index, ele in enumerate(ocr_dict_list):
  60. words = ele['words'].replace(' ', '')
  61. chars_list = ele['chars']
  62. length = len(chars_list)
  63. match_list = [(m.group(), m.span()) for m in re.finditer(pattern, words) if m]
  64. if match_list: # 不为空
  65. for match in match_list:
  66. start_index = match[1][0]
  67. end_index = match[1][1] - 1
  68. for i in range(start_index - 1, -1, -1):
  69. xmin_start = chars_list[start_index]['location']['left']
  70. start_tmp = chars_list[i]['location']['left'] + 2 * chars_list[i]['location']['width']
  71. if xmin_start <= start_tmp:
  72. start_index = i
  73. for i in range(end_index, length):
  74. xmax_end = chars_list[end_index]['location']['left'] + 2 * chars_list[i]['location']['width']
  75. end_tmp = chars_list[i]['location']['left']
  76. if xmax_end >= end_tmp:
  77. end_index = i
  78. bar_code_char_xmin = chars_list[start_index]['location']["left"]
  79. bar_code_char_xmax = chars_list[end_index]['location']["left"]+chars_list[end_index]['location']["width"]
  80. bar_code_char_ymin = chars_list[start_index]['location']["top"]
  81. bar_code_char_ymax = chars_list[end_index]['location']["top"]+chars_list[end_index]['location']["height"]
  82. bar_code_char_polygon = Polygon([(bar_code_char_xmin, bar_code_char_ymin),
  83. (bar_code_char_xmax, bar_code_char_ymin),
  84. (bar_code_char_xmax, bar_code_char_ymax),
  85. (bar_code_char_xmin, bar_code_char_ymax)])
  86. contain_cond = [False]*len(attention_polygon_list)
  87. for i, attention_ele in enumerate(attention_polygon_list):
  88. if attention_ele.contains(bar_code_char_polygon):
  89. contain_cond[i] = True
  90. if True not in contain_cond: # 条形码文字不在attention里面
  91. left_board_location = _get_board(image, chars_list[start_index]['location'], 'left')
  92. right_board_location = _get_board(image, chars_list[end_index]['location'], 'right')
  93. up_board_location = _get_board(image, chars_list[start_index]['location'], 'up')
  94. down_board_location = _get_board(image, chars_list[end_index]['location'], 'down')
  95. xmin = left_board_location['left']
  96. ymin = up_board_location['top']
  97. xmax = right_board_location['left'] + right_board_location['width']
  98. ymax = down_board_location['top'] + down_board_location['height']
  99. xmin = int(xmin) if xmin >= 1 else 1
  100. ymin = int(ymin) if ymin >= 1 else 1
  101. xmax = int(xmax) if xmax <= img_cols - 1 else img_cols - 1
  102. ymax = int(ymax) if ymax <= img_rows - 1 else img_rows - 1
  103. bar_code_dict = {'class_name': 'bar_code',
  104. 'bounding_box': {'xmin': xmin, 'ymin': ymin, 'xmax': xmax, 'ymax': ymax}}
  105. bar_code_dict_list.append(bar_code_dict)
  106. # print(bar_code_dict)
  107. break # 默认只有一个条形码
  108. else:
  109. continue
  110. # 过滤attention 区域存在条形码的文字
  111. for bar_code in bar_code_dict_list.copy():
  112. coordinates = bar_code['bounding_box']
  113. xmin = coordinates['xmin']
  114. ymin = coordinates['ymin']
  115. xmax = coordinates['xmax']
  116. ymax = coordinates['ymax']
  117. bar_code_polygon = Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
  118. for attention_polygon in attention_polygon_list:
  119. cond1 = bar_code_polygon.within(attention_polygon) or bar_code_polygon.contains(attention_polygon)
  120. cond2 = False
  121. cond3 = bar_code_polygon.overlaps(attention_polygon)
  122. if cond3:
  123. intersection_poly = bar_code_polygon.intersection(attention_polygon)
  124. cond2 = intersection_poly.area / bar_code_polygon.area >= 0.01
  125. cond3 = intersection_poly.area / attention_polygon.area >= 0.01
  126. if cond1 or cond2 or cond3:
  127. bar_code_dict_list.remove(bar_code)
  128. break
  129. return bar_code_dict_list
  130. def infer_exam_number(image, ocr_dict_list, existed_regions, times_threshold=5):
  131. # existed_polygon_list = []
  132. # for region in existed_regions:
  133. # coordinates = region['bounding_box']
  134. # xmin = coordinates['xmin']
  135. # ymin = coordinates['ymin']
  136. # xmax = coordinates['xmax']
  137. # ymax = coordinates['ymax']
  138. # existed_polygon = Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
  139. # existed_polygon_list.append(existed_polygon)
  140. img_rows, img_cols = image.shape[0], image.shape[1]
  141. exam_number_dict_list = []
  142. xmin, ymin, xmax, ymax = 9999, 9999, 0, 0
  143. pattern = r'[0oO]|[2-9]' # 除去1,避免[]被识别为1
  144. exclude = r'分|题|[ABD]'
  145. key_digital = []
  146. all_height = []
  147. cols = []
  148. for index, ele in enumerate(ocr_dict_list):
  149. words = ele['words'].replace(' ', '')
  150. match_list = [(m.group(), m.span()) for m in re.finditer(pattern, words) if m]
  151. exclude_list = [(m.group(), m.span()) for m in re.finditer(exclude, words, re.I) if m]
  152. match_digital_arr = np.asarray([int(char[0].replace('o', '0').replace('O', '0')) for char in match_list])
  153. if len(match_digital_arr) > 0:
  154. counts = np.bincount(match_digital_arr)
  155. mode_times = np.max(counts)
  156. if mode_times >= times_threshold and len(exclude_list) < 1:
  157. mode_value = np.argmax(counts) # 众数,避免考号末尾出现的其他数字
  158. key_index = np.where(match_digital_arr == mode_value)[0]
  159. cols.append(len(key_index))
  160. start_index = match_list[key_index[0]][1][0]
  161. end_index = match_list[key_index[-1]][1][0]
  162. xmin_t = ele['chars'][start_index]['location']['left']
  163. ymin_t = ele['chars'][start_index]['location']['top']
  164. xmax_t = ele['chars'][end_index]['location']['left'] + ele['chars'][end_index]['location']['width']
  165. ymax_t = ele['chars'][end_index]['location']['top'] + ele['chars'][end_index]['location']['height']
  166. mean_width = sum([int(ele['chars'][match_list[i][1][0]]['location']['width'])
  167. for i in key_index]) // len(key_index)
  168. mean_height = sum([int(ele['chars'][match_list[i][1][0]]['location']['height'])
  169. for i in key_index]) // len(key_index)
  170. all_height.append(mean_height)
  171. xmin = min(xmin, xmin_t-mean_width)
  172. ymin = min(ymin, ymin_t)
  173. xmax = max(xmax, xmax_t+mean_width)
  174. ymax = max(ymax, ymax_t)
  175. xmin = int(xmin) if xmin >= 1 else 1
  176. ymin = int(ymin) if ymin >= 1 else 1
  177. xmax = int(xmax) if xmax <= img_cols - 1 else img_cols - 1
  178. ymax = int(ymax) if ymax <= img_rows - 1 else img_rows - 1
  179. key_digital.append(mode_value)
  180. if 9 in key_digital:
  181. break
  182. if 0 in key_digital and 9 in key_digital:
  183. mean_height = sum(all_height)//10
  184. exam_number_dict = {'class_name': 'exam_number',
  185. 'bounding_box': {'xmin': xmin, 'ymin': ymin, 'xmax': xmax, 'ymax': ymax+mean_height},
  186. 'rows': 10,
  187. 'cols': max(cols)
  188. }
  189. exam_number_dict_list.append(exam_number_dict)
  190. return exam_number_dict_list
  191. else:
  192. if len(key_digital) > 1:
  193. dgt_min = min(key_digital)
  194. dgt_max = max(key_digital)
  195. mean_height = sum(all_height)//len(all_height)
  196. dif = dgt_max - dgt_min
  197. blank_height = ymax - ymin - mean_height * (dif+1)
  198. mean_blank = blank_height // dif
  199. upper_height = dgt_min * (mean_blank + mean_height) + mean_blank//2
  200. downward_height = (9-dgt_max) * (mean_blank + mean_height) + mean_blank
  201. exam_number_dict = {'class_name': 'exam_number',
  202. 'bounding_box': {'xmin': xmin, 'ymin': ymin-upper_height,
  203. 'xmax': xmax, 'ymax': ymax+downward_height},
  204. 'rows': 10,
  205. 'cols': max(cols)}
  206. exam_number_dict_list.append(exam_number_dict)
  207. if len(key_digital) == 1:
  208. dgt_min = dgt_max = min(key_digital)
  209. eval_height = sum(all_height)//len(all_height) * 1.5
  210. upper_height = dgt_min * eval_height
  211. downward_height = (9-dgt_max) * eval_height
  212. exam_number_dict = {'class_name': 'exam_number',
  213. 'bounding_box': {'xmin': xmin, 'ymin': ymin-upper_height,
  214. 'xmax': xmax, 'ymax': ymax+downward_height},
  215. 'rows': 10,
  216. 'cols': max(cols)}
  217. exam_number_dict_list.append(exam_number_dict)
  218. iou_cond = True
  219. exam_number_dict_list_check = []
  220. for exam_number_dict in exam_number_dict_list:
  221. exam_number_polygon = Polygon([(exam_number_dict["xmin"], exam_number_dict["ymin"]),
  222. (exam_number_dict["xmax"], exam_number_dict["ymin"]),
  223. (exam_number_dict["xmax"], exam_number_dict["ymax"]),
  224. (exam_number_dict["xmin"], exam_number_dict["ymax"])])
  225. for region in existed_regions:
  226. class_name = region["class_name"]
  227. if class_name in ["attention", "solve", "choice", "choice_m", 'choice_s', "cloze", 'cloze_s',
  228. 'bar_code', 'qr_code', 'composition', 'solve0']:
  229. coordinates = region['bounding_box']
  230. xmin = coordinates['xmin']
  231. ymin = coordinates['ymin']
  232. xmax = coordinates['xmax']
  233. ymax = coordinates['ymax']
  234. existed_polygon = Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
  235. overlab_area = existed_polygon.intersection(exam_number_polygon).area
  236. iou = overlab_area / (exam_number_polygon.area + existed_polygon.area - overlab_area)
  237. if iou > 0:
  238. iou_cond = False
  239. break
  240. if iou_cond:
  241. exam_number_dict_list_check.append(exam_number_polygon)
  242. return exam_number_dict_list_check
  243. def adjust_exam_number(regions):
  244. exam_number_w_regions = list()
  245. exam_number_regions = list()
  246. for i in range(len(regions) - 1, -1, -1):
  247. region = regions[i]
  248. if region['class_name'] == 'exam_number_w':
  249. exam_number_w_regions.append(region)
  250. if region['class_name'] == 'exam_number':
  251. exam_number_regions.append(region)
  252. regions.pop(i)
  253. exam_number_region = exam_number_regions[0]
  254. if len(exam_number_regions) > 1:
  255. exam_number_regions = sorted(exam_number_regions, key=lambda x: x['bounding_box']['ymin'])
  256. exam_number_region = exam_number_regions[0]
  257. exam_number_w_index = 0
  258. if len(exam_number_w_regions) > 1:
  259. distance = [abs(int(ele['bounding_box']['ymax']) - int(exam_number_region['bounding_box']['ymin ']))
  260. for ele in exam_number_w_regions]
  261. exam_number_w_index = distance.index(min(distance))
  262. exam_number_w_region = exam_number_w_regions[exam_number_w_index]
  263. standard = exam_number_w_region['bounding_box']
  264. exam_number_region['bounding_box'].update({'xmin': standard['xmin'], 'xmax': standard['xmax']})
  265. regions.append(exam_number_region)
  266. return regions
  267. def exam_number_infer_by_s(image, regions):
  268. exam_number_s_list = [ele for ele in regions if ele['class_name'] == 'exam_number_s'
  269. and (int(ele['bounding_box']['xmax'])-int(ele['bounding_box']['xmin']) <
  270. int(ele['bounding_box']['ymax'])-int(ele['bounding_box']['ymin']))]
  271. # 找边界
  272. exam_number_s_list = sorted(exam_number_s_list, key=lambda x: x['bounding_box']['xmin'])
  273. left_limit = exam_number_s_list[0]['bounding_box']['xmin']
  274. right_limit = exam_number_s_list[-1]['bounding_box']['xmax']
  275. left_image = crop_region(image, exam_number_s_list[0]['bounding_box'])
  276. right_image = crop_region(image, exam_number_s_list[-1]['bounding_box'])
  277. mean_width = sum([int(ele['bounding_box']['xmax'])-int(ele['bounding_box']['xmin'])
  278. for ele in exam_number_s_list]) // len(exam_number_s_list)
  279. top_limit = min([ele['bounding_box']['ymin'] for ele in exam_number_s_list])
  280. bottom_limit = max([ele['bounding_box']['ymax'] for ele in exam_number_s_list])
  281. left_infer = True
  282. while left_infer:
  283. infer_box_xmin = int(left_limit - 1.5*mean_width)
  284. infer_box_xmax = int(left_limit - 0.5*mean_width)
  285. infer_box_ymin = int(exam_number_s_list[0]['bounding_box']['ymin'])
  286. infer_box_ymax = int(exam_number_s_list[0]['bounding_box']['ymax'])
  287. infer_image = crop_region_direct(image, [infer_box_xmin, infer_box_ymin, infer_box_xmax, infer_box_ymax])
  288. simi = image_hash_detection_simple(left_image, infer_image)
  289. print('l:', simi)
  290. if simi >= 0.85:
  291. left_limit = infer_box_xmin
  292. else:
  293. left_infer = False
  294. right_infer = True
  295. while right_infer:
  296. infer_box_xmin = int(right_limit + 0.5 * mean_width)
  297. infer_box_xmax = int(right_limit + 1.5 * mean_width)
  298. infer_box_ymin = int(exam_number_s_list[-1]['bounding_box']['ymin'])
  299. infer_box_ymax = int(exam_number_s_list[-1]['bounding_box']['ymax'])
  300. infer_image = crop_region_direct(image, [infer_box_xmin, infer_box_ymin, infer_box_xmax, infer_box_ymax])
  301. simi = image_hash_detection_simple(right_image, infer_image)
  302. print('r:', simi)
  303. if simi >= 0.70:
  304. right_limit = infer_box_xmax
  305. else:
  306. right_infer = False
  307. infer_exam_number_region = {'xmin': left_limit, 'xmax': right_limit, 'ymin': top_limit, 'ymax': bottom_limit, }
  308. exam_dict_list = [{'class_name': 'exam_number', 'bounding_box': infer_exam_number_region}]
  309. # print(exam_dict_list)
  310. return exam_dict_list
  311. def gen_xml_new(path, ocr_list):
  312. tree = ET.parse(r'../../tools/000000-template.xml') # xml tree
  313. for index, ele in enumerate(ocr_list):
  314. words = ele['words']
  315. location = ele['location']
  316. xmin = location['xmin']
  317. ymin = location['ymin']
  318. xmax = location['xmax']
  319. ymax = location['ymax']
  320. tree = create_xml('{}'.format(words), tree, str(xmin), str(ymin), str(xmax), str(ymax))
  321. # print(exam_items_bbox)
  322. tree.write(path.replace('.jpg', '.xml'))
  323. def subfield_answer_sheet(img0, answer_sheet):
  324. h, w = img0.shape[:2]
  325. one_part = 0
  326. line_xmax_1 = 0
  327. line_xmax_2 = 0
  328. modules = []
  329. modules11 = []
  330. w_int_1 = w
  331. w_int_2 = round(w / 2)
  332. w_int_3 = round(w / 3)
  333. w_int_4 = round(w / 4)
  334. w_int_8 = round(w / 8)
  335. if w_int_8 < 50:
  336. w_int_8 = 50
  337. key_modules_classes = ['choice', 'cloze', 'solve', 'solve0', 'composition0', 'composition', 'correction',
  338. 'ban_area', ]
  339. if h > w: # 暂定答题卡高大于宽的为单栏
  340. one_part = 1
  341. else:
  342. temp1 = 0
  343. temp2 = 0
  344. for ele in answer_sheet:
  345. if ele["class_name"] in key_modules_classes:
  346. modules.append(ele)
  347. modules_xmin = sorted(modules, key=lambda x: (x['bounding_box']['xmin']))
  348. modules_xmax = sorted(modules, key=lambda x: (x['bounding_box']['xmax']))
  349. for i in range(len(modules_xmin) - 1):
  350. if i == 0 and modules_xmin[0]['bounding_box']['xmin'] - 0 > w_int_4:
  351. temp1 = 1
  352. else:
  353. if modules_xmin[i + 1]['bounding_box']['xmin'] - modules_xmin[i]['bounding_box']['xmax'] > w_int_4:
  354. if modules11 == []:
  355. line_xmax_1 = modules_xmin[i]['bounding_box']['xmax'] + 20
  356. line_xmax_2 = modules_xmin[i + 1]['bounding_box']['xmin'] - 20
  357. else:
  358. modules11.append(modules_xmin[i]['bounding_box']['xmax'])
  359. modules11_xmax = sorted(modules11)[-1]
  360. line_xmax_1 = modules11_xmax + 20
  361. line_xmax_2 = modules_xmin[i + 1]['bounding_box']['xmin'] - 20
  362. modules11 = []
  363. temp1 = 1
  364. temp2 = 1
  365. break
  366. elif modules_xmin[i + 1]['bounding_box']['xmin'] - modules_xmin[i]['bounding_box']['xmax'] > -w_int_8:
  367. if temp1 == 0:
  368. if modules11 == []:
  369. line_xmax_1 = int((modules_xmin[i + 1]['bounding_box']['xmin'] +
  370. modules_xmin[i]['bounding_box']['xmax']) / 2)
  371. else:
  372. modules11.append(modules_xmin[i]['bounding_box']['xmax'])
  373. modules11_xmax = sorted(modules11)[-1]
  374. line_xmax_1 = int((modules_xmin[i + 1]['bounding_box']['xmin'] +
  375. modules11_xmax) / 2)
  376. modules11 = []
  377. temp1 = 1
  378. elif temp1 == 1:
  379. if modules11 == []:
  380. line_xmax_2 = int((modules_xmin[i + 1]['bounding_box']['xmin'] +
  381. modules_xmin[i]['bounding_box']['xmax']) / 2)
  382. else:
  383. modules11.append(modules_xmin[i]['bounding_box']['xmax'])
  384. modules11_xmax = sorted(modules11)[-1]
  385. line_xmax_2 = int((modules_xmin[i + 1]['bounding_box']['xmin'] +
  386. modules11_xmax) / 2)
  387. temp2 = 1
  388. else:
  389. modules11.append(modules_xmin[i]['bounding_box']['xmax'])
  390. if temp1 == 0 and temp2 == 0:
  391. if modules_xmax[-1]['bounding_box']['xmax'] - w < -(2 * w_int_4):
  392. line_xmax_1 = modules_xmax[-1]['bounding_box']['xmax'] + 20
  393. line_xmax_2 = 2 * w_int_3
  394. elif modules_xmax[-1]['bounding_box']['xmax'] - w < -w_int_4:
  395. line_xmax_1 = modules_xmax[-1]['bounding_box']['xmax'] + 20
  396. elif temp1 == 1 and temp2 == 0:
  397. if modules_xmax[-1]['bounding_box']['xmax'] - w < -w_int_4:
  398. line_xmax_2 = 2 * w_int_3
  399. return line_xmax_1, line_xmax_2
  400. def get_intersection_point(lines, orthogonal_lines, border):
  401. intersect_point_list = []
  402. for line in lines:
  403. width_min, height_min, width_max, height_max = border
  404. (x_l, y_u), (x_r, y_d) = line.coords
  405. x_l = x_l if x_l > width_min else width_min + 1 # 避免边界
  406. x_r = x_r if x_r < width_max else width_max - 1
  407. y_u = y_u if y_u > height_min else height_min + 1
  408. y_d = y_d if y_d < height_max else height_max - 1
  409. points_list = []
  410. if x_l == x_r:
  411. line_direction = 'lon'
  412. raw_line = LineString([(x_l, y_u), (x_r, y_d)])
  413. extend_line = LineString([(x_l, height_min), (x_r, height_max)])
  414. points_list.extend([height_min + 1, height_max - 1]) # 延长线与边界交点,并避免key_point位于现有边界上
  415. line_start, line_end = y_u, y_d
  416. else:
  417. line_direction = 'lat'
  418. raw_line = LineString([(x_l, y_u), (x_r, y_d)])
  419. extend_line = LineString([(width_min, y_u), (width_max, y_d)])
  420. points_list.extend([width_min + 1, width_max - 1]) # 延长线与边界交点,并避免key_point位于现有边界上
  421. line_start, line_end = x_l, x_r
  422. for ele in orthogonal_lines:
  423. cond1 = extend_line.intersects(ele) # T, L, 十交叉
  424. cond2 = extend_line.crosses(ele) # 十字交叉
  425. cond3 = raw_line.intersects(ele)
  426. cond4 = raw_line.crosses(ele)
  427. if line_direction == 'lat':
  428. if cond3:
  429. (xp, yp) = raw_line.intersection(ele).bounds[:2]
  430. intersect_point_list.append((xp, yp))
  431. elif cond1:
  432. (xp, yp) = extend_line.intersection(ele).bounds[:2]
  433. points_list.append(xp)
  434. if line_direction == 'lon':
  435. if cond3:
  436. (xp, yp) = raw_line.intersection(ele).bounds[:2]
  437. intersect_point_list.append((xp, yp))
  438. elif cond1:
  439. (xp, yp) = extend_line.intersection(ele).bounds[:2]
  440. points_list.append(yp)
  441. points_array = np.asarray(points_list, dtype=np.uint)
  442. left_key = np.max(points_array[points_array <= line_start])
  443. right_key = np.min(points_array[points_array >= line_end]) # 延长线两边延长并取得第一个交点
  444. if line_direction == 'lat':
  445. intersect_point = [(left_key, y_u), (right_key, y_d)]
  446. else:
  447. intersect_point = [(x_l, left_key), (x_r, right_key)]
  448. # print(intersect_point)
  449. intersect_point_list.extend(intersect_point)
  450. return intersect_point_list
  451. def infer_sheet_box(image, sheet_dict, lon_split_line, exclude_classes):
  452. height_max, width_max = image.shape[0], image.shape[1]
  453. height_min, width_min = 0, 0
  454. latitude = []
  455. longitude = []
  456. lines = []
  457. sheet_polygons = []
  458. all_sheet_polygons = []
  459. choice_polygon = []
  460. # exclude_classes = ['cloze_s', 'exam_number_s', 'choice_s', 'type_score',
  461. # 'mark', 'page', 'exam_number_s', 'cloze_score', 'name_w',
  462. # 'class_w',]
  463. h_min = []
  464. h_max = []
  465. for index, region_box in enumerate(sheet_dict):
  466. coordinates = region_box['bounding_box']
  467. xmin = coordinates['xmin']
  468. ymin = coordinates['ymin']
  469. xmax = coordinates['xmax']
  470. ymax = coordinates['ymax']
  471. if region_box['class_name'] == 'info_title': # 上限
  472. h_min.append(ymin)
  473. if region_box['class_name'] == 'page': # 下限
  474. h_max.append(ymin)
  475. if region_box['class_name'] == 'alarm_info':
  476. h_min.append(ymin)
  477. h_max.append(ymin)
  478. if h_min:
  479. hgt_min = min(h_min)
  480. if hgt_min < height_max / 4:
  481. height_min = hgt_min
  482. if h_max:
  483. hgt_max = max(h_max)
  484. if hgt_max > 3 * height_max / 4:
  485. height_max = hgt_max
  486. # height_min = h_min if h_min != 9999 else height_min
  487. # height_max = h_max if h_max != 0 else height_max
  488. for index, region_box in enumerate(sheet_dict):
  489. coordinates = region_box['bounding_box']
  490. xmin = coordinates['xmin']
  491. ymin = coordinates['ymin']
  492. xmax = coordinates['xmax']
  493. ymax = coordinates['ymax']
  494. box_polygon = Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
  495. if region_box['class_name'] not in exclude_classes:
  496. if region_box['class_name'] not in ['choice', 'cloze']: # 推断选择题区域内的choice_m
  497. sheet_polygons.append(box_polygon)
  498. if region_box['class_name'] == 'choice':
  499. choice_polygon.append(box_polygon)
  500. all_sheet_polygons.append(box_polygon)
  501. line1 = LineString([(xmin, ymin), (xmin, ymax)])
  502. line2 = LineString([(xmax, ymin), (xmax, ymax)])
  503. line3 = LineString([(xmin, ymin), (xmax, ymin)])
  504. line4 = LineString([(xmin, ymax), (xmax, ymax)])
  505. lines.extend([line1, line2, line3, line4])
  506. longitude.extend([line1, line2])
  507. latitude.extend([line3, line4])
  508. # sheet_polygons 去除包裹的情况
  509. sheet_polygons_ = list(combinations(sheet_polygons, 2))
  510. for polygons in sheet_polygons_:
  511. if polygons[0].within(polygons[1]) or polygons[0].contains(polygons[1]):
  512. area_list = [polygons[0].area, polygons[1].area]
  513. min_polygon = polygons[area_list.index(min(area_list))]
  514. if min_polygon in sheet_polygons:
  515. sheet_polygons.remove(min_polygon)
  516. min_polygon = sorted(all_sheet_polygons, key=lambda p: p.area)[0]
  517. avg_area = sum([polygon.area for polygon in sheet_polygons]) / len(sheet_polygons)
  518. # 所有矩形框的延长线与矩形框集图像边界的交点
  519. latitude = sorted(latitude, key=lambda x: x.bounds[1]) # y
  520. longitude = sorted(longitude, key=lambda x: x.bounds[0]) # x
  521. lat_intersect_point_list = get_intersection_point(latitude, longitude,
  522. (width_min, height_min, width_max, height_max))
  523. lon_intersect_point_list = get_intersection_point(longitude, latitude,
  524. (width_min, height_min, width_max, height_max))
  525. raw_corner = [(width_min + 1, height_min + 1), (width_min + 1, height_max - 1), (width_max - 1, 1),
  526. (width_max - 1, height_max - 1)]
  527. # raw_corner = []
  528. intersect_point_list = lat_intersect_point_list + lon_intersect_point_list + raw_corner
  529. intersect_point_list = list(set(intersect_point_list))
  530. intersect_point_dict = {k: index + 1 for index, k in enumerate(intersect_point_list)}
  531. def _filter_rect(p_list):
  532. flag = 0
  533. for ele in p_list:
  534. try:
  535. flag = intersect_point_dict[ele]
  536. except KeyError:
  537. flag = 0
  538. break
  539. if flag > 0:
  540. x_c = sum([ele[0] for ele in p_list]) / 4
  541. y_c = sum([ele[1] for ele in p_list]) / 4
  542. d1, d2, d3, d4 = [LineString([p, (x_c, y_c)]).length for p in p_list]
  543. return (0 not in [d1, d2, d3, d4]) and d1 == d2 and d1 == d3 and d1 == d4
  544. else:
  545. return False
  546. def _find_rect(point):
  547. (x1, y1) = point[0]
  548. (x2, y2) = point[1]
  549. if x1 != x2 and y1 != y2:
  550. xmin, ymin = min(x1, x2), min(y1, y2)
  551. xmax, ymax = max(x1, x2), max(y1, y2)
  552. points_4 = [(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)]
  553. w, h = xmax - xmin, ymax - ymin
  554. aspect_flag_extreme = max(w / h, h / w) < 1.5 * ASPECT_FLAG # 解决极端情况
  555. rect_flag = _filter_rect(points_4)
  556. if aspect_flag_extreme and rect_flag:
  557. gen_polygon = Polygon([(points_4[0]), (points_4[1]), (points_4[2]), (points_4[3])])
  558. flags = set()
  559. for polygon in sheet_polygons:
  560. decision = [gen_polygon.contains(polygon),
  561. gen_polygon.within(polygon),
  562. gen_polygon.overlaps(polygon)]
  563. if True in decision: # 边界问题
  564. flags.add(False)
  565. break
  566. else:
  567. flags.add(True)
  568. if False in flags:
  569. pass
  570. else:
  571. return gen_polygon
  572. def _filter_none(p):
  573. if p is not None:
  574. return True
  575. points_2 = combinations(intersect_point_list, 2)
  576. gen_polygon_list = map(_find_rect, points_2)
  577. gen_polygon_list = list(filter(_filter_none, gen_polygon_list))
  578. gen_polygon_list = sorted(gen_polygon_list, key=lambda p: p.area, reverse=True)
  579. # gen_polygon_list = [polygon for index, polygon in enumerate(gen_polygon_list) if index % 2 == 0]
  580. it = itertools.groupby(gen_polygon_list)
  581. gen_polygon_list = [k for k, g in it]
  582. # 在选择题区域的infer polygon
  583. gen_choice = []
  584. for ele in gen_polygon_list:
  585. for choice_p in choice_polygon:
  586. if ele.within(choice_p):
  587. gen_choice.append(ele)
  588. sheet_box_area = sum([polygon.area for polygon in sheet_polygons])
  589. image_area = width_max * height_max
  590. blank_ratio = 1 - sheet_box_area / image_area
  591. polygon_index = 0
  592. include_polygon = []
  593. while blank_ratio > REMAIN_RATIO and polygon_index < len(gen_polygon_list):
  594. polygon = gen_polygon_list[polygon_index]
  595. blank_ratio = blank_ratio - polygon.area / image_area
  596. include_polygon.append(polygon)
  597. polygon_index += 1
  598. # gen_polygon_list = [polygon for index, polygon in enumerate(gen_polygon_list)
  599. # if polygon.area > 1.5 * min_polygon.area]
  600. for polygon in gen_polygon_list.copy():
  601. xi, yi, xx, yx = polygon.bounds
  602. w, h = xx - xi, yx - yi
  603. if polygon.area <= 1.5 * min_polygon.area or h / w > 2 and polygon.area < avg_area:
  604. gen_polygon_list.remove(polygon)
  605. polygon_2 = list(combinations(gen_polygon_list, 2))
  606. for polygons in polygon_2:
  607. try:
  608. cond2 = polygons[0].overlaps(polygons[1]) # 叠置关系二次分段
  609. if cond2:
  610. area_list = [polygons[0].area, polygons[1].area]
  611. min_index = area_list.index(min(area_list))
  612. smaller_polygon = polygons[min_index]
  613. larger_polygon = polygons[1 - min_index]
  614. new_polygon = smaller_polygon.difference(larger_polygon)
  615. if smaller_polygon in gen_polygon_list:
  616. gen_polygon_list.remove(smaller_polygon)
  617. if 'MultiPolygon' in str(type(new_polygon)):
  618. for ele in new_polygon:
  619. xm, ym, xx, yx = ele.bounds
  620. w, h = xx - xm, yx - ym
  621. if max(w / h, h / w) < 1.5 * ASPECT_FLAG and ele.area > 1.5 * min_polygon.area:
  622. gen_polygon_list.append(ele)
  623. elif len(set(new_polygon.exterior.coords)) == 4:
  624. xm, ym, xx, yx = new_polygon.bounds
  625. w, h = xx - xm, yx - ym
  626. if max(w / h, h / w) < 1.5 * ASPECT_FLAG and new_polygon.area > 1.5 * min_polygon.area:
  627. gen_polygon_list.append(new_polygon)
  628. except Exception as polygon_e:
  629. print(polygon_e)
  630. continue
  631. polygon_2 = list(combinations(gen_polygon_list, 2)) # 包含关系取大值
  632. for polygons in polygon_2:
  633. cond1 = polygons[0].equals(polygons[1])
  634. if cond1 and polygons[1] in gen_polygon_list:
  635. gen_polygon_list.remove(polygons[1])
  636. polygon_2 = list(combinations(gen_polygon_list, 2))
  637. for polygons in polygon_2:
  638. cond2 = polygons[0].contains(polygons[1]) or polygons[0].within(polygons[1])
  639. if cond2:
  640. area_list = [polygons[0].area, polygons[1].area]
  641. min_index = area_list.index(min(area_list))
  642. smaller_polygon = polygons[min_index]
  643. larger_polygon = polygons[1 - min_index]
  644. sxi, syi, sxx, syx = smaller_polygon.bounds
  645. bxi, byi, bxx, byx = larger_polygon.bounds
  646. # inner_touch_cond = '212F11FF2' == larger_polygon.relate(smaller_polygon)
  647. two_side_touch_cond = (sxi == bxi and sxx == bxx) or (syi == byi and syx == byx)
  648. if two_side_touch_cond:
  649. dif_polygon = larger_polygon.difference(smaller_polygon)
  650. if larger_polygon in gen_polygon_list:
  651. gen_polygon_list.remove(larger_polygon)
  652. if 'MultiPolygon' in str(type(dif_polygon)):
  653. for ele in dif_polygon:
  654. xm, ym, xx, yx = ele.bounds
  655. w, h = xx - xm, yx - ym
  656. if max(w / h, h / w) < 1.5 * ASPECT_FLAG and ele.area > 1.5 * min_polygon.area:
  657. gen_polygon_list.append(ele)
  658. elif len(set(dif_polygon.exterior.coords)) == 4: # empty
  659. xm, ym, xx, yx = dif_polygon.bounds
  660. w, h = xx - xm, yx - ym
  661. if max(w / h, h / w) < 1.5 * ASPECT_FLAG and dif_polygon.area > 1.5 * min_polygon.area:
  662. gen_polygon_list.append(dif_polygon)
  663. else:
  664. if smaller_polygon in gen_polygon_list:
  665. gen_polygon_list.remove(smaller_polygon)
  666. polygon_2 = list(combinations(gen_polygon_list, 2)) # 包含关系取大值
  667. for polygons in polygon_2:
  668. cond1 = polygons[0].equals(polygons[1])
  669. if cond1 and polygons[1] in gen_polygon_list:
  670. gen_polygon_list.remove(polygons[1])
  671. if len(lon_split_line) > 0:
  672. for line in lon_split_line:
  673. # line = LineString([(286, 1), (286, 599)])
  674. for poly in gen_polygon_list.copy():
  675. cond1 = line.intersects(poly)
  676. cond2 = line.touches(poly)
  677. if cond1 and not cond2:
  678. dif_polygons = poly.difference(line)
  679. corner_list = list(set(dif_polygons.exterior.coords))
  680. sorted_corner_list = sorted(corner_list, key=lambda x: x[0])
  681. if len(sorted_corner_list) == 6:
  682. left = sorted(sorted_corner_list[0:2], key=lambda x: x[1])
  683. middle = sorted(sorted_corner_list[2:4], key=lambda x: x[1])
  684. right = sorted(sorted_corner_list[4:6], key=lambda x: x[1])
  685. tmp_corner_list = [middle[0], left[0], left[1], middle[1], right[1], right[0], middle[0]]
  686. polygon1 = Polygon(tmp_corner_list[:4])
  687. polygon2 = Polygon(tmp_corner_list[3:])
  688. gen_polygon_list.remove(poly)
  689. for p in [polygon1, polygon2]:
  690. xi, yi, xx, yx = p.bounds
  691. w, h = xx - xi, yx - yi
  692. aspect_flag = max(w / h, h / w) < ASPECT_FLAG
  693. if aspect_flag:
  694. gen_polygon_list.append(p)
  695. gen_polygon_list = [polygon for index, polygon in enumerate(gen_polygon_list) if polygon.area > min_polygon.area]
  696. if gen_choice:
  697. gen_choice = sorted(gen_choice, key=lambda x: x.area)[-1]
  698. gen_polygon_list.append(gen_choice)
  699. return gen_polygon_list
  def _infer_box_dict(class_name, xmin, ymin, xmax, ymax):
      """Return a region dict ``{'class_name', 'bounding_box'}`` with integer coordinates."""
      return {'class_name': class_name,
              'bounding_box': {'xmin': int(xmin), 'ymin': int(ymin),
                               'xmax': int(xmax), 'ymax': int(ymax)}}


  def _infer_aspect_ok(w, h, limit):
      """Return True when a ``w`` x ``h`` box is non-degenerate and max(w/h, h/w) < ``limit``."""
      if w <= 0 or h <= 0:
          # Zero/negative sides would otherwise raise ZeroDivisionError (or give a bogus ratio).
          return False
      return max(w / h, h / w) < limit


  def _infer_ocr_hits(poly, ocr):
      """Return True when one OCR item's text box marks ``poly`` as containing text.

      That is the case when the polygon and the text box contain one another, or when
      they overlap and the overlap covers at least 20% of ``poly``'s area.
      """
      location = ocr['location']
      xmin, ymin = location['left'], location['top']
      xmax, ymax = xmin + location['width'], ymin + location['height']
      text_box = Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
      if poly.within(text_box) or poly.contains(text_box):
          return True
      return text_box.overlaps(poly) and text_box.intersection(poly).area / poly.area >= 0.2


  def infer_class(image, sheet_dict_list, infer_polygon, image_cols, ocr_dict_list=''):
      """Classify inferred (unlabelled) polygons and merge the results into ``sheet_dict_list``.

      Each polygon in ``infer_polygon`` is consumed by the first pass that claims it:

      1. type_score split. A type_score box "matches" a polygon when it lies within it,
         or overlaps it with overlap / type_score area >= TYPE_SCORE_MNS while being
         smaller than 20% of the polygon. One match turns the whole polygon into a
         'solve' box. Several matches cut it at each header's top edge into 'solve'
         slices; slices failing the ASPECT_FLAG test are dropped.
      2. OCR blank. Polygons that no OCR text box hits are blank.
      3. Pixel blank. Polygons whose Otsu-binarised mean is below PIX_VALUE_LOW, or
         whose raw mean is above PIX_VALUE_HIGH, are blank.
         Blank polygons from passes 2 and 3 are only removed; they are not added to
         ``sheet_dict_list``.
      4. Fallback. Every remaining polygon becomes a provisional 'solve_infer' box when
         its aspect ratio passes ASPECT_FLAG, otherwise a 'blank' box.

      Afterwards, when type_score boxes exist, 'solve_infer' boxes with an area below
      three times the mean type_score area are dropped. The surviving 'solve_infer'
      boxes are renamed to 'solve'.

      Args:
          image: image passed to ``crop_region_direct`` / Otsu thresholding.
              ``box_infer_and_complete`` converts it to grayscale before calling.
          sheet_dict_list: detected regions, each
              ``{'class_name': str, 'bounding_box': {'xmin', 'ymin', 'xmax', 'ymax'}}``.
              It is mutated in place (extended, filtered, renamed) and returned.
          infer_polygon: shapely polygons to classify; emptied in place.
          image_cols: unused; kept for interface compatibility.
          ocr_dict_list: OCR items, each with a ``'location'`` dict holding
              ``left``/``top``/``width``/``height``.

      Returns:
          ``sheet_dict_list`` (the same list object).
      """
      # Only type_score boxes feed the active passes. The choice / cloze / solve /
      # choice_s collections were consumed solely by disabled choice_m / cloze_s passes.
      type_score_polygons = []
      for region_box in sheet_dict_list:
          if region_box['class_name'] != 'type_score':
              continue
          box = region_box['bounding_box']
          type_score_polygons.append(Polygon([(box['xmin'], box['ymin']), (box['xmax'], box['ymin']),
                                              (box['xmax'], box['ymax']), (box['xmin'], box['ymax'])]))

      # Pass 1: split polygons at the type_score headers they contain.
      for poly in infer_polygon.copy():
          p_xmin, p_ymin, p_xmax, p_ymax = poly.bounds
          type_score_num = 0
          cut_ys = []
          for ts_polygon in type_score_polygons:
              matched = ts_polygon.within(poly)
              if ts_polygon.overlaps(poly):
                  overlap_ratio = ts_polygon.intersection(poly).area / ts_polygon.area
                  print('type_score:', overlap_ratio)
                  # A partially overlapping header still counts when most of it lies
                  # inside the polygon and it is small relative to the polygon.
                  matched = matched or (overlap_ratio >= TYPE_SCORE_MNS
                                        and ts_polygon.area < 0.2 * poly.area)
              if not matched:
                  continue
              type_score_num += 1
              _, t_ymin, _, t_ymax = ts_polygon.bounds
              cut_ys.append(t_ymin)
              # A header more than three header-heights below the polygon's top leaves
              # an unlabelled section above it, so the top edge becomes a cut too.
              if t_ymin - p_ymin > 3 * (t_ymax - t_ymin):
                  type_score_num += 1
                  cut_ys.append(p_ymin)
          if type_score_num == 1:
              sheet_dict_list.append(_infer_box_dict('solve', p_xmin, p_ymin, p_xmax, p_ymax))
              infer_polygon.remove(poly)
          elif type_score_num > 1:
              cut_ys.sort()
              cut_ys[0] = min(p_ymin, cut_ys[0])
              cut_ys.append(p_ymax)
              width = p_xmax - p_xmin
              for top, bottom in zip(cut_ys, cut_ys[1:]):
                  # The top edge can be appended once per far-away header, giving
                  # zero-height slices. _infer_aspect_ok rejects them instead of crashing.
                  if _infer_aspect_ok(width, bottom - top, ASPECT_FLAG):
                      sheet_dict_list.append(_infer_box_dict('solve', p_xmin, top, p_xmax, bottom))
              infer_polygon.remove(poly)

      # Pass 2: polygons without any OCR text are blank.
      # NOTE(review): with the default ocr_dict_list='' nothing has text, so every
      # remaining polygon is treated as blank here. Confirm this is intended.
      for poly in infer_polygon.copy():
          if not any(_infer_ocr_hits(poly, ocr) for ocr in ocr_dict_list):
              infer_polygon.remove(poly)

      # Pass 3: polygons whose pixels look empty are blank.
      for poly in infer_polygon.copy():
          region = crop_region_direct(image, [int(v) for v in poly.bounds])
          binary = cv2.threshold(region, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
          if np.mean(binary) < PIX_VALUE_LOW or np.mean(region) > PIX_VALUE_HIGH:
              infer_polygon.remove(poly)

      # Pass 4: whatever is left becomes a provisional solve box or a recorded blank box.
      for poly in infer_polygon:
          xmin, ymin, xmax, ymax = poly.bounds
          if _infer_aspect_ok(xmax - xmin, ymax - ymin, ASPECT_FLAG):
              class_name = 'solve_infer'
          else:
              class_name = 'blank'
          sheet_dict_list.append(_infer_box_dict(class_name, xmin, ymin, xmax, ymax))
      infer_polygon.clear()

      if type_score_polygons:
          mean_type_score_area = sum(p.area for p in type_score_polygons) / len(type_score_polygons)

          def _too_small(region):
              b = region['bounding_box']
              # Height is ymax - ymin. (It used to be ymin - ymin, i.e. always 0, which
              # dropped every solve_infer box whenever type_score boxes existed.)
              return (b['xmax'] - b['xmin']) * (b['ymax'] - b['ymin']) < mean_type_score_area * 3

          # Slice assignment keeps the caller's list object.
          sheet_dict_list[:] = [region for region in sheet_dict_list
                                if not (region['class_name'] == 'solve_infer' and _too_small(region))]

      for region in sheet_dict_list:
          if region['class_name'] == 'solve_infer':
              region['class_name'] = 'solve'
      return sheet_dict_list
  def box_infer_and_complete(image, sheet_region_dict, ocr=''):
      """Infer unlabelled answer-sheet regions and merge them into the detected regions.

      The image is converted to grayscale. The column positions returned by
      ``subfield_answer_sheet`` become vertical split lines, ``infer_sheet_box`` proposes
      candidate polygons, and ``infer_class`` classifies them.

      Args:
          image: image array. 3-channel input is treated as RGB, 4-channel as RGBA,
              and 2-D input is used as-is.
          sheet_region_dict: detected regions. Despite the name, ``infer_class`` iterates
              and appends to it, so it is a list of ``{'class_name', 'bounding_box'}`` dicts.
          ocr: OCR results forwarded to ``infer_class`` as ``ocr_dict_list``.

      Returns:
          The region list returned by ``infer_class``.
      """
      if len(image.shape) == 3:
          # An RGBA image is (h, w, 4), i.e. still 3-D. Checking len(shape) == 4 never
          # matched, so dispatch on the channel count instead.
          channels = image.shape[2]
          if channels == 4:
              image = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
          elif channels == 1:
              image = image[:, :, 0]
          else:
              image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
      exclude_classes = [
          'cloze_s',
          'exam_number_s',
          'type_score',
          'page',
          'alarm_info',
          # 'score_collect',
          'choice_s',
      ]
      height = image.shape[0]
      x1, x2 = subfield_answer_sheet(image, sheet_region_dict)
      # Split positions equal to 0 are skipped (presumably "no split on that side").
      lon_split_line = [LineString([(px, 1), (px, height - 1)]) for px in (x1, x2) if px != 0]
      poly_list = infer_sheet_box(image, sheet_region_dict, lon_split_line, exclude_classes)
      image_cols = len(lon_split_line) + 1
      return infer_class(image, sheet_region_dict, poly_list, image_cols, ocr)
  # Choice-question (multiple-choice) region completion
  def _get_split_index(sorted_list, spilt_value):
      """Return slice boundaries that split ``sorted_list`` at large gaps.

      A boundary is placed before every element whose distance from its predecessor is
      at least ``spilt_value``. ``0`` and ``len(sorted_list)`` are always included, so
      consecutive boundaries ``b[i]:b[i + 1]`` delimit the groups.

      Returns:
          Sorted list of unique int indices.
      """
      boundaries = {0, len(sorted_list)}
      neighbours = zip(sorted_list[:-1], sorted_list[1:])
      for position, (previous, current) in enumerate(neighbours, start=1):
          if current - previous >= spilt_value:
              boundaries.add(position)
      return sorted(boundaries)
  def get_letter_group(letter, location_list):
      """Group one letter's OCR character boxes into rows, then each row into blocks.

      Characters are ordered by ``top``. A new row starts wherever consecutive tops differ
      by at least 1.5x the mean character height. Within a row, characters are ordered by
      ``left``, and a new block starts wherever consecutive lefts differ by at least 1.5x
      the mean character width.

      Args:
          letter: the letter the boxes belong to (echoed back in the result).
          location_list: per-character OCR entries, each with a ``'location'`` dict holding
              ``'left'``, ``'top'``, ``'width'`` and ``'height'``.

      Returns:
          dict with
            - ``'letter'``: ``letter``;
            - ``'letter_group'``: one list per row holding that row's blocks
              (each block a list of entries sorted by ``top``);
            - ``'letter_group_location'``: the parallel structure of
              ``(xmin, ymin, xmax, ymax, middle_x, middle_y)`` per block;
            - ``'width'`` / ``'height'``: mean character width / height (NaN for empty input).
      """
      if not location_list:
          return {'letter': letter, 'letter_group': [], 'letter_group_location': [],
                  'width': float('nan'), 'height': float('nan')}
      # Rows are sliced out of this list with indices computed on the sorted tops, so the
      # list itself must be in top order. Slicing the unsorted input mixed unrelated characters.
      location_list = sorted(location_list, key=lambda k: k['location']['top'])
      height = np.mean(np.array([ele['location']['height'] for ele in location_list]))
      width = np.mean(np.array([ele['location']['width'] for ele in location_list]))
      y_split_dif, x_split_dif = height * 1.5, width * 1.5
      row_bounds = _get_split_index([ele['location']['top'] for ele in location_list], y_split_dif)
      letter_group_list = []
      letter_group_location_list = []
      for row_start, row_end in zip(row_bounds, row_bounds[1:]):
          # Blocks are separated by horizontal gaps, so order and split the row on ``left``.
          # The width-based threshold only makes sense there; splitting on ``top`` did not.
          row = sorted(location_list[row_start:row_end], key=lambda k: k['location']['left'])
          col_bounds = _get_split_index([ele['location']['left'] for ele in row], x_split_dif)
          block = []
          block_location = []
          for col_start, col_end in zip(col_bounds, col_bounds[1:]):
              letter_group = sorted(row[col_start:col_end], key=lambda k: k['location']['top'])
              lefts = [ele['location']['left'] for ele in letter_group]
              tops = [ele['location']['top'] for ele in letter_group]
              xmin, ymin = min(lefts), min(tops)
              xmax, ymax = max(lefts) + width, max(tops) + height
              middle_x, middle_y = (xmax - xmin) / 2 + xmin, (ymax - ymin) / 2 + ymin
              block_location.append((xmin, ymin, xmax, ymax, middle_x, middle_y))
              block.append(letter_group)
          letter_group_list.append(block)
          letter_group_location_list.append(block_location)
      return {'letter': letter,
              'letter_group': letter_group_list,
              'letter_group_location': letter_group_location_list,
              'width': width, 'height': height}
  def get_letter_group_h(letter, location_list):
      """Group one letter's OCR character boxes into columns by horizontal position.

      Characters are ordered by ``left``. A new column starts wherever consecutive lefts
      differ by at least 1.5x the (truncated) mean character width. Each column yields a box
      that:
        - spans from the column's leftmost ``left`` to that plus one mean width,
          padded by two mean widths on each side (clamped at 0 on the left);
        - runs vertically from the topmost character's ``top`` to the bottommost
          character's ``top`` plus twice its height.

      Args:
          letter: the letter the boxes belong to (echoed back in the result).
          location_list: per-character OCR entries, each with a ``'location'`` dict holding
              ``'left'``, ``'top'``, ``'width'`` and ``'height'``.

      Returns:
          ``{'letter': letter, 'group_location': [(xmin, ymin, xmax, ymax), ...]}`` with one
          tuple per column, ordered left to right.
      """
      if not location_list:
          return {'letter': letter, 'group_location': []}
      location_list = sorted(location_list, key=lambda k: k['location']['left'])
      x_list = [ele['location']['left'] for ele in location_list]  # already in left order
      # Truncated mean character size as plain ints. np.mean(..., dtype=np.uint) gave
      # unsigned scalars, so ``xmin - 2 * width`` wrapped around near the left edge.
      height = int(np.mean([ele['location']['height'] for ele in location_list]))
      width = int(np.mean([ele['location']['width'] for ele in location_list]))
      print('h, w: ', height, width)
      x_split_index = _get_split_index(x_list, width * 1.5)
      letter_group_location_list = []
      for start, end in zip(x_split_index, x_split_index[1:]):
          column = sorted(location_list[start:end], key=lambda k: k['location']['top'])
          xmin = min(ele['location']['left'] for ele in column)
          ymin = column[0]['location']['top']
          xmax = xmin + width
          ymax = column[-1]['location']['top'] + 2 * column[-1]['location']['height']
          # Pad by two mean widths on each side; image coordinates cannot go below 0.
          letter_group_location_list.append((max(0, xmin - 2 * width), ymin,
                                             xmax + 2 * width, ymax))
      return {'letter': letter, 'group_location': letter_group_location_list}
  def infer_choice_m_by_ocr(ocr_dict_list):
      """Infer multiple-choice ('choice_m') answer areas from OCR'd option letters A-F.

      Every A-F letter (case-insensitive) in an OCR line is mapped to its per-character OCR
      entry. Each letter's entries are grouped into columns by ``get_letter_group_h``, and the
      i-th column of every letter is merged into the i-th choice_m box. This method is
      unusable when OCR misses too many of the option letters.

      Args:
          ocr_dict_list: OCR lines, each with ``'words'`` (text) and ``'chars'`` (per-character
              entries with a ``'location'`` dict), the latter read only when a letter is found.

      Returns:
          List of ``{'class_name': 'choice_m', 'location': {'xmin', 'ymin', 'xmax', 'ymax'}}``.
      """
      letters = 'ABCDEF'
      letter_pattern = re.compile('[A-Fa-f]')
      letter_chars = {letter: [] for letter in letters}
      for line in ocr_dict_list:
          # Spaces are stripped before indexing into 'chars' (original note: Baidu API bug).
          # Presumably the per-character list omits spaces; confirm against the OCR service.
          compact = line['words'].replace(' ', '')
          matches = list(letter_pattern.finditer(compact))
          if not matches:
              continue
          chars = line['chars']
          for match in matches:
              index = match.start()
              # Match case-insensitively rather than upper-casing the text first:
              # str.upper() can change the string length (e.g. 'ß' -> 'SS') and shift
              # indices. Letters without a per-character entry are skipped instead of
              # raising IndexError.
              if index < len(chars):
                  letter_chars[match.group().upper()].append(chars[index])

      letter_group_list = []
      block_num = 1  # default
      for letter, entries in letter_chars.items():
          if entries:
              letter_group = get_letter_group_h(letter, entries)
              block_num = max(block_num, len(letter_group['group_location']))
              print(letter_group)
              letter_group_list.append(letter_group)

      choice_m_list = []
      for i in range(block_num):
          # The i-th column of every letter that has at least i + 1 columns.
          block = [group['group_location'][i] for group in letter_group_list
                   if len(group['group_location']) > i]
          if block:
              block_array = np.asarray(block)
              b_min = np.min(block_array, axis=0)
              b_max = np.max(block_array, axis=0)
              choice_m_list.append({'class_name': 'choice_m',
                                    'location': {'xmin': b_min[0], 'ymin': b_min[1],
                                                 'xmax': b_max[2], 'ymax': b_max[3]}})
      return choice_m_list