123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673 |
- # @Author : lightXu
- # @File : choice_infer.py
- import os
- import traceback
- import time
- import random
- from django.conf import settings
- from segment.sheet_resolve.tools import utils, brain_api
- from itertools import chain
- import re
- import numpy as np
- import cv2
- import xml.etree.cElementTree as ET
- from segment.sheet_resolve.tools.utils import crop_region_direct, create_xml, infer_number, combine_char_in_raw_format
- from sklearn.cluster import DBSCAN
- from segment.sheet_resolve.analysis.sheet.ocr_sheet import ocr2sheet
def get_split_index(array, dif=0):
    """Return the boundaries that split a 1-D sequence at its large gaps.

    A new segment starts after every position where the absolute difference
    between consecutive values exceeds a threshold. The threshold is ``dif``
    when it is truthy, otherwise the mean of all consecutive differences.

    Args:
        array: 1-D sequence of numbers (list or ndarray).
        dif: explicit gap threshold; ``0`` (the default) means "use the mean gap".

    Returns:
        Sorted list of ``int`` boundaries that always starts with ``0`` and ends
        with ``len(array)``, so ``array[b[k]:b[k + 1]]`` enumerates the segments.
    """
    values = np.asarray(array)
    # With fewer than two values there are no gaps. Returning early also avoids
    # calling np.mean on an empty slice (RuntimeWarning).
    if len(values) < 2:
        return sorted({0, len(values)})
    intervals = np.abs(values[1:] - values[:-1])
    # The threshold does not depend on the position, so compute it once.
    threshold = dif if dif else np.mean(intervals)
    boundaries = {0, len(values)}
    boundaries.update(int(i) + 1 for i in np.flatnonzero(intervals > threshold))
    return sorted(boundaries)
