analysis_sheet.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. # @Author : lightXu
  2. # @File : analysis_sheet.py
  3. import time
  4. import os
  5. import traceback
  6. import numpy as np
  7. import cv2
  8. from segment.sheet_resolve.lib.model.test import im_detect
  9. from segment.sheet_resolve.lib.model.nms_wrapper import nms
  10. from segment.sheet_resolve.lib.utils.timer import Timer
  11. from segment.sheet_resolve.tools import utils
  12. from segment.sheet_resolve.analysis.solve.optional_solve import find_contours, resolve_optional_choice
  13. def analysis_single_image_with_regions(analysis_type, classes,
  14. sess, net,
  15. im_raw, conf_thresh, mns_thresh,
  16. coordinate_bias_dict):
  17. """Detect object classes in an image using pre-computed object proposals."""
  18. size = im_raw.shape
  19. # Detect all object classes and regress object bounds
  20. timer = Timer()
  21. timer.tic()
  22. if analysis_type in ['unknown_subject', 'math', 'math_zxhx', 'english', 'chinese',
  23. 'physics', 'chemistry', 'biology', 'politics', 'history',
  24. 'geography', 'science_comprehensive', 'arts_comprehensive'
  25. ]:
  26. analysis_type = 'sheet'
  27. im, ratio = utils.img_resize(analysis_type, im_raw)
  28. scores, boxes = im_detect(analysis_type, sess, net, im)
  29. timer.toc()
  30. print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))
  31. content_list = []
  32. analysis_cls_list = []
  33. qr_code_info = 'Nan'
  34. for cls_ind, cls in enumerate(classes[1:]): # classes
  35. cls_ind += 1 # because we skipped background
  36. cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
  37. cls_scores = scores[:, cls_ind]
  38. dets = np.hstack((cls_boxes,
  39. cls_scores[:, np.newaxis])).astype(np.float32)
  40. keep = nms(dets, mns_thresh)
  41. dets = dets[keep, :]
  42. # vis_detections(im, cls, dets, ax, thresh=conf_thresh)
  43. inds = np.where(dets[:, -1] >= conf_thresh)[0]
  44. if len(inds) > 0:
  45. if cls in list(coordinate_bias_dict.keys()):
  46. xmin_bias = coordinate_bias_dict[cls]['xmin_bias']
  47. ymin_bias = coordinate_bias_dict[cls]['ymin_bias']
  48. xmax_bias = coordinate_bias_dict[cls]['xmax_bias']
  49. ymax_bias = coordinate_bias_dict[cls]['ymax_bias']
  50. else:
  51. xmin_bias = 0
  52. ymin_bias = 0
  53. xmax_bias = 0
  54. ymax_bias = 0
  55. for i in inds:
  56. bbox = dets[i, :4]
  57. score = '{:.4f}'.format(dets[i, -1])
  58. xmin = int(int(bbox[0]) * ratio[0]) + xmin_bias
  59. ymin = int(int(bbox[1]) * ratio[1]) + ymin_bias
  60. xmax = int(int(bbox[2]) * ratio[0]) + xmax_bias
  61. ymax = int(int(bbox[3]) * ratio[1]) + ymax_bias
  62. xmin = (xmin if (xmin > 0) else 1)
  63. ymin = (ymin if (ymin > 0) else 1)
  64. xmax = (xmax if (xmax < size[1]) else size[1] - 1)
  65. ymax = (ymax if (ymax < size[0]) else size[0] - 1)
  66. if cls in ['solve0', ]:
  67. cls = 'solve'
  68. bbox_dict = {"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax}
  69. # class_dict = {"class_name": cls, "bounding_box": bbox_dict, "score": score}
  70. class_dict = {"class_name": cls, "bounding_box": bbox_dict}
  71. # if cls == 'qr_code':
  72. # qr_img = utils.crop_region(im_raw, bbox_dict)
  73. # qr_path = r'./qr_code.jpg'
  74. # cv2.imwrite(qr_path, qr_img)
  75. # qr_code_info = utils.check_qr_code_with_region_img(qr_path)
  76. # os.remove(qr_path)
  77. content_list.append(class_dict)
  78. return content_list, analysis_cls_list, qr_code_info
  79. def get_single_image_sheet_regions(analysis_type, img_path, img, classes,
  80. sess, net, conf_thresh, mns_thresh,
  81. coordinate_bias_dict):
  82. start_time = time.time()
  83. print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
  84. print('analysis for JPG {}'.format(img_path))
  85. content, cls, qr_code_info = \
  86. analysis_single_image_with_regions(analysis_type, classes, sess, net,
  87. img, conf_thresh, mns_thresh,
  88. coordinate_bias_dict)
  89. img_dict = {"img_name": img_path,
  90. # 'qr_code': qr_code_info,
  91. 'subject': analysis_type,
  92. "regions": content,
  93. }
  94. end_time = time.time()
  95. print(end_time - start_time)
  96. return img_dict
  97. def question_number_format(init_number, crt_numbers, sheet_dict):
  98. for region in sheet_dict['regions']:
  99. numbers = region.get("number")
  100. if numbers and isinstance(numbers, int):
  101. if numbers <= 0 or numbers in crt_numbers or numbers >= 1000:
  102. numbers = init_number
  103. crt_numbers.append(numbers)
  104. init_number += 1
  105. region.update({"number": numbers})
  106. crt_numbers.append(numbers)
  107. if numbers and isinstance(numbers, list):
  108. for i, num in enumerate(numbers):
  109. if num <= 0 or num in crt_numbers or num >= 1000:
  110. numbers[i] = init_number
  111. crt_numbers.append(init_number)
  112. init_number += 1
  113. region.update({"number": numbers})
  114. crt_numbers.extend(numbers)
  115. return sheet_dict, init_number, crt_numbers
  116. def box_region_format(sheet_dict, image, subject, shrink=True):
  117. include_class = ['anchor_point',
  118. 'bar_code',
  119. 'choice_m',
  120. 'cloze',
  121. 'cloze_s',
  122. 'exam_number_col_row',
  123. 'optional_choice',
  124. 'optional_solve',
  125. # 'qr_code',
  126. 'solve',
  127. 'optional_solve',
  128. 'composition',
  129. # 'correction'
  130. ]
  131. sheet_regions = sheet_dict['regions']
  132. optional_solve_tmp = []
  133. default_points_dict = {'choice_m': 5, "cloze": 5, 'solve': 12, 'cloze_s': 5, "composition": 60}
  134. if subject == "english":
  135. default_points_dict = {'choice_m': 2, "cloze": 2, 'solve': 2, 'cloze_s': 2, "composition": 25}
  136. for i in range(len(sheet_regions) - 1, -1, -1):
  137. if subject == "math":
  138. if sheet_regions[i]['class_name'] == 'cloze':
  139. sheet_regions[i]['class_name'] = 'cloze_big' # math exclude cloze big
  140. if sheet_regions[i]['class_name'] == 'cloze_s':
  141. sheet_regions[i]['class_name'] = 'cloze' # math exclude cloze big
  142. if subject == "english":
  143. if sheet_regions[i]['class_name'] == 'solve':
  144. sheet_regions[i]['class_name'] = 'cloze'
  145. if sheet_regions[i]['class_name'] == 'correction':
  146. sheet_regions[i]['class_name'] = 'solve'
  147. for i in range(len(sheet_regions) - 1, -1, -1):
  148. if sheet_regions[i]['class_name'] in ['solve0']:
  149. sheet_regions[i]['class_name'] = 'solve'
  150. if sheet_regions[i]['class_name'] in ['composition0']:
  151. sheet_regions[i]['class_name'] = 'composition'
  152. if sheet_regions[i]['class_name'] == 'select_s':
  153. sheet_regions[i]['class_name'] = 'optional_choice'
  154. optional_solve_tmp.append(sheet_regions[i])
  155. sheet_regions.pop(i)
  156. if shrink:
  157. if sheet_regions[i]['class_name'] not in include_class:
  158. sheet_regions.pop(i)
  159. for ele in sheet_regions:
  160. if ele['class_name'] == 'solve':
  161. solve_box = (ele['bounding_box']['xmin'], ele['bounding_box']['ymin'],
  162. ele['bounding_box']['xmax'], ele['bounding_box']['ymax'])
  163. for optional_solve in optional_solve_tmp:
  164. optional_solve_box = (optional_solve['bounding_box']['xmin'], optional_solve['bounding_box']['ymin'],
  165. optional_solve['bounding_box']['xmax'], optional_solve['bounding_box']['ymax'])
  166. if utils.decide_coordinate_contains(optional_solve_box, solve_box):
  167. ele['class_name'] = 'optional_solve'
  168. break
  169. else:
  170. continue
  171. if ele['class_name'] == "composition":
  172. if isinstance(ele['default_points'], list):
  173. for i, dp in enumerate(ele['default_points']):
  174. if dp != default_points_dict[ele['class_name']]:
  175. ele['default_points'][i] = default_points_dict[ele['class_name']]
  176. if isinstance(ele['default_points'], int):
  177. if ele['default_points'] != default_points_dict[ele['class_name']]:
  178. ele['default_points'] = default_points_dict[ele['class_name']]
  179. if ele['class_name'] in ["choice_m", "cloze", "cloze_s", "solve"]:
  180. if isinstance(ele['default_points'], list):
  181. for i, dp in enumerate(ele['default_points']):
  182. if dp == -1:
  183. ele['default_points'][i] = default_points_dict[ele['class_name']]
  184. if isinstance(ele['default_points'], int):
  185. if ele['default_points'] == -1:
  186. ele['default_points'] = default_points_dict[ele['class_name']]
  187. for ele in optional_solve_tmp: # 选做题
  188. bbox = ele['bounding_box']
  189. box_region = utils.crop_region(image, bbox)
  190. left = bbox['xmin']
  191. top = bbox['ymin']
  192. right = bbox['xmax']
  193. bottom = bbox['ymax']
  194. if (right - left) >= (bottom-top):
  195. direction = 180
  196. else:
  197. direction = 90
  198. # res = find_contours(left, top, box_region)
  199. try:
  200. res = resolve_optional_choice(left, top, direction, box_region)
  201. except Exception as e:
  202. res = {'rows': 1, 'cols': 2,
  203. 'option': 'A, B',
  204. 'single_width': (right - left) // 3,
  205. 'single_height': bottom - top,
  206. 'bounding_box': {'xmin': left,
  207. 'ymin': top,
  208. 'xmax': right,
  209. 'ymax': bottom}}
  210. res['class_name'] = 'optional_choice'
  211. sheet_regions.append(res)
  212. # iou
  213. sheet_tmp = sheet_regions.copy()
  214. remove_index = []
  215. for i, region in enumerate(sheet_tmp):
  216. if i not in remove_index:
  217. box = region['bounding_box']
  218. for j, region_in in enumerate(sheet_tmp):
  219. box_in = region_in['bounding_box']
  220. iou = utils.cal_iou(box, box_in)
  221. if iou[0] > 0.75 and i != j:
  222. sheet_regions.remove(region)
  223. remove_index.append(j)
  224. break
  225. sheet_dict.update({'regions': sheet_regions})
  226. return sheet_dict