# @Author : lightXu
# @File : resolve.py
# @Time : 2018/12/3 0003 上午 10:16
import os
import traceback
import xml.etree.cElementTree as ET

from django.conf import settings

import segment.logging_config as logging
import segment.sheet_resolve.analysis.choice.analysis_choice as resolve_choice
import segment.sheet_resolve.analysis.choice.choice_line_box as choice_line_box
import segment.sheet_resolve.analysis.cloze.analysis_cloze as resolve_cloze
import segment.sheet_resolve.analysis.cloze.cloze_line_box as resolve_cloze_line_box
import segment.sheet_resolve.analysis.exam_number.exam_number_box as resolve_exam_number_box
import segment.sheet_resolve.analysis.exam_number.exam_number_row_column as exam_number_row_column
import segment.sheet_resolve.analysis.sheet.analysis_sheet as resolve_sheet
import segment.sheet_resolve.analysis.solve.mark_box as resolve_mark_box
import segment.sheet_resolve.analysis.solve.mark_line_box as resolve_mark_line_box
from segment.sheet_resolve.analysis.sheet.choice_infer import infer_choice_m
from segment.sheet_resolve.analysis.sheet.ocr_sheet import tell_columns, sheet_sorted
from segment.sheet_resolve.analysis.sheet.ocr_sheet import ocr2sheet
from segment.sheet_resolve.analysis.sheet.sheet_adjust import adjust_item_edge_by_gray_image
from segment.sheet_resolve.analysis.sheet.sheet_infer import infer_bar_code, adjust_exam_number, exam_number_infer_by_s
from segment.sheet_resolve.analysis.sheet.sheet_infer import infer_exam_number, infer_solve, box_infer_and_complete
from segment.sheet_resolve.analysis.sheet.sheet_infer import exam_number_adjust_infer
from segment.sheet_resolve.tools import utils
from segment.sheet_resolve.tools.tf_settings import xml_template_path, model_dict
from segment.sheet_resolve.tools.utils import create_xml
  27. logger = logging.getLogger(settings.LOGGING_TYPE)
  28. sheet_infer_dict = dict(bar_code=True,
  29. choice_m=True,
  30. exam_number=True,
  31. common_sheet=False,
  32. cloze=True,
  33. solve=True)
  34. infer_choice_m_flag = False
  35. def sheet(series_number, image_path, image, conf_thresh, mns_thresh, subject, sheet_sess, ocr=''):
  36. global infer_choice_m_flag
  37. model_type = subject
  38. classes = list(model_dict[model_type]['classes'])
  39. coordinate_bias_dict = model_dict[model_type]['class_coordinate_bias']
  40. sheets_dict = resolve_sheet.get_single_image_sheet_regions(model_type, image_path, image, classes,
  41. sheet_sess.sess, sheet_sess.net,
  42. conf_thresh, mns_thresh, coordinate_bias_dict)
  43. h, w = image.shape[0], image.shape[1]
  44. regions = sheets_dict['regions']
  45. fetched_class = [ele['class_name'] for ele in regions]
  46. infer_box_list = []
  47. try:
  48. regions = adjust_item_edge_by_gray_image(image, regions)
  49. except Exception as e:
  50. traceback.print_exc()
  51. logger.info('试卷:{} 自适应边框失败: {}'.format(image_path, e))
  52. # 分栏
  53. col_split_x = tell_columns(image, regions)
  54. if sheet_infer_dict['bar_code']:
  55. try:
  56. if ('bar_code' not in fetched_class) and ocr:
  57. attention_region = [ele for ele in regions if ele['class_name'] == 'attention']
  58. bar_code_list = infer_bar_code(image, ocr, attention_region)
  59. regions.extend(bar_code_list)
  60. except Exception as e:
  61. traceback.print_exc()
  62. logger.info('试卷:{} 条形码推断失败: {}'.format(image_path, e))
  63. if sheet_infer_dict['exam_number']:
  64. try:
  65. cond1 = 'exam_number' in fetched_class
  66. tmp = ['info_title', 'qr_code', 'bar_code', 'choice', 'choice_m', 'exam_number_w']
  67. cond2 = True in [True for ele in tmp if ele in fetched_class] # 第一面特征
  68. cond3 = 'exam_number_w' in fetched_class
  69. cond4 = 'exam_number_s' in fetched_class
  70. # if cond1 and cond3 and not cond4:
  71. if cond1 and cond3:
  72. regions = adjust_exam_number(regions)
  73. if not cond1 and cond4:
  74. exam_number_list = exam_number_infer_by_s(image, regions)
  75. regions.extend(exam_number_list)
  76. if not cond1 and not cond4 and cond2 and ocr:
  77. exam_number_list = infer_exam_number(image, ocr, regions)
  78. if len(exam_number_list) > 0:
  79. regions.extend(exam_number_list)
  80. image, regions = exam_number_adjust_infer(image, regions)
  81. except Exception as e:
  82. traceback.print_exc()
  83. logger.info('试卷:{} 考号推断失败: {}'.format(image_path, e))
  84. if sheet_infer_dict['choice_m']:
  85. try:
  86. col_split = col_split_x.copy()
  87. if not col_split:
  88. col_split = [w - 1]
  89. infer_box_list = ocr2sheet(image, col_split_x, ocr)
  90. choice_m_list = infer_choice_m(image, regions, infer_box_list, col_split)
  91. if len(choice_m_list) > 0:
  92. choice_m_old_list = [ele for ele in regions if 'choice_m' == ele['class_name']]
  93. for infer_box in choice_m_list.copy():
  94. infer_loc = infer_box['bounding_box']
  95. for tf_box in choice_m_old_list:
  96. tf_loc = tf_box['bounding_box']
  97. iou = utils.cal_iou(infer_loc, tf_loc)
  98. # if iou[0] > 0.70 or iou[1] > 0.70 or iou[2] > 0.70:
  99. # if iou[0] > 0.70 or iou[2] > 0.70:
  100. if iou[0] > 0.85:
  101. # if infer_box not in remain_choice_m:
  102. # remain_choice_m.append(infer_box)
  103. # choice_m_list.remove(infer_box)
  104. regions.remove(tf_box)
  105. # break
  106. elif iou[0] > 0.05:
  107. choice_m_list.remove(infer_box)
  108. break
  109. # remain_choice_m.extend(choice_m_list)
  110. # regions = [ele for ele in regions if 'choice_m' != ele['class_name']]
  111. # regions.extend(remain_choice_m)
  112. regions.extend(choice_m_list)
  113. infer_choice_m_flag = True
  114. except Exception as e:
  115. traceback.print_exc()
  116. logger.info('试卷:{} 选择题推断失败: {}'.format(image_path, e))
  117. if sheet_infer_dict['cloze']:
  118. cloze_list = []
  119. for infer_box in infer_box_list:
  120. # {'loc': [240, 786, 1569, 1368]}
  121. loc = infer_box['loc']
  122. xmin, ymin, xmax, ymax = loc[0], loc[1], loc[2], loc[3]
  123. for ele in regions:
  124. if ele['class_name'] in ['cloze_s', 'cloze']:
  125. tf_loc = ele['bounding_box']
  126. tf_loc_l = tf_loc['xmin']
  127. tf_loc_t = tf_loc['ymin']
  128. if xmin < tf_loc_l < xmax and ymin < tf_loc_t < ymax:
  129. cloze_box = {'class_name': 'cloze',
  130. 'bounding_box': {'xmin': xmin, 'ymin': ymin, 'xmax': xmax, 'ymax': ymax}}
  131. cloze_list.append(cloze_box)
  132. break
  133. regions.extend(cloze_list)
  134. if sheet_infer_dict['solve']:
  135. try:
  136. include_class = ['info_title',
  137. 'bar_code',
  138. 'choice_m',
  139. 'cloze',
  140. 'cloze_s',
  141. 'exam_number',
  142. 'solve',
  143. 'solve0',
  144. 'composition',
  145. 'composition0',
  146. 'correction',
  147. 'alarm_info',
  148. 'page',
  149. 'mark'
  150. ]
  151. if 'math' not in subject:
  152. include_class.remove('cloze_s')
  153. regions_subset = [ele for ele in regions if ele['class_name'] in include_class]
  154. col_split = col_split_x.copy()
  155. if not col_split:
  156. col_split = [w - 1]
  157. col_regions = sheet_sorted(regions_subset, col_split.copy())
  158. top = min([ele['bounding_box']['ymin'] for ele in regions if 'seal' not in ele['class_name']])
  159. bottom = max([ele['bounding_box']['ymax'] for ele in regions if 'seal' not in ele['class_name']])
  160. seal_area = [ele for ele in regions if 'seal' in ele['class_name']]
  161. if len(seal_area) > 0:
  162. right, left = w, 1
  163. for ele in seal_area:
  164. if ele['bounding_box']['xmax'] > w // 2:
  165. right = ele['bounding_box']['xmin']
  166. if ele['bounding_box']['xmax'] < w // 2:
  167. left = ele['bounding_box']['xmax']
  168. else:
  169. left = min([ele['bounding_box']['xmin'] for ele in regions])
  170. right = max([ele['bounding_box']['xmax'] for ele in regions])
  171. solve_regions = infer_solve(regions, left, right, top, bottom, h, w, col_regions, col_split_x.copy())
  172. regions.extend(solve_regions)
  173. except Exception as e:
  174. traceback.print_exc()
  175. logger.info('试卷:{} 解答题补全推断失败: {}'.format(image_path, e))
  176. if sheet_infer_dict['common_sheet']:
  177. try:
  178. regions = box_infer_and_complete(image, regions, ocr)
  179. except Exception as e:
  180. traceback.print_exc()
  181. logger.info('试卷:{} 识别框补全推断失败: {}'.format(image_path, e))
  182. try:
  183. adjust_regions = adjust_item_edge_by_gray_image(image, regions)
  184. except Exception as e:
  185. adjust_regions = regions
  186. traceback.print_exc()
  187. logger.info('试卷:{} 自适应边框失败: {}'.format(image_path, e))
  188. sheets_dict.update({'regions': adjust_regions, 'col_split': col_split_x})
  189. # generate xml
  190. tree = ET.parse(xml_template_path)
  191. xml_save_path = sheets_dict['img_name'].replace('.jpg', '.xml')
  192. root = tree.getroot()
  193. series = ET.SubElement(root, 'paper_id')
  194. series.text = series_number
  195. img_shape = image.shape
  196. project = ET.SubElement(root, 'size', {})
  197. width = ET.SubElement(project, 'width')
  198. width.text = str(img_shape[1])
  199. height = ET.SubElement(project, 'height')
  200. height.text = str(img_shape[0])
  201. depth = ET.SubElement(project, 'depth')
  202. if len(img_shape) >= 3:
  203. depth.text = '3'
  204. else:
  205. depth.text = '1'
  206. # 没有adjust 的regions
  207. for ele in regions:
  208. name = ele['class_name']
  209. xmin = ele['bounding_box']['xmin']
  210. ymin = ele['bounding_box']['ymin']
  211. xmax = ele['bounding_box']['xmax']
  212. ymax = ele['bounding_box']['ymax']
  213. tree = create_xml(name, tree, xmin, ymin, xmax, ymax)
  214. for i, ele in enumerate(sheets_dict['col_split']):
  215. name = 'line_{}'.format(i)
  216. xmin = ele
  217. ymin = 1
  218. xmax = ele + 2
  219. ymax = img_shape[0]
  220. tree = create_xml(name, tree, xmin, ymin, xmax, ymax)
  221. tree.write(xml_save_path)
  222. return sheets_dict, xml_save_path
  223. def choice(image, regions, xml_path, conf_thresh, mns_thresh, choice_sess):
  224. model_type = 'choice'
  225. classes = model_dict[model_type]['classes']
  226. coordinate_bias_dict = model_dict[model_type]['class_coordinate_bias']
  227. choice_list = []
  228. for ele in regions:
  229. if ele["class_name"] == 'choice':
  230. choice_bbox = ele['bounding_box']
  231. left = choice_bbox['xmin']
  232. top = choice_bbox['ymin']
  233. choice_img = utils.crop_region(image, choice_bbox)
  234. choice_dict_tf = resolve_choice.get_single_image_sheet_regions('choice', choice_img, classes,
  235. choice_sess.sess, choice_sess.net,
  236. conf_thresh, mns_thresh,
  237. coordinate_bias_dict)
  238. choice_list = choice_list + choice_line_box.choice_line(left, top, choice_img, choice_dict_tf, xml_path)
  239. return choice_list
  240. def choice_row_col(image, regions, xml_path, conf_thresh, mns_thresh, choice_sess):
  241. model_type = 'choice_m'
  242. classes = model_dict[model_type]['classes']
  243. coordinate_bias_dict = model_dict[model_type]['class_coordinate_bias']
  244. choice_list = []
  245. for ele in regions:
  246. if ele["class_name"] == 'choice':
  247. choice_box = ele['bounding_box']
  248. left = choice_box['xmin']
  249. top = choice_box['ymin']
  250. choice_img = utils.crop_region(image, choice_box)
  251. choice_m_dict_tf = resolve_choice.get_single_image_sheet_regions('choice_m', choice_img, classes,
  252. choice_sess.sess, choice_sess.net,
  253. conf_thresh, mns_thresh,
  254. coordinate_bias_dict)
  255. choice_list = choice_list + choice_line_box.choice_line_with_number(left, top, choice_img,
  256. choice_m_dict_tf,
  257. xml_path)
  258. return choice_list
  259. def choice_m_row_col(image, regions, subject, xml_path):
  260. choice_m_dict_tf = []
  261. direction_list = []
  262. for ele in regions:
  263. if ele['class_name'] == 'choice_m':
  264. choice_m_dict_tf.append(ele)
  265. if ele['class_name'] == 'choice_n':
  266. loc = ele['bounding_box']
  267. xmin, ymin, xmax, ymax = loc['xmin'], loc['ymin'], loc['xmax'], loc['ymax']
  268. if ymax - ymin > 2 * (xmax - xmin):
  269. direction = 180
  270. else:
  271. direction = 90
  272. direction_list.append(direction)
  273. if ele['class_name'] == 'choice_s':
  274. loc = ele['bounding_box']
  275. xmin, ymin, xmax, ymax = loc['xmin'], loc['ymin'], loc['xmax'], loc['ymax']
  276. if ymax - ymin > 2 * (xmax - xmin):
  277. direction = 90
  278. else:
  279. direction = 180
  280. direction_list.append(direction)
  281. c180 = direction_list.count(180)
  282. c90 = direction_list.count(90)
  283. if c180 == c90 == 0:
  284. direction = 0
  285. else:
  286. direction = 180 if c180 >= c90 else 90
  287. # choice_m_row_col_with_number
  288. choice_list = []
  289. try:
  290. # choice_list = choice_box.get_number_by_enlarge_choice_m(image, choice_m_dict_tf, xml_path)
  291. # if infer_choice_m_flag:
  292. # choice_list = choice_line_box.choice_m_adjust(image, choice_m_dict_tf)
  293. #
  294. # else:
  295. # choice_list = choice_line_box.choice_m_row_col(image, choice_m_dict_tf, xml_path) # 找选择题行列、分数
  296. choice_list = choice_line_box.choice_m_row_col(image, choice_m_dict_tf, direction, subject,
  297. xml_path) # 找选择题行列、分数
  298. tree = ET.parse(xml_path) # xml tree
  299. for index_num, box in enumerate(choice_list):
  300. if len(box['bounding_box']) > 0:
  301. abcd = box['bounding_box']
  302. number = str(box['number'])
  303. name = '{}_{}*{}_{}_{}'.format('choice_m', box['rows'], box['cols'], box['direction'], number)
  304. tree = utils.create_xml(name, tree,
  305. abcd['xmin'], abcd['ymin'],
  306. abcd['xmax'], abcd['ymax'])
  307. tree.write(xml_path)
  308. except Exception as e:
  309. traceback.print_exc()
  310. print(e)
  311. return choice_list
  312. def exam_number(image, regions, xml_path):
  313. exam_number_dict = {}
  314. for ele in regions:
  315. if ele["class_name"] == 'exam_number':
  316. exam_number_dict = ele
  317. exam_number_box = exam_number_dict['bounding_box']
  318. left = exam_number_box['xmin']
  319. top = exam_number_box['ymin']
  320. exam_number_img = utils.crop_region(image, exam_number_box)
  321. # exam_number_dict = resolve_exam_number_box.exam_number(left, top, exam_number_img, xml_path)
  322. exam_number_dict = resolve_exam_number_box.exam_number_whole(left, top, exam_number_img, xml_path)
  323. # print(exam_number_dict)
  324. return exam_number_dict
  325. def exam_number_row_col(image, regions, xml_path):
  326. exam_number_dict = {}
  327. for ele in regions:
  328. if ele["class_name"] == 'exam_number':
  329. exam_number_dict = ele
  330. exam_number_box = exam_number_dict['bounding_box']
  331. left = exam_number_box['xmin']
  332. top = exam_number_box['ymin']
  333. exam_number_img = utils.crop_region(image, exam_number_box)
  334. exam_number_row_col_dict = exam_number_row_column.get_exam_number_row_and_col(left, top, exam_number_img)
  335. tree = ET.parse(xml_path) # xml tree
  336. if len(exam_number_row_col_dict) > 0:
  337. exam_number_box = exam_number_row_col_dict['bounding_box']
  338. name = '{}_{}*{}_{}'.format('exam_number',
  339. exam_number_row_col_dict['rows'],
  340. exam_number_row_col_dict['cols'],
  341. exam_number_row_col_dict['direction'])
  342. tree = utils.create_xml(name, tree,
  343. exam_number_box['xmin'], exam_number_box['ymin'],
  344. exam_number_box['xmax'], exam_number_box['ymax'])
  345. tree.write(xml_path)
  346. return [exam_number_row_col_dict]
  347. else:
  348. tree = utils.create_xml('exam_number', tree,
  349. exam_number_box['xmin'], exam_number_box['ymin'],
  350. exam_number_box['xmax'], exam_number_box['ymax'])
  351. tree.write(xml_path)
  352. return []
  353. def cloze(image, regions, xml_path, conf_thresh, mns_thresh, cloze_sess):
  354. classes = model_dict['cloze']['classes']
  355. coordinate_bias_dict = model_dict['cloze']['class_coordinate_bias']
  356. cloze_list = []
  357. for ele in regions:
  358. if ele["class_name"] == 'cloze':
  359. cloze_box = ele['bounding_box']
  360. left = cloze_box['xmin']
  361. top = cloze_box['ymin']
  362. cloze_img = utils.crop_region(image, cloze_box)
  363. cloze_dict_tf = resolve_cloze.get_single_image_sheet_regions('cloze', cloze_img, classes,
  364. cloze_sess.sess, cloze_sess.net, conf_thresh,
  365. mns_thresh, coordinate_bias_dict)
  366. cloze_list = cloze_list + resolve_cloze_line_box.cloze_line(left, top, cloze_img, cloze_dict_tf['regions'],
  367. xml_path)
  368. return cloze_list
  369. def solve_with_mark(image, regions, xml_path):
  370. solve_list = []
  371. mark_list = []
  372. for ele in regions.copy():
  373. if 'solve' in ele["class_name"]:
  374. exam_number_box = ele['bounding_box']
  375. left = exam_number_box['xmin']
  376. top = exam_number_box['ymin']
  377. exam_number_img = utils.crop_region(image, exam_number_box)
  378. solve_mark_dict = resolve_mark_box.solve_mark(left, top, exam_number_img, xml_path)
  379. if len(solve_mark_dict) > 0:
  380. ele['class_name'] = 'solve_' + str(solve_mark_dict['number'])
  381. solve_list.append(ele)
  382. mark_list.append(solve_mark_dict)
  383. return solve_list, mark_list
  384. def solve(image, regions, xml_path):
  385. solve_list = []
  386. tree = ET.parse(xml_path)
  387. for ele in regions.copy():
  388. if 'solve' in ele["class_name"]:
  389. exam_number_box = ele['bounding_box']
  390. exam_number_img = utils.crop_region(image, exam_number_box)
  391. number = resolve_mark_line_box.solve_line(exam_number_img)
  392. solve_dict = {'number': number, 'location': exam_number_box, 'default_points': 12}
  393. solve_list.append(solve_dict)
  394. tree = utils.create_xml(str(number), tree,
  395. exam_number_box['xmin'], exam_number_box['ymin'],
  396. exam_number_box['xmax'], exam_number_box['ymax'])
  397. tree.write(xml_path)
  398. return solve_list
  399. def solve_with_number(regions, xml_path):
  400. solve_list = []
  401. for ele in regions:
  402. if 'solve' in ele["class_name"] or 'composition' in ele["class_name"] or 'correction' in ele["class_name"]:
  403. solve_dict = {'number': -1, 'default_points': -1, 'span': False, 'span_id': 1}
  404. ele.update(solve_dict)
  405. solve_list.append(ele)
  406. tree = ET.parse(xml_path) # xml tree
  407. for index_num, box in enumerate(solve_list):
  408. if len(box['bounding_box']) > 0:
  409. abcd = box['bounding_box']
  410. number = str(box['number'])
  411. default_points = box["default_points"]
  412. name = '{}_{}_{}'.format(box["class_name"], number, default_points)
  413. tree = utils.create_xml(name, tree,
  414. abcd['xmin'], abcd['ymin'],
  415. abcd['xmax'], abcd['ymax'])
  416. tree.write(xml_path)
  417. return solve_list
  418. def cloze_with_number(regions, xml_path):
  419. cloze_list = []
  420. for ele in regions:
  421. if 'cloze' == ele["class_name"] or "cloze_s" == ele["class_name"]:
  422. cloze_dict = {'number': -1, 'default_points': -1}
  423. ele.update(cloze_dict)
  424. cloze_list.append(ele)
  425. tree = ET.parse(xml_path) # xml tree
  426. for index_num, box in enumerate(cloze_list):
  427. if len(box['bounding_box']) > 0:
  428. abcd = box['bounding_box']
  429. number = str(box['number'])
  430. default_points = box["default_points"]
  431. name = '{}_{}_{}'.format(box["class_name"], number, default_points)
  432. tree = utils.create_xml(name, tree,
  433. abcd['xmin'], abcd['ymin'],
  434. abcd['xmax'], abcd['ymax'])
  435. tree.write(xml_path)
  436. return cloze_list