def adjust_choice_m(image, xe, ye):
    """Return how far the ink reaches to the right and bottom of a cropped region.

    The crop is converted to gray (when it has channels) and blurred with a 5x5
    Gaussian. It is then binarised with an inverse Otsu threshold, so dark
    strokes become foreground. The result is dilated with a ``ye`` x ``xe``
    kernel so neighbouring glyphs merge. Finally, the outer contours are
    scanned for their furthest right and bottom edges.

    Args:
        image: numpy image, BGR(A) with 3 dims or already single-channel.
        xe: dilation kernel width in pixels (x direction).
        ye: dilation kernel height in pixels (y direction).

    Returns:
        ``(right_limit, bottom_limit)`` in crop coordinates. Returns ``(0, 0)``
        when no contour is found.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    binary = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
    kernel = np.ones((ye, xe), np.uint8)  # rows = y dilation, cols = x dilation
    dilated = cv2.dilate(binary, kernel, iterations=1)
    # Return values differ by version:
    #   OpenCV 3:    (image, contours, hierarchy)
    #   OpenCV 2/4:  (contours, hierarchy)
    # The contours are second-to-last in every case.
    contours = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    right_limit = 0
    bottom_limit = 0
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        right_limit = max(right_limit, x + w)
        bottom_limit = max(bottom_limit, y + h)
    return right_limit, bottom_limit
def find_digital(ocr_raw_list):
    """Locate every run of digits in OCR output and collect the remaining characters.

    Args:
        ocr_raw_list: raw OCR result. It is passed through
            ``combine_char_in_raw_format``, whose items are read here as dicts
            with:
            - a ``'words'`` string;
            - a ``'chars'`` list whose entries have ``'char'`` and a
              ``'location'`` dict (``left``/``top``/``width``/``height``).

    Returns:
        ``(digital_list, chars_list, mean_height, mean_width)``, where:
        - ``digital_list`` has one ``{"digital": int, "loc": (xmin, ymin,
          xmax, ymax, mid_x, mid_y)}`` entry per digit run;
        - ``chars_list`` holds the char dicts that are neither part of a digit
          run nor one of the punctuation marks ``. , 。 、``;
        - the last two values are the integer mean height/width of the digit
          run boxes.

    Raises:
        ZeroDivisionError: when no digit run is found at all.
    """
    ignored_chars = {'.', ',', '。', '、'}
    digital_list = []
    chars_list = []
    height_list, width_list = [], []
    for line in combine_char_in_raw_format(ocr_raw_list):
        chars = line['chars']
        # Spaces are removed before matching. The match offsets are then used
        # directly as indices into ``chars``.
        words = line['words'].replace(' ', '').upper()
        digit_indices = set()  # positions in ``chars`` covered by a digit run
        for match in re.finditer(r'\d+', words):
            index_start, index_stop = match.span()
            start_box = chars[index_start]['location']
            end_box = chars[index_stop - 1]['location']
            digit_indices.update(range(index_start, index_stop))

            # The run's box spans from the first digit's left edge to the last
            # digit's right edge, vertically covering both.
            xmin = int(start_box['left'])
            ymin = min(int(start_box['top']), int(end_box['top']))
            xmax = int(end_box['left']) + int(end_box['width'])
            ymax = max(int(start_box['top']) + int(start_box['height']),
                       int(end_box['top']) + int(end_box['height']))
            mid_x = xmin + (xmax - xmin) // 2
            mid_y = ymin + (ymax - ymin) // 2

            height_list.append(ymax - ymin)
            width_list.append(xmax - xmin)
            digital_list.append({"digital": int(match.group()),
                                 "loc": (xmin, ymin, xmax, ymax, mid_x, mid_y)})
        chars_list.extend(char for index, char in enumerate(chars)
                          if index not in digit_indices and char['char'] not in ignored_chars)
    d_mean_height = sum(height_list) // len(height_list)
    d_mean_width = sum(width_list) // len(width_list)
    return digital_list, chars_list, d_mean_height, d_mean_width
def cluster2choice_m_(cluster_list, m_h, m_w):
    """Group an ordered run of question numbers into vertical blocks (legacy variant).

    A new block starts wherever the bottom edge (``loc[3]``) of an item lies
    more than ``1.5 * m_h`` below the bottom edge of the item before it.

    Args:
        cluster_list: ordered list of
            ``{"digital": int, "loc": (xmin, ymin, xmax, ymax, ...)}``.
        m_h: typical digit height in pixels.
        m_w: not used; kept for signature compatibility.

    Returns:
        list of ``{"number": [int, ...], "loc": (xmin, ymin, xmax, ymax)}``.
    """
    bottoms = np.array([item['loc'][3] for item in cluster_list])
    gaps = bottoms[1:] - bottoms[:-1]

    cut_points = {0, len(cluster_list)}
    cut_points.update(pos + 1 for pos, gap in enumerate(gaps) if gap > m_h * 1.5)
    cut_points = sorted(cut_points)

    blocks = []
    for start, stop in zip(cut_points, cut_points[1:]):
        members = cluster_list[start:stop]
        bbox = (
            min(item["loc"][0] for item in members),
            min(item["loc"][1] for item in members),
            max(item["loc"][2] for item in members),
            max(item["loc"][3] for item in members),
        )
        blocks.append({"number": [item['digital'] for item in members], "loc": bbox})
    return blocks
def cluster2choice_m(cluster_list, mean_width):
    """Split one cluster of question-number detections into choice_m blocks.

    Steps:
    1. Drop x-outliers. Consecutive centre-x gaps (``loc[4]``) of at least
       ``mean_width`` form runs; every item after the first one of a run is
       discarded.
    2. Repair the numbers. They are assumed to be a subset of one arithmetic
       progression, so ``n[k] + n[-1-k]`` should be constant. When it is not,
       numbers that fit neither neighbour at the most common step are set to
       -1 and re-inferred with ``infer_number``.
    3. Split the run wherever the step between consecutive numbers exceeds the
       mean step.

    Args:
        cluster_list: ordered list of
            ``{"digital": int, "loc": (xmin, ymin, xmax, ymax, mid_x, mid_y)}``.
        mean_width: typical digit width in pixels, used as the x-outlier gap.

    Returns:
        list of ``{"numbers": [...], "loc": [xmin, ymin, xmax, ymax, mid_x, mid_y]}``.
        Returns ``[]`` for an empty cluster.
    """
    if not cluster_list:
        return []

    # 1. Compare x coordinates and drop outliers. An interval index whose
    #    predecessor is also a big gap is a non-first member of a consecutive
    #    run; exactly those items are removed.
    centre_x = np.array([ele['loc'][4] for ele in cluster_list])
    big_gaps = set(np.flatnonzero(np.abs(centre_x[1:] - centre_x[:-1]) >= mean_width).tolist())
    outliers = {idx for idx in big_gaps if idx - 1 in big_gaps}
    cluster_list = [ele for idx, ele in enumerate(cluster_list) if idx not in outliers]

    numbers_array = np.array([ele['digital'] for ele in cluster_list])

    # 2. Locate the question numbers, assuming one block is a subset of an
    #    arithmetic progression.
    numbers_sum = numbers_array + np.flipud(numbers_array)
    counts = np.bincount(numbers_sum)
    mode_times = np.max(counts)
    mode_value = np.argmax(counts)
    if mode_times != len(numbers_array) and mode_times >= 2:
        # Start question-number completion with the most common step.
        step = np.argmax(np.bincount(np.abs(numbers_array[1:] - numbers_array[:-1])))
        last = len(numbers_array) - 1
        for suspect in np.flatnonzero(numbers_sum != mode_value):
            # Suspects are processed in order, so a neighbour may already be -1.
            fits_left = suspect > 0 and numbers_array[suspect - 1] == numbers_array[suspect] - step
            fits_right = suspect < last and numbers_array[suspect + 1] == numbers_array[suspect] + step
            if not (fits_left or fits_right):
                numbers_array[suspect] = -1  # unknown, to be inferred
        numbers_array = np.array(infer_number(numbers_array, step))

    # 3. Split at steps larger than the mean step (computed once).
    intervals = np.abs(numbers_array[1:] - numbers_array[:-1])
    cuts = {0, len(cluster_list)}
    if len(intervals):
        mean_interval = np.mean(intervals)
        cuts.update(pos + 1 for pos, gap in enumerate(intervals) if gap > mean_interval)
    cuts = sorted(cuts)

    block_list = []
    for start, stop in zip(cuts, cuts[1:]):
        block = cluster_list[start:stop]
        xmin = min(ele["loc"][0] for ele in block)
        ymin = min(ele["loc"][1] for ele in block)
        xmax = max(ele["loc"][2] for ele in block)
        ymax = max(ele["loc"][3] for ele in block)
        block_list.append({
            "numbers": list(numbers_array[start:stop]),
            "loc": [xmin, ymin, xmax, ymax, xmin + (xmax - xmin) // 2, ymin + (ymax - ymin) // 2],
        })
    return block_list
def _fit_box_to_ink(image, box, mean_height, mean_width):
    """Shrink ``box`` (absolute ``xmin/ymin/xmax/ymax`` dict) in place to its ink.

    The right and bottom edges move to the extent reported by
    ``adjust_choice_m``; a zero extent leaves that edge unchanged. Any
    exception is printed and leaves ``box`` untouched, so the refinement is
    best effort. Returns ``box`` for convenience.
    """
    try:
        crop = utils.crop_region(image, box)
        right_loc, bottom_loc = adjust_choice_m(crop, mean_height, mean_width * 2)
        if right_loc > 0:
            box['xmax'] = right_loc + box['xmin']
        if bottom_loc > 0:
            box['ymax'] = bottom_loc + box['ymin']
    except Exception as e:
        print(e)
        traceback.print_exc()
    return box


def cluster_and_anti_abnormal(image, xml_path, digital_list, chars_list,
                              mean_height, mean_width, choice_s_height, choice_s_width, limit_loc):
    """Build ``choice_m`` regions from the digits and letters OCR'd in one sheet region.

    Pipeline:
    1. Cluster the digit centres with DBSCAN (Chebyshev metric) and split
       each cluster into number blocks with ``cluster2choice_m``.
    2. Complete short blocks. A block with fewer numbers than the typical
       block size (most frequent size, or the largest when no size repeats)
       is completed with ``infer_number``. Its box grows upwards (numbers
       missing in front) or downwards (numbers missing at the end).
    3. Walk rows, tallest block first. Blocks starting at nearly the same x
       are merged. The option letters right of each block give
       ``choice_option``/``cols``, and the option area is fitted to the ink.
    4. Split regions wider than ``2 * choice_s_width`` into a left and a
       right part.
    5. Drop regions whose doubled area is below one ``choice_s`` area.

    Args:
        image: the full sheet image.
        xml_path: unused; kept for signature compatibility.
        digital_list: digit runs from ``find_digital`` (``loc`` relative to
            the region).
        chars_list: non-digit OCR chars from ``find_digital`` (relative to
            the region).
        mean_height, mean_width: mean digit-run height/width in pixels.
        choice_s_height, choice_s_width: mean ``choice_s`` box size, 0 when
            unknown.
        limit_loc: ``(left, top, right, bottom)`` of the region in the full
            image.

    Returns:
        list of ``choice_m`` dicts whose ``bounding_box`` is in full-image
        coordinates. Returns ``[]`` when no digit cluster is found.
    """
    limit_left, limit_top, _, limit_bottom = limit_loc
    limit_height = limit_bottom - limit_top

    # ---- 1. cluster digit centres and split clusters into number blocks ----
    centres = np.array([[ele["loc"][-2], ele["loc"][-1]] for ele in digital_list],
                       dtype=float).reshape(-1, 2)
    if choice_s_height != 0:
        eps = int(choice_s_height * 2.5)
    else:
        eps = int(mean_height * 3)
    print("eps: ", eps)
    labels = DBSCAN(eps=eps, min_samples=2, metric='chebyshev').fit(centres).labels_

    # Label -> detections, in first-seen label order. Noise (-1) is dropped.
    clusters = {}
    for label, digital in zip(labels, digital_list):
        if label != -1:
            clusters.setdefault(label, []).append(digital)

    choice_m_numbers_list = []
    for cluster in clusters.values():
        choice_m_numbers_list += cluster2choice_m(cluster, mean_width)
    if not choice_m_numbers_list:
        return []

    # ---- 2. complete blocks that are missing question numbers ----
    all_list_nums = [ele["numbers"] for ele in choice_m_numbers_list]
    all_nums_len = [len(nums) for nums in all_list_nums]
    all_nums = list(chain.from_iterable(all_list_nums))
    counts = np.bincount(np.array(all_nums_len))
    mode_value = int(max(all_nums_len) if np.max(counts) < 2 else np.argmax(counts))
    if mode_value > 1:
        error_index_list = [k for k, n in enumerate(all_nums_len) if n != mode_value]
        all_height = [ele["loc"][3] - ele["loc"][1]
                      for k, ele in enumerate(choice_m_numbers_list) if all_nums_len[k] == mode_value]
        choice_m_mean_height = int(sum(all_height) / len(all_height))
        for e_index in error_index_list:
            current_choice_m = choice_m_numbers_list[e_index]
            current_numbers_list = list(all_list_nums[e_index])
            current_len = all_nums_len[e_index]
            dif = mode_value - current_len
            if dif <= 0:
                # Only blocks that are short of numbers are completed; a
                # negative count would shrink the box.
                continue
            if 1 in current_numbers_list:
                # Question 1 is present, so missing numbers can only follow it.
                infer_t2_list = infer_number(current_numbers_list + [-1] * dif)
                cond1, cond2 = False, True
            else:
                infer_t1_list = infer_number([-1] * dif + current_numbers_list)  # missing in front
                infer_t2_list = infer_number(current_numbers_list + [-1] * dif)  # missing at the end
                # A side is plausible only if none of its inferred numbers is
                # 0 (front only) or already used by another block.
                cond1 = all(infer_t1_list[k] != 0 and infer_t1_list[k] not in all_nums
                            for k in range(dif))
                cond2 = all(infer_t2_list[-k - 1] not in all_nums for k in range(dif))
            if cond1 == cond2:
                # Neither or both sides plausible: leave the block unchanged.
                continue
            current_loc = current_choice_m["loc"]
            current_height = current_loc[3] - current_loc[1]
            infer_height = max(choice_m_mean_height - current_height,
                               int(dif * current_height / current_len))
            if cond1:
                # Grow upwards, but not above the region (mirrors the bottom clamp).
                current_loc[1] = max(0, current_loc[1] - infer_height)
                current_choice_m["numbers"] = infer_t1_list
                all_nums.extend(infer_t1_list[:dif])
            else:
                current_loc[3] = min(current_loc[3] + infer_height, limit_height - 1)
                current_choice_m["numbers"] = infer_t2_list
                all_nums.extend(infer_t2_list[-dif:])
            current_loc[5] = current_loc[1] + (current_loc[3] - current_loc[1]) // 2

    for ele in choice_m_numbers_list:
        loc = ele["loc"]
        # 180: block at least as tall as wide; 90: wider than tall.
        ele['direction'] = 180 if loc[3] - loc[1] >= loc[2] - loc[0] else 90

    # ---- 3. walk the blocks row by row, tallest block first ----
    # Option letters indexed by column. '_' placeholders keep the indices of
    # letters that are deliberately not matched; 'T' is used for true/false.
    a_z = '_ABCD_FGHT'
    bias = 3  # pixels kept clear of the number block and of the next block
    pending = sorted(choice_m_numbers_list, key=lambda x: x['loc'][3] - x['loc'][1], reverse=True)
    choice_m_list = []
    need_revised_choice_m_list = []
    while pending:
        seed = pending[0]
        ymin_limit, ymax_limit = seed["loc"][1], seed["loc"][3]
        # Blocks whose vertical centre lies strictly inside the seed's span share its row.
        row_members = [ele for ele in pending if ymin_limit < ele["loc"][5] < ymax_limit]
        if not any(ele is seed for ele in row_members):
            # A seed under 2 px tall fails its own test; keep it so it is consumed.
            row_members.append(seed)
        row_members.sort(key=lambda x: x["loc"][0])
        # Remove every original member of this row, including the ones merged below.
        row_ids = {id(ele) for ele in row_members}
        pending = [ele for ele in pending if id(ele) not in row_ids]

        xmins = [ele["loc"][0] for ele in row_members]
        split_index = get_split_index(xmins, dif=choice_s_width * 0.8)
        split_pix = [xmins[k] for k in split_index[:-1]]

        row_blocks = []
        for start, stop in zip(split_index, split_index[1:]):
            group = row_members[start:stop]
            if len(group) == 1:
                row_blocks.append(group[0])
                continue
            # Several blocks start at (almost) the same x: merge them into one.
            xmin = min(ele["loc"][0] for ele in group)
            ymin = min(ele["loc"][1] for ele in group)
            xmax = max(ele["loc"][2] for ele in group)
            ymax = max(ele["loc"][3] for ele in group)
            row_blocks.append({
                "numbers": sorted(chain.from_iterable(ele["numbers"] for ele in group)),
                "loc": [xmin, ymin, xmax, ymax, xmin + (xmax - xmin) // 2, ymin + (ymax - ymin) // 2],
                "direction": group[0]["direction"],
            })

        row_chars = [ele for ele in chars_list
                     if ymin_limit < (ele["location"]["top"] + ele["location"]["height"] // 2) < ymax_limit]
        # The last block has no right neighbour; close it 1.2 choice_s widths
        # after its left edge.
        split_pix.append(round(split_pix[-1] + choice_s_width * 1.2))

        for i, block in enumerate(row_blocks):
            left_limit, right_limit = split_pix[i], split_pix[i + 1]
            block_chars = [ele for ele in row_chars
                           if left_limit < (ele["location"]["left"] + ele["location"]["width"] // 2) < right_limit]
            # Column index of each recognised option letter. Index 0 (the '_'
            # placeholder, or an empty char) is ignored, so option_count >= 1.
            letter_index = [idx for idx in (a_z.find(ele['char'].upper()) for ele in block_chars) if idx > 0]
            if a_z.index("T") in letter_index and a_z.index("F") in letter_index:
                choice_option = "T, F"
                option_count = 2
            elif not letter_index:
                choice_option = 'A,B,C,D'
                option_count = 4
            else:
                option_count = max(letter_index)
                choice_option = ",".join(a_z[min(letter_index):option_count + 1])

            current_loc = block["loc"]
            location = dict(xmin=(current_loc[2] + bias) + limit_left,  # right of the number block
                            ymin=current_loc[1] + limit_top,
                            xmax=(right_limit - bias) + limit_left,
                            ymax=current_loc[3] + limit_top)
            _fit_box_to_ink(image, location, mean_height, mean_width)
            tmp_w = location['xmax'] - location['xmin']
            tmp_h = location['ymax'] - location['ymin']
            numbers = block["numbers"]
            direction = block["direction"]
            if direction == 180:
                # One grid row per question, one column per option.
                rows, cols = len(numbers), option_count
            else:
                rows, cols = option_count, len(numbers)
            choice_m = dict(class_name='choice_m',
                            number=numbers,
                            bounding_box=location,
                            choice_option=choice_option,
                            default_points=[5] * len(numbers),
                            direction=direction,
                            cols=cols,
                            rows=rows,
                            single_width=tmp_w // cols,
                            single_height=tmp_h // rows)
            if tmp_w > 2 * choice_s_width:
                need_revised_choice_m_list.append(choice_m)
            else:
                choice_m_list.append(choice_m)

    # ---- 4. split regions too wide for one option group into left + right parts ----
    for revised_choice_m in need_revised_choice_m_list:
        loc = revised_choice_m['bounding_box']
        left_part_loc = _fit_box_to_ink(image, dict(loc, xmax=loc['xmin'] + choice_s_width),
                                        mean_height, mean_width)
        right_part_loc = _fit_box_to_ink(image, dict(loc, xmin=left_part_loc['xmax'] + 5),
                                         mean_height, mean_width)
        left_tmp_height = left_part_loc['ymax'] - left_part_loc['ymin']
        right_tmp_height = right_part_loc['ymax'] - right_part_loc['ymin']
        # Estimate the right part's question count from the height ratio. Both
        # divisors are clamped to >= 1; otherwise a right part at least as tall
        # as the left (or of zero height) raises ZeroDivisionError.
        height_ratio = max(1, left_tmp_height // max(1, right_tmp_height))
        number_len = max(1, int(revised_choice_m['rows'] // height_ratio))
        number = [k + revised_choice_m['number'][-1] + 1 for k in range(number_len)]
        revised_choice_m['bounding_box'] = left_part_loc
        choice_m_list.append(revised_choice_m)
        choice_m_list.append(dict(revised_choice_m, bounding_box=right_part_loc,
                                  number=number, rows=len(number)))

    # ---- 5. drop regions that are too small to be a choice group ----
    min_area = choice_s_width * choice_s_height
    kept = []
    for ele in choice_m_list:
        box = ele["bounding_box"]
        if 2 * (box['xmax'] - box['xmin']) * (box['ymax'] - box['ymin']) >= min_area:
            kept.append(ele)
    return kept
def _ocr_choice_region(infer_image):
    """Run accurate CHN_ENG OCR on a cropped choice region.

    The crop is first written as a JPEG under ``MEDIA_ROOT/tmp`` and re-read
    with ``utils.read_single_img``. If any step of that fails, the error is
    printed and OCR runs on the in-memory crop instead.

    The temporary file gets a unique name, so concurrent calls cannot read
    each other's crop, and it is always removed.
    """
    import tempfile

    try:
        save_dir = os.path.join(settings.MEDIA_ROOT, 'tmp')
        os.makedirs(save_dir, exist_ok=True)
        fd, save_path = tempfile.mkstemp(suffix='.jpeg', prefix='choice_', dir=save_dir)
        try:
            os.close(fd)
            cv2.imwrite(save_path, infer_image)
            img_tmp = utils.read_single_img(save_path)
        finally:
            os.remove(save_path)
        return brain_api.get_ocr_text_and_coordinate(img_tmp, 'accurate', 'CHN_ENG')
    except Exception:
        print('write choice and ocr failed')
        traceback.print_exc()
        return brain_api.get_ocr_text_and_coordinate(infer_image, 'accurate', 'CHN_ENG')


def _drop_overlapping(choice_m_list, iou_threshold=0.85):
    """Remove near-duplicate regions, keeping one region of each overlapping pair.

    Each region that has not been kept as a partner is compared with the
    other surviving regions. If some region overlaps it with
    ``utils.cal_iou(...)[0] > iou_threshold``, that first partner is kept and
    the current region is dropped.

    Already-dropped regions are never used as partners, so a region is not
    discarded merely for overlapping something that is gone.
    """
    removed = set()
    protected = set()
    for i, region in enumerate(choice_m_list):
        if i in protected:
            continue
        for j, other in enumerate(choice_m_list):
            if j == i or j in removed:
                continue
            if utils.cal_iou(region['bounding_box'], other['bounding_box'])[0] > iou_threshold:
                removed.add(i)
                protected.add(j)
                break
    return [region for k, region in enumerate(choice_m_list) if k not in removed]


def infer_choice_m(image, tf_sheet, ocr, xml=None):
    """Detect ``choice_m`` (multi-question choice) regions on an answer sheet.

    Candidate regions come from ``ocr2sheet``. A region is analysed only if it
    strictly contains the top-left corner of a ``choice_m``/``choice_s`` box
    from ``tf_sheet``. Such regions are re-OCR'd, and their digits and option
    letters are turned into regions by ``cluster_and_anti_abnormal``.

    Args:
        image: full sheet image.
        tf_sheet: detected sheet elements, each with ``class_name`` and a
            ``bounding_box`` dict (``xmin``/``ymin``/``xmax``/``ymax``).
        ocr: OCR result passed to ``ocr2sheet``.
        xml: optional; forwarded to ``ocr2sheet`` and ``cluster_and_anti_abnormal``.

    Returns:
        list of ``choice_m`` dicts. Regions overlapping another one with
        IoU > 0.85 are de-duplicated.
    """
    infer_box_list = ocr2sheet(image, tf_sheet, ocr, xml)

    # Mean choice_s size, 0 when none was detected.
    choice_s_boxes = [ele['bounding_box'] for ele in tf_sheet if ele['class_name'] == 'choice_s']
    if choice_s_boxes:
        choice_s_height = sum(int(b['ymax']) - int(b['ymin']) for b in choice_s_boxes) // len(choice_s_boxes)
        choice_s_width = sum(int(b['xmax']) - int(b['xmin']) for b in choice_s_boxes) // len(choice_s_boxes)
    else:
        choice_s_height = choice_s_width = 0

    choice_m_list = []
    for infer_box in infer_box_list:
        loc = infer_box['loc']  # e.g. [240, 786, 1569, 1368]
        xmin, ymin, xmax, ymax = loc[0], loc[1], loc[2], loc[3]
        has_choice = any(
            ele['class_name'] in ('choice_m', 'choice_s')
            and xmin < ele['bounding_box']['xmin'] < xmax
            and ymin < ele['bounding_box']['ymin'] < ymax
            for ele in tf_sheet
        )
        if not has_choice:
            continue
        infer_image = utils.crop_region_direct(image, loc)
        region_ocr = _ocr_choice_region(infer_image)
        try:
            digital_list, chars_list, digital_mean_h, digital_mean_w = find_digital(region_ocr)
            choice_m = cluster_and_anti_abnormal(image, xml, digital_list, chars_list,
                                                 digital_mean_h, digital_mean_w,
                                                 choice_s_height, choice_s_width, loc)
            choice_m_list.extend(choice_m)
        except Exception:
            # Best effort: a region without recognisable choice features is skipped.
            traceback.print_exc()
            print('not found choice feature')

    return _drop_overlapping(choice_m_list)
|