# @Author : lightXu
# @File : choice_box.py
# @Time : 2018/11/22 0022 PM 16:01
import re
import time
import xml.etree.cElementTree as ET

import cv2
import numpy as np

from segment.sheet_resolve.analysis.choice.choice_m_row_column import get_choice_m_row_and_col
from segment.sheet_resolve.tools import utils
from segment.sheet_resolve.tools.brain_api import get_ocr_text_and_coordinate
  12. def get_interval(word_result_list):
  13. all_char_str = ''
  14. location = []
  15. for i, chars_dict in enumerate(word_result_list):
  16. chars_list = chars_dict['chars']
  17. for ele in chars_list:
  18. all_char_str = all_char_str + ele['char']
  19. location.append(ele['location'])
  20. pattern1 = re.compile(r"\]\[")
  21. pattern2 = re.compile(r"\[[ABCD]")
  22. def intervel(pattern):
  23. group_list = []
  24. for i in pattern.finditer(all_char_str):
  25. # print(i.group() + str(i.span()))
  26. group_list.append(list(i.span()))
  27. # print(group_list)
  28. sum_intervel = 0
  29. size = 0
  30. for group in group_list:
  31. left_x, right_x = location[group[0]]['left'] \
  32. + location[group[0]]['width'], location[group[1] - 1]['left']
  33. if abs(location[group[0]]['top'] - location[group[1]]['top']) < location[group[0]]['height']:
  34. if right_x - left_x > 0:
  35. sum_intervel = sum_intervel + right_x - left_x
  36. size += 1
  37. # print(sum_intervel // size)
  38. return sum_intervel // size
  39. intervel_width1 = intervel(pattern1)
  40. intervel_width2 = intervel(pattern2)
  41. return (intervel_width1 + intervel_width2) * 2 // 3
  42. def preprocess(image0, xe, ye):
  43. scale = 0
  44. dilate = 1
  45. blur = 5
  46. # 预处理图像
  47. img = image0
  48. # rescale the image
  49. if scale != 0:
  50. img = cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
  51. # Convert to gray
  52. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  53. # # Apply dilation and erosion to remove some noise
  54. # if dilate != 0:
  55. # kernel = np.ones((dilate, dilate), np.uint8)
  56. # img = cv2.dilate(img, kernel, iterations=1)
  57. # img = cv2.erode(img, kernel, iterations=1)
  58. # Apply blur to smooth out the edges
  59. # if blur != 0:
  60. # img = cv2.GaussianBlur(img, (blur, blur), 0)
  61. # Apply threshold to get image with only b&w (binarization)
  62. img = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
  63. # cv2.namedWindow('image', cv2.WINDOW_NORMAL)
  64. # cv2.imshow('image', img)
  65. # if cv2.waitKey(0) == 27:
  66. # cv2.destroyAllWindows()
  67. # cv2.imwrite('otsu.jpg', img)
  68. kernel = np.ones((ye, xe), np.uint8) # y轴膨胀, x轴膨胀
  69. dst = cv2.dilate(img, kernel, iterations=1)
  70. # cv2.imshow('dilate', dst)
  71. # if cv2.waitKey(0) == 27:
  72. # cv2.destroyAllWindows()
  73. return dst
  74. def contours(image):
  75. _, cnts, hierarchy = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  76. bboxes = []
  77. for cnt_id, cnt in enumerate(reversed(cnts)):
  78. x, y, w, h = cv2.boundingRect(cnt)
  79. bboxes.append((x, y, x + w, y + h))
  80. return bboxes
  81. def box_coordinates(img):
  82. img_arr = np.asarray(img)
  83. def axix_break_point(img, tolerance_number, axis):
  84. sum_x_axis = img.sum(axis=axis)
  85. sum_x_axis[sum_x_axis > 255 * tolerance_number] = 1 # 白色有字
  86. sum_x_axis[sum_x_axis != 1] = 0 # 黑色无字
  87. sum_x_axis_list = list(sum_x_axis)
  88. sum_x_axis_list.append(0) # 最后几行到结束有字时,使索引值增加最后一位
  89. split_x_index = []
  90. num = 1
  91. for index, ele in enumerate(sum_x_axis_list):
  92. num = num % 2
  93. if ele == num:
  94. # print(i)
  95. num = num + 1
  96. split_x_index.append(index)
  97. # print('length: ', len(split_x_index), split_x_index)
  98. return split_x_index
  99. y_break_points_list = axix_break_point(img_arr, 1, axis=1)
  100. x_break_points_list = axix_break_point(img_arr, 1, axis=0)
  101. all_coordinates = []
  102. for i in range(0, len(y_break_points_list), 2): # y轴分组
  103. ymin = y_break_points_list[i]
  104. ymax = y_break_points_list[i + 1]
  105. for j in range(0, len(x_break_points_list), 2):
  106. xmin = x_break_points_list[j]
  107. xmax = x_break_points_list[j + 1]
  108. all_coordinates.append([xmin, ymin, xmax, ymax])
  109. return all_coordinates
  110. def get_choice_box_coordinate(word_result_list, choice_img, cv_box_list, choice_bbox_list):
  111. shape = choice_img.shape
  112. y, x = shape[0], shape[1]
  113. # cv2.imshow('ocr_region', ocr_region)
  114. # if cv2.waitKey(0) == 27:
  115. # cv2.destroyAllWindows()
  116. all_digital_list = []
  117. digital_model = re.compile(r'\d')
  118. for i, chars_dict in enumerate(word_result_list):
  119. chars_list = chars_dict['chars']
  120. for ele in chars_list:
  121. if digital_model.search(ele['char']):
  122. all_digital_list.append(ele)
  123. new_all_digital_list = []
  124. i = 1
  125. while i <= len(all_digital_list):
  126. pre_one = all_digital_list[i - 1]
  127. if i == len(all_digital_list):
  128. new_all_digital_list.append(pre_one)
  129. break
  130. rear_one = all_digital_list[i]
  131. condition1 = abs(pre_one['location']['top'] - rear_one['location']['top']) < pre_one['location'][
  132. 'height'] # 两字高度差小于一字高度
  133. condition2 = pre_one['location']['left'] + 2 * pre_one['location']['width'] > rear_one['location'][
  134. 'left'] # 某字宽度的2倍大于两字间间隔
  135. if condition1:
  136. if condition2:
  137. new_char = pre_one['char'] + rear_one['char']
  138. new_location = {'left': pre_one['location']['left'],
  139. 'top': min(pre_one['location']['top'], rear_one['location']['top']),
  140. 'width': rear_one['location']['left'] + rear_one['location']['width'] -
  141. pre_one['location']['left'],
  142. 'height': max(pre_one['location']['height'], rear_one['location']['height'])}
  143. new_all_digital_list.append({'char': new_char, 'location': new_location})
  144. i = i + 1 + 1
  145. else:
  146. new_all_digital_list.append(pre_one)
  147. i = i + 1
  148. else:
  149. new_all_digital_list.append(pre_one) # 遇到字符y轴相差过大就结束
  150. i = i + 1
  151. content_list = list()
  152. for index, box in enumerate(choice_bbox_list['regions']): # rcnn识别的框匹配题号
  153. box = box['bounding_box']
  154. box_coordinate = (box['xmin'], box['ymin'], box['xmax'], box['ymax'])
  155. horizontal = box['xmax'] - box['xmin'] >= box['ymax'] - box['ymin']
  156. vertical = box['xmax'] - box['xmin'] < box['ymax'] - box['ymin']
  157. choice_number = {'number': 99, 'location': box_coordinate}
  158. content_list.insert(index, choice_number)
  159. for digital in new_all_digital_list:
  160. digital_coordiante = (digital['location']['left'], digital['location']['top'],
  161. digital['location']['left'] + digital['location']['width'],
  162. digital['location']['top'] + digital['location']['height'])
  163. if utils.decide_coordinate_contains(digital_coordiante, box_coordinate):
  164. if horizontal:
  165. box['xmin'] = digital['location']['left'] + digital['location']['width'] + 1 # 从数字处截取
  166. if vertical:
  167. box['ymin'] = digital['location']['top'] + digital['location']['height'] + 1
  168. box_coordinate = (box['xmin'], box['ymin'], box['xmax'], box['ymax'])
  169. content_list[index]['number'] = digital['char']
  170. content_list[index]['location'] = box_coordinate
  171. break
  172. for box in content_list:
  173. box_coordinate = (box['location'][0], box['location'][1], box['location'][2], box['location'][3])
  174. mtx = []
  175. for cv_box in cv_box_list:
  176. if utils.decide_coordinate_contains(cv_box, box_coordinate): # 若fasterrcnn未识别到选项框,单独的ABCD也舍去
  177. mtx.append(cv_box)
  178. matrix = np.asarray(sorted(mtx))
  179. dif = matrix[1:, 0] - matrix[:-1, 2] # 后一个char的left与起一个char的right的差
  180. dif[dif < 0] = 0
  181. dif_length = np.mean(dif) # 小于平均间隔的合并
  182. block_list = utils.box_by_x_intervel(matrix, dif_length)
  183. # block_list = utils.box_by_x_intervel(matrix, 5)
  184. box['abcd'] = block_list
  185. return content_list
  186. def choice(left, top, image, choice_bbox_list, xml_path):
  187. a_z = '_ABCDEFGHIJKLMTUNOPQRSVWXYZ'
  188. t1 = time.time()
  189. word_result_list0 = get_ocr_text_and_coordinate(image, ocr_accuracy='accurate', language_type='ENG')
  190. t2 = time.time()
  191. print('choice ocr time cost: ', t2 - t1)
  192. # print(word_result_list0)
  193. # try:
  194. # intervel_x = get_interval(word_result_list0)
  195. # except Exception:
  196. # intervel_x = 15
  197. intervel_x = 3
  198. img = preprocess(image, intervel_x, 3)
  199. cv_box_list0 = box_coordinates(img)
  200. content_list = get_choice_box_coordinate(word_result_list0, image, cv_box_list0, choice_bbox_list)
  201. tree = ET.parse(xml_path) # xml tree
  202. w = content_list[0]['location'][2] - content_list[0]['location'][0]
  203. h = content_list[0]['location'][3] - content_list[0]['location'][1]
  204. def xml(xml_tree, sorted_abcd_list, bias=0):
  205. ii = 0
  206. for i, choice_bbox in enumerate(sorted_abcd_list):
  207. area = (choice_bbox[2] - choice_bbox[0]) * (choice_bbox[3] - choice_bbox[1])
  208. if area > 400:
  209. name = '{:02d}_{}'.format(int(choice['number']), a_z[ii + bias])
  210. xml_tree = utils.create_xml(name, xml_tree,
  211. choice_bbox[0] + left, choice_bbox[1] + top, choice_bbox[2] + left,
  212. choice_bbox[3] + top)
  213. ii += 1
  214. return xml_tree
  215. def get_json(ajson_list, sorted_abcd_list, bias=0):
  216. ii = 0
  217. for i, choice_bbox in enumerate(sorted_abcd_list):
  218. area = (choice_bbox[2] - choice_bbox[0]) * (choice_bbox[3] - choice_bbox[1])
  219. if area > 400:
  220. name = '{:02d}_{}'.format(int(choice['number']), a_z[ii + bias])
  221. region = [choice_bbox[0] + left, choice_bbox[1] + top, choice_bbox[2] + left, choice_bbox[3] + top]
  222. ajson_list.append({'number': name, 'region': region})
  223. ii += 1
  224. return ajson_list
  225. json_list = []
  226. for index_num, choice in enumerate(content_list):
  227. abcd = choice['abcd']
  228. if int(choice['number']) == 99:
  229. if w >= h:
  230. tree = xml(tree, sorted(abcd))
  231. json_list = get_json(json_list, sorted(abcd))
  232. else:
  233. tree = xml(tree, sorted(abcd, key=lambda x: (x[1], x[0])))
  234. json_list = get_json(json_list, sorted(abcd, key=lambda x: (x[1], x[0])))
  235. else:
  236. if w >= h:
  237. tree = xml(tree, sorted(abcd), bias=1)
  238. json_list = get_json(json_list, sorted(abcd), bias=1)
  239. else:
  240. tree = xml(tree, sorted(abcd, key=lambda x: (x[1], x[0])), bias=1)
  241. json_list = get_json(json_list, sorted(abcd, key=lambda x: (x[1], x[0])), bias=1)
  242. tree.write(xml_path)
  243. return json_list
  244. def get_number_by_enlarge_choice_m(image, choice_m_region_list, xml_path):
  245. a_z = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
  246. choice_m_dict_list = [] # choice_m region with same index
  247. choice_m_enlarge = []
  248. left, top, right, bottom = 9999, 9999, 0, 0
  249. for _, box in enumerate(choice_m_region_list):
  250. box = box['bounding_box']
  251. m_left, m_top = box['xmin'], box['ymin'],
  252. width, height = box['xmax'] - box['xmin'], box['ymax'] - box['ymin']
  253. box_coordinate = (m_left, m_top, box['xmax'], box['ymax'])
  254. single_choice_m = utils.crop_region_direct(image, box_coordinate)
  255. row_col_dict = get_choice_m_row_and_col(m_left, m_top, single_choice_m)
  256. choice_m_dict_list.append(row_col_dict)
  257. box_coordinate_enlarge = (
  258. m_left - int(width / 2), m_top - int(height / 2), box['xmax'], box['ymax']) # 扩大的choice_m, 多个分散choice_m
  259. choice_m_enlarge.append(box_coordinate_enlarge)
  260. left = min(left, box_coordinate_enlarge[0])
  261. top = min(top, box_coordinate_enlarge[1])
  262. right = max(right, box_coordinate_enlarge[2])
  263. bottom = max(bottom, box_coordinate_enlarge[3])
  264. choice_whole_region = utils.crop_region_direct(image, (left, top, right, bottom))
  265. # cv2.imwrite(r'C:\Users\Administrator\Desktop\test\sheet\choice_enlarge.jpg', choice_whole_region)
  266. # cv2.imshow('img', choice_whole_region)
  267. # cv2.waitKey(0)
  268. # cv2.destroyAllWindows()
  269. choice_region_text = get_ocr_text_and_coordinate(choice_whole_region)
  270. all_digital_list = []
  271. pattern = re.compile(r'\d')
  272. for i, chars_dict in enumerate(choice_region_text):
  273. chars_list = chars_dict['chars']
  274. for ele in chars_list:
  275. if pattern.search(ele['char']):
  276. all_digital_list.append(ele)
  277. combined_digital_list = utils.combine_char(all_digital_list)
  278. direction_list = []
  279. for index, enlarge_box in enumerate(choice_m_enlarge):
  280. digital_list = []
  281. xmin, ymin, xmax, ymax = 9999, 9999, 0, 0
  282. choice_m_dict = choice_m_dict_list[index]
  283. choice_m_dict_box = (choice_m_dict['bounding_box']['xmin'], choice_m_dict['bounding_box']['ymin'],
  284. choice_m_dict['bounding_box']['xmax'], choice_m_dict['bounding_box']['ymax'],)
  285. for jndex, digital_box in enumerate(combined_digital_list):
  286. digital_coordinate = (digital_box['location']['left'] + left,
  287. digital_box['location']['top'] + top,
  288. digital_box['location']['left'] + digital_box['location']['width'] + left,
  289. digital_box['location']['top'] + digital_box['location']['height'] + top)
  290. digital_box.update({'coordinate': digital_coordinate})
  291. if (utils.decide_coordinate_contains(digital_coordinate, enlarge_box)) and not \
  292. (utils.decide_coordinate_contains(digital_coordinate, choice_m_dict_box)):
  293. digital_list.append(digital_box)
  294. xmin = min(xmin, digital_box['coordinate'][0])
  295. ymin = min(ymin, digital_box['coordinate'][1])
  296. xmax = max(xmax, digital_box['coordinate'][2])
  297. ymax = max(ymax, digital_box['coordinate'][3])
  298. digital_list_coordinate = (xmin, ymin, xmax, ymax)
  299. direction = utils.decide_choice_m_left_top(digital_list_coordinate, choice_m_dict_box)
  300. if int(direction):
  301. choice_m_dict['direction'] = direction
  302. direction_list.append(direction)
  303. if direction == '180': # 数字垂直排列
  304. std_num_length = choice_m_dict['rows']
  305. choice_option = a_z[:choice_m_dict['cols']].replace('', ',')[1:-1]
  306. default_points = [-1] * std_num_length
  307. choice_m_dict.update({'option': choice_option, 'default_points': default_points})
  308. sorted(digital_list, key=lambda k: k.get('coordinate')[1])
  309. choice_ymin = choice_m_dict['bounding_box']['ymin']
  310. single_height = choice_m_dict['single_height']
  311. mean_interval = ((choice_m_dict['bounding_box']['ymax'] - choice_m_dict['bounding_box']['ymin'])
  312. - single_height * std_num_length) / (std_num_length - 1)
  313. spilt_index = [choice_ymin - mean_interval / 2 + (single_height + mean_interval) * ele for ele in
  314. range(std_num_length + 1)]
  315. number_list = [-1] * std_num_length
  316. number_location = [(-1, -1, -1, -1)] * std_num_length
  317. for i in range(0, len(spilt_index) - 1):
  318. start = spilt_index[i]
  319. end = spilt_index[i + 1]
  320. number_location[i] = (xmin, start, xmax, end)
  321. for digital_coordinate in digital_list:
  322. middle_y = (digital_coordinate['coordinate'][3] - digital_coordinate['coordinate'][1]) / 2 + \
  323. digital_coordinate['coordinate'][1]
  324. middle_x = (digital_coordinate['coordinate'][2] - digital_coordinate['coordinate'][0]) / 2 + \
  325. digital_coordinate['coordinate'][0]
  326. if (start <= middle_y <= end
  327. and
  328. middle_x < choice_m_dict['bounding_box']['xmin']): # 数字在choice_m外侧
  329. number_list[i] = int(digital_coordinate['char'])
  330. number_location[i] = digital_coordinate['coordinate']
  331. number_list = _infer_number(number_list)
  332. choice_m_dict['number'] = _infer_number(number_list)
  333. # choice_m_dict['number'] = [{'number': number,
  334. # 'location': {'xmin': xi, 'ymin': yi, 'xmax': xm, 'ymax': ym}}
  335. # for number in number_list
  336. # for (xi, yi, xm, ym) in number_location]
  337. if direction == '90': # 数字水平排列
  338. std_num_length = choice_m_dict['cols']
  339. choice_option = a_z[:std_num_length].replace('', ',')[1:-1]
  340. default_points = [-1] * std_num_length
  341. choice_m_dict.update({'option': choice_option, 'default_points': default_points})
  342. sorted(digital_list, key=lambda k: k.get('coordinate')[0])
  343. choice_xmin = choice_m_dict['bounding_box']['ymin']
  344. single_width = choice_m_dict['single_width']
  345. mean_interval = ((choice_m_dict['bounding_box']['xmax'] - choice_m_dict['bounding_box']['xmin'])
  346. - single_width * std_num_length) / (std_num_length - 1)
  347. spilt_index = [choice_xmin - mean_interval / 2 + (single_width + mean_interval) * ele for ele in
  348. range(std_num_length)]
  349. number_list = [-1] * std_num_length
  350. number_location = [(-1, -1, -1, -1)] * std_num_length
  351. for i in range(0, len(spilt_index) - 1):
  352. start = spilt_index[i]
  353. end = spilt_index[i + 1]
  354. number_location[i] = (start, ymin, end, ymax)
  355. for digital_coordinate in digital_list:
  356. middle_y = (digital_coordinate['coordinate'][3] - digital_coordinate['coordinate'][1]) / 2 + \
  357. digital_coordinate['coordinate'][1]
  358. middle_x = (digital_coordinate['coordinate'][2] - digital_coordinate['coordinate'][0]) / 2 + \
  359. digital_coordinate['coordinate'][0]
  360. if start <= middle_x <= end and middle_y < choice_m_dict['bounding_box']['ymin']:
  361. number_list[i] = int(digital_coordinate['char'])
  362. number_location[i] = digital_coordinate['coordinate']
  363. number_list = _infer_number(number_list)
  364. choice_m_dict['number'] = _infer_number(number_list)
  365. # choice_m_dict['number'] = [{'number': number,
  366. # 'location': {'xmin': xi, 'ymin': yi, 'xmax': xm, 'ymax': ym}}
  367. # for number in number_list
  368. # for (xi, yi, xm, ym) in number_location]
  369. else:
  370. choice_m_dict['direction'] = '0'
  371. choice_m_dict['number'] = [-1]
  372. choice_m_dict['default_points'] = [-1]
  373. count180 = ','.join(direction_list).count('180')
  374. count90 = ','.join(direction_list).count('90')
  375. infer_direction = ['180', '90'][[count180, count90].index(max(count180, count90))]
  376. for ele in choice_m_dict_list:
  377. if ele['direction'] != '0':
  378. ele.update({'direction': infer_direction})
  379. # tree = ET.parse(xml_path) # xml tree
  380. # for index_num, choice_box in enumerate(choice_m_dict_list):
  381. # if len(choice_box['bounding_box']) > 0:
  382. # abcd = choice_box['bounding_box']
  383. # number = str(choice_box['number'])
  384. # name = '{}_{}*{}_{}_{}'.format('choice_m', choice_box['rows'],
  385. # choice_box['cols'], choice_box['direction'],
  386. # number)
  387. # tree = utils.create_xml(name, tree,
  388. # abcd['xmin'], abcd['ymin'],
  389. # abcd['xmax'], abcd['ymax'])
  390. #
  391. # tree.write(xml_path)
  392. return choice_m_dict_list
  393. def _infer_number(number_list):
  394. if -1 not in number_list or sum(number_list) == -1 * len(number_list):
  395. return number_list
  396. else:
  397. for n_index in range(0, len(number_list) - 1):
  398. if n_index == 0:
  399. if number_list[n_index] != -1:
  400. if len(number_list) > 1 and number_list[n_index + 1] == -1:
  401. number_list[n_index + 1] = number_list[n_index] + 1
  402. if number_list[n_index] != -1:
  403. if number_list[n_index - 1] == -1:
  404. number_list[n_index - 1] = number_list[n_index] - 1
  405. if number_list[n_index + 1] == -1:
  406. number_list[n_index + 1] = number_list[n_index] + 1
  407. return _infer_number(number_list)