# @Author : lightXu
# @File : resolve.py
# @Time : 2018/12/3 10:16 AM
import time
import traceback
try:
    import xml.etree.cElementTree as ET
except ImportError:  # xml.etree.cElementTree was removed in Python 3.9
    import xml.etree.ElementTree as ET

from django.conf import settings

import segment.logging_config as logging
import segment.sheet_resolve.analysis.choice.analysis_choice as resolve_choice
import segment.sheet_resolve.analysis.choice.choice_box as choice_box
import segment.sheet_resolve.analysis.choice.choice_line_box as choice_line_box
import segment.sheet_resolve.analysis.cloze.analysis_cloze as resolve_cloze
import segment.sheet_resolve.analysis.cloze.cloze_line_box as resolve_cloze_line_box
import segment.sheet_resolve.analysis.exam_number.exam_number_box as resolve_exam_number_box
import segment.sheet_resolve.analysis.exam_number.exam_number_row_column as exam_number_row_column
import segment.sheet_resolve.analysis.sheet.analysis_sheet as resolve_sheet
import segment.sheet_resolve.analysis.solve.mark_box as resolve_mark_box
import segment.sheet_resolve.analysis.solve.mark_line_box as resolve_mark_line_box
from segment.sheet_resolve.tools import utils
from segment.sheet_resolve.tools.tf_sess import TfSess
from segment.sheet_resolve.tools.tf_settings import xml_template_path, model_dict
from segment.sheet_resolve.tools.utils import read_single_img, read_xml_to_json, create_xml
from segment.sheet_resolve.analysis.sheet.sheet_adjust import adjust_item_edge_by_gray_image
from segment.sheet_resolve.analysis.sheet.sheet_infer import infer_bar_code, box_infer_and_complete, infer_solve
from segment.sheet_resolve.analysis.sheet.sheet_infer import infer_exam_number, adjust_exam_number, exam_number_infer_by_s
from segment.sheet_resolve.analysis.sheet.choice_infer import infer_choice_m
from segment.sheet_resolve.analysis.sheet.ocr_sheet import tell_columns, sheet_sorted

  28. logger = logging.getLogger(settings.LOGGING_TYPE)
  29. sheet_infer_dict = dict(bar_code=True,
  30. choice_m=True,
  31. exam_number=True,
  32. common_sheet=False,
  33. solve=True)
  34. infer_choice_m_flag = False
  35. def sheet(series_number, image_path, image, conf_thresh, mns_thresh, subject, sheet_sess, ocr=''):
  36. global infer_choice_m_flag
  37. model_type = subject
  38. classes = list(model_dict[model_type]['classes'])
  39. coordinate_bias_dict = model_dict[model_type]['class_coordinate_bias']
  40. if '_blank' in model_type:
  41. model_type = model_type.replace("_blank", "")
  42. sheets_dict = resolve_sheet.get_single_image_sheet_regions(model_type, image_path, image, classes,
  43. sheet_sess.sess, sheet_sess.net,
  44. conf_thresh, mns_thresh, coordinate_bias_dict)
  45. h, w = image.shape[0], image.shape[1]
  46. regions = sheets_dict['regions']
  47. fetched_class = [ele['class_name'] for ele in regions]
  48. try:
  49. regions = adjust_item_edge_by_gray_image(image, regions)
  50. except Exception as e:
  51. traceback.print_exc()
  52. logger.info('试卷:{} 自适应边框失败: {}'.format(image_path, e))
  53. # 分栏
  54. col_split_x = tell_columns(image, regions)
  55. if sheet_infer_dict['bar_code']:
  56. try:
  57. if ('bar_code' not in fetched_class) and ocr:
  58. attention_region = [ele for ele in regions if ele['class_name'] == 'attention']
  59. bar_code_list = infer_bar_code(image, ocr, attention_region)
  60. regions.extend(bar_code_list)
  61. except Exception as e:
  62. traceback.print_exc()
  63. logger.info('试卷:{} 条形码推断失败: {}'.format(image_path, e))
  64. if sheet_infer_dict['exam_number']:
  65. try:
  66. cond1 = 'exam_number' in fetched_class
  67. tmp = ['info_title', 'qr_code', 'bar_code', 'choice', 'choice_m', 'exam_number_w']
  68. cond2 = True in [True for ele in tmp if ele in fetched_class] # 第一面特征
  69. cond3 = 'exam_number_w' in fetched_class
  70. cond4 = 'exam_number_s' in fetched_class
  71. # if cond1 and cond3 and not cond4:
  72. if cond1 and cond3:
  73. regions = adjust_exam_number(regions)
  74. if not cond1 and cond4:
  75. exam_number_list = exam_number_infer_by_s(image, regions)
  76. regions.extend(exam_number_list)
  77. if not cond1 and not cond4 and cond2 and ocr:
  78. exam_number_list = infer_exam_number(image, ocr, regions)
  79. regions.extend(exam_number_list)
  80. except Exception as e:
  81. traceback.print_exc()
  82. logger.info('试卷:{} 考号推断失败: {}'.format(image_path, e))
  83. if sheet_infer_dict['choice_m']:
  84. try:
  85. choice_m_list = infer_choice_m(image, regions, col_split_x, ocr)
  86. #remain_choice_m = []
  87. if len(choice_m_list) > 0:
  88. choice_m_old_list = [ele for ele in regions if 'choice_m' == ele['class_name']]
  89. for infer_box in choice_m_list.copy():
  90. infer_loc = infer_box['bounding_box']
  91. for tf_box in choice_m_old_list:
  92. tf_loc = tf_box['bounding_box']
  93. iou = utils.cal_iou(infer_loc, tf_loc)
  94. # if iou[0] > 0.70 or iou[1] > 0.70 or iou[2] > 0.70:
  95. # if iou[0] > 0.70 or iou[2] > 0.70:
  96. if iou[0] > 0.85:
  97. # if infer_box not in remain_choice_m:
  98. # remain_choice_m.append(infer_box)
  99. # choice_m_list.remove(infer_box)
  100. regions.remove(tf_box)
  101. # break
  102. elif iou[0] > 0:
  103. choice_m_list.remove(infer_box)
  104. break
  105. #remain_choice_m.extend(choice_m_list)
  106. # regions = [ele for ele in regions if 'choice_m' != ele['class_name']]
  107. # regions.extend(remain_choice_m)
  108. regions.extend(choice_m_list)
  109. infer_choice_m_flag = True
  110. except Exception as e:
  111. traceback.print_exc()
  112. logger.info('试卷:{} 选择题推断失败: {}'.format(image_path, e))
  113. if sheet_infer_dict['solve']:
  114. try:
  115. include_class = ['info_title',
  116. 'bar_code',
  117. 'choice_m',
  118. 'cloze',
  119. 'cloze_s',
  120. 'exam_number',
  121. 'solve',
  122. 'composition',
  123. 'correction'
  124. ]
  125. regions_subset = [ele for ele in regions if ele['class_name'] in include_class]
  126. col_regions = sheet_sorted(regions_subset, col_split_x)
  127. top = min([ele['bounding_box']['ymin'] for ele in regions])
  128. bottom = max([ele['bounding_box']['ymax'] for ele in regions])
  129. seal_area = [ele for ele in regions if 'seal' in ele['class_name']]
  130. if len(seal_area) > 0:
  131. right, left = w, 1
  132. for ele in seal_area:
  133. if ele['bounding_box']['xmax'] > w // 2:
  134. right = ele['bounding_box']['xmin']
  135. if ele['bounding_box']['xmax'] < w // 2:
  136. left = ele['bounding_box']['xmax']
  137. else:
  138. left = min([ele['bounding_box']['xmin'] for ele in regions])
  139. right = max([ele['bounding_box']['xmax'] for ele in regions])
  140. solve_regions = infer_solve(regions, left, right, top, bottom, col_regions, col_split_x)
  141. regions.append(solve_regions)
  142. except Exception as e:
  143. traceback.print_exc()
  144. logger.info('试卷:{} 解答题补全推断失败: {}'.format(image_path, e))
  145. if sheet_infer_dict['common_sheet']:
  146. try:
  147. regions = box_infer_and_complete(image, regions, ocr)
  148. except Exception as e:
  149. traceback.print_exc()
  150. logger.info('试卷:{} 识别框补全推断失败: {}'.format(image_path, e))
  151. try:
  152. adjust_regions = adjust_item_edge_by_gray_image(image, regions)
  153. except Exception as e:
  154. adjust_regions = regions
  155. traceback.print_exc()
  156. logger.info('试卷:{} 自适应边框失败: {}'.format(image_path, e))
  157. sheets_dict.update({'regions': adjust_regions})
  158. # generate xml
  159. tree = ET.parse(xml_template_path)
  160. xml_save_path = sheets_dict['img_name'].replace('.jpg', '.xml')
  161. root = tree.getroot()
  162. series = ET.SubElement(root, 'paper_id')
  163. series.text = series_number
  164. img_shape = image.shape
  165. project = ET.SubElement(root, 'size', {})
  166. width = ET.SubElement(project, 'width')
  167. width.text = str(img_shape[1])
  168. height = ET.SubElement(project, 'height')
  169. height.text = str(img_shape[0])
  170. depth = ET.SubElement(project, 'depth')
  171. if len(img_shape) >= 3:
  172. depth.text = '3'
  173. else:
  174. depth.text = '1'
  175. for ele in regions:
  176. name = ele['class_name']
  177. xmin = ele['bounding_box']['xmin']
  178. ymin = ele['bounding_box']['ymin']
  179. xmax = ele['bounding_box']['xmax']
  180. ymax = ele['bounding_box']['ymax']
  181. tree = create_xml(name, tree, xmin, ymin, xmax, ymax)
  182. tree.write(xml_save_path)
  183. return sheets_dict, xml_save_path
  184. def choice(image, regions, xml_path, conf_thresh, mns_thresh, choice_sess):
  185. model_type = 'choice'
  186. classes = model_dict[model_type]['classes']
  187. coordinate_bias_dict = model_dict[model_type]['class_coordinate_bias']
  188. choice_list = []
  189. for ele in regions:
  190. if ele["class_name"] == 'choice':
  191. choice_bbox = ele['bounding_box']
  192. left = choice_bbox['xmin']
  193. top = choice_bbox['ymin']
  194. choice_img = utils.crop_region(image, choice_bbox)
  195. choice_dict_tf = resolve_choice. \
  196. get_single_image_sheet_regions('choice', choice_img, classes,
  197. choice_sess.sess, choice_sess.net, conf_thresh, mns_thresh,
  198. coordinate_bias_dict)
  199. choice_list = choice_list + choice_line_box.choice_line(left, top, choice_img, choice_dict_tf, xml_path)
  200. return choice_list
  201. def choice_row_col(image, regions, xml_path, conf_thresh, mns_thresh, choice_sess):
  202. model_type = 'choice_m'
  203. classes = model_dict[model_type]['classes']
  204. coordinate_bias_dict = model_dict[model_type]['class_coordinate_bias']
  205. choice_list = []
  206. for ele in regions:
  207. if ele["class_name"] == 'choice':
  208. choice_box = ele['bounding_box']
  209. left = choice_box['xmin']
  210. top = choice_box['ymin']
  211. choice_img = utils.crop_region(image, choice_box)
  212. choice_m_dict_tf = resolve_choice. \
  213. get_single_image_sheet_regions('choice_m', choice_img, classes,
  214. choice_sess.sess, choice_sess.net, conf_thresh, mns_thresh,
  215. coordinate_bias_dict)
  216. choice_list = choice_list + choice_line_box.choice_line_with_number(left, top, choice_img, choice_m_dict_tf, xml_path)
  217. return choice_list
  218. def choice_m_row_col(image, regions, xml_path):
  219. choice_m_dict_tf = [ele for ele in regions if ele['class_name'] == 'choice_m']
  220. # choice_m_row_col_with_number
  221. choice_list = []
  222. try:
  223. # choice_list = choice_box.get_number_by_enlarge_choice_m(image, choice_m_dict_tf, xml_path)
  224. # if infer_choice_m_flag:
  225. # choice_list = choice_line_box.choice_m_adjust(image, choice_m_dict_tf)
  226. #
  227. # else:
  228. # choice_list = choice_line_box.choice_m_row_col(image, choice_m_dict_tf, xml_path) # 找选择题行列、分数
  229. choice_list = choice_line_box.choice_m_row_col(image, choice_m_dict_tf, xml_path) # 找选择题行列、分数
  230. tree = ET.parse(xml_path) # xml tree
  231. for index_num, box in enumerate(choice_list):
  232. if len(box['bounding_box']) > 0:
  233. abcd = box['bounding_box']
  234. number = str(box['number'])
  235. name = '{}_{}*{}_{}_{}'.format('choice_m', box['rows'], box['cols'], box['direction'], number)
  236. tree = utils.create_xml(name, tree,
  237. abcd['xmin'], abcd['ymin'],
  238. abcd['xmax'], abcd['ymax'])
  239. tree.write(xml_path)
  240. except Exception as e:
  241. traceback.print_exc()
  242. print(e)
  243. return choice_list
  244. def exam_number(image, regions, xml_path):
  245. exam_number_dict = {}
  246. for ele in regions:
  247. if ele["class_name"] == 'exam_number':
  248. exam_number_dict = ele
  249. exam_number_box = exam_number_dict['bounding_box']
  250. left = exam_number_box['xmin']
  251. top = exam_number_box['ymin']
  252. exam_number_img = utils.crop_region(image, exam_number_box)
  253. # exam_number_dict = resolve_exam_number_box.exam_number(left, top, exam_number_img, xml_path)
  254. exam_number_dict = resolve_exam_number_box.exam_number_whole(left, top, exam_number_img, xml_path)
  255. # print(exam_number_dict)
  256. return exam_number_dict
  257. def exam_number_row_col(image, regions, xml_path):
  258. exam_number_dict = {}
  259. for ele in regions:
  260. if ele["class_name"] == 'exam_number':
  261. exam_number_dict = ele
  262. exam_number_box = exam_number_dict['bounding_box']
  263. left = exam_number_box['xmin']
  264. top = exam_number_box['ymin']
  265. exam_number_img = utils.crop_region(image, exam_number_box)
  266. exam_number_row_col_dict = exam_number_row_column.get_exam_number_row_and_col(left, top, exam_number_img)
  267. tree = ET.parse(xml_path) # xml tree
  268. if len(exam_number_row_col_dict) > 0:
  269. exam_number_box = exam_number_row_col_dict['bounding_box']
  270. name = '{}_{}*{}_{}'.format('exam_number',
  271. exam_number_row_col_dict['rows'],
  272. exam_number_row_col_dict['cols'],
  273. exam_number_row_col_dict['direction'])
  274. tree = utils.create_xml(name, tree,
  275. exam_number_box['xmin'], exam_number_box['ymin'],
  276. exam_number_box['xmax'], exam_number_box['ymax'])
  277. tree.write(xml_path)
  278. return [exam_number_row_col_dict]
  279. else:
  280. tree = utils.create_xml('exam_number', tree,
  281. exam_number_box['xmin'], exam_number_box['ymin'],
  282. exam_number_box['xmax'], exam_number_box['ymax'])
  283. tree.write(xml_path)
  284. return []
  285. def cloze(image, regions, xml_path, conf_thresh, mns_thresh, cloze_sess):
  286. classes = model_dict['cloze']['classes']
  287. coordinate_bias_dict = model_dict['cloze']['class_coordinate_bias']
  288. cloze_list = []
  289. for ele in regions:
  290. if ele["class_name"] == 'cloze':
  291. cloze_box = ele['bounding_box']
  292. left = cloze_box['xmin']
  293. top = cloze_box['ymin']
  294. cloze_img = utils.crop_region(image, cloze_box)
  295. cloze_dict_tf = resolve_cloze.get_single_image_sheet_regions('cloze', cloze_img, classes,
  296. cloze_sess.sess, cloze_sess.net, conf_thresh,
  297. mns_thresh, coordinate_bias_dict)
  298. cloze_list = cloze_list + resolve_cloze_line_box.cloze_line(left, top, cloze_img, cloze_dict_tf['regions'], xml_path)
  299. return cloze_list
  300. def solve_with_mark(image, regions, xml_path):
  301. solve_list = []
  302. mark_list = []
  303. for ele in regions.copy():
  304. if 'solve' in ele["class_name"]:
  305. exam_number_box = ele['bounding_box']
  306. left = exam_number_box['xmin']
  307. top = exam_number_box['ymin']
  308. exam_number_img = utils.crop_region(image, exam_number_box)
  309. solve_mark_dict = resolve_mark_box.solve_mark(left, top, exam_number_img, xml_path)
  310. if len(solve_mark_dict) > 0:
  311. ele['class_name'] = 'solve_'+str(solve_mark_dict['number'])
  312. solve_list.append(ele)
  313. mark_list.append(solve_mark_dict)
  314. return solve_list, mark_list
  315. def solve(image, regions, xml_path):
  316. solve_list = []
  317. tree = ET.parse(xml_path)
  318. for ele in regions.copy():
  319. if 'solve' in ele["class_name"]:
  320. exam_number_box = ele['bounding_box']
  321. exam_number_img = utils.crop_region(image, exam_number_box)
  322. number = resolve_mark_line_box.solve_line(exam_number_img)
  323. solve_dict = {'number': number, 'location': exam_number_box, 'default_points': 12}
  324. solve_list.append(solve_dict)
  325. tree = utils.create_xml(str(number), tree,
  326. exam_number_box['xmin'], exam_number_box['ymin'],
  327. exam_number_box['xmax'], exam_number_box['ymax'])
  328. tree.write(xml_path)
  329. return solve_list
  330. def solve_with_number(regions, xml_path):
  331. solve_list = []
  332. for ele in regions:
  333. if 'solve' in ele["class_name"] or 'composition' in ele["class_name"]:
  334. solve_dict = {'number': -1, 'default_points': -1}
  335. ele.update(solve_dict)
  336. solve_list.append(ele)
  337. tree = ET.parse(xml_path) # xml tree
  338. for index_num, box in enumerate(solve_list):
  339. if len(box['bounding_box']) > 0:
  340. abcd = box['bounding_box']
  341. number = str(box['number'])
  342. default_points = box["default_points"]
  343. name = '{}_{}_{}'.format(box["class_name"], number, default_points)
  344. tree = utils.create_xml(name, tree,
  345. abcd['xmin'], abcd['ymin'],
  346. abcd['xmax'], abcd['ymax'])
  347. tree.write(xml_path)
  348. return solve_list
  349. def cloze_with_number(regions, xml_path):
  350. cloze_list = []
  351. for ele in regions:
  352. if 'cloze' == ele["class_name"] or "cloze_s" == ele["class_name"]:
  353. cloze_dict = {'number': -1, 'default_points': -1}
  354. ele.update(cloze_dict)
  355. cloze_list.append(ele)
  356. tree = ET.parse(xml_path) # xml tree
  357. for index_num, box in enumerate(cloze_list):
  358. if len(box['bounding_box']) > 0:
  359. abcd = box['bounding_box']
  360. number = str(box['number'])
  361. default_points = box["default_points"]
  362. name = '{}_{}_{}'.format(box["class_name"], number, default_points)
  363. tree = utils.create_xml(name, tree,
  364. abcd['xmin'], abcd['ymin'],
  365. abcd['xmax'], abcd['ymax'])
  366. tree.write(xml_path)
  367. return cloze_list
  368. def make_together(image_path):
  369. sheet_sess = TfSess('sheet')
  370. choice_sess = TfSess('choice')
  371. cloze_sess = TfSess('cloze')
  372. raw_img = read_single_img(image_path)
  373. conf_thresh_0 = 0.7
  374. mns_thresh_0 = 0.3
  375. series_number = 123456789
  376. subject = 'english'
  377. sheets_dict_0, xml_save_path = sheet(series_number, image_path, raw_img, conf_thresh_0, mns_thresh_0, subject, sheet_sess)
  378. # 手动修改faster_rcnn识别生成的框
  379. sheets_dict_0 = read_xml_to_json(xml_save_path)
  380. regions = sheets_dict_0['regions']
  381. classes_name = str([ele['class_name'] for ele in regions])
  382. if 'choice' in classes_name:
  383. try:
  384. sheets_dict_0['choice'] = choice(raw_img, regions, xml_save_path, conf_thresh_0, mns_thresh_0, choice_sess)
  385. except Exception:
  386. traceback.print_exc()
  387. if 'exam_number' in classes_name:
  388. try:
  389. sheets_dict_0['exam_number'] = exam_number(raw_img, regions, xml_save_path)
  390. except Exception:
  391. traceback.print_exc()
  392. if 'cloze' in classes_name:
  393. try:
  394. sheets_dict_0['cloze'] = cloze(raw_img, regions, xml_save_path, conf_thresh_0, mns_thresh_0, cloze_sess)
  395. except Exception:
  396. traceback.print_exc()
  397. if 'solve' in classes_name:
  398. try:
  399. solve_list, mark_list = solve(raw_img, regions, xml_save_path,)
  400. sheets_dict_0['solve'] = solve_list
  401. sheets_dict_0['mark'] = mark_list
  402. except Exception:
  403. traceback.print_exc()
  404. # print(sheets_dict_0)
  405. return sheets_dict_0
# if __name__ == '__main__':
#     start_time = time.time()
#
#     image_path_0 = os.path.join(r'C:\Users\Administrator\Desktop\sheet\correct\back_sizes\template',
#                                 '20180719004308818_0020.jpg')
#     make_together(image_path_0)
#     end_time = time.time()
#     print('time cost: ', (end_time - start_time))