analysis_sheet.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. # @Author : lightXu
  2. # @File : analysis_sheet.py
  3. import time
  4. import os
  5. import traceback
  6. import numpy as np
  7. import cv2
  8. from django.conf import settings
  9. import segment.logging_config as logging
  10. from segment.sheet_resolve.lib.model.test import im_detect
  11. from segment.sheet_resolve.lib.model.nms_wrapper import nms
  12. from segment.sheet_resolve.lib.utils.timer import Timer
  13. from segment.sheet_resolve.tools import utils
  14. from segment.sheet_resolve.analysis.solve.optional_solve import resolve_optional_choice
  15. logger = logging.getLogger(settings.LOGGING_TYPE)
  16. def analysis_single_image_with_regions(analysis_type, classes,
  17. sess, net,
  18. im_raw, conf_thresh, mns_thresh,
  19. coordinate_bias_dict):
  20. """Detect object classes in an image using pre-computed object proposals."""
  21. size = im_raw.shape
  22. # Detect all object classes and regress object bounds
  23. timer = Timer()
  24. timer.tic()
  25. if '_blank' in analysis_type:
  26. analysis_type = analysis_type.replace('_blank', '')
  27. if analysis_type in ['unknown_subject', 'math', 'math_zxhx', 'english', 'chinese',
  28. 'physics', 'chemistry', 'biology', 'politics', 'history',
  29. 'geography', 'science_comprehensive', 'arts_comprehensive'
  30. ]:
  31. analysis_type = 'sheet'
  32. # im, ratio = utils.img_resize(analysis_type, im_raw)
  33. im, ratio = utils.resize_faster_rcnn(analysis_type, im_raw)
  34. scores, boxes = im_detect(analysis_type, sess, net, im)
  35. timer.toc()
  36. print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))
  37. content_list = []
  38. analysis_cls_list = []
  39. qr_code_info = 'Nan'
  40. for cls_ind, cls in enumerate(classes[1:]): # classes
  41. cls_ind += 1 # because we skipped background
  42. cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
  43. cls_scores = scores[:, cls_ind]
  44. dets = np.hstack((cls_boxes,
  45. cls_scores[:, np.newaxis])).astype(np.float32)
  46. keep = nms(dets, mns_thresh)
  47. dets = dets[keep, :]
  48. inds = np.where(dets[:, -1] >= conf_thresh)[0]
  49. if len(inds) > 0:
  50. if cls in list(coordinate_bias_dict.keys()):
  51. xmin_bias = coordinate_bias_dict[cls]['xmin_bias']
  52. ymin_bias = coordinate_bias_dict[cls]['ymin_bias']
  53. xmax_bias = coordinate_bias_dict[cls]['xmax_bias']
  54. ymax_bias = coordinate_bias_dict[cls]['ymax_bias']
  55. else:
  56. xmin_bias = 0
  57. ymin_bias = 0
  58. xmax_bias = 0
  59. ymax_bias = 0
  60. for i in inds:
  61. bbox = dets[i, :4]
  62. score = '{:.4f}'.format(dets[i, -1])
  63. xmin = int(int(bbox[0]) * ratio[0]) + xmin_bias
  64. ymin = int(int(bbox[1]) * ratio[1]) + ymin_bias
  65. xmax = int(int(bbox[2]) * ratio[0]) + xmax_bias
  66. ymax = int(int(bbox[3]) * ratio[1]) + ymax_bias
  67. xmin = (xmin if (xmin > 0) else 1)
  68. ymin = (ymin if (ymin > 0) else 1)
  69. xmax = (xmax if (xmax < size[1]) else size[1] - 1)
  70. ymax = (ymax if (ymax < size[0]) else size[0] - 1)
  71. if cls in ['solve0', 'composition0']:
  72. cls = cls.replace('0', '')
  73. bbox_dict = {"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax}
  74. class_dict = {"class_name": cls, "bounding_box": bbox_dict, "score": score}
  75. # if cls == 'qr_code':
  76. # qr_img = utils.crop_region(im_raw, bbox_dict)
  77. # qr_path = r'./qr_code.jpg'
  78. # cv2.imwrite(qr_path, qr_img)
  79. # qr_code_info = utils.check_qr_code_with_region_img(qr_path)
  80. # os.remove(qr_path)
  81. content_list.append(class_dict)
  82. return content_list, analysis_cls_list, qr_code_info
  83. def get_single_image_sheet_regions(analysis_type, img_path, img, classes,
  84. sess, net, conf_thresh, mns_thresh,
  85. coordinate_bias_dict):
  86. start_time = time.time()
  87. print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
  88. print('analysis for JPG {}'.format(img_path))
  89. content, cls, qr_code_info = \
  90. analysis_single_image_with_regions(analysis_type, classes, sess, net,
  91. img, conf_thresh, mns_thresh,
  92. coordinate_bias_dict)
  93. analysis_type.replace('_blank', '')
  94. img_dict = {"img_name": img_path,
  95. # 'qr_code': qr_code_info,
  96. 'subject': analysis_type,
  97. "regions": content,
  98. }
  99. end_time = time.time()
  100. print(end_time - start_time)
  101. return img_dict
  102. def question_number_format(init_number, crt_numbers, regions):
  103. """
  104. 将重复或者是-1的题号改为501
  105. :param init_number: 初始替换的题号
  106. :param crt_numbers: 目前已经有的题号
  107. :param regions: 答题卡各区域
  108. :return:
  109. """
  110. logger.info('regions: {}'.format(regions))
  111. for region in regions:
  112. logger.info('region: {}'.format(region))
  113. if region['class_name'] == 'optional_choice':
  114. continue
  115. while init_number in crt_numbers:
  116. init_number += 1
  117. numbers = region.get("number")
  118. if numbers and (isinstance(numbers, int) or isinstance(numbers, float)):
  119. if numbers <= 0 or numbers in crt_numbers or numbers >= 1000:
  120. if not region.get("span"):
  121. numbers = init_number
  122. crt_numbers.append(numbers)
  123. init_number += 1
  124. region.update({"number": numbers})
  125. crt_numbers.append(numbers)
  126. if numbers and isinstance(numbers, list):
  127. for i, num in enumerate(numbers):
  128. if num <= 0 or num in crt_numbers or num >= 1000:
  129. if not region.get("span"):
  130. numbers[i] = init_number
  131. crt_numbers.append(init_number)
  132. init_number += 1
  133. region.update({"number": numbers})
  134. crt_numbers.extend(numbers)
  135. return regions, init_number, crt_numbers
  136. def resolve_select_s(image, bbox):
  137. box_region = utils.crop_region(image, bbox)
  138. left = bbox['xmin']
  139. top = bbox['ymin']
  140. right = bbox['xmax']
  141. bottom = bbox['ymax']
  142. if (right - left) >= (bottom - top):
  143. direction = 180
  144. else:
  145. direction = 90
  146. try:
  147. res = resolve_optional_choice(left, top, direction, box_region)
  148. except Exception as e:
  149. res = {'class_name': 'optional_choice',
  150. 'rows': 1, 'cols': 2,
  151. 'number': [501, 502],
  152. 'single_width': right - left,
  153. 'single_height': bottom - top,
  154. 'bounding_box': {'xmin': left,
  155. 'ymin': top,
  156. 'xmax': right,
  157. 'ymax': bottom}}
  158. return res
  159. def box_region_format(sheet_dict, image, subject, shrink=True):
  160. include_class = ['anchor_point',
  161. 'bar_code',
  162. 'choice_m',
  163. 'cloze',
  164. 'cloze_s',
  165. 'exam_number_col_row',
  166. 'optional_choice',
  167. 'optional_solve',
  168. # 'qr_code',
  169. 'solve',
  170. 'optional_solve',
  171. 'composition',
  172. # 'correction'
  173. ]
  174. default_points_dict = {'choice_m': 5, "cloze": 5, 'solve': 12, 'optional_solve': 10, 'cloze_s': 5,
  175. "composition": 60}
  176. if subject in ["english", 'physics', 'chemistry', 'biology', 'science_comprehensive']:
  177. default_points_dict = {'choice_m': 2, "cloze": 2, 'solve': 10, 'optional_solve': 10, 'cloze_s': 2,
  178. "composition": 25}
  179. sheet_regions = sheet_dict['regions']
  180. select_s_list = []
  181. for i in range(len(sheet_regions) - 1, -1, -1):
  182. if subject == "math":
  183. if sheet_regions[i]['class_name'] == 'cloze':
  184. sheet_regions[i]['class_name'] = 'cloze_big' # math exclude cloze big
  185. if sheet_regions[i]['class_name'] == 'cloze_s':
  186. sheet_regions[i]['class_name'] = 'cloze' # math exclude cloze big
  187. if subject == "english":
  188. if sheet_regions[i]['class_name'] == 'cloze':
  189. sheet_regions[i]['class_name'] = 'solve'
  190. if sheet_regions[i]['class_name'] == 'correction':
  191. sheet_regions[i]['class_name'] = 'solve'
  192. if sheet_regions[i]['class_name'] in ['solve0']:
  193. sheet_regions[i]['class_name'] = 'solve'
  194. if sheet_regions[i]['class_name'] in ['composition0']:
  195. sheet_regions[i]['class_name'] = 'composition'
  196. if sheet_regions[i]['class_name'] == 'select_s':
  197. select_s_list.append(sheet_regions[i])
  198. if shrink:
  199. if sheet_regions[i]['class_name'] not in include_class:
  200. sheet_regions.pop(i)
  201. # 去重
  202. sheet_tmp = sheet_regions.copy()
  203. remove_index = []
  204. for i, region in enumerate(sheet_tmp):
  205. if i not in remove_index:
  206. box = region['bounding_box']
  207. name = region['class_name']
  208. for j, region_in in enumerate(sheet_tmp):
  209. box_in = region_in['bounding_box']
  210. name_in = region_in['class_name']
  211. iou = utils.cal_iou(box, box_in)
  212. if name == name_in and (iou[0] > 0.75 or iou[1] > 0.85 or iou[2] > 0.85) and i != j:
  213. box_area = (box['xmax'] - box['xmin']) * (box['ymax'] - box['ymin'])
  214. box_in_area = (box_in['xmax'] - box_in['xmin']) * (box_in['ymax'] - box_in['ymin'])
  215. if box_area >= box_in_area:
  216. sheet_regions.remove(region_in)
  217. remove_index.append(j)
  218. else:
  219. sheet_regions.remove(region)
  220. remove_index.append(i)
  221. break
  222. # 合并select_s
  223. optional_choice_tmp = []
  224. select_s_list_copy = select_s_list.copy()
  225. if len(select_s_list) > 0:
  226. for ele in sheet_regions:
  227. if ele['class_name'] == 'solve':
  228. solve_box = (ele['bounding_box']['xmin'], ele['bounding_box']['ymin'],
  229. ele['bounding_box']['xmax'], ele['bounding_box']['ymax'])
  230. xn, yn, xm, ym = 9999, 9999, 0, 0
  231. merge = False
  232. for select_s in select_s_list:
  233. select_s_box = (select_s['bounding_box']['xmin'], select_s['bounding_box']['ymin'],
  234. select_s['bounding_box']['xmax'], select_s['bounding_box']['ymax'])
  235. if utils.decide_coordinate_contains(select_s_box, solve_box):
  236. merge = True
  237. xn = min(xn, select_s_box[0])
  238. yn = min(yn, select_s_box[1])
  239. xm = max(xm, select_s_box[2])
  240. ym = max(ym, select_s_box[3])
  241. select_s_list_copy.remove(select_s)
  242. if merge:
  243. new_box = {'xmin': xn, 'ymin': yn, 'xmax': xm, 'ymax': ym}
  244. optional_choice_info = resolve_select_s(image, new_box)
  245. optional_choice_tmp.append(optional_choice_info)
  246. for ele in select_s_list_copy:
  247. box = ele['bounding_box']
  248. optional_choice_info = resolve_select_s(image, box)
  249. optional_choice_tmp.append(optional_choice_info)
  250. optional_choice_tmp_ = optional_choice_tmp.copy()
  251. for ele in sheet_regions:
  252. if len(optional_choice_tmp) > 0:
  253. if ele['class_name'] == 'solve':
  254. solve_box = (ele['bounding_box']['xmin'], ele['bounding_box']['ymin'],
  255. ele['bounding_box']['xmax'], ele['bounding_box']['ymax'])
  256. for optional_choice in optional_choice_tmp_:
  257. optional_choice_box = (optional_choice['bounding_box']['xmin'], optional_choice['bounding_box']['ymin'],
  258. optional_choice['bounding_box']['xmax'], optional_choice['bounding_box']['ymax'])
  259. if utils.decide_coordinate_contains(optional_choice_box, solve_box):
  260. optional_choice_tmp.remove(optional_choice)
  261. ele['class_name'] = 'optional_solve'
  262. choice_numbers = optional_choice['number']
  263. solve_numbers = ele['number']
  264. if choice_numbers[0] < 500:
  265. ele['number'] = choice_numbers
  266. ele['default_points'] = [ele['default_points']] * len(choice_numbers)
  267. else:
  268. tmp = [solve_numbers] * len(choice_numbers)
  269. for i, num in enumerate(tmp):
  270. tmp[i] = num + i
  271. ele['number'] = tmp
  272. optional_choice['number'] = tmp
  273. ele['default_points'] = [ele['default_points']] * len(choice_numbers)
  274. break
  275. else:
  276. continue
  277. if ele['class_name'] in ["choice_m", "cloze", "cloze_s", "solve", "optional_solve", "composition"]:
  278. if isinstance(ele['default_points'], list):
  279. for i, dp in enumerate(ele['default_points']):
  280. if dp < 1: # 小于一分
  281. ele['default_points'][i] = default_points_dict[ele['class_name']]
  282. if isinstance(ele['default_points'], int) or isinstance(ele['default_points'], float):
  283. if ele['default_points'] < 1: # 小于一分
  284. ele['default_points'] = default_points_dict[ele['class_name']]
  285. # select_s 在解答区域外侧
  286. if len(optional_choice_tmp) > 0:
  287. for oc in optional_choice_tmp:
  288. optional_choice_box = (oc['bounding_box']['xmin'], oc['bounding_box']['ymin'],
  289. oc['bounding_box']['xmax'], oc['bounding_box']['ymax'],
  290. oc['bounding_box']['xmin']
  291. + (oc['bounding_box']['xmax'] - oc['bounding_box']['xmin']) // 2)
  292. for sr in sheet_regions:
  293. if sr['class_name'] == 'solve':
  294. solve_box = (sr['bounding_box']['xmin'], sr['bounding_box']['ymin'],
  295. sr['bounding_box']['xmax'], sr['bounding_box']['ymax'])
  296. if (optional_choice_box[1] <= solve_box[1] and
  297. solve_box[0] < optional_choice_box[4] < solve_box[2] and
  298. abs(optional_choice_box[1] - solve_box[1]) < solve_box[3] - solve_box[1]):
  299. sr['class_name'] = 'optional_solve'
  300. choice_numbers = oc['number']
  301. solve_numbers = sr['number']
  302. if choice_numbers[0] < 500:
  303. sr['number'] = choice_numbers
  304. sr['default_points'] = [sr['default_points']] * len(choice_numbers)
  305. else:
  306. tmp = [solve_numbers] * len(choice_numbers)
  307. for i, num in enumerate(tmp):
  308. tmp[i] = num + i
  309. sr['number'] = tmp
  310. oc['number'] = tmp
  311. sr['default_points'] = [sr['default_points']] * len(choice_numbers)
  312. break
  313. if len(optional_choice_tmp_):
  314. sheet_regions.extend(optional_choice_tmp_)
  315. # 去重
  316. sheet_tmp = sheet_regions.copy()
  317. remove_index = []
  318. for i, region in enumerate(sheet_tmp):
  319. if i not in remove_index:
  320. box = region['bounding_box']
  321. name = region['class_name']
  322. for j, region_in in enumerate(sheet_tmp):
  323. box_in = region_in['bounding_box']
  324. name_in = region_in['class_name']
  325. iou = utils.cal_iou(box, box_in)
  326. if name == name_in and (iou[0] > 0.75 or iou[1] > 0.85 or iou[2] > 0.85) and i != j:
  327. box_area = (box['xmax'] - box['xmin']) * (box['ymax'] - box['ymin'])
  328. box_in_area = (box_in['xmax'] - box_in['xmin']) * (box_in['ymax'] - box_in['ymin'])
  329. if box_area >= box_in_area:
  330. sheet_regions.remove(region_in)
  331. remove_index.append(j)
  332. else:
  333. sheet_regions.remove(region)
  334. remove_index.append(i)
  335. break
  336. sheet_dict.update({'regions': sheet_regions})
  337. return sheet_dict
  338. def merge_span_boxes(col_sheets):
  339. if len(col_sheets) <= 1:
  340. return col_sheets
  341. for i, cur_col in enumerate(col_sheets[:-1]):
  342. next_col = col_sheets[i + 1]
  343. if not cur_col or not next_col:
  344. continue
  345. current_bottom_box = cur_col[-1] # 当前栏的最后一个,bottom
  346. next_col_top_box = next_col[0] # 下一栏的第一个,top
  347. b_name = current_bottom_box['class_name']
  348. t_name = next_col_top_box['class_name']
  349. if b_name == t_name == 'solve':
  350. b_number = current_bottom_box['number']
  351. t_number = next_col_top_box['number']
  352. if b_number >= 500 or t_number >= 500 or b_number == t_number:
  353. numbers = min(b_number, t_number)
  354. crt_points = current_bottom_box['default_points']
  355. next_points = next_col_top_box['default_points']
  356. # default_points = max(current_bottom_box['default_points'], next_col_top_box['default_points'])
  357. default_points = crt_points
  358. current_bottom_box.update({'number': numbers, 'default_points': default_points, "span": True})
  359. next_col_top_box.update({'number': numbers, 'default_points': default_points,
  360. "span": True, "span_id": current_bottom_box["span_id"] + 1})
  361. elif b_name == t_name == 'composition':
  362. b_number = current_bottom_box['number']
  363. t_number = next_col_top_box['number']
  364. numbers = min(b_number, t_number)
  365. default_points = max(current_bottom_box['default_points'], next_col_top_box['default_points'])
  366. current_bottom_box.update({'number': numbers, 'default_points': default_points, "span": True})
  367. next_col_top_box.update({'number': numbers, 'default_points': default_points,
  368. "span": True, "span_id": current_bottom_box["span_id"] + 1})
  369. else:
  370. continue
  371. return col_sheets