choice_infer.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671
  1. # @Author : lightXu
  2. # @File : choice_infer.py
  3. import os
  4. import traceback
  5. import time
  6. import random
  7. from django.conf import settings
  8. from segment.sheet_resolve.tools import utils, brain_api
  9. from itertools import chain
  10. import re
  11. import numpy as np
  12. import cv2
  13. import xml.etree.cElementTree as ET
  14. from segment.sheet_resolve.tools.utils import crop_region_direct, create_xml, infer_number, combine_char_in_raw_format
  15. from sklearn.cluster import DBSCAN
  16. from segment.sheet_resolve.analysis.sheet.ocr_sheet import ocr2sheet
  17. def get_split_index(array, dif=0):
  18. array = np.array(array)
  19. interval_list = np.abs(array[1:] - array[:-1])
  20. split_index = [0]
  21. for i, interval in enumerate(interval_list):
  22. if dif:
  23. split_dif = dif
  24. else:
  25. split_dif = np.mean(interval_list)
  26. if interval > split_dif:
  27. split_index.append(i + 1)
  28. split_index.append(len(array))
  29. split_index = sorted(list(set(split_index)))
  30. return split_index
  31. def adjust_choice_m(image, xe, ye):
  32. dilate = 1
  33. blur = 5
  34. # Convert to gray
  35. image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  36. if blur != 0:
  37. image = cv2.GaussianBlur(image, (blur, blur), 0)
  38. # Apply threshold to get image with only b&w (binarization)
  39. image = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
  40. kernel = np.ones((ye, xe), np.uint8) # y轴膨胀, x轴膨胀
  41. dst = cv2.dilate(image, kernel, iterations=1)
  42. (major, minor, _) = cv2.__version__.split(".")
  43. contours = cv2.findContours(dst, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  44. cnts = contours[0] if int(major) > 3 else contours[1]
  45. # _, cnts, hierarchy = cv2.findContours(dst, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  46. right_limit = 0
  47. bottom_limit = 0
  48. for cnt_id, cnt in enumerate(reversed(cnts)):
  49. x, y, w, h = cv2.boundingRect(cnt)
  50. if x + w > right_limit:
  51. right_limit = x + w
  52. if y + h > bottom_limit:
  53. bottom_limit = y + h
  54. return right_limit, bottom_limit
  55. def find_digital(ocr_raw_list):
  56. pattern = r'\d+'
  57. x_list = []
  58. y_list = []
  59. digital_list = list()
  60. chars_list = list()
  61. height_list, width_list = list(), list()
  62. ocr_dict_list = combine_char_in_raw_format(ocr_raw_list)
  63. for i, ele in enumerate(ocr_dict_list):
  64. words = ele['words']
  65. words = words.replace(' ', '').upper() # 去除空格
  66. digital_words_m = re.finditer(pattern, words)
  67. digital_index_list = [(m.group(), m.span()) for m in digital_words_m if m]
  68. chars_index = [ele for ele in range(0, len(ele['chars']))]
  69. digital_index_detail_list = []
  70. for letter_info in digital_index_list:
  71. number = letter_info[0]
  72. index_start = letter_info[1][0]
  73. index_end = letter_info[1][1] - 1
  74. char_start = ele['chars'][index_start]
  75. char_end = ele['chars'][index_end]
  76. if index_start == index_end:
  77. digital_index_detail_list += [index_start]
  78. else:
  79. digital_index_detail_list += chars_index[index_start:index_end + 1]
  80. letter_loc_xmin = int(char_start['location']['left'])
  81. letter_loc_ymin = min(int(char_start['location']['top']), int(char_end['location']['top']))
  82. letter_loc_xmax = int(char_end['location']['left']) + int(char_end['location']['width'])
  83. letter_loc_ymax = max(int(char_start['location']['top']) + int(char_start['location']['height']),
  84. int(char_end['location']['top']) + int(char_end['location']['height']))
  85. mid_x = letter_loc_xmin + (letter_loc_xmax - letter_loc_xmin) // 2
  86. mid_y = letter_loc_ymin + (letter_loc_ymax - letter_loc_ymin) // 2
  87. # print(number, (mid_x, mid_y))
  88. x_list.append(mid_x)
  89. y_list.append(mid_y)
  90. height_list.append(letter_loc_ymax - letter_loc_ymin)
  91. width_list.append(letter_loc_xmax - letter_loc_xmin)
  92. number_loc = (letter_loc_xmin, letter_loc_ymin, letter_loc_xmax, letter_loc_ymax, mid_x, mid_y)
  93. digital_list.append({"digital": int(number), "loc": number_loc})
  94. current_chars = [char for index, char in enumerate(ele['chars'])
  95. if index not in digital_index_detail_list and char['char'] not in ['.', ',', '。', '、']]
  96. chars_list += current_chars
  97. d_mean_height = sum(height_list) // len(height_list)
  98. d_mean_width = sum(width_list) // len(width_list)
  99. # mean_height = max(height_list)
  100. # mean_width = max(width_list)
  101. # print(x_list)
  102. # print(y_list)
  103. return digital_list, chars_list, d_mean_height, d_mean_width
  104. def cluster2choice_m_(cluster_list, m_h, m_w):
  105. numbers = [ele['digital'] for ele in cluster_list]
  106. loc_top_interval = (np.array([ele['loc'][3] for ele in cluster_list][1:]) -
  107. np.array([ele['loc'][3] for ele in cluster_list][:-1]))
  108. split_index = [0]
  109. for i, interval in enumerate(loc_top_interval):
  110. if interval > m_h * 1.5:
  111. split_index.append(i + 1)
  112. split_index.append(len(cluster_list))
  113. split_index = sorted(list(set(split_index)))
  114. block_list = []
  115. for i in range(len(split_index) - 1):
  116. block = cluster_list[split_index[i]: split_index[i + 1]]
  117. xmin = min([ele["loc"][0] for ele in block])
  118. ymin = min([ele["loc"][1] for ele in block])
  119. xmax = max([ele["loc"][2] for ele in block])
  120. ymax = max([ele["loc"][3] for ele in block])
  121. numbers = [ele['digital'] for ele in block]
  122. choice_m = {"number": numbers, "loc": (xmin, ymin, xmax, ymax)}
  123. block_list.append(choice_m)
  124. return block_list
  125. def cluster2choice_m(cluster_list, mean_width):
  126. # 比较x坐标,去掉误差值
  127. numbers_x = [ele['loc'][4] for ele in cluster_list]
  128. numbers_x_array = np.array(numbers_x)
  129. numbers_x_interval = np.abs((numbers_x_array[1:] - numbers_x_array[:-1]))
  130. error_index_superset = np.where(numbers_x_interval >= mean_width)[0]
  131. error_index_superset_interval = error_index_superset[1:] - error_index_superset[:-1]
  132. t_index = list(np.where(error_index_superset_interval > 1)[0] + 1)
  133. t_index.insert(0, 0)
  134. t_index.append(len(error_index_superset))
  135. error = []
  136. for i in range(0, len(t_index) - 1):
  137. a = t_index[i]
  138. b = t_index[i + 1]
  139. block = list(error_index_superset[a: b])
  140. error += block[1:]
  141. cluster_list = [ele for i, ele in enumerate(cluster_list) if i not in error]
  142. numbers = [ele['digital'] for ele in cluster_list]
  143. numbers_array = np.array(numbers)
  144. # numbers_y = [ele['loc'][5] for ele in cluster_list]
  145. # numbers_y_array = np.array(numbers_y)
  146. # numbers_y_interval = np.abs((numbers_y_array[1:] - numbers_y_array[:-1]))
  147. # split_index = [0]
  148. # for i, interval in enumerate(numbers_y_interval):
  149. # if interval > np.mean(numbers_y_interval):
  150. # split_index.append(i + 1)
  151. #
  152. # split_index.append(len(cluster_list))
  153. # split_index = sorted(list(set(split_index)))
  154. # for i in range(len(split_index) - 1):
  155. # block = cluster_list[split_index[i]: split_index[i + 1]]
  156. # block_numbers = numbers_array[split_index[i]: split_index[i + 1]]
  157. # 确定数字题号的位置,前提:同block题号是某等差数列的子集
  158. numbers_sum = numbers_array + np.flipud(numbers_array)
  159. counts = np.bincount(numbers_sum)
  160. mode_times = np.max(counts)
  161. mode_value = np.argmax(counts)
  162. if mode_times != len(numbers) and mode_times >= 2:
  163. # 启动题号补全
  164. number_interval_list = abs(numbers_array[1:] - numbers_array[:-1])
  165. number_interval_counts = np.bincount(number_interval_list)
  166. # number_interval_mode_times = np.max(number_interval_counts)
  167. number_interval_mode_value = np.argmax(number_interval_counts)
  168. suspect_index = np.where(numbers_sum != mode_value)[0]
  169. numbers_array_len = len(numbers_array)
  170. for suspect in suspect_index:
  171. if suspect == 0:
  172. cond_left = False
  173. cond_right = numbers_array[suspect + 1] == numbers_array[suspect] + number_interval_mode_value
  174. elif suspect == numbers_array_len - 1:
  175. cond_right = False
  176. cond_left = numbers_array[suspect - 1] == numbers_array[suspect] - number_interval_mode_value
  177. else:
  178. cond_left = numbers_array[suspect - 1] == numbers_array[suspect] - number_interval_mode_value
  179. cond_right = numbers_array[suspect + 1] == numbers_array[suspect] + number_interval_mode_value
  180. if cond_left or cond_right:
  181. pass
  182. else:
  183. numbers_array[suspect] = -1
  184. numbers_array = infer_number(numbers_array, number_interval_mode_value) # 推断题号
  185. numbers_interval = np.abs(numbers_array[1:] - numbers_array[:-1])
  186. split_index = [0]
  187. for i, interval in enumerate(numbers_interval):
  188. if interval > np.mean(numbers_interval):
  189. split_index.append(i + 1)
  190. split_index.append(len(cluster_list))
  191. split_index = sorted(list(set(split_index)))
  192. block_list = []
  193. for i in range(len(split_index) - 1):
  194. block = cluster_list[split_index[i]: split_index[i + 1]]
  195. block_numbers = numbers_array[split_index[i]: split_index[i + 1]]
  196. xmin = min([ele["loc"][0] for ele in block])
  197. ymin = min([ele["loc"][1] for ele in block])
  198. xmax = max([ele["loc"][2] for ele in block])
  199. ymax = max([ele["loc"][3] for ele in block])
  200. mid_x = xmin + (xmax - xmin) // 2
  201. mid_y = ymin + (ymax - ymin) // 2
  202. choice_m = {"numbers": list(block_numbers), "loc": [xmin, ymin, xmax, ymax, mid_x, mid_y]}
  203. block_list.append(choice_m)
  204. return block_list
  205. def cluster_and_anti_abnormal(image, xml_path, digital_list, chars_list,
  206. mean_height, mean_width, choice_s_height, choice_s_width, limit_loc):
  207. limit_left, limit_top, limit_right, limit_bottom = limit_loc
  208. limit_width, limit_height = limit_right - limit_left, limit_bottom - limit_top
  209. arr = np.ones((len(digital_list), 2))
  210. for i, ele in enumerate(digital_list):
  211. arr[i] = np.array([ele["loc"][-2], ele["loc"][-1]])
  212. if choice_s_height != 0:
  213. eps = int(choice_s_height * 2)
  214. else:
  215. eps = int(mean_height * 2.5)
  216. print("eps: ", eps)
  217. db = DBSCAN(eps=eps, min_samples=2, metric='chebyshev').fit(arr)
  218. labels = db.labels_
  219. # print(labels)
  220. cluster_label = []
  221. for ele in labels:
  222. if ele not in cluster_label and ele != -1:
  223. cluster_label.append(ele)
  224. a_e_dict = {k: [] for k in cluster_label}
  225. choice_m_numbers_list = []
  226. for index, ele in enumerate(labels):
  227. if ele != -1:
  228. a_e_dict[ele].append(digital_list[index])
  229. for ele in cluster_label:
  230. cluster = a_e_dict[ele]
  231. choice_m_numbers_list += cluster2choice_m(cluster, mean_width)
  232. all_list_nums = [ele["numbers"] for ele in choice_m_numbers_list]
  233. all_nums_len = [len(ele) for ele in all_list_nums]
  234. all_nums = list(chain.from_iterable(all_list_nums))
  235. counts = np.bincount(np.array(all_nums_len))
  236. if np.max(counts) < 2:
  237. mode_value = max(all_nums_len)
  238. else:
  239. mode_value = np.argmax(counts)
  240. mode_value = all_nums_len[np.where(np.array(all_nums_len) == mode_value)[0][-1]]
  241. if mode_value > 1: # 缺失补全
  242. error_index_list = list(np.where(np.array(all_nums_len) != mode_value)[0])
  243. all_height = [ele["loc"][3] - ele["loc"][1] for index, ele
  244. in enumerate(choice_m_numbers_list) if index not in error_index_list]
  245. choice_m_mean_height = int(sum(all_height) / len(all_height))
  246. for e_index in list(error_index_list):
  247. current_choice_m = choice_m_numbers_list[e_index]
  248. current_numbers_list = list(all_list_nums[e_index])
  249. current_len = all_nums_len[e_index]
  250. dif = mode_value - current_len
  251. if 1 in current_numbers_list:
  252. t2 = current_numbers_list + [-1] * dif
  253. infer_t1_list = infer_number(t2) # 后补
  254. infer_t2_list = infer_number(t2) # 后补
  255. cond1 = False
  256. cond2 = True
  257. else:
  258. t1_cond = [True] * dif
  259. t2_cond = [True] * dif
  260. t1 = [-1] * dif + current_numbers_list
  261. infer_t1_list = infer_number(t1) # 前补
  262. t2 = current_numbers_list + [-1] * dif
  263. infer_t2_list = infer_number(t2) # 后补
  264. for i in range(0, dif):
  265. t1_infer = infer_t1_list[i]
  266. t2_infer = infer_t2_list[-i - 1]
  267. if t1_infer == 0 or t1_infer in all_nums:
  268. t1_cond[i] = False
  269. if t2_infer in all_nums:
  270. t2_cond[i] = False
  271. cond1 = not (False in t1_cond)
  272. cond2 = not (False in t2_cond)
  273. if cond1 and not cond2:
  274. current_loc = current_choice_m["loc"]
  275. current_height = current_loc[3] - current_loc[1]
  276. infer_height = max((choice_m_mean_height - current_height), int(dif * current_height / current_len))
  277. choice_m_numbers_list[e_index]["loc"][1] = current_loc[1] - infer_height
  278. choice_m_numbers_list[e_index]["loc"][5] = (choice_m_numbers_list[e_index]["loc"][1] +
  279. (choice_m_numbers_list[e_index]["loc"][3] -
  280. choice_m_numbers_list[e_index]["loc"][1]) // 2)
  281. choice_m_numbers_list[e_index]["numbers"] = infer_t1_list
  282. all_nums.extend(infer_t1_list[:dif])
  283. if not cond1 and cond2:
  284. current_loc = current_choice_m["loc"]
  285. current_height = current_loc[3] - current_loc[1]
  286. infer_height = max((choice_m_mean_height - current_height), int(dif * current_height / current_len))
  287. infer_bottom = min(current_loc[3] + infer_height, limit_height-1)
  288. if infer_bottom <= limit_height:
  289. choice_m_numbers_list[e_index]["loc"][3] = infer_bottom
  290. choice_m_numbers_list[e_index]["loc"][5] = (choice_m_numbers_list[e_index]["loc"][1] +
  291. (choice_m_numbers_list[e_index]["loc"][3] -
  292. choice_m_numbers_list[e_index]["loc"][1]) // 2)
  293. choice_m_numbers_list[e_index]["numbers"] = infer_t2_list
  294. all_nums.extend(infer_t2_list[-dif:])
  295. else:
  296. # cond1 = cond2 = true, 因为infer选择题时已横向排序, 默认这种情况不会出现
  297. pass
  298. for ele in choice_m_numbers_list:
  299. loc = ele["loc"]
  300. if loc[3] - loc[1] >= loc[2] - loc[0]:
  301. direction = 180
  302. else:
  303. direction = 90
  304. ele.update({'direction': direction})
  305. # tree = ET.parse(xml_path)
  306. # for index, choice_m in enumerate(choice_m_numbers_list):
  307. # name = str(choice_m["numbers"])
  308. # xmin, ymin, xmax, ymax, _, _ = choice_m["loc"]
  309. # tree = create_xml(name, tree, str(xmin + limit_left), str(ymin + limit_top), str(xmax + limit_left), str(ymax + limit_top))
  310. #
  311. # tree.write(xml_path)
  312. choice_m_numbers_list = sorted(choice_m_numbers_list, key=lambda x: x['loc'][3] - x['loc'][1], reverse=True)
  313. choice_m_numbers_right_limit = max([ele['loc'][2] for ele in choice_m_numbers_list])
  314. remain_len = len(choice_m_numbers_list)
  315. choice_m_list = list()
  316. need_revised_choice_m_list = list()
  317. while remain_len > 0:
  318. # 先确定属于同行的数据,然后找字母划分block
  319. # random_index = random.randint(0, len(choice_m_numbers_list)-1)
  320. random_index = 0
  321. # print(random_index)
  322. ymax_limit = choice_m_numbers_list[random_index]["loc"][3]
  323. ymin_limit = choice_m_numbers_list[random_index]["loc"][1]
  324. # choice_m_numbers_list.pop(random_index)
  325. # 当前行的choice_m
  326. current_row_choice_m_d = [ele for ele in choice_m_numbers_list if ymin_limit < ele["loc"][5] < ymax_limit]
  327. current_row_choice_m_d = sorted(current_row_choice_m_d, key=lambda x: x["loc"][0])
  328. # current_row_choice_m_d.append(choice_m_numbers_list[random_index])
  329. split_pix = sorted([ele["loc"][0] for ele in current_row_choice_m_d]) # xmin排序
  330. split_index = get_split_index(split_pix)
  331. split_pix = [split_pix[ele] for ele in split_index[:-1]]
  332. block_list = []
  333. for i in range(len(split_index) - 1):
  334. block = current_row_choice_m_d[split_index[i]: split_index[i + 1]]
  335. if len(block) > 1:
  336. remain_len = remain_len - (len(block) - 1)
  337. numbers_new = []
  338. loc_new = [[], [], [], []]
  339. for blk in block:
  340. loc_old = blk["loc"]
  341. numbers_new.extend(blk["numbers"])
  342. for ii in range(4):
  343. loc_new[ii].append(loc_old[ii])
  344. loc_new[0] = min(loc_new[0])
  345. loc_new[1] = min(loc_new[1])
  346. loc_new[2] = max(loc_new[2])
  347. loc_new[3] = max(loc_new[3])
  348. loc_new.append(loc_new[0] + (loc_new[2] - loc_new[0]) // 2)
  349. loc_new.append(loc_new[1] + (loc_new[3] - loc_new[1]) // 2)
  350. block = [{"numbers": sorted(numbers_new), "loc": loc_new, "direction": block[0]["direction"]}]
  351. block_list.extend(block)
  352. current_row_choice_m_d = block_list
  353. current_row_chars = [ele for ele in chars_list
  354. if ymin_limit < (ele["location"]["top"] + ele["location"]["height"] // 2) < ymax_limit]
  355. # if not current_row_chars:
  356. # max_char_width = choice_s_width // 4
  357. # row_chars_xmax = choice_m_numbers_right_limit + int(choice_s_width * 1.5)
  358. # else:
  359. # max_char_width = max([ele["location"]["width"] for ele in current_row_chars]) // 2
  360. # row_chars_xmax = max(
  361. # [ele["location"]["left"] + ele["location"]["width"] for ele in current_row_chars]) + max_char_width * 2
  362. # split_index.append(row_chars_xmax) # 边界
  363. split_pix.append(round(split_pix[-1] + choice_s_width * 1.2))
  364. for i in range(0, len(split_index) - 1):
  365. left_limit = split_index[i]
  366. right_limit = split_index[i + 1]
  367. block_chars = [ele for ele in current_row_chars
  368. if left_limit < (ele["location"]["left"] + ele["location"]["width"] // 2) < right_limit]
  369. # chars_xmin = min([ele["location"]["left"] for ele in block_chars]) - max_char_width
  370. # chars_xmax = max(
  371. # [ele["location"]["left"] + ele["location"]["width"] for ele in block_chars]) + max_char_width
  372. # a_z = '_ABCD_FGH__K_MNOPQRSTUVWXYZ' EIJL -> _
  373. # a_z = '_ABCDEFGHI_K_MNOPQRSTUVWXYZ'
  374. a_z = '_ABCD_FGHT'
  375. # letter_text = set([ele['char'].upper() for ele in block_chars if ele['char'].upper() in a_z])
  376. letter_index = [a_z.index(ele['char'].upper()) for ele in block_chars if ele['char'].upper() in a_z]
  377. letter_index_times = {ele: 0 for ele in set(letter_index)}
  378. for l_index in letter_index:
  379. letter_index_times[l_index] += 1
  380. if (a_z.index("T") in letter_index) and (a_z.index("F") in letter_index):
  381. choice_option = "T, F"
  382. cols = 2
  383. else:
  384. if len(letter_index) < 1:
  385. tmp = 4
  386. choice_option = 'A,B,C,D'
  387. else:
  388. tmp = max(set(letter_index))
  389. # while letter_index_times[tmp] < 2 and tmp > 3:
  390. # t_list = list(set(letter_index))
  391. # t_list.remove(tmp)
  392. # tmp = max(t_list)
  393. choice_option = ",".join(a_z[min(letter_index):tmp + 1])
  394. cols = tmp
  395. bias = 3 # pix
  396. current_loc = current_row_choice_m_d[i]["loc"]
  397. location = dict(xmin=(current_loc[2] + bias) + limit_left, # 当前数字xmax右边
  398. # xmin=max(current_loc[2] + bias, chars_xmin) + limit_left,
  399. ymin=current_loc[1] + limit_top,
  400. xmax=(right_limit - bias) + limit_left,
  401. # xmax=min(chars_xmax, right_limit - bias) + limit_left,
  402. ymax=current_loc[3] + limit_top)
  403. try:
  404. choice_m_img = utils.crop_region(image, location)
  405. right_loc, bottom_loc = adjust_choice_m(choice_m_img, mean_height, mean_width * 2)
  406. if right_loc > 0:
  407. location.update(dict(xmax=right_loc + location['xmin']))
  408. if bottom_loc > 0:
  409. location.update(dict(ymax=bottom_loc + location['ymin']))
  410. except Exception as e:
  411. print(e)
  412. traceback.print_exc()
  413. tmp_w, tmp_h = location['xmax'] - location['xmin'], location['ymax'] - location['ymin'],
  414. numbers = current_row_choice_m_d[i]["numbers"]
  415. direction = current_row_choice_m_d[i]["direction"]
  416. if direction == 180:
  417. choice_m = dict(class_name='choice_m',
  418. number=numbers,
  419. bounding_box=location,
  420. choice_option=choice_option,
  421. default_points=[5] * len(numbers),
  422. direction=direction,
  423. cols=cols,
  424. rows=len(numbers),
  425. single_width=tmp_w // cols,
  426. single_height=tmp_h // len(numbers))
  427. else:
  428. choice_m = dict(class_name='choice_m',
  429. number=numbers,
  430. bounding_box=location,
  431. choice_option=choice_option,
  432. default_points=[5] * len(numbers),
  433. direction=direction,
  434. cols=len(numbers),
  435. rows=cols,
  436. single_width=tmp_w // len(numbers),
  437. single_height=tmp_h // cols
  438. )
  439. if tmp_w > 2 * choice_s_width:
  440. need_revised_choice_m_list.append(choice_m)
  441. else:
  442. choice_m_list.append(choice_m)
  443. remain_len = remain_len - len(current_row_choice_m_d)
  444. for ele in choice_m_numbers_list.copy():
  445. if ele in current_row_choice_m_d:
  446. choice_m_numbers_list.remove(ele)
  447. for ele in choice_m_numbers_list.copy():
  448. if ele in current_row_chars:
  449. choice_m_numbers_list.remove(ele)
  450. # 单独一行不聚类
  451. for i, revised_choice_m in enumerate(need_revised_choice_m_list):
  452. loc = revised_choice_m['bounding_box']
  453. left_part_loc = loc.copy()
  454. left_part_loc.update({'xmax': loc['xmin']+choice_s_width})
  455. choice_m_img = utils.crop_region(image, left_part_loc)
  456. right_loc, bottom_loc = adjust_choice_m(choice_m_img, mean_height, mean_width * 2)
  457. if right_loc > 0:
  458. left_part_loc.update(dict(xmax=right_loc + left_part_loc['xmin']))
  459. if bottom_loc > 0:
  460. left_part_loc.update(dict(ymax=bottom_loc + left_part_loc['ymin']))
  461. left_tmp_height = left_part_loc['ymax'] - left_part_loc['ymin']
  462. right_part_loc = loc.copy()
  463. # right_part_loc.update({'xmin': loc['xmax']-choice_s_width})
  464. right_part_loc.update({'xmin': left_part_loc['xmax']+5})
  465. choice_m_img = utils.crop_region(image, right_part_loc)
  466. right_loc, bottom_loc = adjust_choice_m(choice_m_img, mean_height, mean_width * 2)
  467. if right_loc > 0:
  468. right_part_loc.update(dict(xmax=right_loc + right_part_loc['xmin']))
  469. if bottom_loc > 0:
  470. right_part_loc.update(dict(ymax=bottom_loc + right_part_loc['ymin']))
  471. right_tmp_height = right_part_loc['ymax'] - right_part_loc['ymin']
  472. number_len = max(1, int(revised_choice_m['rows'] // (left_tmp_height // right_tmp_height)))
  473. number = [ele+revised_choice_m['number'][-1]+1 for ele in range(number_len)]
  474. rows = len(number)
  475. revised_choice_m.update({'bounding_box': left_part_loc})
  476. choice_m_list.append(revised_choice_m)
  477. tmp = revised_choice_m.copy()
  478. tmp.update({'bounding_box': right_part_loc, 'number': number, 'rows': rows})
  479. choice_m_list.append(tmp)
  480. tmp = choice_m_list.copy()
  481. for ele in tmp:
  482. loc = ele["bounding_box"]
  483. w, h = loc['xmax'] - loc['xmin'], loc['ymax'] - loc['ymin']
  484. if w*h < choice_s_width*choice_s_height:
  485. choice_m_list.remove(ele)
  486. return choice_m_list
  487. def infer_choice_m(image, tf_sheet, ocr, xml=None):
  488. infer_box_list = ocr2sheet(image, tf_sheet, ocr, xml)
  489. # print(sheet_region_list)
  490. choice_m_list = []
  491. choice_s_h_list = [int(ele['bounding_box']['ymax']) - int(ele['bounding_box']['ymin']) for ele in tf_sheet
  492. if ele['class_name'] == 'choice_s']
  493. if choice_s_h_list:
  494. choice_s_height = sum(choice_s_h_list) // len(choice_s_h_list)
  495. else:
  496. choice_s_height = 0
  497. choice_s_w_list = [int(ele['bounding_box']['xmax']) - int(ele['bounding_box']['xmin']) for ele in tf_sheet
  498. if ele['class_name'] == 'choice_s']
  499. if choice_s_w_list:
  500. choice_s_width = sum(choice_s_w_list) // len(choice_s_w_list)
  501. else:
  502. choice_s_width = 0
  503. for infer_box in infer_box_list:
  504. # {'loc': [240, 786, 1569, 1368]}
  505. loc = infer_box['loc']
  506. xmin, ymin, xmax, ymax = loc[0], loc[1], loc[2], loc[3]
  507. choice_flag = False
  508. for ele in tf_sheet:
  509. if ele['class_name'] in ['choice_m', 'choice_s']:
  510. tf_loc = ele['bounding_box']
  511. tf_loc_l = tf_loc['xmin']
  512. tf_loc_t = tf_loc['ymin']
  513. if xmin < tf_loc_l < xmax and ymin < tf_loc_t < ymax:
  514. choice_flag = True
  515. break
  516. if choice_flag:
  517. infer_image = utils.crop_region_direct(image, loc)
  518. try:
  519. save_dir = os.path.join(settings.MEDIA_ROOT, 'tmp')
  520. if not os.path.exists(save_dir):
  521. os.makedirs(save_dir)
  522. save_path = os.path.join(save_dir, 'choice.jpeg')
  523. cv2.imwrite(save_path, infer_image)
  524. img_tmp = utils.read_single_img(save_path)
  525. os.remove(save_path)
  526. ocr = brain_api.get_ocr_text_and_coordinate(img_tmp, 'accurate', 'CHN_ENG')
  527. except Exception as e:
  528. print('write choice and ocr failed')
  529. traceback.print_exc()
  530. ocr = brain_api.get_ocr_text_and_coordinate(infer_image, 'accurate', 'CHN_ENG')
  531. try:
  532. digital_list, chars_list, digital_mean_h, digital_mean_w = find_digital(ocr)
  533. choice_m = cluster_and_anti_abnormal(image, xml, digital_list, chars_list,
  534. digital_mean_h, digital_mean_w,
  535. choice_s_height, choice_s_width, loc)
  536. choice_m_list.extend(choice_m)
  537. except Exception as e:
  538. traceback.print_exc()
  539. print('not found choice feature')
  540. pass
  541. # print(choice_m_list)
  542. # tf_choice_sheet = [ele for ele in tf_sheet if ele['class_name'] == 'choice_m']
  543. sheet_tmp = choice_m_list.copy()
  544. remove_index = []
  545. for i, region in enumerate(sheet_tmp):
  546. if i not in remove_index:
  547. box = region['bounding_box']
  548. for j, region_in in enumerate(sheet_tmp):
  549. box_in = region_in['bounding_box']
  550. iou = utils.cal_iou(box, box_in)
  551. if iou[0] > 0.85 and i != j:
  552. choice_m_list.remove(region)
  553. remove_index.append(j)
  554. break
  555. return choice_m_list