choice_infer.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673
  1. # @Author : lightXu
  2. # @File : choice_infer.py
  3. import os
  4. import traceback
  5. import time
  6. import random
  7. from django.conf import settings
  8. from segment.sheet_resolve.tools import utils, brain_api
  9. from itertools import chain
  10. import re
  11. import numpy as np
  12. import cv2
  13. import xml.etree.cElementTree as ET
  14. from segment.sheet_resolve.tools.utils import crop_region_direct, create_xml, infer_number, combine_char_in_raw_format
  15. from sklearn.cluster import DBSCAN
  16. from segment.sheet_resolve.analysis.sheet.ocr_sheet import ocr2sheet
  17. def get_split_index(array, dif=0):
  18. array = np.array(array)
  19. interval_list = np.abs(array[1:] - array[:-1])
  20. split_index = [0]
  21. for i, interval in enumerate(interval_list):
  22. if dif:
  23. split_dif = dif
  24. else:
  25. split_dif = np.mean(interval_list)
  26. if interval > split_dif:
  27. split_index.append(i + 1)
  28. split_index.append(len(array))
  29. split_index = sorted(list(set(split_index)))
  30. return split_index
  31. def adjust_choice_m(image, xe, ye):
  32. dilate = 1
  33. blur = 5
  34. # Convert to gray
  35. image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  36. if blur != 0:
  37. image = cv2.GaussianBlur(image, (blur, blur), 0)
  38. # Apply threshold to get image with only b&w (binarization)
  39. image = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
  40. kernel = np.ones((ye, xe), np.uint8) # y轴膨胀, x轴膨胀
  41. dst = cv2.dilate(image, kernel, iterations=1)
  42. (major, minor, _) = cv2.__version__.split(".")
  43. contours = cv2.findContours(dst, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  44. cnts = contours[0] if int(major) > 3 else contours[1]
  45. # _, cnts, hierarchy = cv2.findContours(dst, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  46. right_limit = 0
  47. bottom_limit = 0
  48. for cnt_id, cnt in enumerate(reversed(cnts)):
  49. x, y, w, h = cv2.boundingRect(cnt)
  50. if x + w > right_limit:
  51. right_limit = x + w
  52. if y + h > bottom_limit:
  53. bottom_limit = y + h
  54. return right_limit, bottom_limit
  55. def find_digital(ocr_raw_list):
  56. pattern = r'\d+'
  57. x_list = []
  58. y_list = []
  59. digital_list = list()
  60. chars_list = list()
  61. height_list, width_list = list(), list()
  62. ocr_dict_list = combine_char_in_raw_format(ocr_raw_list)
  63. for i, ele in enumerate(ocr_dict_list):
  64. words = ele['words']
  65. words = words.replace(' ', '').upper() # 去除空格
  66. digital_words_m = re.finditer(pattern, words)
  67. digital_index_list = [(m.group(), m.span()) for m in digital_words_m if m]
  68. chars_index = [ele for ele in range(0, len(ele['chars']))]
  69. digital_index_detail_list = []
  70. for letter_info in digital_index_list:
  71. number = letter_info[0]
  72. index_start = letter_info[1][0]
  73. index_end = letter_info[1][1] - 1
  74. char_start = ele['chars'][index_start]
  75. char_end = ele['chars'][index_end]
  76. if index_start == index_end:
  77. digital_index_detail_list += [index_start]
  78. else:
  79. digital_index_detail_list += chars_index[index_start:index_end + 1]
  80. letter_loc_xmin = int(char_start['location']['left'])
  81. letter_loc_ymin = min(int(char_start['location']['top']), int(char_end['location']['top']))
  82. letter_loc_xmax = int(char_end['location']['left']) + int(char_end['location']['width'])
  83. letter_loc_ymax = max(int(char_start['location']['top']) + int(char_start['location']['height']),
  84. int(char_end['location']['top']) + int(char_end['location']['height']))
  85. mid_x = letter_loc_xmin + (letter_loc_xmax - letter_loc_xmin) // 2
  86. mid_y = letter_loc_ymin + (letter_loc_ymax - letter_loc_ymin) // 2
  87. # print(number, (mid_x, mid_y))
  88. x_list.append(mid_x)
  89. y_list.append(mid_y)
  90. height_list.append(letter_loc_ymax - letter_loc_ymin)
  91. width_list.append(letter_loc_xmax - letter_loc_xmin)
  92. number_loc = (letter_loc_xmin, letter_loc_ymin, letter_loc_xmax, letter_loc_ymax, mid_x, mid_y)
  93. digital_list.append({"digital": int(number), "loc": number_loc})
  94. current_chars = [char for index, char in enumerate(ele['chars'])
  95. if index not in digital_index_detail_list and char['char'] not in ['.', ',', '。', '、']]
  96. chars_list += current_chars
  97. d_mean_height = sum(height_list) // len(height_list)
  98. d_mean_width = sum(width_list) // len(width_list)
  99. # mean_height = max(height_list)
  100. # mean_width = max(width_list)
  101. # print(x_list)
  102. # print(y_list)
  103. return digital_list, chars_list, d_mean_height, d_mean_width
  104. def cluster2choice_m_(cluster_list, m_h, m_w):
  105. numbers = [ele['digital'] for ele in cluster_list]
  106. loc_top_interval = (np.array([ele['loc'][3] for ele in cluster_list][1:]) -
  107. np.array([ele['loc'][3] for ele in cluster_list][:-1]))
  108. split_index = [0]
  109. for i, interval in enumerate(loc_top_interval):
  110. if interval > m_h * 1.5:
  111. split_index.append(i + 1)
  112. split_index.append(len(cluster_list))
  113. split_index = sorted(list(set(split_index)))
  114. block_list = []
  115. for i in range(len(split_index) - 1):
  116. block = cluster_list[split_index[i]: split_index[i + 1]]
  117. xmin = min([ele["loc"][0] for ele in block])
  118. ymin = min([ele["loc"][1] for ele in block])
  119. xmax = max([ele["loc"][2] for ele in block])
  120. ymax = max([ele["loc"][3] for ele in block])
  121. numbers = [ele['digital'] for ele in block]
  122. choice_m = {"number": numbers, "loc": (xmin, ymin, xmax, ymax)}
  123. block_list.append(choice_m)
  124. return block_list
  125. def cluster2choice_m(cluster_list, mean_width):
  126. # 比较x坐标,去掉误差值
  127. numbers_x = [ele['loc'][4] for ele in cluster_list]
  128. numbers_x_array = np.array(numbers_x)
  129. numbers_x_interval = np.abs((numbers_x_array[1:] - numbers_x_array[:-1]))
  130. error_index_superset = np.where(numbers_x_interval >= mean_width)[0]
  131. error_index_superset_interval = error_index_superset[1:] - error_index_superset[:-1]
  132. t_index = list(np.where(error_index_superset_interval > 1)[0] + 1)
  133. t_index.insert(0, 0)
  134. t_index.append(len(error_index_superset))
  135. error = []
  136. for i in range(0, len(t_index) - 1):
  137. a = t_index[i]
  138. b = t_index[i + 1]
  139. block = list(error_index_superset[a: b])
  140. error += block[1:]
  141. cluster_list = [ele for i, ele in enumerate(cluster_list) if i not in error]
  142. numbers = [ele['digital'] for ele in cluster_list]
  143. numbers_array = np.array(numbers)
  144. # numbers_y = [ele['loc'][5] for ele in cluster_list]
  145. # numbers_y_array = np.array(numbers_y)
  146. # numbers_y_interval = np.abs((numbers_y_array[1:] - numbers_y_array[:-1]))
  147. # split_index = [0]
  148. # for i, interval in enumerate(numbers_y_interval):
  149. # if interval > np.mean(numbers_y_interval):
  150. # split_index.append(i + 1)
  151. #
  152. # split_index.append(len(cluster_list))
  153. # split_index = sorted(list(set(split_index)))
  154. # for i in range(len(split_index) - 1):
  155. # block = cluster_list[split_index[i]: split_index[i + 1]]
  156. # block_numbers = numbers_array[split_index[i]: split_index[i + 1]]
  157. # 确定数字题号的位置,前提:同block题号是某等差数列的子集
  158. numbers_sum = numbers_array + np.flipud(numbers_array)
  159. counts = np.bincount(numbers_sum)
  160. mode_times = np.max(counts)
  161. mode_value = np.argmax(counts)
  162. if mode_times != len(numbers) and mode_times >= 2:
  163. # 启动题号补全
  164. number_interval_list = abs(numbers_array[1:] - numbers_array[:-1])
  165. number_interval_counts = np.bincount(number_interval_list)
  166. # number_interval_mode_times = np.max(number_interval_counts)
  167. number_interval_mode_value = np.argmax(number_interval_counts)
  168. suspect_index = np.where(numbers_sum != mode_value)[0]
  169. numbers_array_len = len(numbers_array)
  170. for suspect in suspect_index:
  171. if suspect == 0:
  172. cond_left = False
  173. cond_right = numbers_array[suspect + 1] == numbers_array[suspect] + number_interval_mode_value
  174. elif suspect == numbers_array_len - 1:
  175. cond_right = False
  176. cond_left = numbers_array[suspect - 1] == numbers_array[suspect] - number_interval_mode_value
  177. else:
  178. cond_left = numbers_array[suspect - 1] == numbers_array[suspect] - number_interval_mode_value
  179. cond_right = numbers_array[suspect + 1] == numbers_array[suspect] + number_interval_mode_value
  180. if cond_left or cond_right:
  181. pass
  182. else:
  183. numbers_array[suspect] = -1
  184. numbers_array = infer_number(numbers_array, number_interval_mode_value) # 推断题号
  185. numbers_array = np.array(numbers_array)
  186. numbers_interval = np.abs(numbers_array[1:] - numbers_array[:-1])
  187. split_index = [0]
  188. for i, interval in enumerate(numbers_interval):
  189. if interval > np.mean(numbers_interval):
  190. split_index.append(i + 1)
  191. split_index.append(len(cluster_list))
  192. split_index = sorted(list(set(split_index)))
  193. block_list = []
  194. for i in range(len(split_index) - 1):
  195. block = cluster_list[split_index[i]: split_index[i + 1]]
  196. block_numbers = numbers_array[split_index[i]: split_index[i + 1]]
  197. xmin = min([ele["loc"][0] for ele in block])
  198. ymin = min([ele["loc"][1] for ele in block])
  199. xmax = max([ele["loc"][2] for ele in block])
  200. ymax = max([ele["loc"][3] for ele in block])
  201. mid_x = xmin + (xmax - xmin) // 2
  202. mid_y = ymin + (ymax - ymin) // 2
  203. choice_m = {"numbers": list(block_numbers), "loc": [xmin, ymin, xmax, ymax, mid_x, mid_y]}
  204. block_list.append(choice_m)
  205. return block_list
  206. def cluster_and_anti_abnormal(image, xml_path, digital_list, chars_list,
  207. mean_height, mean_width, choice_s_height, choice_s_width, limit_loc):
  208. limit_left, limit_top, limit_right, limit_bottom = limit_loc
  209. limit_width, limit_height = limit_right - limit_left, limit_bottom - limit_top
  210. arr = np.ones((len(digital_list), 2))
  211. for i, ele in enumerate(digital_list):
  212. arr[i] = np.array([ele["loc"][-2], ele["loc"][-1]])
  213. if choice_s_height != 0:
  214. eps = int(choice_s_height * 2.5)
  215. else:
  216. eps = int(mean_height * 3)
  217. print("eps: ", eps)
  218. db = DBSCAN(eps=eps, min_samples=2, metric='chebyshev').fit(arr)
  219. labels = db.labels_
  220. # print(labels)
  221. cluster_label = []
  222. for ele in labels:
  223. if ele not in cluster_label and ele != -1:
  224. cluster_label.append(ele)
  225. a_e_dict = {k: [] for k in cluster_label}
  226. choice_m_numbers_list = []
  227. for index, ele in enumerate(labels):
  228. if ele != -1:
  229. a_e_dict[ele].append(digital_list[index])
  230. for ele in cluster_label:
  231. cluster = a_e_dict[ele]
  232. choice_m_numbers_list += cluster2choice_m(cluster, mean_width)
  233. all_list_nums = [ele["numbers"] for ele in choice_m_numbers_list]
  234. all_nums_len = [len(ele) for ele in all_list_nums]
  235. all_nums = list(chain.from_iterable(all_list_nums))
  236. counts = np.bincount(np.array(all_nums_len))
  237. if np.max(counts) < 2:
  238. mode_value = max(all_nums_len)
  239. else:
  240. mode_value = np.argmax(counts)
  241. mode_value = all_nums_len[np.where(np.array(all_nums_len) == mode_value)[0][-1]]
  242. if mode_value > 1: # 缺失补全
  243. error_index_list = list(np.where(np.array(all_nums_len) != mode_value)[0])
  244. all_height = [ele["loc"][3] - ele["loc"][1] for index, ele
  245. in enumerate(choice_m_numbers_list) if index not in error_index_list]
  246. choice_m_mean_height = int(sum(all_height) / len(all_height))
  247. for e_index in list(error_index_list):
  248. current_choice_m = choice_m_numbers_list[e_index]
  249. current_numbers_list = list(all_list_nums[e_index])
  250. current_len = all_nums_len[e_index]
  251. dif = mode_value - current_len
  252. if 1 in current_numbers_list:
  253. t2 = current_numbers_list + [-1] * dif
  254. infer_t1_list = infer_number(t2) # 后补
  255. infer_t2_list = infer_number(t2) # 后补
  256. cond1 = False
  257. cond2 = True
  258. else:
  259. t1_cond = [True] * dif
  260. t2_cond = [True] * dif
  261. t1 = [-1] * dif + current_numbers_list
  262. infer_t1_list = infer_number(t1) # 前补
  263. t2 = current_numbers_list + [-1] * dif
  264. infer_t2_list = infer_number(t2) # 后补
  265. for i in range(0, dif):
  266. t1_infer = infer_t1_list[i]
  267. t2_infer = infer_t2_list[-i - 1]
  268. if t1_infer == 0 or t1_infer in all_nums:
  269. t1_cond[i] = False
  270. if t2_infer in all_nums:
  271. t2_cond[i] = False
  272. cond1 = not (False in t1_cond)
  273. cond2 = not (False in t2_cond)
  274. if cond1 and not cond2:
  275. current_loc = current_choice_m["loc"]
  276. current_height = current_loc[3] - current_loc[1]
  277. infer_height = max((choice_m_mean_height - current_height), int(dif * current_height / current_len))
  278. choice_m_numbers_list[e_index]["loc"][1] = current_loc[1] - infer_height
  279. choice_m_numbers_list[e_index]["loc"][5] = (choice_m_numbers_list[e_index]["loc"][1] +
  280. (choice_m_numbers_list[e_index]["loc"][3] -
  281. choice_m_numbers_list[e_index]["loc"][1]) // 2)
  282. choice_m_numbers_list[e_index]["numbers"] = infer_t1_list
  283. all_nums.extend(infer_t1_list[:dif])
  284. if not cond1 and cond2:
  285. current_loc = current_choice_m["loc"]
  286. current_height = current_loc[3] - current_loc[1]
  287. infer_height = max((choice_m_mean_height - current_height), int(dif * current_height / current_len))
  288. infer_bottom = min(current_loc[3] + infer_height, limit_height-1)
  289. if infer_bottom <= limit_height:
  290. choice_m_numbers_list[e_index]["loc"][3] = infer_bottom
  291. choice_m_numbers_list[e_index]["loc"][5] = (choice_m_numbers_list[e_index]["loc"][1] +
  292. (choice_m_numbers_list[e_index]["loc"][3] -
  293. choice_m_numbers_list[e_index]["loc"][1]) // 2)
  294. choice_m_numbers_list[e_index]["numbers"] = infer_t2_list
  295. all_nums.extend(infer_t2_list[-dif:])
  296. else:
  297. # cond1 = cond2 = true, 因为infer选择题时已横向排序, 默认这种情况不会出现
  298. pass
  299. for ele in choice_m_numbers_list:
  300. loc = ele["loc"]
  301. if loc[3] - loc[1] >= loc[2] - loc[0]:
  302. direction = 180
  303. else:
  304. direction = 90
  305. ele.update({'direction': direction})
  306. # tree = ET.parse(xml_path)
  307. # for index, choice_m in enumerate(choice_m_numbers_list):
  308. # name = str(choice_m["numbers"])
  309. # xmin, ymin, xmax, ymax, _, _ = choice_m["loc"]
  310. # tree = create_xml(name, tree, str(xmin + limit_left), str(ymin + limit_top), str(xmax + limit_left), str(ymax + limit_top))
  311. #
  312. # tree.write(xml_path)
  313. choice_m_numbers_list = sorted(choice_m_numbers_list, key=lambda x: x['loc'][3] - x['loc'][1], reverse=True)
  314. choice_m_numbers_right_limit = max([ele['loc'][2] for ele in choice_m_numbers_list])
  315. remain_len = len(choice_m_numbers_list)
  316. choice_m_list = list()
  317. need_revised_choice_m_list = list()
  318. while remain_len > 0:
  319. # 先确定属于同行的数据,然后找字母划分block
  320. # random_index = random.randint(0, len(choice_m_numbers_list)-1)
  321. random_index = 0
  322. # print(random_index)
  323. ymax_limit = choice_m_numbers_list[random_index]["loc"][3]
  324. ymin_limit = choice_m_numbers_list[random_index]["loc"][1]
  325. # choice_m_numbers_list.pop(random_index)
  326. # 当前行的choice_m
  327. current_row_choice_m_d = [ele for ele in choice_m_numbers_list if ymin_limit < ele["loc"][5] < ymax_limit]
  328. current_row_choice_m_d = sorted(current_row_choice_m_d, key=lambda x: x["loc"][0])
  329. # current_row_choice_m_d.append(choice_m_numbers_list[random_index])
  330. split_pix = sorted([ele["loc"][0] for ele in current_row_choice_m_d]) # xmin排序
  331. split_index = get_split_index(split_pix, dif=choice_s_width*0.8)
  332. split_pix = [split_pix[ele] for ele in split_index[:-1]]
  333. block_list = []
  334. for i in range(len(split_index) - 1):
  335. block = current_row_choice_m_d[split_index[i]: split_index[i + 1]]
  336. if len(block) > 1:
  337. remain_len = remain_len - (len(block) - 1)
  338. numbers_new = []
  339. loc_new = [[], [], [], []]
  340. for blk in block:
  341. loc_old = blk["loc"]
  342. numbers_new.extend(blk["numbers"])
  343. for ii in range(4):
  344. loc_new[ii].append(loc_old[ii])
  345. loc_new[0] = min(loc_new[0])
  346. loc_new[1] = min(loc_new[1])
  347. loc_new[2] = max(loc_new[2])
  348. loc_new[3] = max(loc_new[3])
  349. loc_new.append(loc_new[0] + (loc_new[2] - loc_new[0]) // 2)
  350. loc_new.append(loc_new[1] + (loc_new[3] - loc_new[1]) // 2)
  351. block = [{"numbers": sorted(numbers_new), "loc": loc_new, "direction": block[0]["direction"]}]
  352. block_list.extend(block)
  353. current_row_choice_m_d = block_list
  354. current_row_chars = [ele for ele in chars_list
  355. if ymin_limit < (ele["location"]["top"] + ele["location"]["height"] // 2) < ymax_limit]
  356. # if not current_row_chars:
  357. # max_char_width = choice_s_width // 4
  358. # row_chars_xmax = choice_m_numbers_right_limit + int(choice_s_width * 1.5)
  359. # else:
  360. # max_char_width = max([ele["location"]["width"] for ele in current_row_chars]) // 2
  361. # row_chars_xmax = max(
  362. # [ele["location"]["left"] + ele["location"]["width"] for ele in current_row_chars]) + max_char_width * 2
  363. # split_index.append(row_chars_xmax) # 边界
  364. split_pix.append(round(split_pix[-1] + choice_s_width * 1.2))
  365. for i in range(0, len(split_pix) - 1):
  366. left_limit = split_pix[i]
  367. right_limit = split_pix[i + 1]
  368. block_chars = [ele for ele in current_row_chars
  369. if left_limit < (ele["location"]["left"] + ele["location"]["width"] // 2) < right_limit]
  370. # chars_xmin = min([ele["location"]["left"] for ele in block_chars]) - max_char_width
  371. # chars_xmax = max(
  372. # [ele["location"]["left"] + ele["location"]["width"] for ele in block_chars]) + max_char_width
  373. # a_z = '_ABCD_FGH__K_MNOPQRSTUVWXYZ' EIJL -> _
  374. # a_z = '_ABCDEFGHI_K_MNOPQRSTUVWXYZ'
  375. a_z = '_ABCD_FGHT'
  376. # letter_text = set([ele['char'].upper() for ele in block_chars if ele['char'].upper() in a_z])
  377. letter_index = [a_z.index(ele['char'].upper()) for ele in block_chars if ele['char'].upper() in a_z]
  378. letter_index_times = {ele: 0 for ele in set(letter_index)}
  379. for l_index in letter_index:
  380. letter_index_times[l_index] += 1
  381. if (a_z.index("T") in letter_index) and (a_z.index("F") in letter_index):
  382. choice_option = "T, F"
  383. cols = 2
  384. else:
  385. if len(letter_index) < 1:
  386. tmp = 4
  387. choice_option = 'A,B,C,D'
  388. else:
  389. tmp = max(set(letter_index))
  390. # while letter_index_times[tmp] < 2 and tmp > 3:
  391. # t_list = list(set(letter_index))
  392. # t_list.remove(tmp)
  393. # tmp = max(t_list)
  394. choice_option = ",".join(a_z[min(letter_index):tmp + 1])
  395. cols = tmp
  396. bias = 3 # pix
  397. current_loc = current_row_choice_m_d[i]["loc"]
  398. location = dict(xmin=(current_loc[2] + bias) + limit_left, # 当前数字xmax右边
  399. # xmin=max(current_loc[2] + bias, chars_xmin) + limit_left,
  400. ymin=current_loc[1] + limit_top,
  401. xmax=(right_limit - bias) + limit_left,
  402. # xmax=min(chars_xmax, right_limit - bias) + limit_left,
  403. ymax=current_loc[3] + limit_top)
  404. try:
  405. choice_m_img = utils.crop_region(image, location)
  406. right_loc, bottom_loc = adjust_choice_m(choice_m_img, mean_height, mean_width * 2)
  407. if right_loc > 0:
  408. location.update(dict(xmax=right_loc + location['xmin']))
  409. if bottom_loc > 0:
  410. location.update(dict(ymax=bottom_loc + location['ymin']))
  411. except Exception as e:
  412. print(e)
  413. traceback.print_exc()
  414. tmp_w, tmp_h = location['xmax'] - location['xmin'], location['ymax'] - location['ymin'],
  415. numbers = current_row_choice_m_d[i]["numbers"]
  416. direction = current_row_choice_m_d[i]["direction"]
  417. if direction == 180:
  418. choice_m = dict(class_name='choice_m',
  419. number=numbers,
  420. bounding_box=location,
  421. choice_option=choice_option,
  422. default_points=[5] * len(numbers),
  423. direction=direction,
  424. cols=cols,
  425. rows=len(numbers),
  426. single_width=tmp_w // cols,
  427. single_height=tmp_h // len(numbers))
  428. else:
  429. choice_m = dict(class_name='choice_m',
  430. number=numbers,
  431. bounding_box=location,
  432. choice_option=choice_option,
  433. default_points=[5] * len(numbers),
  434. direction=direction,
  435. cols=len(numbers),
  436. rows=cols,
  437. single_width=tmp_w // len(numbers),
  438. single_height=tmp_h // cols
  439. )
  440. if tmp_w > 2 * choice_s_width:
  441. need_revised_choice_m_list.append(choice_m)
  442. else:
  443. choice_m_list.append(choice_m)
  444. remain_len = remain_len - len(current_row_choice_m_d)
  445. for ele in choice_m_numbers_list.copy():
  446. if ele in current_row_choice_m_d:
  447. choice_m_numbers_list.remove(ele)
  448. for ele in choice_m_numbers_list.copy():
  449. if ele in current_row_chars:
  450. choice_m_numbers_list.remove(ele)
  451. # 单独一行不聚类
  452. for i, revised_choice_m in enumerate(need_revised_choice_m_list):
  453. loc = revised_choice_m['bounding_box']
  454. left_part_loc = loc.copy()
  455. left_part_loc.update({'xmax': loc['xmin']+choice_s_width})
  456. choice_m_img = utils.crop_region(image, left_part_loc)
  457. right_loc, bottom_loc = adjust_choice_m(choice_m_img, mean_height, mean_width * 2)
  458. if right_loc > 0:
  459. left_part_loc.update(dict(xmax=right_loc + left_part_loc['xmin']))
  460. if bottom_loc > 0:
  461. left_part_loc.update(dict(ymax=bottom_loc + left_part_loc['ymin']))
  462. left_tmp_height = left_part_loc['ymax'] - left_part_loc['ymin']
  463. right_part_loc = loc.copy()
  464. # right_part_loc.update({'xmin': loc['xmax']-choice_s_width})
  465. right_part_loc.update({'xmin': left_part_loc['xmax']+5})
  466. choice_m_img = utils.crop_region(image, right_part_loc)
  467. right_loc, bottom_loc = adjust_choice_m(choice_m_img, mean_height, mean_width * 2)
  468. if right_loc > 0:
  469. right_part_loc.update(dict(xmax=right_loc + right_part_loc['xmin']))
  470. if bottom_loc > 0:
  471. right_part_loc.update(dict(ymax=bottom_loc + right_part_loc['ymin']))
  472. right_tmp_height = right_part_loc['ymax'] - right_part_loc['ymin']
  473. number_len = max(1, int(revised_choice_m['rows'] // (left_tmp_height // right_tmp_height)))
  474. number = [ele+revised_choice_m['number'][-1]+1 for ele in range(number_len)]
  475. rows = len(number)
  476. revised_choice_m.update({'bounding_box': left_part_loc})
  477. choice_m_list.append(revised_choice_m)
  478. tmp = revised_choice_m.copy()
  479. tmp.update({'bounding_box': right_part_loc, 'number': number, 'rows': rows})
  480. choice_m_list.append(tmp)
  481. tmp = choice_m_list.copy()
  482. for ele in tmp:
  483. loc = ele["bounding_box"]
  484. w, h = loc['xmax'] - loc['xmin'], loc['ymax'] - loc['ymin']
  485. if 2*w*h < choice_s_width*choice_s_height:
  486. choice_m_list.remove(ele)
  487. return choice_m_list
  488. def infer_choice_m(image, tf_sheet, ocr, xml=None):
  489. infer_box_list = ocr2sheet(image, tf_sheet, ocr, xml)
  490. # print(sheet_region_list)
  491. choice_m_list = []
  492. choice_s_h_list = [int(ele['bounding_box']['ymax']) - int(ele['bounding_box']['ymin']) for ele in tf_sheet
  493. if ele['class_name'] == 'choice_s']
  494. if choice_s_h_list:
  495. choice_s_height = sum(choice_s_h_list) // len(choice_s_h_list)
  496. else:
  497. choice_s_height = 0
  498. choice_s_w_list = [int(ele['bounding_box']['xmax']) - int(ele['bounding_box']['xmin']) for ele in tf_sheet
  499. if ele['class_name'] == 'choice_s']
  500. if choice_s_w_list:
  501. choice_s_width = sum(choice_s_w_list) // len(choice_s_w_list)
  502. else:
  503. choice_s_width = 0
  504. for infer_box in infer_box_list:
  505. # {'loc': [240, 786, 1569, 1368]}
  506. loc = infer_box['loc']
  507. xmin, ymin, xmax, ymax = loc[0], loc[1], loc[2], loc[3]
  508. choice_flag = False
  509. for ele in tf_sheet:
  510. if ele['class_name'] in ['choice_m', 'choice_s']:
  511. tf_loc = ele['bounding_box']
  512. tf_loc_l = tf_loc['xmin']
  513. tf_loc_t = tf_loc['ymin']
  514. if xmin < tf_loc_l < xmax and ymin < tf_loc_t < ymax:
  515. choice_flag = True
  516. break
  517. if choice_flag:
  518. infer_image = utils.crop_region_direct(image, loc)
  519. try:
  520. save_dir = os.path.join(settings.MEDIA_ROOT, 'tmp')
  521. if not os.path.exists(save_dir):
  522. os.makedirs(save_dir)
  523. save_path = os.path.join(save_dir, 'choice.jpeg')
  524. cv2.imwrite(save_path, infer_image)
  525. img_tmp = utils.read_single_img(save_path)
  526. os.remove(save_path)
  527. ocr = brain_api.get_ocr_text_and_coordinate(img_tmp, 'accurate', 'CHN_ENG')
  528. except Exception as e:
  529. print('write choice and ocr failed')
  530. traceback.print_exc()
  531. ocr = brain_api.get_ocr_text_and_coordinate(infer_image, 'accurate', 'CHN_ENG')
  532. try:
  533. digital_list, chars_list, digital_mean_h, digital_mean_w = find_digital(ocr)
  534. choice_m = cluster_and_anti_abnormal(image, xml, digital_list, chars_list,
  535. digital_mean_h, digital_mean_w,
  536. choice_s_height, choice_s_width, loc)
  537. choice_m_list.extend(choice_m)
  538. except Exception as e:
  539. traceback.print_exc()
  540. print('not found choice feature')
  541. pass
  542. # print(choice_m_list)
  543. # tf_choice_sheet = [ele for ele in tf_sheet if ele['class_name'] == 'choice_m']
  544. sheet_tmp = choice_m_list.copy()
  545. remove_index = []
  546. for i, region in enumerate(sheet_tmp):
  547. if i not in remove_index:
  548. box = region['bounding_box']
  549. for j, region_in in enumerate(sheet_tmp):
  550. box_in = region_in['bounding_box']
  551. iou = utils.cal_iou(box, box_in)
  552. if iou[0] > 0.85 and i != j:
  553. choice_m_list.remove(region)
  554. remove_index.append(j)
  555. break
  556. return choice_m_list