Pārlūkot izejas kodu

1.定位点修改;
2.选择题推断更新;
3.区域修正增加correction;
4.选择题修改。

lighttxu 4 gadi atpakaļ
vecāks
revīzija
1a27cfd39f

+ 68 - 12
segment/sheet_resolve/analysis/anchor/marker_detection.py

@@ -4,8 +4,8 @@ import math
 from .util import *
 import ctypes
 import time
-import sys
 import numpy as np
+
 try:
     temp = ctypes.windll.LoadLibrary('opencv_ffmpeg410_64.dll')
 except:
@@ -181,6 +181,7 @@ def detect_anchor_by_position(anchors, markers, image, method, debug=0):
 
 def detect_anchor_public(image, method='connected', debug=0):
     #   寻找第三方试卷最上方及最下方的定位点
+    #   anchor:[xmin, ymin, xmax, ymax]
     shift_threshold = 50  # 80
     height, width = image.shape[:2]
     h0, h1 = 0.1, 0.9
@@ -188,34 +189,45 @@ def detect_anchor_public(image, method='connected', debug=0):
     pos_threshold = 0.1
     area_threshold = 0.28  # 0.25   # 大定位点面积差阈值
     anchors_len_threshold = 2
-    shape_para = {'height': (80, 10), 'w2h': (3, 0.6), 'area': (6000, 500), 'area_ratio': 0.9}
-    blur_size = 3
-    sigma = 5
+    shape_para = {'height': (80, 10), 'w2h': (3, 0.6), 'area': (8000, 500), 'area_ratio': 0.9}
+    blur_size = 1
+    sigma = 1
+    ker_size1, ker_size2 = 2, 10
+
     shift_edge = 2  # 对定位点位置做微调偏移
 
-    binary = pre_process_for_anchors(image, debug=0, blur_size=blur_size, sigma=sigma)
-    binary = extract_feature(binary, method=4, debug=0)
+    binary = pre_process_for_anchors(image, h_ratio=0, blur_size=blur_size, sigma=sigma, debug=0)
+    binary = extract_feature(binary, ker_size1=ker_size1, ker_size2=ker_size2, method=4, debug=0)
     boxes = find_boxes(binary, method=method, debug=0)
     markers = find_marker_by_shape(boxes, shape_para=shape_para, debug=0)
     marker_list = collect_markers_by_position(markers, method='h', shift_threshold=shift_threshold)
+    anchor_list = []
+    for m in marker_list:
+        m = collect_markers_by_area(m)
+        anchor_list.append(m)
 
-    if len(marker_list) == 0:
+    if len(anchor_list) == 0:
         anchors = []
-    elif len(marker_list) == 1:
-        anchors = marker_list[0]
+    elif len(anchor_list) == 1:
+        anchors = anchor_list[0]
     else:
-        anchors = marker_list[0]
-        anchors.extend(marker_list[-1])
+        anchors = anchor_list[0]
+        anchors.extend(anchor_list[-1])
     anchors = [[a[0]-shift_edge, a[1]-shift_edge, a[2]-shift_edge-1, a[3]-shift_edge-1] for a in anchors]
 
     if debug == 1:
-        # print(anchors)
+        #   显示偏移处理后的定位点信息
+        print('Anchors after edge shifting:')
+        print(anchors)
         draw_box(image, anchors, (0, 255, 255), debug=1)
         plt.figure(figsize=(15, 10))
         plt.title(method)
         plt.imshow(image, cmap='gray')
         plt.show()
     elif debug == 2:
+        #   显示偏移处理前的定位点信息
+        print('Achors before edge shifting:')
+        print(markers)
         markers.sort(reverse=True, key=lambda x: x[4][1])
         draw_box(image, markers, debug=1)
         plt.figure(figsize=(15, 10))
@@ -642,3 +654,47 @@ def find_anchor(image, method='connected'):
         anchor_dict = {'class_name': 'anchor_point', 'bounding_box': bbox}
         anchors_list.append(anchor_dict)
     return anchors_list
+
+
+if __name__ == "__main__":
+    from pathlib import Path
+
+    # work_dir = sys.argv[1]
+    # if os.path.exists(work_dir):
+    #     img_list = glob.glob(os.path.join(work_dir, '*.jpg'))
+    #     for img_file in img_list:
+    #         print(img_file)
+    #         anchors, problem_markers = main(img_file, debug=1)
+    #
+    # else:
+    #     print('Directory not found!')
+
+    work_dir = Path(r'E:\data\location-point\8_5-lin\biology')
+    img_list = work_dir.rglob('*.jpg')
+    # img_list = glob.glob(os.path.join(work_dir, '*.jpg'))
+
+    tick = time.clock()
+    for img_file in img_list:
+        print(img_file)
+        image = read_single_img(img_file)
+        # rot_image, flag = rotate_by_anchor(image, method='connected', debug=0)
+        anchors = detect_anchor_public(image, debug=1)
+        # anchors, problem_markers = main(img_file, debug=1)
+    toc = time.clock()
+    print('time cost:', toc - tick)
+    #
+    # index = 14
+    # img_file = img_list[index]
+    # # img_file = r'E:\data\location_point\1.jpg'
+    # print(img_file)
+    # markers, image = detect_marker(img_file, debug=1)
+    # # for index in (0, 2, -20, -19, -15, -16, -1, 13, 16, 18, -34, -21, -30, 7, 14, -42):
+    # #     img_file = img_list[index]
+    # #     print(img_file)
+    # #     markers, image = main(img_file, debug=1)
+
+
+
+
+
+

+ 39 - 9
segment/sheet_resolve/analysis/anchor/util.py

@@ -19,6 +19,8 @@ def pre_process(image, blank_top=20, blank_bottom=-20, blur_size=5, sigma=5, deb
         gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
     elif image.ndim == 2:
         gray = image
+    else:
+        raise Exception('The dimension of the image should be 2 or 3!')
     #   裁边
     gray[0:blank_top, :] = 255
     gray[blank_bottom:, :] = 255
@@ -29,6 +31,7 @@ def pre_process(image, blank_top=20, blank_bottom=-20, blur_size=5, sigma=5, deb
     binary = cv2.threshold(pre, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)[1]
 
     if debug == 1:
+        #   显示灰度图和二值图
         plt.figure(figsize=(15, 10))
         plt.subplot(211)
         plt.title('gray')
@@ -41,31 +44,35 @@ def pre_process(image, blank_top=20, blank_bottom=-20, blur_size=5, sigma=5, deb
     return binary
 
 
-def pre_process_for_anchors(image, blank_top=20, blank_bottom=-20, blur_size=5, sigma=10, blank_size=20, debug=0):
+def pre_process_for_anchors(image, blank_top=20, blank_bottom=-20, h_ratio=(0.1, 0.9),
+                            blur_size=3, sigma=5, blank_size=20, debug=0):
     #   去掉中间内容,返回上下定位点的二值逆图
+    #   h_ration=0 则不裁去中间部分
 
-    h_ratio = (0.1, 0.9)
-    h0 = int(image.shape[0] * h_ratio[0])
-    h1 = int(image.shape[0] * h_ratio[1])
-
+    # assert(image.ndim == 2 or image.ndim == 3)
     if image.ndim == 3:
         gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
     elif image.ndim == 2:
         gray = image.copy()
-
+    else:
+        raise Exception('The dimension of the image should be 2 or 3!')
     #   裁边
     gray[0:blank_top, :] = 255
     gray[blank_bottom:, :] = 255
     gray[:, 0:blank_size] = 255
     gray[:, -blank_size:] = 255
     #   去掉中间内容
-    gray[h0:h1, :] = 255
+    if h_ratio != 0:
+        h0 = int(image.shape[0] * h_ratio[0])
+        h1 = int(image.shape[0] * h_ratio[1])
+        gray[h0:h1, :] = 255
 
     pre = 255 - gray
     pre = cv2.GaussianBlur(pre, (blur_size, blur_size), sigma)
     binary = cv2.threshold(pre, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)[1]
 
     if debug == 1:
+        #   显示灰度图和二值图
         plt.figure(figsize=(15, 10))
         plt.subplot(211)
         plt.title('gray')
@@ -106,6 +113,7 @@ def extract_feature(binary, method=4, ker_size1=2, ker_size2=10, debug=0):
     else:
         ret = binary
     if debug == 1:
+        #   显示特征提取其后的图像
         plt.figure(figsize=(15, 10))
         plt.subplot(211)
         plt.title('before feature extraction')
@@ -155,8 +163,13 @@ def find_boxes(binary, method='connected', debug=0):
         boxes = draw_contour(binary)
     elif method == 'connected':
         boxes = draw_connected_component(binary)
+    else:
+        raise Exception('Wrong method for finding boxes!')
+
     if debug == 1:
+        #   显示boxes信息
         boxes.sort(key=lambda x: x[4][1])
+        print('The features of the boxes after feature extractions:')
         for box in boxes:
             width = box[2] - box[0]
             height = box[3] - box[1]
@@ -191,7 +204,9 @@ def find_marker_by_shape(boxes,
             markers.append(box)
 
     if debug == 1:
+        #   显示通过形状参数找到的定位点信息
         markers.sort(reverse=True, key=lambda x: x[-1])
+        print('The features of the markers picking up by shapes:')
         for box in markers:
             width = box[2] - box[0]
             height = box[3] - box[1]
@@ -202,6 +217,8 @@ def find_marker_by_shape(boxes,
             print('width:{}, height:{}, centroid:{}, w_to_h:{}, area:{}, area ratio:{}'.
                   format(width, height, centroid, w_to_h, area, area_ratio))
     elif debug == 2:
+        #   显示所有定位点信息
+        print('The features of the markers without picking up by shapes:')
         for box in boxes:
             markers.append(box)
         for box in markers:
@@ -324,6 +341,15 @@ def collect_markers_by_position(boxes, method='h', shift_threshold=30, slope_thr
     return box_list
 
 
+def collect_markers_by_area(boxes, area_threshold=2):
+    #   去除面积相差过大的定位点
+    boxes.sort(reverse=True, key=lambda x: x[-1])
+    # print(boxes[0])
+    min_area = boxes[0][-1] / area_threshold
+    boxes = [b for b in boxes if b[-1] > min_area]
+    return boxes
+
+
 def check_with_anchor(problem_markers, top_anchors, page_width, column_num):
     #    根据top_anchors位置去除异常markers
     min_shift = 100
@@ -483,14 +509,16 @@ def draw_box(image, boxes, color=(0, 255, 0), debug=0):
             cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), color, 5)
 
     if debug == 1:
+        #   显示定位点信息
         for box in boxes:
             if len(box) == 4:
                 width = box[2] - box[0]
                 height = box[3] - box[1]
                 w_to_h = width / height
+                area = width * height
                 # centroid = (box[0]+box[1])/2, (box[1]+box[3])/2
-                print('width:{}, height:{}, w_to_h:{}, top_left:{}, bottom_right:{},'.
-                      format(width, height, w_to_h, box[0:2], box[2:4]))
+                print('width:{}, height:{}, w_to_h:{}, area:{}, top_left:{}, bottom_right:{},'.
+                      format(width, height, w_to_h, area, box[0:2], box[2:4]))
             elif len(box) > 4:
                 width = box[2] - box[0]
                 height = box[3] - box[1]
@@ -622,3 +650,5 @@ def find_column(anchors, width, column_num=2, debug=0):
         print('page width:', page_width, 'column number:', column_num, 'column position:', column_pos)
 
     return page_width, column_num, column_pos
+
+

+ 73 - 16
segment/sheet_resolve/analysis/choice/choice_line_box.py

@@ -291,6 +291,9 @@ def choice_m_row_col0(left, top, image, choice_bbox_list, xml_path, img, choice_
     choice_m_dict_list = []
     for index0, box in enumerate(choice_bbox_list['regions']):  # rcnn识别的框匹配题号
         box = box['bounding_box']
+        box['ymin'] = box['ymin'] - 3  # lf_7_2 1change
+        box['xmax'] = box['xmax'] + 3
+        box['ymax'] = box['ymax'] + 3
         m_left, m_top = box['xmin']+left, box['ymin']+top,
         box_coordiante = (m_left, m_top, box['xmax']+left, box['ymax']+top)
         tree = utils.create_xml('choice_m', tree,
@@ -376,13 +379,25 @@ def choice_bbox_vague(choice_m_list0, x_y_interval_ave, singe_box_width_height_a
             x_diff = x_y_interval_ave[0]
             s_width = singe_box_width_height_ave[0]
             choice_bbox = (np.hstack((np.array([min(xmin0) - x_diff - 3 * s_width, min(ymin0)]), np.array([max(xmax0), max(ymax0)])))).tolist()
-            choice_bbox_with_index_list = (choice_bbox, choice_m_list1[1])
+            choice_box = []
+            for element in choice_bbox:
+                if element < 0:
+                    choice_box.append(0)
+                else:
+                    choice_box.append(element)
+            choice_bbox_with_index_list = (choice_box, choice_m_list1[1])
             choice_bbox_all.append(choice_bbox_with_index_list)
         elif direction == 90:
             y_diff = x_y_interval_ave[1]
             s_height = singe_box_width_height_ave[1]
             choice_bbox = (np.hstack((np.array([min(xmin0), min(ymin0) - y_diff - 3 * s_height]), np.array([max(xmax0), max(ymax0)])))).tolist()
-            choice_bbox_with_index_list = (choice_bbox, choice_m_list1[1])
+            choice_box = []
+            for element in choice_bbox:
+                if element < 0:
+                    choice_box.append(0)
+                else:
+                    choice_box.append(element)
+            choice_bbox_with_index_list = (choice_box, choice_m_list1[1])
             choice_bbox_all.append(choice_bbox_with_index_list)
     return choice_bbox_all
 
@@ -418,6 +433,9 @@ def choice_m_adjust(image, choice_m_bbox_list):
     a_z = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
     for index0, choice_m in enumerate(choice_m_bbox_list):  # rcnn识别的框匹配题号
         box = choice_m['bounding_box']
+        box['ymin'] = choice_m['ymin'] - 3  # lf_7_2 1change
+        box['xmax'] = choice_m['xmax'] + 3
+        box['ymax'] = choice_m['ymax'] + 3
         m_left, m_top = box['xmin'], box['ymin'],
         # box_coordiante = (m_left, m_top, box['xmax'], box['ymax'])
         single_choice_m = utils.crop_region(image, box)
@@ -434,7 +452,7 @@ def choice_m_adjust(image, choice_m_bbox_list):
     return choice_m_bbox_list
 
 
-def choice_m_row_col(image, choice_m_bbox_list, direction, xml_path):
+def choice_m_row_col(image, choice_m_bbox_list, direction, subject, xml_path):
     a_z = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
     choice_m_dict_list = []
 
@@ -444,14 +462,16 @@ def choice_m_row_col(image, choice_m_bbox_list, direction, xml_path):
         choice_m_for_dircetion = utils.crop_region(image, choice_m_bbox_list[random_one]['bounding_box'])
         res_dict = get_ocr_text_and_coordinate(choice_m_for_dircetion, ocr_accuracy='accurate', language_type='ENG')
         direction = get_direction(res_dict)
-
     for index0, box in enumerate(choice_m_bbox_list):  # rcnn识别的框匹配题号
         box = box['bounding_box']
+        box['ymin'] = box['ymin'] - 3   # lf_7_2 1change
+        box['xmax'] = box['xmax'] + 3
+        box['ymax'] = box['ymax'] + 3
         m_left, m_top = box['xmin'], box['ymin'],
         # box_coordiante = (m_left, m_top, box['xmax'], box['ymax'])
         single_choice_m = utils.crop_region(image, box)
         try:
-            row_col_dict = get_choice_m_row_and_col(m_left, m_top, single_choice_m)     # 所有的小框,行列
+            row_col_dict = get_choice_m_row_and_col(m_left, m_top, single_choice_m)     # 所有的小框,行列
             if len(row_col_dict) > 0:
 
                 if direction == 90:
@@ -472,6 +492,8 @@ def choice_m_row_col(image, choice_m_bbox_list, direction, xml_path):
                 row_col_dict.update({'direction': direction, 'option': choice_option,
                                      'number': title_number, 'default_points': default_points})
                 choice_m_dict_list.append(row_col_dict)
+            # else:
+            #     del choice_m_bbox_list[index0]
         except Exception:
             pass
 
@@ -496,7 +518,8 @@ def choice_m_row_col(image, choice_m_bbox_list, direction, xml_path):
         else:
             for index, s_box in enumerate(choice_m_box_dict):
                 all_small_coordinate_dict = s_box['all_small_coordinate']
-                all_small_coordinate_list = [[ele['xmin'], ele['ymin'], ele['xmax'], ele['ymax']] for ele in all_small_coordinate_dict]
+                all_small_coordinate_list = [[ele['xmin'], ele['ymin'], ele['xmax'], ele['ymax']] for ele in
+                                             all_small_coordinate_dict]
                 col = s_box['cols']
                 rows = s_box['rows']
                 if rows == 1:
@@ -511,11 +534,35 @@ def choice_m_row_col(image, choice_m_bbox_list, direction, xml_path):
                         int(np.mean(s_box_array[:, 2])) - int(np.mean(s_box_array[:, 0])),
                         int(np.mean(s_box_array[:, 3])) - int(np.mean(s_box_array[:, 1])))
                     s_box_w_h.append(s_box_wid_hei)
-        x_y_interval_arr = np.array(x_y_interval_all)
-        if len(x_y_interval_arr) == 1:
+        whether_nan = [ele for index, ele in enumerate(x_y_interval_all) if 'nan' in ele]
+        if len(whether_nan) == 0:
+            x_y_interval_all0 = x_y_interval_all
+        else:
+            if len(x_y_interval_all) == 1:
+                x_y_interval_all0 = s_box_w_h
+            else:
+                x_y_interval_all1 = []
+                for element in x_y_interval_all:
+                    if element not in whether_nan:
+                        x_y_interval_all1.append(element)
+                x_y_interval_all2 = [x_y_interval_all1[0] for i in range(len(whether_nan))]
+                x_y_interval_all0 = x_y_interval_all1 + x_y_interval_all2
+
+        x_y_interval_arr = np.array(x_y_interval_all0)
+        if len(x_y_interval_arr) == 1 and len(rows_list) != 0:
             x_y_interval_all_arr = np.array(x_y_interval_all)
-            x_ = int(np.mean(x_y_interval_all_arr[:, 0]))
-            y_ = int(np.mean(x_y_interval_all_arr[:, 1]))
+            x_ = int(np.mean(x_y_interval_all_arr[:, 0][0][0]))
+            y_ = int(np.mean(x_y_interval_all_arr[:, 0][0][1]))
+            x_y_interval_ave = (x_, y_)
+
+            singe_box_width_height_ave = s_box_w_h[0]
+
+            image_height, image_width, _ = image.shape
+            image_size = (image_width, image_height)
+        elif len(x_y_interval_arr) == 1 and len(rows_list) == 0:
+            x_y_interval_all_arr = np.array(x_y_interval_all)
+            x_ = int(np.mean(x_y_interval_all_arr[0][0]))
+            y_ = int(np.mean(x_y_interval_all_arr[0][1]))
             x_y_interval_ave = (x_, y_)
 
             singe_box_width_height_ave = s_box_w_h[0]
@@ -530,15 +577,25 @@ def choice_m_row_col(image, choice_m_bbox_list, direction, xml_path):
 
             image_height, image_width, _ = image.shape
             image_size = (image_width, image_height)
-        choice_bbox = choice_bbox_vague(choice_m_list, x_y_interval_ave, singe_box_width_height_ave, direction, image_size)
+
+        choice_bbox = choice_bbox_vague(choice_m_list, x_y_interval_ave, singe_box_width_height_ave, direction,
+                                        image_size)
         choice_m_dict_list_all_tmp = []
         for index, choice_box_ele in enumerate(choice_bbox):
-            choice_region = utils.crop_region_direct(image, choice_box_ele[0])
-            # choice_path = xml_path[: xml_path.rfind('\\')]
-            # cv2.imwrite(os.path.join(choice_path, 'choice_region_' + str(index) + '.jpg'), choice_region)
+            choice_bbox_ele = []
+            if direction == 180:
+                choice_bbox_ele0 = choice_box_ele[0]
+                choice_bbox_ele = [choice_bbox_ele0[0], choice_bbox_ele0[1] - 2,
+                                   choice_bbox_ele0[2], choice_bbox_ele0[3] + 5]
+            elif direction == 90:
+                choice_bbox_ele = choice_box_ele[0]
+            choice_region = utils.crop_region_direct(image, choice_bbox_ele)
+            choice_path = xml_path[: xml_path.rfind('/')]
+            cv2.imwrite(os.path.join(choice_path, 'choice_region_' + str(index) + '.jpg'), choice_region)
             choice_m_box_dict_new = [choice_m_box_dict[i] for i in choice_box_ele[1]]
-            choice_m_dict_list_part = get_title_number_by_choice_m.get_title_number(choice_box_ele[0], choice_region,
-                                                                                     choice_m_box_dict_new, direction)
+            choice_m_dict_list_part = get_title_number_by_choice_m.get_title_number(choice_bbox_ele, choice_region,
+                                                                                    image, choice_m_box_dict_new,
+                                                                                    x_y_interval_ave, subject, direction)
             choice_m_dict_list_all_tmp.append(choice_m_dict_list_part)
         if len(choice_m_dict_list_all_tmp) == 1:
             choice_m_dict_list_all = choice_m_dict_list_all_tmp[0]

+ 358 - 42
segment/sheet_resolve/analysis/choice/get_title_number_by_choice_m.py

@@ -4,6 +4,7 @@ import numpy as np
 import re, os
 import xml.etree.cElementTree as ET
 import cv2
+from concurrent.futures import ThreadPoolExecutor
 
 
 def combine_char(all_digital_list):
@@ -42,6 +43,82 @@ def combine_char(all_digital_list):
     return new_all_digital_list
 
 
+def modify_coordinate(words_result_choice):
+    # single character height 22  width  17
+    # two character height 21  width 30
+    digital_list = []
+    for ele in words_result_choice:
+        pattern1 = re.compile('\d+')
+        if pattern1.findall(ele['char']):
+            digital_list.append(ele)
+            # if re.search('[ABCD]', ele['char']) == None:
+            #     digital_list.append(ele)
+
+    s_l = []
+    t_l = []
+
+    for ele in digital_list:
+        char = ele['char']
+        char1 = re.findall('\d+', char)[0]
+        pattern1 = re.compile('\d+')
+        pattern2 = re.compile(r'[A|B|C|D]|[((]|[))]')  # ABCD ..
+        pattern3 = re.compile('[\u4e00-\u9fa5]')  # chinese char
+        if pattern1.findall(ele['char']):
+            if not pattern3.findall(ele['char']) and not pattern2.findall(ele['char']):
+                if len(char1) == 1:
+                    s_l.append(ele)
+                else:
+                    t_l.append(ele)
+
+    if s_l != []:
+        s_h_list = [ele['location']['height'] for ele in s_l]
+        s_w_list = [ele['location']['width'] for ele in s_l]
+
+        s_h_mean = int(np.mean(np.array(s_h_list)))  # single  height
+        s_w_mean = int(np.mean(np.array(s_w_list)))  # single  width
+    else:
+        s_h_list = [ele['location']['height'] for ele in words_result_choice]
+        s_h_mean = int(np.mean(np.array(s_h_list)))  # single  height
+        s_w_mean = 18
+
+    if t_l != []:
+        t_h_list = [ele['location']['height'] for ele in t_l]
+        t_w_list = [ele['location']['width'] for ele in t_l]
+
+        t_h_mean = int(np.mean(np.array(t_h_list)))  # single  height
+        t_w_mean = int(np.mean(np.array(t_w_list)))  # single  width
+    else:
+        t_h_list = [ele['location']['height'] for ele in words_result_choice]
+        t_h_mean = int(np.mean(np.array(t_h_list)))  # single  height
+        t_w_mean = 31
+
+    digital_list_all = []
+    for ele in words_result_choice:
+        pattern1 = re.compile('\d+')
+        if pattern1.findall(ele['char']):
+            if re.search('[A|B|C|D]', ele['char']) != None:
+                char1 = len(pattern1.findall(ele['char'])[0])
+                if char1 == 1:
+                    location = {}
+                    location['left'] = ele['location']['left']
+                    location['top'] = ele['location']['top']
+                    location['width'] = s_w_mean
+                    location['height'] = s_h_mean
+                    ele.update({'location': location})
+                    digital_list_all.append(ele)
+                elif char1 == 2:
+                    location = {}
+                    location['left'] = ele['location']['left']
+                    location['top'] = ele['location']['top']
+                    location['width'] = t_w_mean
+                    location['height'] = t_h_mean
+                    ele.update({'location': location})
+                    digital_list_all.append(ele)
+            else:
+                digital_list_all.append(ele)
+    return digital_list_all
+
+
 def get_x_diff_and_y_diff0(single_choice_m_coordinates):
     single_choice_m_matrix = np.array(single_choice_m_coordinates)
     x_diff = single_choice_m_matrix[1:, 0] - single_choice_m_matrix[:-1, 2]
@@ -303,26 +380,203 @@ def analysis_s_box(choice_m_bbox_list):
     return choice_m_box_dict
 
 
-def get_title_number(choice_bbox, choice_region, choice_m_box_dict, direction):
-    words_result_choice = get_ocr_text_and_coordinate0(choice_region, ocr_accuracy='accurate', language_type='CHN_ENG')
+def decide_box_left(one_line_box, digital_list):
+    one_line_box = sorted(one_line_box, key=lambda k: k[0])
+
+    min_w = one_line_box[0][0]
+    max_w = one_line_box[-1][2]
+    width = max_w - min_w
+
+    s_box_width = 27
+    s_box_height = 20
+
+    # compare two box
+    s_box = one_line_box[0]
+
+    s_box_xmin = s_box[0]
+    s_box_ymin = s_box[1]
+    s_box_xmax = s_box[2]
+    s_box_ymax = s_box[3]
+
+    mid_x1 = int(s_box_xmin + (s_box_xmax - s_box_xmin) // 2)
+    mid_y1 = int(s_box_ymin + (s_box_ymax - s_box_ymin) // 2)
+
+    digital_box = digital_list[0]
+
+    xmin_d = digital_box['location']['left']
+    ymin_d = digital_box['location']['top']
+    xmax_d = digital_box['location']['left'] + digital_box['location']['width']
+    ymax_d = digital_box['location']['top'] + digital_box['location']['height']
+
+    if xmin_d < s_box_xmin and mid_x1 <= 2 * width and ymin_d <= mid_y1 <= ymax_d:
+        return True
+    else:
+        return False
+
+
+def analyse_words_result(words_result):
+    char_list = []
+    for ele in words_result:
+        digital_list = []
+        un_digital_list = []
+        for ele0 in ele['chars']:
+            digital = re.findall('\d+', ele0['char'])
+            if len(digital) == 0:
+                un_digital_list.append(ele0)
+            else:
+                digital_list.append(ele0)
+        digital_list1 = combine_char(digital_list)
+        words_dict = {}
+        words_dict['words'] = ele['words']
+        char_list_all = digital_list1 + un_digital_list + [words_dict]
+        char_list.append(char_list_all)
+    return char_list
+
+
+def get_chinese_choice_number(choice_m_box_dict, image):
+    for index, ele in enumerate(choice_m_box_dict):
+        choice_m_box = [ele['bounding_box']['xmin'], ele['bounding_box']['ymin'],
+                        ele['bounding_box']['xmax'], ele['bounding_box']['ymax']]
+
+        all_small_coordinate_dict = ele['all_small_coordinate']
+        all_small_coordinate_list = [[ele['xmin'], ele['ymin'], ele['xmax'], ele['ymax']] for ele in
+                                     all_small_coordinate_dict]
+        res_list = get_one_line_box(all_small_coordinate_list, ele['single_height'])
+
+        c_w = choice_m_box[2] - choice_m_box[0]
+        c_h = choice_m_box[3] - choice_m_box[1]
+
+        choice_m_box_forward = [choice_m_box[0] - int(1.5 * c_w), ele['bounding_box']['ymin'] - 5,
+                                    ele['bounding_box']['xmax'], ele['bounding_box']['ymax']]
+
+        all_small_coordinate_list_new = []
+        for s_box_line in res_list:
+            line_box = []
+            for s_box in s_box_line:
+                ss_box = utils.get_img_region_box1(s_box, choice_m_box_forward)
+                line_box.append(ss_box)
+            line_box_array = np.array(line_box)
+            width_s = int(np.mean(line_box_array[:, 2] - line_box_array[:, 0]))
+            width_h = int(np.mean(line_box_array[:, 3] - line_box_array[:, 1]))
+            s_box_dict = {}
+            s_box_dict['bbox'] = line_box
+            s_box_dict['width_s'] = width_s
+            s_box_dict['width_h'] = width_h
+            all_small_coordinate_list_new.append(s_box_dict)
+        # print(all_small_coordinate_list_new)
+
+        region = utils.crop_region_direct(image, choice_m_box_forward)
+        words_result = get_ocr_text_and_coordinate0(region)
+        char_list = analyse_words_result(words_result)
+
+        number0 = []
+        value0 = []
+        for one_line in all_small_coordinate_list_new:
+            one_line_box = one_line['bbox']
+
+            number_list = []
+            for ele_line in char_list:
+                ele_line_copy = ele_line.copy()
+                ele_line_copy = [ele for ele in ele_line_copy if 'char' in ele]
+
+                digital_list_line = sorted(ele_line_copy, key=lambda k: k['location']['left'])
+                # TODO  decide left
+
+                if decide_box_left(one_line_box, digital_list_line) == True:
+                    number_list.append(ele_line)
+                else:
+                    continue
+            # print(number_list)
+
+            for s_bbox0 in number_list:
+                for s_bbox in s_bbox0:
+                    if 'words' in s_bbox:
+                        words = s_bbox['words']
+
+                        pattern1 = re.compile(
+                            '^\d+[,、.]?[\u4e00-\u9fa5]?[((]?\d+分+[))]?[((]?\d+[))]?|^\d+[,、.]?[\u4e00-\u9fa5]?[((]?\d+分+[))]?')
+                        result1 = re.findall(pattern1, words)
+
+                        pattern11 = re.compile('^\d+[,、.]?[\u4e00-\u9fa5]?[((]?\d+[))]?')
+                        result11 = re.findall(pattern11, words)
+
+                        pattern2 = re.compile('[((]?\d+[))]?')
+                        result2 = re.findall(pattern2, words)
+
+                        pattern3 = re.compile('[((]?\d?分+[))]?')
+                        result3 = re.findall(pattern3, words)
+
+                        if result2 and result3:
+                            number = int(result2[0])
+                            value_str = re.findall('\d+', result2[1])
+                            value = int(value_str[0])
+                            number0.append(number)
+                            value0.append(value)
+
+                        elif result2 and not result3:
+                            number = int(result2[0])
+                            value = -1
+                            number0.append(number)
+                            value0.append(value)
+
+        ele.update({'number': number0, 'default_points': value0})
+    return choice_m_box_dict
+
+
+def get_title_number(choice_bbox, choice_region, image, choice_m_box_dict, x_y_interval_ave0, subject, direction):
+    # c_h, c_w, _ = choice_region.shape
+    # if c_h * c_w > 1000 * 1000:
+    #     # baidu ocr format
+    #     words_result_choice_raw = get_ocr_text_and_coordinate0(choice_region, ocr_accuracy='accurate', language_type='CHN_ENG')
+    #     words_result_choice1 = [ele1 for ele in words_result_choice_raw for ele1 in ele['chars']]
+    #     words_result_choice1 = combine_char(words_result_choice1)    # baidu need
+
+    # else:
+    # tr ocr
+    words_result_choice1 = ''
+    if direction == 90:
+        words_result_choice_raw = get_ocr_text_and_coordinate0(choice_region, ocr_accuracy='accurate', language_type='CHN_ENG')
+        char_list = [ele1 for ele in words_result_choice_raw for ele1 in ele['chars']]
+        words_result_choice1 = combine_char(char_list)
+
+    elif direction == 180:
+        choice_gray = cv2.cvtColor(choice_region, cv2.COLOR_BGR2GRAY)
+        with ThreadPoolExecutor() as executor:
+            future = executor.submit(tr.run, choice_gray)
+            words_result_choice0 = future.result()
+            words_result_choice_tmp = utils.change_baidu_to_tr_format(words_result_choice0)
+            # tr  I  -- > 1
+            words_result_choice1 = []
+            for ele in words_result_choice_tmp:
+                if 'I' in ele['char']:
+                    element = re.sub('I', '1', ele['char'])
+                    ele.update({'char': element})
+                    words_result_choice1.append(ele)
+                else:
+                    words_result_choice1.append(ele)
+    words_result_choice = modify_coordinate(words_result_choice1)
+
     all_digital_list0 = []
-    pattern = re.compile(r'\d')
-    for i, chars_dict in enumerate(words_result_choice):
-        chars_list = chars_dict['chars']
-        for ele in chars_list:
-            if pattern.search(ele['char']):
-                all_digital_list0.append(ele)
+    pattern = re.compile(r'\d+')
+    for i, ele in enumerate(words_result_choice):
+        if pattern.findall(ele['char']):
+            all_digital_list0.append(ele)
+
+    for i, ele in enumerate(all_digital_list0):
+        if pattern.findall(ele['char']):
+            new_char = pattern.findall(ele['char'])[0]
+            ele.update({'char': new_char})
+    # # tree = ET.parse(r'C:\Users\admin\Desktop\exam_segment_django113\segment\exam_info\000000-template.xml')  # xml tree
+    # # for index, bbox in enumerate(all_digital_list0):
+    # #     # bbox0 = region_info['bbox']
+    # #     location = bbox['location']
+    # #     xmin = location['left']
+    # #     ymin = location['top']
+    # #     xmax = location['left'] + location['width']
+    # #     ymax = location['top'] + location['height']
+    # #     tree = utils.create_xml(bbox['char'], tree, xmin, ymin, xmax, ymax)
+    # # tree.write(r'C:\Users\admin\Desktop\exam_segment_django113\segment\exam_image\sheet\arts_comprehensive\2020-02-05\choice_region_00.xml')
 
-    # tree = ET.parse(r'C:\Users\admin\Desktop\exam_segment_django113\segment\exam_info\000000-template.xml')  # xml tree
-    # for index, bbox in enumerate(all_digital_list0):
-    #     # bbox0 = region_info['bbox']
-    #     location = bbox['location']
-    #     xmin = location['left']
-    #     ymin = location['top']
-    #     xmax = location['left'] + location['width']
-    #     ymax = location['top'] + location['height']
-    #     tree = utils.create_xml(bbox['char'], tree, xmin, ymin, xmax, ymax)
-    # tree.write(r'C:\Users\admin\Desktop\exam_segment_django113\segment\exam_image\sheet\arts_comprehensive\2020-02-05\choice_region_00.xml')
 
     delete_list = []
     for ele_digtal in all_digital_list0:
@@ -333,12 +587,20 @@ def get_title_number(choice_bbox, choice_region, choice_m_box_dict, direction):
             ymax_d = ele_digtal['location']['top'] + ele_digtal['location']['height']
 
             ele_digtal_bbox = [xmin_d, ymin_d, xmax_d, ymax_d]
+            ele_digtal_bbox_tmp = [xmin_d + 4, ymin_d - 3, xmax_d - 6, ymax_d - 3]
+
 
             ele_choice_m_bbox = [ele_choice_m['bounding_box']['xmin'], ele_choice_m['bounding_box']['ymin'],
                                  ele_choice_m['bounding_box']['xmax'], ele_choice_m['bounding_box']['ymax']]
 
             choice_m_new_box = utils.get_img_region_box1(ele_choice_m_bbox, choice_bbox)
-            if utils.decide_coordinate_full_contains2(choice_m_new_box, ele_digtal_bbox) == True:
+            ele_digtal_bbox_tmp_array = np.array(ele_digtal_bbox_tmp)
+            neg_ele = np.where(ele_digtal_bbox_tmp_array < 0)[0]
+
+            if len(neg_ele) != 0:
+                delete_list.append(ele_digtal)
+                break
+            elif utils.decide_coordinate_full_contains2(choice_m_new_box, ele_digtal_bbox_tmp) == True:
                 delete_list.append(ele_digtal)
 
     all_digital_list = []
@@ -348,6 +610,7 @@ def get_title_number(choice_bbox, choice_region, choice_m_box_dict, direction):
             continue
         else:
             all_digital_list.append(ele)
+    # print(all_digital_list)
 
     # new_all_digital_list = combine_char(all_digital_list)
     #
@@ -407,8 +670,11 @@ def get_title_number(choice_bbox, choice_region, choice_m_box_dict, direction):
                 choice_m_box_dict_new.append(s_choice_m_box)
             else:
                 x_y_interval = utils.get_x_diff_and_y_diff1(all_small_coordinate_new, col)
-                x_y_interval_all.append(x_y_interval)
-
+                if 'nan' in x_y_interval:
+                    x_y_interval = x_y_interval_ave0
+                    x_y_interval_all.append(x_y_interval)
+                else:
+                    x_y_interval_all.append(x_y_interval)
                 all_small_coordinate_list = sorted(all_small_coordinate_new, key=lambda k: k[1])
                 s_box_array = np.array(all_small_coordinate_list)
                 s_box_wid_hei = (int(np.mean(s_box_array[:, 2])) - int(np.mean(s_box_array[:, 0])),
@@ -427,14 +693,14 @@ def get_title_number(choice_bbox, choice_region, choice_m_box_dict, direction):
     s_box_w_h_arr = np.array(s_box_w_h)
     singe_box_width_height_ave = (int(np.mean(s_box_w_h_arr[:, 0])), int(np.mean(s_box_w_h_arr[:, 1])))
 
-    digital_list_by_choice_m = get_digital_near_choice_m_box(all_digital_list0, choice_m_box_dict_new, x_y_interval_ave, singe_box_width_height_ave, direction)
+    digital_list_by_choice_m = get_digital_near_choice_m_box(all_digital_list, choice_m_box_dict_new, x_y_interval_ave, singe_box_width_height_ave, direction)
 
     for number in digital_list_by_choice_m:
         title_number_list = number['title_number']
 
         all_digital_list = sorted(title_number_list, key=lambda k: k.get('location')['top'])
-        new_title_number_list = combine_char(all_digital_list)
-        number.update({'title_number': new_title_number_list})
+        # new_title_number_list = combine_char(all_digital_list)
+        number.update({'title_number': all_digital_list})
 
     # tree = ET.parse(r'C:\Users\admin\Desktop\exam_segment_django113\segment\exam_info\000000-template.xml')  # xml tree
     # for index, bbox0 in enumerate(digital_list_by_choice_m):
@@ -485,23 +751,50 @@ def get_title_number(choice_bbox, choice_region, choice_m_box_dict, direction):
                         title_number_bbox = [ele1['location']['left'], ele1['location']['top'],
                                              ele1['location']['left'] + ele1['location']['width'],
                                              ele1['location']['top'] + ele1['location']['height']]
-                        row_box = {}
-                        if title_number_bbox[1] - single_bbox_width_height[1] < ele0[0][1] < title_number_bbox[1] + \
-                                single_bbox_width_height[1] \
-                                and title_number_bbox[3] - single_bbox_width_height[1] < ele0[0][3] < \
-                                title_number_bbox[3] + single_bbox_width_height[1]:
-                            row_box['title_number'] = ele1
-                            row_box['row_box'] = ele0
-                            row_box_.append(row_box)
-                            index_list.append(index0)
-                        elif title_number_bbox[0] - single_bbox_width_height[0] < ele0[0][0] < title_number_bbox[0] + \
-                                single_bbox_width_height[0] \
-                                and title_number_bbox[2] - single_bbox_width_height[0] < ele0[0][2] < \
-                                title_number_bbox[2] + single_bbox_width_height[0]:
-                            row_box['title_number'] = ele1
-                            row_box['row_box'] = ele0
-                            row_box_.append(row_box)
-                            index_list.append(index0)
+                        if choice_m_s['direction'] == 180:
+                            # x axis
+                            pattern1 = int(3 / 5 * (title_number_bbox[0] - single_bbox_width_height[0])) < \
+                                       int(9/10 * ele0[0][0]) < \
+                                       title_number_bbox[0] + int(1.6 * single_bbox_width_height[0]) \
+                                       and int(3 / 5 * (title_number_bbox[2] - single_bbox_width_height[0])) < \
+                                       int(ele0[0][2] * 9/10) < \
+                                       title_number_bbox[2] + int(1.6 * single_bbox_width_height[0])
+                            # y axis
+                            pattern2 = title_number_bbox[1] - single_bbox_width_height[1] < ele0[0][1] < \
+                                       title_number_bbox[1] + int(1.6 * single_bbox_width_height[1]) \
+                                       and int(3 / 5 * (title_number_bbox[3] - single_bbox_width_height[1])) < \
+                                       int(ele0[0][3] * 9 /10) < \
+                                       title_number_bbox[3] + int(1.6 * single_bbox_width_height[1])
+
+                            row_box = {}
+                            if pattern1 and pattern2:
+                                row_box['title_number'] = ele1
+                                row_box['row_box'] = ele0
+                                row_box_.append(row_box)
+                                index_list.append(index0)
+
+                        elif choice_m_s['direction'] == 90:
+                            # x axis
+                            pattern1 = title_number_bbox[0] - single_bbox_width_height[0] <= ele0[0][0] <= \
+                                       title_number_bbox[0] + int(1.6 * single_bbox_width_height[0]) \
+                                       and int(3 / 5 * (title_number_bbox[2] - single_bbox_width_height[0])) <= \
+                                       int(ele0[0][2] * 0.9) \
+                                       <= title_number_bbox[2] + int(1.6 * single_bbox_width_height[0])
+
+                            # y axis
+                            pattern2 = title_number_bbox[1] - single_bbox_width_height[1] <= \
+                                       ele0[0][1] \
+                                       <= title_number_bbox[1] + int(3 * single_bbox_width_height[1]) \
+                                       and int(3 / 5 * (title_number_bbox[3] - single_bbox_width_height[1])) <= \
+                                       int(ele0[0][3] * 0.9) \
+                                       <= title_number_bbox[3] + int(3 * single_bbox_width_height[1])
+                            row_box = {}
+                            if pattern1 and pattern2:
+                                row_box['title_number'] = ele1
+                                row_box['row_box'] = ele0
+                                row_box_.append(row_box)
+                                index_list.append(index0)
+                # print(index_list)
                 index_list = sorted(list(set(index_list)))
                 index0 = list(set([i for i in range(0, row_and_col[0])]) - set(index_list))  # lack index
                 number0 = choice_m_s['number']
@@ -512,7 +805,11 @@ def get_title_number(choice_bbox, choice_region, choice_m_box_dict, direction):
                 choice_m_s.update({'title_number': new_number_list})
                 all_list_new.append(choice_m_s)
             except Exception as e:
+                number0 = choice_m_s['number']
+                new_number_list = utils.infer_number(number0)
+                choice_m_s.update({'title_number': new_number_list})
                 all_list_new.append(choice_m_s)
+
     # print(all_list_new)
     title_number_by_choice_m_list = []     # sort change coordinate
     for index, single_choice_m in enumerate(all_list_new):
@@ -536,4 +833,23 @@ def get_title_number(choice_bbox, choice_region, choice_m_box_dict, direction):
         single_choice_m.pop('s_box_w_h')
         single_choice_m.pop('title_number')
         title_number_by_choice_m_list.append(single_choice_m)
-    return title_number_by_choice_m_list
+    # print(title_number_by_choice_m_list)
+
+    # Chinese number and value    11 (3 fen)
+    if subject == 'chinese' or subject == 'chinese_blank':
+        title_number_by_choice_m_list_tmp = title_number_by_choice_m_list.copy()
+        number_list_all = [ele['number'] for ele in title_number_by_choice_m_list_tmp if ele['number']]
+        number_list_all_ = [ele1 for ele in number_list_all for ele1 in ele]
+        number_tmp = list(set(number_list_all_))
+        if len(number_tmp) == 1 and number_tmp[0] == -1:
+            try:
+                number_by_choice_m_list = get_chinese_choice_number(title_number_by_choice_m_list_tmp, image)
+                return number_by_choice_m_list
+            except Exception as e:
+                print('Chinses number and value error, please check')
+                return title_number_by_choice_m_list
+        else:
+            return title_number_by_choice_m_list
+
+    else:
+        return title_number_by_choice_m_list

+ 637 - 0
segment/sheet_resolve/analysis/correct/register_once_mbq_recursive_by_page.py

@@ -0,0 +1,637 @@
+# @Author  : lightXu
+# @File    : register_once_mbq_recursive_by_page.py
+# @Time    : 2019/8/8 0008 上午 10:25
+import argparse
+import time
+from os import path as os_path, makedirs
+import math
+import cv2  # VERSION <=3.3
+import numpy as np
+from PIL import Image
+from pywt import wavedec2, waverec2
+
+REGISTER_RIGHT_HIGH = 0.85
+REGISTER_RIGHT_LOW = 0.5
+RECURSIVE_EPSILON = 0.02
+STD_DIFF = 5.0
+FIXED_MAX_LENGTH = 800
+CLIP_RATIO = 0.25
+WEIGHT_GLOBAL = 0.1
+WEIGHT_CORNER = 0.9
+SINGLE_GLOBAL = SINGLE_CORNER = 0.92
+NEED_FLIP = True
+MATCH_PIX = 10
+L_W_RATIO = 0.2
+RATIO = 0.7
+RECURSIVE_RATIO = 1.0
+
+
+def _binary_array_to_hex(arr):
+    """
+    internal function to make a hex string out of a binary array.
+    """
+    bit_string = ''.join(str(b) for b in 1 * arr.flatten())
+    width = int(np.ceil(len(bit_string) / 4))
+    return '{:0>{width}x}'.format(int(bit_string, 2), width=width)
+
+
+class ImageHash(object):
+    """
+    Hash encapsulation. Can be used for dictionary keys and comparisons.
+    """
+
+    def __init__(self, binary_array):
+        self.hash = binary_array
+
+    def __str__(self):
+        return _binary_array_to_hex(self.hash.flatten())
+
+    def __repr__(self):
+        return repr(self.hash)
+
+    def __sub__(self, other):
+        if other is None:
+            raise TypeError('Other hash must not be None.')
+        if self.hash.size != other.hash.size:
+            raise TypeError('ImageHashes must be of the same shape.', self.hash.shape, other.hash.shape)
+        return np.count_nonzero(self.hash.flatten() != other.hash.flatten())
+
+    def __eq__(self, other):
+        if other is None:
+            return False
+        return np.array_equal(self.hash.flatten(), other.hash.flatten())
+
+    def __ne__(self, other):
+        if other is None:
+            return False
+        return not np.array_equal(self.hash.flatten(), other.hash.flatten())
+
+    def __hash__(self):
+        # this returns a 8 bit integer, intentionally shortening the information
+        return sum([2 ** (i % 8) for i, v in enumerate(self.hash.flatten()) if v])
+
+
+def whash(image, hash_size=8, image_scale=None, mode='haar', remove_max_haar_ll=True):
+    if image_scale is not None:
+        assert image_scale & (image_scale - 1) == 0, "image_scale is not power of 2"
+    else:
+        image_natural_scale = 2 ** int(np.log2(min(image.size)))
+        image_scale = max(image_natural_scale, hash_size)
+
+    ll_max_level = int(np.log2(image_scale))
+
+    level = int(np.log2(hash_size))
+    assert hash_size & (hash_size - 1) == 0, "hash_size is not power of 2"
+    assert level <= ll_max_level, "hash_size in a wrong range"
+    dwt_level = ll_max_level - level
+
+    image = image.resize((image_scale, image_scale), Image.ANTIALIAS)
+    pixels = np.asarray(image) / 255
+
+    # Remove low level frequency LL(max_ll) if @remove_max_haar_ll using haar filter
+    if remove_max_haar_ll:
+        coeffs = wavedec2(pixels, 'haar', level=ll_max_level)
+        coeffs = list(coeffs)
+        coeffs[0] *= 0
+        pixels = waverec2(coeffs, 'haar')
+
+    # Use LL(K) as freq, where K is log2(@hash_size)
+    coeffs = wavedec2(pixels, mode, level=dwt_level)
+    dwt_low = coeffs[0]
+
+    # Substract median and compute hash
+    med = np.median(dwt_low)
+    diff = dwt_low > med
+    return ImageHash(diff)
+
+
+def phash(image, hash_size=8, highfreq_factor=4):
+    if hash_size < 2:
+        raise ValueError("Hash size must be greater than or equal to 2")
+
+    img_size = hash_size * highfreq_factor
+    image = image.resize((img_size, img_size), Image.ANTIALIAS)
+    pixels = np.asarray(image)
+    dct = cv2.dct(pixels.astype(np.float))
+    dctlowfreq = dct[:hash_size, :hash_size]
+    med = np.median(dctlowfreq)
+    diff = dctlowfreq > med
+    return ImageHash(diff)
+
+
+def dhash(image, hash_size=8):
+    # resize(w, h), but numpy.array((h, w))
+    if hash_size < 2:
+        raise ValueError("Hash size must be greater than or equal to 2")
+
+    image = image.resize((hash_size + 1, hash_size), Image.ANTIALIAS)
+    pixels = np.asarray(image)
+    # compute differences between columns
+    diff = pixels[:, 1:] > pixels[:, :-1]
+    return ImageHash(diff)
+
+
+def average_hash(image, hash_size=8):
+    if hash_size < 2:
+        raise ValueError("Hash size must be greater than or equal to 2")
+
+    # reduce size and complexity, then covert to grayscale
+    image = image.resize((hash_size, hash_size), Image.ANTIALIAS)
+
+    # find average pixel value; 'pixels' is an array of the pixel values, ranging from 0 (black) to 255 (white)
+    pixels = np.asarray(image)
+    avg = pixels.mean()
+
+    # create string of bits
+    diff = pixels > avg
+    # make a hash
+    return ImageHash(diff)
+
+
+def image_hash(image):
+    image = Image.fromarray(image)
+    hash_size = 10
+    dhash_code = dhash(image, hash_size=hash_size)
+    # ahash_code = average_hash(image, hash_size=hash_size)
+    phash_code = phash(image, hash_size=hash_size, highfreq_factor=4)
+    hash_code = whash(image, image_scale=64, hash_size=8, mode='db4')
+    return hash_code, dhash_code, phash_code
+
+
+def hash_similarity(hash_as_t, hash_to_regi):
+    similarity = 1 - (hash_as_t - hash_to_regi) / len(hash_as_t.hash) ** 2  # 相似性
+    return similarity
+
+
+def hash_detection(t_image, image_to_regi):
+    # t_hash = image_hash(t_image)
+    # r_hash = image_hash(image_to_regi)
+
+    [t_hash, tdh, tph] = image_hash(t_image)
+    [r_hash, rdh, rph] = image_hash(image_to_regi)
+    simi = hash_similarity(tdh, rdh)
+    # simi = hash_similarity(t_hash, r_hash)
+
+    evidence = [hash_similarity(t_hash, r_hash),
+                hash_similarity(tdh, rdh),
+                hash_similarity(tph, rph)]
+    return simi, evidence
+
+
+def read_single_img(img_path):
+    try:
+        im = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), -1)
+        # im = np.asarray(Image.open(img_path).convert('L'))
+        # im = np.asarray(Image.open(img_path))
+        # im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
+        # im = im[:, :, 0]
+        if len(im.shape) == 3:
+            im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
+    except FileNotFoundError as e:
+        raise e
+    return im
+
+
+def write_single_img(dst, save_path):
+    try:
+        # cv2.imencode('.jpg', dst)[1].tofile(save_path)
+        save_img = Image.fromarray(dst)
+        save_img.save(save_path)
+    except FileNotFoundError as e:
+        raise e
+
+
+def resize_by_percent(im, percent):
+    height = im.shape[0]
+    width = im.shape[1]
+    new_x = int(width * percent)
+    new_y = int(height * percent)
+    res = cv2.resize(im, (new_x, new_y), interpolation=cv2.INTER_AREA)
+    return res
+
+
+def resize_by_fixed_size(im, new_x, new_y):
+    res = cv2.resize(im, (new_x, new_y), interpolation=cv2.INTER_AREA)
+    return res
+
+
+def image_flip(image, angle='90'):
+    rotate_code = {"90": Image.ROTATE_90, "180": Image.ROTATE_180, "270": Image.ROTATE_270, }  # 逆时针
+    if str(angle) != "0":
+        img = ''
+        if 'PIL' not in str(type(image)):
+            img = Image.fromarray(image)
+        rotate_img = img.transpose(rotate_code[str(angle)])  # 翻转
+        cv2_img_tmp = np.asarray(rotate_img)
+        if len(cv2_img_tmp.shape) > 2:
+            cv2_img = cv2.cvtColor(cv2_img_tmp, cv2.COLOR_RGB2BGR)
+        else:
+            cv2_img = cv2_img_tmp
+    else:
+        cv2_img = image
+
+    return cv2_img
+
+
+def get_corners_of_image(image, clip_ratio=0.25):
+    t_x, t_y = image.shape[1], image.shape[0]
+    clip_x = int(t_x * clip_ratio)
+    clip_y = int(t_y * clip_ratio)
+    y_del_list = [ele for ele in range(clip_y, t_y - clip_y)]
+    x_del_list = [ele for ele in range(clip_x, t_x - clip_x)]
+
+    t_corner_array = np.delete(image, y_del_list, axis=0)
+    t_corner_array = np.delete(t_corner_array, x_del_list, axis=1)
+
+    return t_corner_array
+
+
+# def sift_kp(image):
+#     sift = cv2.xfeatures2d.SIFT_create()
+#     # sift = cv2.xfeatures2d.SURF_create()
+#     kp, des = sift.detectAndCompute(image, None)
+#     return kp, des
+
+
+def sift_kp(image):
+    fast = cv2.FastFeatureDetector_create(threshold=35)
+    kp = fast.detect(image, None)
+    dsp = cv2.xfeatures2d.BriefDescriptorExtractor_create()
+    kp, des = dsp.compute(image, kp)
+    return kp, des
+
+
+def get_good_match(des1, des2):
+    bf = cv2.BFMatcher()
+    matches = bf.knnMatch(des1, des2, k=2)  # des1为模板图,des2为匹配图
+    for ele in matches:
+        if ele[1].distance == 0:
+            ele[1].distance = 0.001
+    matches = sorted(matches, key=lambda x: x[0].distance / x[1].distance)
+    good = []
+    for m, n in matches:
+        if m.distance < 0.75 * n.distance:
+            good.append(m)
+    return good
+
+
+def transform_mtx_from_points(template_points, regi_points):
+    points1 = np.float64(np.matrix([[point[0], point[1]] for point in template_points]))
+    points2 = np.float64(np.matrix([[point[0], point[1]] for point in regi_points]))
+
+    c1 = np.mean(points1, axis=0)
+    c2 = np.mean(points2, axis=0)
+    points1 -= c1
+    points2 -= c2
+    s1 = np.std(points1)
+    s2 = np.std(points2)
+    points1 /= s1
+    points2 /= s2
+    U, S, Vt = np.linalg.svd(points1.T * points2)
+    R = (U * Vt).T
+    return np.vstack([np.hstack(((s2 / s1) * R, c2.T - (s2 / s1) * R * c1.T)), np.array([0., 0., 1.])])
+
+
+def sift_transform_mtx_by_match(good_match, t_kp, r_kp, ratio=1):
+    # x_ratio != y_ratio may cause bias
+    x_ratio = y_ratio = 1 / ratio
+    transform_mtx = np.ones((3, 3))
+    if len(good_match) > 4:
+        ratio_array = np.array([x_ratio, y_ratio])
+        t_pts = np.float32([t_kp[m.queryIdx].pt for m in good_match])
+        r_pts = np.float32([r_kp[m.trainIdx].pt for m in good_match])
+
+        t_pts = (t_pts * ratio_array)
+        r_pts = (r_pts * ratio_array)
+        pts_t_pers = (t_pts * ratio_array).reshape(-1, 1, 2)
+        pts_r_pers = (r_pts * ratio_array).reshape(-1, 1, 2)
+        _, status_array = cv2.findHomography(pts_t_pers, pts_r_pers, cv2.RANSAC, ransacReprojThreshold=4)
+
+        status_arr = status_array.reshape(-1)
+        idx = np.where(status_arr >= 1)
+
+        t_points = np.float32(t_pts[idx])
+        r_points = np.float32(r_pts[idx])
+        transform_mtx = transform_mtx_from_points(t_points, r_points)
+    return transform_mtx
+
+
+def perspective(regi_img, mtx):
+    width, height = regi_img.shape[1], regi_img.shape[0]
+    # img_out = cv2.warpPerspective(raw_img_r, mtx, (width, height),
+    #                               flags=cv2.INTER_AREA + cv2.WARP_INVERSE_MAP, borderValue=(255, 255, 255))
+
+    img_out = cv2.warpAffine(regi_img, mtx[:2], (width, height), flags=cv2.INTER_AREA + cv2.WARP_INVERSE_MAP,
+                             borderValue=(255, 255, 255))
+    return img_out
+
+
+def pre_process(t_image_path_list, r_img_path):
+    flip_degree = 0
+    t_image_tmp = read_single_img(t_image_path_list[0])
+    t_y, t_x = t_image_tmp.shape[0], t_image_tmp.shape[1]
+    regi_image = read_single_img(r_img_path)
+    r_y, r_x = regi_image.shape[0], regi_image.shape[1]
+
+    if (r_y - r_x) * (t_y - t_x) < 0:
+        regi_image = image_flip(regi_image, "90")
+        r_y, r_x = r_x, r_y
+        flip_degree = 90
+
+    n_x, n_y = t_x, t_y
+    if (r_y / r_x) != (t_y / t_x):
+        # if abs(int(r_y * t_x / t_y - r_x)) > MATCH_PIX:
+        #     raise ValueError("source image_shape(height/width) is not identical to register image_shape with 10+ "
+        #                      "pixels.")
+
+        if abs(max(r_y, r_x) / min(r_y, r_x) - max(t_y, t_x) / min(t_y, t_x)) > L_W_RATIO:
+            raise ValueError("template image shape is not identical to register image shape.")
+        else:
+            if r_y > t_y:  # resize regi_image
+                regi_image = resize_by_fixed_size(regi_image, t_x, t_y)
+                n_x, n_y = t_x, t_y
+
+            if r_y < t_y:
+                # t_norm_ratio = r_y / t_y
+                t_image_tmp = resize_by_fixed_size(t_image_tmp, r_x, r_y)
+                n_x, n_y = r_x, r_y
+
+    return (n_y, n_x), regi_image, t_image_tmp, flip_degree
+
+
+# def pre_process(t_image_path_list, r_img_path):
+#     flip_degree = 0
+#     t_image_tmp = read_single_img(t_image_path_list[0])
+#     t_y, t_x = t_image_tmp.shape[0], t_image_tmp.shape[1]
+#     regi_image = read_single_img(r_img_path)
+#     r_y, r_x = regi_image.shape[0], regi_image.shape[1]
+#
+#     if abs(max(r_y, r_x) / min(r_y, r_x) - max(t_y, t_x) / min(t_y, t_x)) > L_W_RATIO:
+#         raise ValueError("template image shape is not identical to register image shape.")
+#
+#     if (r_y - r_x) * (t_y - t_x) < 0:
+#         regi_image = image_flip(regi_image, "90")
+#         r_y, r_x = r_x, r_y
+#         flip_degree = 90
+#
+#     n_x, n_y = t_x, t_y
+#
+#     if t_y*t_x == r_y*r_x:
+#         if (t_y, t_x) != (r_y, r_x):
+#             regi_image = resize_by_fixed_size(regi_image, t_x, t_y)
+#     else:
+#         regi_image = resize_by_fixed_size(regi_image, t_x, t_y)
+#
+#     return (n_y, n_x), regi_image, t_image_tmp, flip_degree
+
+
+def ds_evidence(evid_list):
+    evid_array = np.array(evid_list)
+    evid_normed = evid_array / evid_array.sum(axis=0)
+
+    k = 0
+    p = []
+    for single in evid_normed:
+        product_tmp = 1
+        for ele in single:
+            product_tmp = product_tmp * ele
+        p.append(product_tmp)
+        k = k + product_tmp
+
+    p = [ele / k for ele in p]
+
+    return p
+
+
+def recursive_regi(save, template_list, r_image_path, max_length, old_similarity,
+                   norm_row, norm_col, regi_image, t_image_0, flip_degree):
+    degree_dict = {0: 0, 90: 1, 180: 2, 270: 3}
+    raw_name = r_image_path.replace('/', '\\').split('\\')[-1]
+
+    resize_ratio = min(round(max_length / max(norm_row, norm_col), 2), 1)
+    print('resize_ratio', resize_ratio)
+    max_simi_index, similarity_dict, std, registered_image, r_flip_degree = run(template_list, raw_name,
+                                                                                norm_row, norm_col, resize_ratio,
+                                                                                regi_image, t_image_0, flip_degree,
+                                                                                need_flip=NEED_FLIP)
+
+    similarity = round(similarity_dict['weight'], 2)
+    g_simi = round(similarity_dict['global'], 2)
+    c_simi = round(similarity_dict['corner'], 2)
+
+    cond1 = similarity >= REGISTER_RIGHT_HIGH  # 权重相似图上阈值, 超过就达标
+    cond2 = g_simi >= SINGLE_GLOBAL  # 全局相似度上阈值, 超过就达标
+    cond3 = c_simi >= SINGLE_CORNER  # 局部相似度上阈值,超过就达标
+    cond4 = abs(old_similarity - similarity) <= RECURSIVE_EPSILON  # 两次校准相似度偏差,是否小于RECURSIVE_EPSILON
+    cond5 = similarity > REGISTER_RIGHT_LOW  # 权重相似度是否大于相似度下限
+    cond6 = abs(std) < STD_DIFF  # 方差阈值, 小于达标
+    cond7 = similarity <= REGISTER_RIGHT_LOW  # 权重相似度下限,小于下限不达标
+
+    regi_flag = (cond6 and (cond1 or cond2 or cond3)) or (cond4 and cond5)
+
+    if resize_ratio >= RECURSIVE_RATIO:
+        if regi_flag:
+            degree_flag = degree_dict[r_flip_degree]
+            save_path = os_path.join(save, raw_name.replace('.jpg',
+                                                            '_index_{:02d}_{}.jpg'
+                                                            .format(max_simi_index, degree_flag)))
+            save_img = Image.fromarray(registered_image)
+            save_img.save(save_path)
+            write_single_img(registered_image, save_path)
+            print('{} >>> {}'.format(r_image_path, save_path))
+            return max_simi_index, save_path
+        else:
+            print('recursion and failed')
+            return -1, ''
+    else:
+        if regi_flag:
+            degree_flag = degree_dict[r_flip_degree]
+            save_path = os_path.join(save, raw_name.replace('.jpg',
+                                                            '_index_{:02d}_{}.jpg'
+                                                            .format(max_simi_index, degree_flag)))
+            write_single_img(registered_image, save_path)
+            print('{} >>> {}'.format(r_image_path, save_path))
+            return max_simi_index, save_path
+        else:
+            if cond7:
+                print('not register')
+                return -1, ''
+            print('recursion')
+            max_length = max_length + int(max_length * 0.5)
+            return recursive_regi(save, template_list, r_image_path, max_length, similarity,
+                                  norm_row, norm_col, regi_image, t_image_0, flip_degree)
+
+
+def run(t_image_path_list, raw_name, norm_row, norm_col, resize_ratio, regi_image, t_img_0, flip_degree, need_flip):
+    opencv_r_img_resize = resize_by_percent(regi_image, resize_ratio)
+
+    match_len = []
+    evid_list = []
+
+    t_length = len(t_image_path_list)
+    opencv_t_img_resize_0 = resize_by_percent(t_img_0, resize_ratio)
+
+    t_h, t_w = opencv_t_img_resize_0.shape[0], opencv_t_img_resize_0.shape[1]
+
+    array_dim = t_length
+    if need_flip:
+        t_image_array = np.zeros((t_h, t_w, 1 + t_length * 2), dtype=np.uint8)
+        array_dim = t_length * 2
+    else:
+        t_image_array = np.zeros((t_h, t_w, 1 + t_length), dtype=np.uint8)
+
+    for index, ele in enumerate(t_image_path_list):
+        if index == 0:
+            opencv_t_img_resize = opencv_t_img_resize_0
+        else:
+            norm_t_img = resize_by_fixed_size(read_single_img(ele), norm_col, norm_row)
+            opencv_t_img_resize = resize_by_percent(norm_t_img, resize_ratio)
+
+        t_image_array[:, :, index] = opencv_t_img_resize
+        if need_flip:
+            flip_image = np.flipud(np.fliplr(opencv_t_img_resize))
+            t_image_array[:, :, index + t_length] = flip_image
+
+    t_image_array[:, :, -1] = opencv_r_img_resize
+    t_corner_array = get_corners_of_image(t_image_array, CLIP_RATIO)
+    r_corners = t_corner_array[:, :, -1]
+
+    for index in range(0, array_dim):
+        similarity_corner, evidence_c = hash_detection(t_corner_array[:, :, index], r_corners)
+        similarity_global, evidence_g = hash_detection(t_image_array[:, :, index], opencv_r_img_resize)
+        evid_list.append(evidence_c + evidence_g)
+
+        print('  t-{}, corner:  whash: {:.4f}, dhash: {:.4f},  phash: {:.4f}'
+              .format(index,
+                      evidence_c[0],
+                      evidence_c[1],
+                      evidence_c[2],
+                      ))
+        print('       global:  whash: {:.4f}, dhash: {:.4f},  phash: {:.4f}'
+              .format(evidence_g[0],
+                      evidence_g[1],
+                      evidence_g[2],
+                      ))
+
+    p_list = ds_evidence(evid_list)
+
+    p_list_str = ', '.join(['page_{}: {:.4f}'.format(index, p) for index, p in enumerate(p_list)])
+    print(p_list_str)
+
+    if need_flip:
+        max_index = p_list.index(max(p_list))
+        half_len = t_length
+        if max_index >= half_len != 0:
+            max_index = max_index - half_len
+            opencv_r_img_resize = image_flip(opencv_r_img_resize, '180')
+            regi_image = image_flip(regi_image, '180')
+            r_kp, r_des = sift_kp(opencv_r_img_resize)
+            flip_degree = flip_degree + 180
+        else:
+            r_kp, r_des = sift_kp(opencv_r_img_resize)
+    else:
+        max_index = p_list.index(max(p_list))
+        r_kp, r_des = sift_kp(opencv_r_img_resize)
+
+    print('max_index: ', max_index)
+
+    t_kp, t_des = sift_kp(t_image_array[:, :, max_index])
+    match_points = get_good_match(t_des, r_des)
+    match_len.append(len(match_points))
+    try:
+        transform_mtx = sift_transform_mtx_by_match(match_points, t_kp, r_kp, resize_ratio)
+    except Exception as e:
+        print(e)
+        transform_mtx = None
+    # print(transform_mtx)
+    if transform_mtx is not None:
+        registered_image = perspective(regi_image, transform_mtx)
+
+        r_corners = get_corners_of_image(registered_image, CLIP_RATIO)
+        raw_corners = get_corners_of_image(regi_image, CLIP_RATIO)
+        simi_corner = hash_similarity(dhash(Image.fromarray(r_corners), hash_size=10),
+                                      dhash(Image.fromarray(raw_corners), hash_size=10))
+        simi_global = hash_similarity(dhash(Image.fromarray(registered_image), hash_size=10),
+                                      dhash(Image.fromarray(regi_image), hash_size=10))
+
+        std0 = np.std(resize_by_percent(regi_image, resize_ratio))
+        std1 = np.std(resize_by_percent(registered_image, resize_ratio))
+        std_dif = abs(std0 - std1)
+        print('std0: {:.4f}, std1: {:.4f}, diff: {:.4f}'.format(std0, std1, std_dif))
+        # std_dif = 0
+
+        similarity = WEIGHT_GLOBAL * simi_global + WEIGHT_CORNER * simi_corner  # 校准图与原图比较
+        print('simi_global: {:.4f}, simi_corner: {:.4f}'.format(simi_global, simi_corner))
+        print('weight_similarity: {:.4f}\n'.format(similarity))
+
+        simi_dict = {'global': simi_global, 'corner': simi_corner, 'weight': similarity}
+        return max_index, simi_dict, std_dif, registered_image, flip_degree
+
+    else:
+        raise ValueError('{} failed to find transform matrix'.format(raw_name))
+
+
+if __name__ == '__main__':
+    t1 = time.time()
+    parser = argparse.ArgumentParser(description="your script description")  # --help
+    parser.add_argument('--image_path', '-i', help='register image_path')
+    parser.add_argument('--save_path', '-s', help='register save_path')
+    parser.add_argument('--template_path', '-t', help='template_path')
+    # description参数可以用于插入描述脚本用途的信息,可以为空
+
+    args = parser.parse_args()  # 将变量以标签-值的字典形式存入args字典
+
+    # regi_image_path = args.image_path
+    # save_dir = args.save_path
+    # template_path_list = args.template_path.split(',')
+    # template_path_list = [ele.strip() for ele in template_path_list]
+
+    regi_image_path = r'C:\Users\Administrator\Desktop\error_imgs\201909061745_0001.jpg'
+    save_dir = r'C:\Users\Administrator\Desktop\error_imgs\save'
+    template_path_list = [r'C:\Users\Administrator\Desktop\error_imgs\10000000.jpg',
+                          r'C:\Users\Administrator\Desktop\error_imgs\10000001.jpg',
+                          # r'C:\Users\Administrator\Desktop\register_img_2\t\2.jpg',
+                          # r'C:\Users\Administrator\Desktop\register_img_2\t\4.jpg'
+                          ]
+    #
+    # # regi_image_path = r'C:\Users\Administrator\Desktop\error_imgs\Page0001.jpg'
+    # # save_dir = r'C:\Users\Administrator\Desktop\error_imgs\save'
+    # # template_path_list = [r'C:\Users\Administrator\Desktop\error_imgs\201910230723_0001.jpg',
+    # #                       # r'C:\Users\Administrator\Desktop\AnalysisRes\t\005983-1.jpg',
+    # #                       # r'C:\Users\Administrator\Desktop\AnalysisRes\t\010401.jpg',
+    # #                       # r'C:\Users\Administrator\Desktop\AnalysisRes\t\010401-1.jpg'
+    # #                       ]
+
+    template_path_list = [ele.strip() for ele in template_path_list]
+
+    if not os_path.exists(save_dir):
+        makedirs(save_dir)
+
+    fixed_max_length = FIXED_MAX_LENGTH
+
+    check = []
+    check_file_list = template_path_list
+    for ele in check_file_list:
+        if os_path.exists(ele):
+            check.append(True)
+        else:
+            check.append(False)
+
+    if False in check:
+        raise ValueError('templates not found, please check the path')
+    if not os_path.exists(regi_image_path):
+        raise ValueError('{} not found, please check the path'.format(regi_image_path))
+
+    else:
+        try:
+            ((norm_y, norm_x),
+             raw_regi_image, t_image_index_0, degree) = pre_process(template_path_list, regi_image_path)
+
+            init_similarity = 0.0
+            recursive_regi(save_dir, check_file_list, regi_image_path, fixed_max_length, init_similarity,
+                           norm_y, norm_x, raw_regi_image, t_image_index_0, degree)
+        except Exception as e:
+            print('{} registered failed: {}'.format(regi_image_path, e))
+
+    t2 = time.time()
+    print('time consume: {:.4f}'.format(t2 - t1))

+ 35 - 8
segment/sheet_resolve/analysis/resolve.py

@@ -26,6 +26,7 @@ from segment.sheet_resolve.analysis.sheet.sheet_infer import exam_number_adjust_
 from segment.sheet_resolve.tools import utils
 from segment.sheet_resolve.tools.tf_settings import xml_template_path, model_dict
 from segment.sheet_resolve.tools.utils import create_xml
+from segment.sheet_resolve.analysis.sheet.ocr_sheet import ocr2sheet
 
 logger = logging.getLogger(settings.LOGGING_TYPE)
 
@@ -33,6 +34,7 @@ sheet_infer_dict = dict(bar_code=True,
                         choice_m=True,
                         exam_number=True,
                         common_sheet=False,
+                        cloze=True,
                         solve=True)
 infer_choice_m_flag = False
 
@@ -50,6 +52,7 @@ def sheet(series_number, image_path, image, conf_thresh, mns_thresh, subject, sh
     h, w = image.shape[0], image.shape[1]
     regions = sheets_dict['regions']
     fetched_class = [ele['class_name'] for ele in regions]
+    infer_box_list = []
 
     try:
         regions = adjust_item_edge_by_gray_image(image, regions)
@@ -87,7 +90,8 @@ def sheet(series_number, image_path, image, conf_thresh, mns_thresh, subject, sh
 
             if not cond1 and not cond4 and cond2 and ocr:
                 exam_number_list = infer_exam_number(image, ocr, regions)
-                regions.extend(exam_number_list)
+                if len(exam_number_list) > 0:
+                    regions.extend(exam_number_list)
 
             image, regions = exam_number_adjust_infer(image, regions)
         except Exception as e:
@@ -95,12 +99,12 @@ def sheet(series_number, image_path, image, conf_thresh, mns_thresh, subject, sh
             logger.info('试卷:{} 考号推断失败: {}'.format(image_path, e))
 
     if sheet_infer_dict['choice_m']:
-
         try:
             col_split = col_split_x.copy()
             if not col_split:
                 col_split = [w - 1]
-            choice_m_list = infer_choice_m(image, regions, col_split, ocr)
+            infer_box_list = ocr2sheet(image, col_split_x, ocr)
+            choice_m_list = infer_choice_m(image, regions, infer_box_list, col_split)
             if len(choice_m_list) > 0:
                 choice_m_old_list = [ele for ele in regions if 'choice_m' == ele['class_name']]
                 for infer_box in choice_m_list.copy():
@@ -117,7 +121,7 @@ def sheet(series_number, image_path, image, conf_thresh, mns_thresh, subject, sh
                             #     choice_m_list.remove(infer_box)
                             regions.remove(tf_box)
                             # break
-                        elif iou[0] > 0:
+                        elif iou[0] > 0.05:
                             choice_m_list.remove(infer_box)
                             break
 
@@ -132,6 +136,27 @@ def sheet(series_number, image_path, image, conf_thresh, mns_thresh, subject, sh
             traceback.print_exc()
             logger.info('试卷:{} 选择题推断失败: {}'.format(image_path, e))
 
+    if sheet_infer_dict['cloze']:
+
+        cloze_list = []
+        for infer_box in infer_box_list:
+            # {'loc': [240, 786, 1569, 1368]}
+            loc = infer_box['loc']
+            xmin, ymin, xmax, ymax = loc[0], loc[1], loc[2], loc[3]
+
+            for ele in regions:
+                if ele['class_name'] in ['cloze_s', 'cloze']:
+                    tf_loc = ele['bounding_box']
+                    tf_loc_l = tf_loc['xmin']
+                    tf_loc_t = tf_loc['ymin']
+                    if xmin < tf_loc_l < xmax and ymin < tf_loc_t < ymax:
+                        cloze_box = {'class_name': 'cloze',
+                                     'bounding_box': {'xmin': xmin, 'ymin': ymin, 'xmax': xmax, 'ymax': ymax}}
+                        cloze_list.append(cloze_box)
+                        break
+
+        regions.extend(cloze_list)
+
     if sheet_infer_dict['solve']:
         try:
             include_class = ['info_title',
@@ -146,7 +171,8 @@ def sheet(series_number, image_path, image, conf_thresh, mns_thresh, subject, sh
                              'composition0',
                              'correction',
                              'alarm_info',
-                             'page'
+                             'page',
+                             'mark'
                              ]
             if 'math' not in subject:
                 include_class.remove('cloze_s')
@@ -285,11 +311,11 @@ def choice_row_col(image, regions, xml_path, conf_thresh, mns_thresh, choice_ses
     return choice_list
 
 
-def choice_m_row_col(image, regions, xml_path):
+def choice_m_row_col(image, regions, subject, xml_path):
     choice_m_dict_tf = []
     direction_list = []
-    for ele in regions:
 
+    for ele in regions:
         if ele['class_name'] == 'choice_m':
             choice_m_dict_tf.append(ele)
         if ele['class_name'] == 'choice_n':
@@ -327,7 +353,8 @@ def choice_m_row_col(image, regions, xml_path):
         # else:
         #     choice_list = choice_line_box.choice_m_row_col(image, choice_m_dict_tf, xml_path)  # 找选择题行列、分数
 
-        choice_list = choice_line_box.choice_m_row_col(image, choice_m_dict_tf, direction, xml_path)  # 找选择题行列、分数
+        choice_list = choice_line_box.choice_m_row_col(image, choice_m_dict_tf, direction, subject,
+                                                       xml_path)  # 找选择题行列、分数
         tree = ET.parse(xml_path)  # xml tree
         for index_num, box in enumerate(choice_list):
             if len(box['bounding_box']) > 0:

+ 178 - 108
segment/sheet_resolve/analysis/sheet/analysis_sheet.py

@@ -14,7 +14,6 @@ from segment.sheet_resolve.lib.utils.timer import Timer
 from segment.sheet_resolve.tools import utils
 from segment.sheet_resolve.analysis.solve.optional_solve import resolve_optional_choice
 
-
 logger = logging.getLogger(settings.LOGGING_TYPE)
 
 
@@ -113,6 +112,7 @@ def get_single_image_sheet_regions(analysis_type, img_path, img, classes,
                                            img, conf_thresh, mns_thresh,
                                            coordinate_bias_dict)
 
+    analysis_type.replace('_blank', '')
     img_dict = {"img_name": img_path,
                 # 'qr_code': qr_code_info,
                 'subject': analysis_type,
@@ -126,6 +126,13 @@ def get_single_image_sheet_regions(analysis_type, img_path, img, classes,
 
 
 def question_number_format(init_number, crt_numbers, regions):
+    """
+    将重复或者是-1的题号改为501
+    :param init_number: 初始替换的题号
+    :param crt_numbers: 目前已经有的题号
+    :param regions: 答题卡各区域
+    :return:
+    """
     logger.info('regions: {}'.format(regions))
     for region in regions:
         logger.info('region: {}'.format(region))
@@ -156,6 +163,34 @@ def question_number_format(init_number, crt_numbers, regions):
     return regions, init_number, crt_numbers
 
 
+def resolve_select_s(image, bbox):
+    box_region = utils.crop_region(image, bbox)
+    left = bbox['xmin']
+    top = bbox['ymin']
+    right = bbox['xmax']
+    bottom = bbox['ymax']
+
+    if (right - left) >= (bottom - top):
+        direction = 180
+    else:
+        direction = 90
+
+    try:
+        res = resolve_optional_choice(left, top, direction, box_region)
+    except Exception as e:
+        res = {'class_name': 'optional_choice',
+               'rows': 1, 'cols': 2,
+               'number': [501, 502],
+               'single_width': right - left,
+               'single_height': bottom - top,
+               'bounding_box': {'xmin': left,
+                                'ymin': top,
+                                'xmax': right,
+                                'ymax': bottom}}
+
+    return res
+
+
 def box_region_format(sheet_dict, image, subject, shrink=True):
     include_class = ['anchor_point',
                      'bar_code',
@@ -172,14 +207,15 @@ def box_region_format(sheet_dict, image, subject, shrink=True):
                      # 'correction'
                      ]
 
-    sheet_regions = sheet_dict['regions']
-    optional_choice_tmp = []
     default_points_dict = {'choice_m': 5, "cloze": 5, 'solve': 12, 'optional_solve': 10, 'cloze_s': 5,
                            "composition": 60}
-    if subject == "english":
-        default_points_dict = {'choice_m': 2, "cloze": 2, 'solve': 2, 'optional_solve': 10, 'cloze_s': 2,
+
+    if subject in ["english", 'physics', 'chemistry', 'biology', 'science_comprehensive']:
+        default_points_dict = {'choice_m': 2, "cloze": 2, 'solve': 10, 'optional_solve': 10, 'cloze_s': 2,
                                "composition": 25}
 
+    sheet_regions = sheet_dict['regions']
+    select_s_list = []
     for i in range(len(sheet_regions) - 1, -1, -1):
         if subject == "math":
             if sheet_regions[i]['class_name'] == 'cloze':
@@ -187,140 +223,174 @@ def box_region_format(sheet_dict, image, subject, shrink=True):
             if sheet_regions[i]['class_name'] == 'cloze_s':
                 sheet_regions[i]['class_name'] = 'cloze'  # math exclude cloze big
         if subject == "english":
-            if sheet_regions[i]['class_name'] == 'solve':
-                sheet_regions[i]['class_name'] = 'cloze'
+            if sheet_regions[i]['class_name'] == 'cloze':
+                sheet_regions[i]['class_name'] = 'solve'
             if sheet_regions[i]['class_name'] == 'correction':
                 sheet_regions[i]['class_name'] = 'solve'
 
-    for i in range(len(sheet_regions) - 1, -1, -1):
         if sheet_regions[i]['class_name'] in ['solve0']:
             sheet_regions[i]['class_name'] = 'solve'
         if sheet_regions[i]['class_name'] in ['composition0']:
             sheet_regions[i]['class_name'] = 'composition'
 
         if sheet_regions[i]['class_name'] == 'select_s':
-            # sheet_regions[i]['class_name'] = 'optional_choice'
-            # optional_solve_tmp.append(sheet_regions[i])
-
-            bbox = sheet_regions[i]['bounding_box']
-            box_region = utils.crop_region(image, bbox)
-            left = bbox['xmin']
-            top = bbox['ymin']
-            right = bbox['xmax']
-            bottom = bbox['ymax']
-
-            if (right - left) >= (bottom - top):
-                direction = 180
-            else:
-                direction = 90
-
-            try:
-                res = resolve_optional_choice(left, top, direction, box_region)
-            except Exception as e:
-                res = {'class_name': 'optional_choice',
-                       'rows': 1, 'cols': 1,
-                       'number': [501],
-                       'single_width': right - left,
-                       'single_height': bottom - top,
-                       'bounding_box': {'xmin': left,
-                                        'ymin': top,
-                                        'xmax': right,
-                                        'ymax': bottom}}
-
-            optional_choice_tmp.append(res)
-
-            sheet_regions.pop(i)
+            select_s_list.append(sheet_regions[i])
 
         if shrink:
             if sheet_regions[i]['class_name'] not in include_class:
                 sheet_regions.pop(i)
 
-    for ele in sheet_regions:
-        if ele['class_name'] == 'solve':
-            solve_box = (ele['bounding_box']['xmin'], ele['bounding_box']['ymin'],
-                         ele['bounding_box']['xmax'], ele['bounding_box']['ymax'])
-            for optional_choice in optional_choice_tmp:
-                optional_choice_box = (optional_choice['bounding_box']['xmin'], optional_choice['bounding_box']['ymin'],
-                                       optional_choice['bounding_box']['xmax'], optional_choice['bounding_box']['ymax'])
-                if utils.decide_coordinate_contains(optional_choice_box, solve_box):
-                    ele['class_name'] = 'optional_solve'
-                    choice_numbers = optional_choice['number']
-                    solve_points = ele['number']
-                    if choice_numbers[0] < 500:
-                        ele['number'] = choice_numbers
-                        ele['default_points'] = [ele['default_points']] * len(choice_numbers)
+    # 去重
+    sheet_tmp = sheet_regions.copy()
+    remove_index = []
+    for i, region in enumerate(sheet_tmp):
+        if i not in remove_index:
+            box = region['bounding_box']
+            name = region['class_name']
+            for j, region_in in enumerate(sheet_tmp):
+                box_in = region_in['bounding_box']
+                name_in = region_in['class_name']
+                iou = utils.cal_iou(box, box_in)
+
+                if name == name_in and (iou[0] > 0.75 or iou[1] > 0.85 or iou[2] > 0.85) and i != j:
+                    box_area = (box['xmax'] - box['xmin']) * (box['ymax'] - box['ymin'])
+                    box_in_area = (box_in['xmax'] - box_in['xmin']) * (box_in['ymax'] - box_in['ymin'])
+
+                    if box_area >= box_in_area:
+                        sheet_regions.remove(region_in)
+                        remove_index.append(j)
                     else:
-                        ele['number'] = [solve_points] * len(choice_numbers)
-                        optional_choice['numbers'] = [solve_points] * len(choice_numbers)
-                        ele['default_points'] = [ele['default_points']] * len(choice_numbers)
+                        sheet_regions.remove(region)
+                        remove_index.append(i)
                     break
-                else:
-                    continue
-
-        # 设置默认分数
-        # if ele['class_name'] == "composition":
-        #     if isinstance(ele['default_points'], list):
-        #         for i, dp in enumerate(ele['default_points']):
-        #             if dp != default_points_dict[ele['class_name']]:
-        #                 ele['default_points'][i] = default_points_dict[ele['class_name']]
-        #
-        #     if isinstance(ele['default_points'], int):
-        #         if ele['default_points'] != default_points_dict[ele['class_name']]:
-        #             ele['default_points'] = default_points_dict[ele['class_name']]
+
+    # 合并select_s
+    optional_choice_tmp = []
+    select_s_list_copy = select_s_list.copy()
+    if len(select_s_list) > 0:
+        for ele in sheet_regions:
+            if ele['class_name'] == 'solve':
+                solve_box = (ele['bounding_box']['xmin'], ele['bounding_box']['ymin'],
+                             ele['bounding_box']['xmax'], ele['bounding_box']['ymax'])
+
+                xn, yn, xm, ym = 9999, 9999, 0, 0
+                merge = False
+                for select_s in select_s_list:
+                    select_s_box = (select_s['bounding_box']['xmin'], select_s['bounding_box']['ymin'],
+                                    select_s['bounding_box']['xmax'], select_s['bounding_box']['ymax'])
+                    if utils.decide_coordinate_contains(select_s_box, solve_box):
+                        merge = True
+                        xn = min(xn, select_s_box[0])
+                        yn = min(yn, select_s_box[1])
+                        xm = max(xm, select_s_box[2])
+                        ym = max(ym, select_s_box[3])
+                        select_s_list_copy.remove(select_s)
+
+                if merge:
+                    new_box = {'xmin': xn, 'ymin': yn, 'xmax': xm, 'ymax': ym}
+                    optional_choice_info = resolve_select_s(image, new_box)
+                    optional_choice_tmp.append(optional_choice_info)
+
+    for ele in select_s_list_copy:
+        box = ele['bounding_box']
+        optional_choice_info = resolve_select_s(image, box)
+        optional_choice_tmp.append(optional_choice_info)
+
+    optional_choice_tmp_ = optional_choice_tmp.copy()
+    for ele in sheet_regions:
+        if len(optional_choice_tmp) > 0:
+            if ele['class_name'] == 'solve':
+                solve_box = (ele['bounding_box']['xmin'], ele['bounding_box']['ymin'],
+                             ele['bounding_box']['xmax'], ele['bounding_box']['ymax'])
+                for optional_choice in optional_choice_tmp_:
+                    optional_choice_box = (optional_choice['bounding_box']['xmin'], optional_choice['bounding_box']['ymin'],
+                                           optional_choice['bounding_box']['xmax'], optional_choice['bounding_box']['ymax'])
+                    if utils.decide_coordinate_contains(optional_choice_box, solve_box):
+                        optional_choice_tmp.remove(optional_choice)
+                        ele['class_name'] = 'optional_solve'
+
+                        choice_numbers = optional_choice['number']
+                        solve_numbers = ele['number']
+                        if choice_numbers[0] < 500:
+                            ele['number'] = choice_numbers
+                            ele['default_points'] = [ele['default_points']] * len(choice_numbers)
+                        else:
+                            tmp = [solve_numbers] * len(choice_numbers)
+                            for i, num in enumerate(tmp):
+                                tmp[i] = num + i
+                            ele['number'] = tmp
+                            optional_choice['number'] = tmp
+                            ele['default_points'] = [ele['default_points']] * len(choice_numbers)
+                        break
+                    else:
+                        continue
 
         if ele['class_name'] in ["choice_m", "cloze", "cloze_s", "solve", "optional_solve", "composition"]:
             if isinstance(ele['default_points'], list):
                 for i, dp in enumerate(ele['default_points']):
-                    if dp <= -1:
+                    if dp < 1:  # 小于一分
                         ele['default_points'][i] = default_points_dict[ele['class_name']]
 
             if isinstance(ele['default_points'], int) or isinstance(ele['default_points'], float):
-                if ele['default_points'] <= -1:
+                if ele['default_points'] < 1:  # 小于一分
                     ele['default_points'] = default_points_dict[ele['class_name']]
 
-    sheet_regions.extend(optional_choice_tmp)
-    # for ele in optional_choice_tmp:  # 选做题
-    #     bbox = ele['bounding_box']
-    #     box_region = utils.crop_region(image, bbox)
-    #     left = bbox['xmin']
-    #     top = bbox['ymin']
-    #     right = bbox['xmax']
-    #     bottom = bbox['ymax']
-    #
-    #     if (right - left) >= (bottom - top):
-    #         direction = 180
-    #     else:
-    #         direction = 90
-    #
-    #     # res = find_contours(left, top, box_region)
-    #     try:
-    #         res = resolve_optional_choice(left, top, direction, box_region)
-    #     except Exception as e:
-    #         res = {'class_name': 'optional_choice',
-    #                'rows': 1, 'cols': 1,
-    #                'numbers': [501],
-    #                'single_width': right - left,
-    #                'single_height': bottom - top,
-    #                'bounding_box': {'xmin': left,
-    #                                 'ymin': top,
-    #                                 'xmax': right,
-    #                                 'ymax': bottom}}
-    #
-    #     sheet_regions.append(res)
-
-    # iou
+    # select_s 在解答区域外侧
+    if len(optional_choice_tmp) > 0:
+        for oc in optional_choice_tmp:
+            optional_choice_box = (oc['bounding_box']['xmin'], oc['bounding_box']['ymin'],
+                                   oc['bounding_box']['xmax'], oc['bounding_box']['ymax'],
+                                   oc['bounding_box']['xmin']
+                                   + (oc['bounding_box']['xmax'] - oc['bounding_box']['xmin']) // 2)
+            for sr in sheet_regions:
+                if sr['class_name'] == 'solve':
+                    solve_box = (sr['bounding_box']['xmin'], sr['bounding_box']['ymin'],
+                                 sr['bounding_box']['xmax'], sr['bounding_box']['ymax'])
+                    if (optional_choice_box[1] <= solve_box[1] and
+                            solve_box[0] < optional_choice_box[4] < solve_box[2] and
+                            abs(optional_choice_box[1] - solve_box[1]) < solve_box[3] - solve_box[1]):
+                        sr['class_name'] = 'optional_solve'
+                        choice_numbers = oc['number']
+                        solve_numbers = sr['number']
+                        if choice_numbers[0] < 500:
+                            sr['number'] = choice_numbers
+                            sr['default_points'] = [sr['default_points']] * len(choice_numbers)
+                        else:
+                            tmp = [solve_numbers] * len(choice_numbers)
+                            for i, num in enumerate(tmp):
+                                tmp[i] = num + i
+
+                            sr['number'] = tmp
+                            oc['number'] = tmp
+                            sr['default_points'] = [sr['default_points']] * len(choice_numbers)
+
+                        break
+
+    if len(optional_choice_tmp_):
+        sheet_regions.extend(optional_choice_tmp_)
+
+    # 去重
     sheet_tmp = sheet_regions.copy()
     remove_index = []
     for i, region in enumerate(sheet_tmp):
         if i not in remove_index:
             box = region['bounding_box']
+            name = region['class_name']
             for j, region_in in enumerate(sheet_tmp):
                 box_in = region_in['bounding_box']
-                # TODO 根据大小
+                name_in = region_in['class_name']
                 iou = utils.cal_iou(box, box_in)
-                if iou[0] > 0.75 and i != j:
-                    sheet_regions.remove(region)
-                    remove_index.append(j)
+
+                if name == name_in and (iou[0] > 0.75 or iou[1] > 0.85 or iou[2] > 0.85) and i != j:
+                    box_area = (box['xmax'] - box['xmin']) * (box['ymax'] - box['ymin'])
+                    box_in_area = (box_in['xmax'] - box_in['xmin']) * (box_in['ymax'] - box_in['ymin'])
+
+                    if box_area >= box_in_area:
+                        sheet_regions.remove(region_in)
+                        remove_index.append(j)
+                    else:
+                        sheet_regions.remove(region)
+                        remove_index.append(i)
                     break
 
     sheet_dict.update({'regions': sheet_regions})
@@ -335,8 +405,8 @@ def merge_span_boxes(col_sheets):
         next_col = col_sheets[i + 1]
         if not cur_col or not next_col:
             continue
-        current_bottom_box = cur_col[-1]
-        next_col_top_box = next_col[0]
+        current_bottom_box = cur_col[-1]   # 当前栏的最后一个,bottom
+        next_col_top_box = next_col[0]  # 下一栏的第一个,top
 
         b_name = current_bottom_box['class_name']
         t_name = next_col_top_box['class_name']

+ 418 - 274
segment/sheet_resolve/analysis/sheet/choice_infer.py

@@ -299,6 +299,7 @@ def cluster_and_anti_abnormal(image, xml_path, choice_n_list, digital_list, char
             digital_list_to_cluster.append(digital_list[i])
             digital_loc_arr.append(point)
 
+    # 得到所有题号区域, 作为后续划分choice_m的依据
     choice_m_numbers_list = []
     for ele in choice_n_list:
         loc = ele['bounding_box']
@@ -418,295 +419,438 @@ def cluster_and_anti_abnormal(image, xml_path, choice_n_list, digital_list, char
     #             # cond1 = cond2 = true, 因为infer选择题时已横向排序, 默认这种情况不会出现
     #             pass
 
+    direction180, direction90 = 0, 0
     for ele in choice_m_numbers_list:
         loc = ele["loc"]
-        if loc[3] - loc[1] >= loc[2] - loc[0]:
+        if loc[3] - loc[1] >= 2 * (loc[2] - loc[0]):
             direction = 180
+            direction180 += 1
         else:
             direction = 90
+            direction90 += 1
         ele.update({'direction': direction})
-    # tree = ET.parse(xml_path)
-    # for index, choice_m in enumerate(choice_m_numbers_list):
-    #     name = str(choice_m["numbers"])
-    #     xmin, ymin, xmax, ymax, _, _ = choice_m["loc"]
-    #     tree = create_xml(name, tree, str(xmin + limit_left), str(ymin + limit_top), str(xmax + limit_left), str(ymax + limit_top))
-    #
-    # tree.write(xml_path)
-    choice_m_numbers_list = sorted(choice_m_numbers_list, key=lambda x: x['loc'][3] - x['loc'][1], reverse=True)
-    choice_m_numbers_right_limit = max([ele['loc'][2] for ele in choice_m_numbers_list])
-    remain_len = len(choice_m_numbers_list)
-    choice_m_list = list()
-    need_revised_choice_m_list = list()
-    while remain_len > 0:
-        # 先确定属于同行的数据,然后找字母划分block
-        # random_index = random.randint(0, len(choice_m_numbers_list)-1)
-        random_index = 0
-        # print(random_index)
-        ymax_limit = choice_m_numbers_list[random_index]["loc"][3]
-        ymin_limit = choice_m_numbers_list[random_index]["loc"][1]
-        # choice_m_numbers_list.pop(random_index)
-
-        # 当前行的choice_m
-        current_row_choice_m_d = [ele for ele in choice_m_numbers_list if ymin_limit < ele["loc"][5] < ymax_limit]
-        current_row_choice_m_d = sorted(current_row_choice_m_d, key=lambda x: x["loc"][0])
-        # current_row_choice_m_d.append(choice_m_numbers_list[random_index])
-        split_pix = sorted([ele["loc"][0] for ele in current_row_choice_m_d])  # xmin排序
-        split_index = get_split_index(split_pix, dif=choice_s_width * 0.8)
-        split_pix = [split_pix[ele] for ele in split_index[:-1]]
-
-        block_list = []
-        for i in range(len(split_index) - 1):
-            block = current_row_choice_m_d[split_index[i]: split_index[i + 1]]
-            if len(block) > 1:
-                remain_len = remain_len - (len(block) - 1)
-                numbers_new = []
-                loc_new = [[], [], [], []]
-                for blk in block:
-                    loc_old = blk["loc"]
-                    numbers_new.extend(blk["numbers"])
-                    for ii in range(4):
-                        loc_new[ii].append(loc_old[ii])
-
-                loc_new[0] = min(loc_new[0])
-                loc_new[1] = min(loc_new[1])
-                loc_new[2] = max(loc_new[2])
-                loc_new[3] = max(loc_new[3])
-
-                loc_new.append(loc_new[0] + (loc_new[2] - loc_new[0]) // 2)
-                loc_new.append(loc_new[1] + (loc_new[3] - loc_new[1]) // 2)
-
-                block = [{"numbers": sorted(numbers_new), "loc": loc_new, "direction": block[0]["direction"]}]
-
-            block_list.extend(block)
-
-        current_row_choice_m_d = block_list
-        current_row_chars = [ele for ele in chars_list
-                             if ymin_limit < (ele["location"]["top"] + ele["location"]["height"] // 2) < ymax_limit]
-
-        # split_index.append(row_chars_xmax)  # 边界
-        split_pix.append(round(split_pix[-1] + choice_s_width * 1.75))
-        for i in range(0, len(split_pix) - 1):
-            left_limit = split_pix[i]
-            right_limit = split_pix[i + 1]
-            block_chars = [ele for ele in current_row_chars
-                           if left_limit < (ele["location"]["left"] + ele["location"]["width"] // 2) < right_limit]
-
-            a_z = '_ABCD_FGHT'
-            letter_index = [a_z.index(ele['char'].upper()) for ele in block_chars if ele['char'].upper() in a_z]
-
-            letter_index_times = {ele: 0 for ele in set(letter_index)}
-            for l_index in letter_index:
-                letter_index_times[l_index] += 1
-
-            if (a_z.index("T") in letter_index) and (a_z.index("F") in letter_index):
-                choice_option = "T, F"
-                cols = 2
-            else:
-                if len(letter_index) < 1:
-                    tmp = 4
-                    choice_option = 'A,B,C,D'
-                else:
-                    tmp = max(set(letter_index))
-
-                    choice_option = ",".join(a_z[min(letter_index):tmp + 1])
-                cols = tmp
 
-            bias = 3  # pix
-            current_loc = current_row_choice_m_d[i]["loc"]
-            location = dict(xmin=(current_loc[2] + bias),  # 当前数字xmax右边
-                            # xmin=max(current_loc[2] + bias, chars_xmin) + limit_left,
-                            ymin=current_loc[1],
-
-                            xmax=(right_limit - bias),
-                            # xmax=min(chars_xmax, right_limit - bias) + limit_left,
-                            ymax=current_loc[3])
-
-            try:
-                choice_m_img = utils.crop_region(image, location)
-                if 0 in choice_m_img.shape[:2]:
-                    continue
-                right_loc, bottom_loc = adjust_choice_m(choice_m_img, mean_height, mean_width * 2)
-                if right_loc > 0:
-                    location.update(dict(xmax=right_loc + location['xmin']))
-                if bottom_loc > 0:
-                    location.update(dict(ymax=bottom_loc + location['ymin']))
-            except Exception as e:
-                print(e)
-                traceback.print_exc()
+    # 判断大多数choice_m的方向
+    if direction180 >= direction90:   # 横排
+
+        choice_m_numbers_list = sorted(choice_m_numbers_list, key=lambda x: x['loc'][3] - x['loc'][1], reverse=True)
+        choice_m_numbers_right_limit = max([ele['loc'][2] for ele in choice_m_numbers_list])
+        remain_len = len(choice_m_numbers_list)
+        choice_m_list = list()
+        need_revised_choice_m_list = list()
+        while remain_len > 0:
+            # 先确定属于同行的数据,然后找字母划分block
+
+            random_index = 0
+            # print(random_index)
+            ymax_limit = choice_m_numbers_list[random_index]["loc"][3]
+            ymin_limit = choice_m_numbers_list[random_index]["loc"][1]
+
+            # 当前行的choice_m
+            current_row_choice_m_d = [ele for ele in choice_m_numbers_list if ymin_limit < ele["loc"][5] < ymax_limit]
+            current_row_choice_m_d = sorted(current_row_choice_m_d, key=lambda x: x["loc"][0])
+            # current_row_choice_m_d.append(choice_m_numbers_list[random_index])
+
+            # 对同行的题号区域排序, 得到分割间隔, 两个题号中间的区域为choice_m
+            split_pix = sorted([ele["loc"][0] for ele in current_row_choice_m_d])  # xmin排序
+            split_index = get_split_index(split_pix, dif=choice_s_width * 0.8)
+            split_pix = [split_pix[ele] for ele in split_index[:-1]]
+
+            block_list = []
+            for i in range(len(split_index) - 1):
+                block = current_row_choice_m_d[split_index[i]: split_index[i + 1]]
+                if len(block) > 1:
+                    remain_len = remain_len - (len(block) - 1)
+                    numbers_new = []
+                    loc_new = [[], [], [], []]
+                    for blk in block:
+                        loc_old = blk["loc"]
+                        numbers_new.extend(blk["numbers"])
+                        for ii in range(4):
+                            loc_new[ii].append(loc_old[ii])
+
+                    loc_new[0] = min(loc_new[0])
+                    loc_new[1] = min(loc_new[1])
+                    loc_new[2] = max(loc_new[2])
+                    loc_new[3] = max(loc_new[3])
+
+                    loc_new.append(loc_new[0] + (loc_new[2] - loc_new[0]) // 2)
+                    loc_new.append(loc_new[1] + (loc_new[3] - loc_new[1]) // 2)
+
+                    block = [{"numbers": sorted(numbers_new), "loc": loc_new, "direction": block[0]["direction"]}]
+
+                block_list.extend(block)
+
+            current_row_choice_m_d = block_list
+            current_row_chars = [ele for ele in chars_list
+                                 if ymin_limit < (ele["location"]["top"] + ele["location"]["height"] // 2) < ymax_limit]
+
+            # split_index.append(row_chars_xmax)  # 边界
+            split_pix.append(limit_right)
+            for i in range(0, len(split_pix) - 1):
+                left_limit = split_pix[i]
+                right_limit = split_pix[i + 1]
+                block_chars = [ele for ele in current_row_chars
+                               if left_limit < (ele["location"]["left"] + ele["location"]["width"] // 2) < right_limit]
+
+                a_z = '_ABCD_FGHT'
+                letter_index = [a_z.index(ele['char'].upper()) for ele in block_chars if ele['char'].upper() in a_z]
+
+                letter_index_times = {ele: 0 for ele in set(letter_index)}
+                for l_index in letter_index:
+                    letter_index_times[l_index] += 1
+
+                if (a_z.index("T") in letter_index) and (a_z.index("F") in letter_index):
+                    choice_option = "T, F"
+                    cols = 2
+                else:
+                    if len(letter_index) < 1:
+                        tmp = 4
+                        choice_option = 'A,B,C,D'
+                    else:
+                        tmp = max(set(letter_index))
+
+                        choice_option = ",".join(a_z[min(letter_index):tmp + 1])
+                    cols = tmp
+
+                bias = 3  # pix
+                current_loc = current_row_choice_m_d[i]["loc"]
+                location = dict(xmin=(current_loc[2] + bias),  # 当前数字xmax右边
+                                # xmin=max(current_loc[2] + bias, chars_xmin) + limit_left,
+                                ymin=current_loc[1],
+
+                                xmax=(right_limit - bias),
+                                # xmax=min(chars_xmax, right_limit - bias) + limit_left,
+                                ymax=current_loc[3])
+
+                try:
+                    # 调整choice-m区域, 避免推断出来的区域过大
+                    choice_m_img = utils.crop_region(image, location)
+                    if 0 in choice_m_img.shape[:2]:
+                        continue
+                    right_loc, bottom_loc = adjust_choice_m(choice_m_img, mean_height, mean_width * 2)
+                    if right_loc > 0:
+                        location.update(dict(xmax=right_loc + location['xmin']))
+                    if bottom_loc > 0:
+                        location.update(dict(ymax=bottom_loc + location['ymin']))
+                except Exception as e:
+                    print(e)
+                    traceback.print_exc()
+
+                tmp_w, tmp_h = location['xmax'] - location['xmin'], location['ymax'] - location['ymin'],
+                numbers = current_row_choice_m_d[i]["numbers"]
+
+                direction = current_row_choice_m_d[i]["direction"]
+                if direction == 180:
+                    choice_m = dict(class_name='choice_m',
+                                    number=numbers,
+                                    bounding_box=location,
+                                    choice_option=choice_option,
+                                    default_points=[5] * len(numbers),
+                                    direction=direction,
+                                    cols=cols,
+                                    rows=len(numbers))
+                else:
+                    choice_m = dict(class_name='choice_m',
+                                    number=numbers,
+                                    bounding_box=location,
+                                    choice_option=choice_option,
+                                    default_points=[5] * len(numbers),
+                                    direction=direction,
+                                    cols=len(numbers),
+                                    rows=cols)
+
+                if tmp_w > 2 * choice_s_width:
+                    need_revised_choice_m_list.append(choice_m)
+                else:
+                    choice_m_list.append(choice_m)
+
+            remain_len = remain_len - len(current_row_choice_m_d)
+            for ele in choice_m_numbers_list.copy():
+                if ele in current_row_choice_m_d:
+                    choice_m_numbers_list.remove(ele)
+
+            for ele in choice_m_numbers_list.copy():
+                if ele in current_row_chars:
+                    choice_m_numbers_list.remove(ele)
+
+            # 解决单行问题
+            if len(choice_m_list) > 0:
+                crt_right_max = max([int(ele['bounding_box']['xmax']) for ele in choice_m_list])
+                if limit_right - crt_right_max > choice_s_width:
+                    # 存在区域
+                    region_loc = {'xmin': crt_right_max + 10,
+                                  'ymin': choice_m_list[0]['bounding_box']['ymin'],
+                                  'xmax': limit_right,
+                                  'ymax': choice_m_list[0]['bounding_box']['ymax']}
+
+                    contain_dig = []
+                    for i, ele in enumerate(digital_loc_arr):
+                        if region_loc['xmin'] < ele[0] < region_loc['xmax'] and region_loc['ymin'] < ele[1] < region_loc['ymax']:
+                            contain_dig.append(digital_list[i])
+
+                    contain_chars = [ele for ele in chars_list
+                                     if region_loc['xmin'] < (ele["location"]["left"] + ele["location"]["width"] // 2) < region_loc['xmax']
+                                     and
+                                     region_loc['xmin'] < (ele["location"]["top"] + ele["location"]["height"] // 2) < region_loc['ymax']]
+                    numbers = [-1]
+                    if contain_dig or contain_chars:
+                        d_ymin, d_ymax, d_xmin, d_xmax = 9999, 0, 9999, 0
+                        if contain_dig:
+                            numbers = [ele["digital"] for ele in contain_dig]
+                            d_ymin = min([ele['loc'][1] for ele in contain_dig])
+                            d_ymax = max([ele['loc'][3] for ele in contain_dig])
+                            d_xmin = min([ele['loc'][0] for ele in contain_dig])
+                            d_xmax = max([ele['loc'][2] for ele in contain_dig])
+
+                        c_ymin, c_ymax, c_xmin, c_xmax = 9999, 0, 9999, 0
+                        if contain_chars:
+                            c_ymin = min([ele["location"]["top"] for ele in contain_chars])
+                            c_ymax = max([ele["location"]["top"] + ele["location"]["height"] for ele in contain_chars])
+                            c_xmin = min([ele["location"]["left"] for ele in contain_chars])
+                            c_xmax = max([ele["location"]["left"] + ele["location"]["width"] for ele in contain_chars])
+
+                        r_ymin, r_ymax = min(d_ymin, c_ymin), max(d_ymax, c_ymax)
+                        r_xmin, r_xmax = min(d_xmin, c_xmin), max(d_xmax, c_xmax)
+
+                        region_loc['ymin'] = r_ymin - 10
+                        region_loc['ymax'] = r_ymax + 10
+                        if d_xmin == r_xmin:
+                            region_loc['xmin'] = d_xmax + 5
+                            region_loc['xmax'] = d_xmax + 5 + int(1.2 * choice_s_width)
+                        else:
+                            if 1.2 * (r_xmax - r_xmin) > choice_s_width:
+                                region_loc['xmin'] = r_xmin - 10
+                                region_loc['xmax'] = r_xmax + 10
+                            else:
+                                region_loc['xmin'] = max((r_xmax - r_xmin) // 2 + r_xmin - choice_s_width,
+                                                         crt_right_max + 10)
+                                region_loc['xmax'] = min((r_xmax - r_xmin) // 2 + r_xmin + choice_s_width ,
+                                                         limit_right)
 
-            tmp_w, tmp_h = location['xmax'] - location['xmin'], location['ymax'] - location['ymin'],
-            numbers = current_row_choice_m_d[i]["numbers"]
-
-            direction = current_row_choice_m_d[i]["direction"]
-            if direction == 180:
-                choice_m = dict(class_name='choice_m',
-                                number=numbers,
-                                bounding_box=location,
-                                choice_option=choice_option,
-                                default_points=[5] * len(numbers),
-                                direction=direction,
-                                cols=cols,
-                                rows=len(numbers))
-            else:
-                choice_m = dict(class_name='choice_m',
-                                number=numbers,
-                                bounding_box=location,
-                                choice_option=choice_option,
-                                default_points=[5] * len(numbers),
-                                direction=direction,
-                                cols=len(numbers),
-                                rows=cols)
-
-            if tmp_w > 2 * choice_s_width:
-                need_revised_choice_m_list.append(choice_m)
-            else:
-                choice_m_list.append(choice_m)
-
-        remain_len = remain_len - len(current_row_choice_m_d)
-        for ele in choice_m_numbers_list.copy():
-            if ele in current_row_choice_m_d:
-                choice_m_numbers_list.remove(ele)
-
-        for ele in choice_m_numbers_list.copy():
-            if ele in current_row_chars:
-                choice_m_numbers_list.remove(ele)
-
-        # 解决单行问题
-        crt_right_max = max([int(ele['bounding_box']['xmax']) for ele in choice_m_list])
-        if limit_right - crt_right_max > choice_s_width:
-            # 存在区域
-            region_loc = {'xmin': crt_right_max + 10,
-                          'ymin': choice_m_list[0]['bounding_box']['ymin'],
-                          'xmax': limit_right,
-                          'ymax': choice_m_list[0]['bounding_box']['ymax']}
-
-            contain_dig = []
-            for i, ele in enumerate(digital_loc_arr):
-                if region_loc['xmin'] < ele[0] < region_loc['xmax'] and region_loc['ymin'] < ele[1] < region_loc['ymax']:
-                    contain_dig.append(digital_list[i])
-
-            contain_chars = [ele for ele in chars_list
-                             if region_loc['xmin'] < (ele["location"]["left"] + ele["location"]["width"] // 2) < region_loc['xmax']
-                             and
-                             region_loc['xmin'] < (ele["location"]["top"] + ele["location"]["height"] // 2) < region_loc['ymax']]
-            numbers = [-1]
-            if contain_dig or contain_chars:
-                d_ymin, d_ymax, d_xmin, d_xmax = 9999, 0, 9999, 0
-                if contain_dig:
-                    numbers = [ele["digital"] for ele in contain_dig]
-                    d_ymin = min([ele['loc'][1] for ele in contain_dig])
-                    d_ymax = max([ele['loc'][3] for ele in contain_dig])
-                    d_xmin = min([ele['loc'][0] for ele in contain_dig])
-                    d_xmax = max([ele['loc'][2] for ele in contain_dig])
-
-                c_ymin, c_ymax, c_xmin, c_xmax = 9999, 0, 9999, 0
-                if contain_chars:
-                    c_ymin = min([ele["location"]["top"] for ele in contain_chars])
-                    c_ymax = max([ele["location"]["top"] + ele["location"]["height"] for ele in contain_chars])
-                    c_xmin = min([ele["location"]["left"] for ele in contain_chars])
-                    c_xmax = max([ele["location"]["left"] + ele["location"]["width"] for ele in contain_chars])
-
-                r_ymin, r_ymax = min(d_ymin, c_ymin), max(d_ymax, c_ymax)
-                r_xmin, r_xmax = min(d_xmin, c_xmin), max(d_xmax, c_xmax)
-
-                region_loc['ymin'] = r_ymin - 10
-                region_loc['ymax'] = r_ymax + 10
-                if d_xmin == r_xmin:
-                    region_loc['xmin'] = d_xmax + 5
-                    region_loc['xmax'] = d_xmax + 5 + int(1.2 * choice_s_width)
+                    else:
+                        # 默认这种情况只有1行或2行
+                        numbers = [-1]
+                        region_xmin = crt_right_max + 5
+                        region_xmax = int(region_xmin + 1.2 * choice_s_width)
+                        region_ymin = min([int(ele['bounding_box']['ymin']) for ele in choice_m_list])
+                        region_ymax = max([int(ele['bounding_box']['ymax']) for ele in choice_m_list])
+                        region_ymax = region_ymin + (region_ymax - region_ymin) // 2  # 默认这种情况只有1行或2行
+                        region_loc = {'xmin': region_xmin, 'ymin': region_ymin, 'xmax': region_xmax, 'ymax': region_ymax}
+
+                    try:
+                        choice_m_img = utils.crop_region(image, region_loc)
+                        if 0 in choice_m_img.shape[:2]:
+                            continue
+                        right_loc, bottom_loc = adjust_choice_m(choice_m_img, mean_height, mean_width * 2)
+                        if right_loc > 0:
+                            region_loc.update(dict(xmax=right_loc + region_loc['xmin']))
+                        if bottom_loc > 0:
+                            region_loc.update(dict(ymax=bottom_loc + region_loc['ymin']))
+                    except Exception as e:
+                        print(e)
+                        traceback.print_exc()
+
+                    choice_m = dict(class_name='choice_m',
+                                    number=numbers,
+                                    bounding_box=region_loc,
+                                    choice_option='A,B,C,D',
+                                    default_points=[5],
+                                    direction=180,
+                                    cols=4,
+                                    rows=1,
+                                    single_width=(region_loc['xmax'] - region_loc['xmin']) // 4,
+                                    )
+                    choice_m_list.append(choice_m)
+
+        # 单独一行不聚类(理论上不会再到这一步了, 上个block解决)
+        for i, revised_choice_m in enumerate(need_revised_choice_m_list):
+            loc = revised_choice_m['bounding_box']
+            left_part_loc = loc.copy()
+            left_part_loc.update({'xmax': loc['xmin'] + choice_s_width})
+            choice_m_img = utils.crop_region(image, left_part_loc)
+            right_loc, bottom_loc = adjust_choice_m(choice_m_img, mean_height, mean_width * 2)
+            if right_loc > 0:
+                left_part_loc.update(dict(xmax=right_loc + left_part_loc['xmin']))
+            if bottom_loc > 0:
+                left_part_loc.update(dict(ymax=bottom_loc + left_part_loc['ymin']))
+
+            left_tmp_height = left_part_loc['ymax'] - left_part_loc['ymin']
+
+            right_part_loc = loc.copy()
+            # right_part_loc.update({'xmin': loc['xmax']-choice_s_width})
+            right_part_loc.update({'xmin': left_part_loc['xmax'] + 5})
+            choice_m_img = utils.crop_region(image, right_part_loc)
+            right_loc, bottom_loc = adjust_choice_m(choice_m_img, mean_height, mean_width * 2)
+            if right_loc > 0:
+                right_part_loc.update(dict(xmax=right_loc + right_part_loc['xmin']))
+            if bottom_loc > 0:
+                right_part_loc.update(dict(ymax=bottom_loc + right_part_loc['ymin']))
+
+            right_tmp_height = right_part_loc['ymax'] - right_part_loc['ymin']
+
+            number_len = max(1, int(revised_choice_m['rows'] // (left_tmp_height // right_tmp_height)))
+            number = [ele + revised_choice_m['number'][-1] + 1 for ele in range(number_len)]
+            rows = len(number)
+
+            revised_choice_m.update({'bounding_box': left_part_loc})
+            choice_m_list.append(revised_choice_m)
+
+            tmp = revised_choice_m.copy()
+            tmp.update({'bounding_box': right_part_loc, 'number': number, 'rows': rows})
+            choice_m_list.append(tmp)
+
+        choice_m_list_copy = choice_m_list.copy()
+        for ele in choice_m_list_copy:
+            loc = ele["bounding_box"]
+            w, h = loc['xmax'] - loc['xmin'], loc['ymax'] - loc['ymin']
+            if 2 * w * h < choice_s_width * choice_s_height:
+                choice_m_list.remove(ele)
+        return choice_m_list
+
+    else:   # 竖排
+        # 横向最大
+        choice_m_numbers_list = sorted(choice_m_numbers_list, key=lambda x: x['loc'][2] - x['loc'][0], reverse=True)
+        remain_len = len(choice_m_numbers_list)
+        choice_m_list = list()
+        need_revised_choice_m_list = list()
+        while remain_len > 0:
+            # 先确定属于同列的数据,然后找字母划分block
+            random_index = 0
+            xmax_limit = choice_m_numbers_list[random_index]["loc"][2]
+            xmin_limit = choice_m_numbers_list[random_index]["loc"][0]
+            # choice_m_numbers_list.pop(random_index)
+
+            # 当前行的choice_m
+            current_row_choice_m_d = [ele for ele in choice_m_numbers_list if xmin_limit < ele["loc"][4] < xmax_limit]
+            current_row_choice_m_d = sorted(current_row_choice_m_d, key=lambda x: x["loc"][1])
+            # current_row_choice_m_d.append(choice_m_numbers_list[random_index])
+            split_pix = sorted([ele["loc"][1] for ele in current_row_choice_m_d])  # ymin排序
+            split_index = get_split_index(split_pix, dif=choice_s_height * 0.8)
+            split_pix = [split_pix[ele] for ele in split_index[:-1]]
+
+            block_list = []
+            for i in range(len(split_index) - 1):
+                block = current_row_choice_m_d[split_index[i]: split_index[i + 1]]
+                if len(block) > 1:
+                    remain_len = remain_len - (len(block) - 1)
+                    numbers_new = []
+                    loc_new = [[], [], [], []]
+                    for blk in block:
+                        loc_old = blk["loc"]
+                        numbers_new.extend(blk["numbers"])
+                        for ii in range(4):
+                            loc_new[ii].append(loc_old[ii])
+
+                    loc_new[0] = min(loc_new[0])
+                    loc_new[1] = min(loc_new[1])
+                    loc_new[2] = max(loc_new[2])
+                    loc_new[3] = max(loc_new[3])
+
+                    loc_new.append(loc_new[0] + (loc_new[2] - loc_new[0]) // 2)
+                    loc_new.append(loc_new[1] + (loc_new[3] - loc_new[1]) // 2)
+
+                    block = [{"numbers": sorted(numbers_new), "loc": loc_new, "direction": block[0]["direction"]}]
+
+                block_list.extend(block)
+
+            current_row_choice_m_d = block_list
+            current_row_chars = [ele for ele in chars_list
+                                 if xmin_limit < (ele["location"]["top"] + ele["location"]["height"] // 2) < xmax_limit]
+
+            split_pix.append(limit_bottom)
+            for i in range(0, len(split_pix) - 1):
+                top_limit = split_pix[i]
+                bottom_limit = split_pix[i + 1]
+                block_chars = [ele for ele in current_row_chars
+                               if top_limit < (ele["location"]["left"] + ele["location"]["width"] // 2) < bottom_limit]
+
+                a_z = '_ABCD_FGHT'
+                letter_text = set([ele['char'].upper() for ele in block_chars if ele['char'].upper() in a_z])
+                letter_index = [a_z.index(ele['char'].upper()) for ele in block_chars if ele['char'].upper() in a_z]
+
+                letter_index_times = {ele: 0 for ele in set(letter_index)}
+                for l_index in letter_index:
+                    letter_index_times[l_index] += 1
+
+                if (a_z.index("T") in letter_index) and (a_z.index("F") in letter_index):
+                    choice_option = "T, F"
+                    cols = 2
                 else:
-                    if 1.2 * (r_xmax - r_xmin) > choice_s_width:
-                        region_loc['xmin'] = r_xmin - 10
-                        region_loc['xmax'] = r_xmax + 10
+                    if len(letter_index) < 1:
+                        tmp = 4
+                        choice_option = 'A,B,C,D'
                     else:
-                        region_loc['xmin'] = max((r_xmax - r_xmin) // 2 + r_xmin - choice_s_width,
-                                                 crt_right_max + 10)
-                        region_loc['xmax'] = min((r_xmax - r_xmin) // 2 + r_xmin + choice_s_width ,
-                                                 limit_right)
+                        tmp = max(set(letter_index))
+                        choice_option = ",".join(a_z[min(letter_index):tmp + 1])
+                    cols = tmp
+
+                bias = 3  # pix
+                current_loc = current_row_choice_m_d[i]["loc"]
+                location = dict(xmin=current_loc[0],
+                                ymin=current_loc[3] + bias,
+                                xmax=current_loc[1],
+                                ymax=bottom_limit - bias)
+
+                try:
+                    choice_m_img = utils.crop_region(image, location)
+                    right_loc, bottom_loc = adjust_choice_m(choice_m_img, mean_height, mean_width * 2)
+                    if right_loc > 0:
+                        location.update(dict(xmax=right_loc + location['xmin']))
+                    if bottom_loc > 0:
+                        location.update(dict(ymax=bottom_loc + location['ymin']))
+                except Exception as e:
+                    print(e)
+                    traceback.print_exc()
+
+                tmp_w, tmp_h = location['xmax'] - location['xmin'], location['ymax'] - location['ymin'],
+                numbers = current_row_choice_m_d[i]["numbers"]
+                direction = current_row_choice_m_d[i]["direction"]
+                if direction == 180:
+                    choice_m = dict(class_name='choice_m',
+                                    number=numbers,
+                                    bounding_box=location,
+                                    choice_option=choice_option,
+                                    default_points=[5] * len(numbers),
+                                    direction=direction,
+                                    cols=cols,
+                                    rows=len(numbers))
+                else:
+                    choice_m = dict(class_name='choice_m',
+                                    number=numbers,
+                                    bounding_box=location,
+                                    choice_option=choice_option,
+                                    default_points=[5] * len(numbers),
+                                    direction=direction,
+                                    cols=len(numbers),
+                                    rows=cols)
+
+                if tmp_h > 2 * choice_s_height:
+                    need_revised_choice_m_list.append(choice_m)
+                else:
+                    choice_m_list.append(choice_m)
 
-            else:
-                # 默认这种情况只有1行或2行
-                numbers = [-1]
-                region_xmin = crt_right_max + 5
-                region_xmax = int(region_xmin + 1.2 * choice_s_width)
-                region_ymin = min([int(ele['bounding_box']['ymin']) for ele in choice_m_list])
-                region_ymax = max([int(ele['bounding_box']['ymax']) for ele in choice_m_list])
-                region_ymax = region_ymin + (region_ymax - region_ymin) // 2  # 默认这种情况只有1行或2行
-                region_loc = {'xmin': region_xmin, 'ymin': region_ymin, 'xmax': region_xmax, 'ymax': region_ymax}
+            remain_len = remain_len - len(current_row_choice_m_d)
+            for ele in choice_m_numbers_list.copy():
+                if ele in current_row_choice_m_d:
+                    choice_m_numbers_list.remove(ele)
 
-            try:
-                choice_m_img = utils.crop_region(image, region_loc)
-                if 0 in choice_m_img.shape[:2]:
-                    continue
-                right_loc, bottom_loc = adjust_choice_m(choice_m_img, mean_height, mean_width * 2)
-                if right_loc > 0:
-                    region_loc.update(dict(xmax=right_loc + region_loc['xmin']))
-                if bottom_loc > 0:
-                    region_loc.update(dict(ymax=bottom_loc + region_loc['ymin']))
-            except Exception as e:
-                print(e)
-                traceback.print_exc()
+            for ele in choice_m_numbers_list.copy():
+                if ele in current_row_chars:
+                    choice_m_numbers_list.remove(ele)
 
-            choice_m = dict(class_name='choice_m',
-                            number=numbers,
-                            bounding_box=region_loc,
-                            choice_option='A,B,C,D',
-                            default_points=[5],
-                            direction=180,
-                            cols=4,
-                            rows=1,
-                            single_width=(region_loc['xmax'] - region_loc['xmin']) // 4,
-                            )
-            choice_m_list.append(choice_m)
-
-    # 单独一行不聚类(理论上不会再到这一步了, 上个block解决)
-    for i, revised_choice_m in enumerate(need_revised_choice_m_list):
-        loc = revised_choice_m['bounding_box']
-        left_part_loc = loc.copy()
-        left_part_loc.update({'xmax': loc['xmin'] + choice_s_width})
-        choice_m_img = utils.crop_region(image, left_part_loc)
-        right_loc, bottom_loc = adjust_choice_m(choice_m_img, mean_height, mean_width * 2)
-        if right_loc > 0:
-            left_part_loc.update(dict(xmax=right_loc + left_part_loc['xmin']))
-        if bottom_loc > 0:
-            left_part_loc.update(dict(ymax=bottom_loc + left_part_loc['ymin']))
-
-        left_tmp_height = left_part_loc['ymax'] - left_part_loc['ymin']
-
-        right_part_loc = loc.copy()
-        # right_part_loc.update({'xmin': loc['xmax']-choice_s_width})
-        right_part_loc.update({'xmin': left_part_loc['xmax'] + 5})
-        choice_m_img = utils.crop_region(image, right_part_loc)
-        right_loc, bottom_loc = adjust_choice_m(choice_m_img, mean_height, mean_width * 2)
-        if right_loc > 0:
-            right_part_loc.update(dict(xmax=right_loc + right_part_loc['xmin']))
-        if bottom_loc > 0:
-            right_part_loc.update(dict(ymax=bottom_loc + right_part_loc['ymin']))
-
-        right_tmp_height = right_part_loc['ymax'] - right_part_loc['ymin']
-
-        number_len = max(1, int(revised_choice_m['rows'] // (left_tmp_height // right_tmp_height)))
-        number = [ele + revised_choice_m['number'][-1] + 1 for ele in range(number_len)]
-        rows = len(number)
-
-        revised_choice_m.update({'bounding_box': left_part_loc})
-        choice_m_list.append(revised_choice_m)
-
-        tmp = revised_choice_m.copy()
-        tmp.update({'bounding_box': right_part_loc, 'number': number, 'rows': rows})
-        choice_m_list.append(tmp)
-
-    choice_m_list_copy = choice_m_list.copy()
-    for ele in choice_m_list_copy:
-        loc = ele["bounding_box"]
-        w, h = loc['xmax'] - loc['xmin'], loc['ymax'] - loc['ymin']
-        if 2 * w * h < choice_s_width * choice_s_height:
-            choice_m_list.remove(ele)
-    return choice_m_list
+        choice_m_list_copy = choice_m_list.copy()
+        for ele in choice_m_list_copy:
+            loc = ele["bounding_box"]
+            w, h = loc['xmax'] - loc['xmin'], loc['ymax'] - loc['ymin']
+            if 2 * w * h < choice_s_width * choice_s_height:
+                choice_m_list.remove(ele)
+
+        return choice_m_list
 
 
-def infer_choice_m(image, tf_sheet, col_split_x, ocr, xml=None):
-    infer_box_list = ocr2sheet(image, col_split_x, ocr, xml)
+def infer_choice_m(image, tf_sheet, infer_box_list, col_split_x, xml=None):
+    # infer_box_list = ocr2sheet(image, col_split_x, ocr, xml)
     if not infer_box_list:
         for ele in tf_sheet:
             if ele['class_name'] == 'choice':

+ 17 - 6
segment/sheet_resolve/analysis/sheet/ocr_sheet.py

@@ -161,10 +161,12 @@ def ocr2sheet(image, col_split_list, raw_ocr, xml_path=None):
 
         # print(raw_chn_index)
 
-        left_limit = min([ele['location']['left'] for ele in ocr_res
-                          if ele['location']['width'] >= ele['location']['height']]) - 10
-        right_limit = max([ele['location']['right'] for ele in ocr_res
-                           if ele['location']['width'] >= ele['location']['height']]) + 10
+        left_limit, right_limit = col_split_list[ii], col_split_list[ii+1]
+        if len(ocr_res) > 0:
+            left_limit = min([ele['location']['left'] for ele in ocr_res
+                              if ele['location']['width'] >= ele['location']['height']]) - 10
+            right_limit = max([ele['location']['right'] for ele in ocr_res
+                               if ele['location']['width'] >= ele['location']['height']]) + 10
 
         # 文字识别结果可能连横跨栏,并导致文字分栏错误
         left_limit = max(left_limit, col_split_list[ii])
@@ -237,8 +239,17 @@ def sheet_sorted(regions, split_x):
     region_contain_set.extend(region_contain_set_)
 
     # split_x = tell_columns(image, regions)
-    regions = sorted(regions, key=lambda x: x['bounding_box']['xmin'])
-    x_min_list = [ele['bounding_box']['xmin'] for ele in regions]
+    for rg in regions:
+        xi = rg['bounding_box']['xmin']
+        xm = rg['bounding_box']['xmax']
+        yi = rg['bounding_box']['ymin']
+        ym = rg['bounding_box']['ymax']
+        xmid = (xm - xi)//2 + xi
+        ymid = (ym - yi)//2 + yi
+        rg['bounding_box'].update({'xmid': xmid, 'ymid': ymid})
+    regions = sorted(regions, key=lambda x: x['bounding_box']['xmid'])
+
+    x_min_list = [ele['bounding_box']['xmid'] for ele in regions]
 
     # 分栏
     col_list = []

+ 41 - 40
segment/sheet_resolve/analysis/sheet/sheet_adjust.py

@@ -11,6 +11,8 @@ import numpy as np
 ''' 根据CV检测矩形框 调整模型输出框'''
 ''' LSD直线检测 暂时改用 霍夫曼检测'''
 
+ADJUST_CLASS = ['solve', 'solve0', 'composition', 'composition0', 'choice', 'cloze', 'correction']
+
 
 # 用户自己计算阈值
 def custom_threshold(gray, type_inv=cv2.THRESH_BINARY):
@@ -53,7 +55,7 @@ def dilation_img(image, kernel_size):
 # 图像padding
 def image_padding(image, padding_w, padding_h):
     h, w = image.shape[:2]
-    if (3 == len(image.shape)):
+    if 3 == len(image.shape):
         image_new = np.zeros((h + padding_h, w + padding_w, 3), np.uint8)
     else:
         image_new = np.zeros((h + padding_h, w + padding_w), np.uint8)
@@ -62,7 +64,7 @@ def image_padding(image, padding_w, padding_h):
 
 
 def horizontal_projection(img_bin, mut=0):
-    '''水平方向投影'''
+    """水平方向投影"""
     h, w = img_bin.shape[:2]
     hist = [0 for i in range(w)]
     for x in range(w):
@@ -90,33 +92,33 @@ def vertical_projection(img_bin, mut=0):
 
 
 def get_white_blok_pos(arry, blok_w=0):
-    '''获取投影结果中的白色块'''
+    """获取投影结果中的白色块"""
     pos = []
     start = 1
     x0 = 0
     x1 = 0
     for idx, val in enumerate(arry):
-        if (start):
+        if start:
             if val:
                 x0 = idx
                 start = 0
         else:
-            if (0 == val):
+            if 0 == val:
                 x1 = idx
                 start = 1
-                if (x1 - x0 > blok_w):
+                if x1 - x0 > blok_w:
                     pos.append((x0, x1))
-    if (0 == start):
+    if 0 == start:
         x1 = len(arry) - 1
-        if (x1 - x0 > blok_w):
+        if x1 - x0 > blok_w:
             pos.append((x0, x1))
     return pos
 
 
 def get_decide_boberLpa(itemRe, itemGT):
-    '''
+    """
     IOU 计算
-    '''
+    """
     x1 = int(itemRe[0])
     y1 = int(itemRe[1])
     x1_ = int(itemRe[2])
@@ -153,6 +155,7 @@ def get_decide_boberLpa(itemRe, itemGT):
 
 
 # 查找连通区域 微调专用 不通用
+
 def get_contours(image):
     # image = cv2.imread(img_path,0)
     # if debug: plt_imshow(image)
@@ -172,29 +175,29 @@ def get_contours(image):
         w = int(box[2])
         h = int(box[3])
         area = int(box[4])
-        if (w < img_w / 5 or w > img_w - 10 or h < 50 or h > img_h - 10):  # 常见框大小限定
+        if w < img_w / 5 or w > img_w - 10 or h < 50 or h > img_h - 10:  # 常见框大小限定
             continue
-        if (img_w > img_h):  # 多栏答题卡 w大于宽度的一般肯定是错误的框
-            if (w > img_w / 2):
+        if img_w > img_h:  # 多栏答题卡 w大于宽度的一般肯定是错误的框
+            if w > img_w / 2:
                 continue
-        if (area < w * h / 3):  # 大框套小框 中空白色区域形成的面积 排除
+        if area < w * h / 3:  # 大框套小框 中空白色区域形成的面积 排除
             continue
         rects.append((x0, y0, x0 + w, y0 + h))
     return rects
 
 
 def adjust_alarm_info(image, box):
-    '''
+    """
     调整上下坐标 排除内部含有了边框线情况
     左右调整只有100%确认的 从边界开始遇到的第一个非0列就终止 误伤情况太多
     LSD算法转不过来  霍夫曼检测不靠谱 连通区域测试后排除误伤情况太多  改用投影
     image: 灰度 非 二值图
     box  : 坐标信息
-    '''
+    """
     # debug
     # debug = 0
 
-    if (image is None):
+    if image is None:
         print("error image")
         return box
     img_box = image[box[1]:box[3], box[0]:box[2]]
@@ -259,15 +262,15 @@ def adjust_alarm_info(image, box):
     new_right = w_len - 1
     b_flag = [0, 0]
     for idx, val in enumerate(hist_proj):
-        if (0 == b_flag[0]):
-            if (val != 0):
+        if 0 == b_flag[0]:
+            if val != 0:
                 new_left = idx
                 b_flag[0] = 1
-        if (0 == b_flag[1]):
-            if (hist_proj[w_len - 1 - idx] != 0):
+        if 0 == b_flag[1]:
+            if hist_proj[w_len - 1 - idx] != 0:
                 new_right = w_len - idx - 1
                 b_flag[1] = 1
-        if (b_flag[0] and b_flag[1]):
+        if b_flag[0] and b_flag[1]:
             break
 
     new_top = box[1] + y_pos[max_id][0]
@@ -283,12 +286,12 @@ def adjust_alarm_info(image, box):
 
 
 def adjust_zg_info(image, box, cv_boxes):
-    '''
+    """
     调整大区域的box
     1、cvbox要与box纵坐标有交叉
     2、IOU值大于0。8时 默认相等拷贝区域坐标
-    '''
-    if (image is None):
+    """
+    if image is None:
         return box
 
     min_rotio = 0.5
@@ -299,9 +302,9 @@ def adjust_zg_info(image, box, cv_boxes):
     tmp_rotio = 0
     rc_mz = box
     for idx, cv_box in enumerate(cv_boxes):
-        if ((box[1] - 10) > (cv_box[3])):  # 首先要保证纵坐标有交叉
+        if (box[1] - 10) > (cv_box[3]):  # 首先要保证纵坐标有交叉
             continue
-        if ((box[3] + 10) < cv_box[1]):
+        if (box[3] + 10) < cv_box[1]:
             continue
 
         jc_x = max(box[0], cv_box[0])
@@ -310,14 +313,14 @@ def adjust_zg_info(image, box, cv_boxes):
         bj_y = max(box[2], cv_box[2])
 
         rt = abs(jc_y - jc_x) * 1.0 / abs(bj_y - bj_x) * 1.0
-        if (rt < min_rotio):
+        if rt < min_rotio:
             continue
         jc_boxes.append(cv_box)
-        if (rt > tmp_rotio):
+        if rt > tmp_rotio:
             rc_mz = cv_box
             tmp_rotio = rt
     # 判断 调整
-    if (len(jc_boxes) != 0):
+    if len(jc_boxes) != 0:
         box[0] = rc_mz[0]
         box[2] = rc_mz[2]
         b_find = 0
@@ -325,29 +328,29 @@ def adjust_zg_info(image, box, cv_boxes):
         rc_biggst = rc_mz
         for mz_box in jc_boxes:
             iou = get_decide_boberLpa(mz_box, box)
-            if (iou > 0.8):
+            if iou > 0.8:
                 b_find = 1
                 frotio = iou
                 rc_biggst = mz_box
-        if (b_find):
+        if b_find:
             box[1] = rc_biggst[1]
             box[3] = rc_biggst[3]
     return box
 
 
 def adjust_item_edge(img_path, reback_json):
-    '''
+    """
     根据图像的CV分析结果和 模型直接输出结果 对模型输出的边框做微调
     1、外接矩形查找
     2、LSD直线检测 替换方法 霍夫曼直线检测
     3、只处理有把握的情况 任何含有不确定因素的一律不作任何处理
     img_path: 待处理图像绝对路径
     re_json : 模型输出结果
-    '''
+    """
     debug = 1
     # 存放新的结果
     re_json = copy.deepcopy(reback_json)
-    if (not os.path.exists(img_path) or 0 == len(re_json)):
+    if not os.path.exists(img_path) or 0 == len(re_json):
         return
     image = cv2.imread(img_path, 0)
     # 获取CV连通区域结果
@@ -365,7 +368,7 @@ def adjust_item_edge(img_path, reback_json):
         box = [item["bounding_box"]["xmin"], item["bounding_box"]["ymin"], item["bounding_box"]["xmax"],
                item["bounding_box"]["ymax"]]
         # print(name ,box)
-        if (name == "alarm_info" or name == "page" or name == "type_score"):
+        if name == "alarm_info" or name == "page" or name == "type_score":
             if debug:
                 cv2.rectangle(image_draw, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
             new_box = adjust_alarm_info(image, box)
@@ -423,7 +426,7 @@ def adjust_item_edge_by_gray_image(image, reback_json):
         box = [item["bounding_box"]["xmin"], item["bounding_box"]["ymin"], item["bounding_box"]["xmax"],
                item["bounding_box"]["ymax"]]
         # print(name ,box)
-        if (name == "alarm_info" or name == "page" or name == "type_score"):
+        if name == "alarm_info" or name == "page" or name == "type_score":
             if debug:
                 cv2.rectangle(image_draw, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
             new_box = adjust_alarm_info(image, box)
@@ -433,9 +436,7 @@ def adjust_item_edge_by_gray_image(image, reback_json):
             item["bounding_box"]["xmax"] = box[2]
             item["bounding_box"]["ymin"] = box[1]
             item["bounding_box"]["ymax"] = box[3]
-        elif (name == "solve" or name == "solve0"
-              or name == "cloze" or name == "choice"
-              or name == "composition" or name == "composition0"):
+        elif name in ADJUST_CLASS:
             if debug:
                 cv2.rectangle(image_draw, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
             new_box = adjust_zg_info(image, box, cv_boxes)

+ 34 - 30
segment/sheet_resolve/analysis/sheet/sheet_infer.py

@@ -117,10 +117,11 @@ def infer_bar_code(image, ocr_dict_list, attention_region):
                     xmax = right_board_location['left'] + right_board_location['width']
                     ymax = down_board_location['top'] + down_board_location['height']
 
-                    xmin = max(1, int(xmin)-5)
-                    ymin = max(1, int(ymin)-5)
-                    xmax = min(int(xmax), img_rows - 1)
-                    ymax = min(int(ymax), img_cols - 1 )
+                    bias = 5
+                    xmin = max(1, int(xmin)-bias)
+                    ymin = max(1, int(ymin)-bias)
+                    xmax = min(int(xmax)+bias, img_rows - 1)
+                    ymax = min(int(ymax)+bias, img_cols - 1)
                     bar_code_dict = {'class_name': 'bar_code',
                                      'bounding_box': {'xmin': xmin, 'ymin': ymin, 'xmax': xmax, 'ymax': ymax}}
                     bar_code_dict_list.append(bar_code_dict)
@@ -214,7 +215,7 @@ def infer_exam_number(image, ocr_dict_list, existed_regions, times_threshold=5):
                 if 9 in key_digital:
                     break
 
-    if 0 in key_digital and 9 in key_digital:
+    if 0 in key_digital and 9 in key_digital and len(key_digital) > 4:
         mean_height = sum(all_height)//10
         exam_number_dict = {'class_name': 'exam_number',
                             'bounding_box': {'xmin': xmin, 'ymin': ymin, 'xmax': xmax, 'ymax': ymax+mean_height},
@@ -257,29 +258,30 @@ def infer_exam_number(image, ocr_dict_list, existed_regions, times_threshold=5):
 
         iou_cond = True
         exam_number_dict_list_check = []
-        for exam_number_dict in exam_number_dict_list:
-            exam_number_polygon = Polygon([(exam_number_dict["xmin"], exam_number_dict["ymin"]),
-                                           (exam_number_dict["xmax"], exam_number_dict["ymin"]),
-                                           (exam_number_dict["xmax"], exam_number_dict["ymax"]),
-                                           (exam_number_dict["xmin"], exam_number_dict["ymax"])])
-            for region in existed_regions:
-                class_name = region["class_name"]
-
-                if class_name in ["attention", "solve", "choice", "choice_m", 'choice_s', "cloze", 'cloze_s',
-                                  'bar_code', 'qr_code', 'composition', 'solve0']:
-                    coordinates = region['bounding_box']
-                    xmin = coordinates['xmin']
-                    ymin = coordinates['ymin']
-                    xmax = coordinates['xmax']
-                    ymax = coordinates['ymax']
-                    existed_polygon = Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
-                    overlab_area = existed_polygon.intersection(exam_number_polygon).area
-                    iou = overlab_area / (exam_number_polygon.area + existed_polygon.area - overlab_area)
-                    if iou > 0:
-                        iou_cond = False
-                        break
-            if iou_cond:
-                exam_number_dict_list_check.append(exam_number_polygon)
+        if len(exam_number_dict_list) > 0:
+            for exam_number_dict in exam_number_dict_list:
+                exam_number_polygon = Polygon([(exam_number_dict["xmin"], exam_number_dict["ymin"]),
+                                               (exam_number_dict["xmax"], exam_number_dict["ymin"]),
+                                               (exam_number_dict["xmax"], exam_number_dict["ymax"]),
+                                               (exam_number_dict["xmin"], exam_number_dict["ymax"])])
+                for region in existed_regions:
+                    class_name = region["class_name"]
+
+                    if class_name in ["attention", "solve", "choice", "choice_m", 'choice_s', "cloze", 'cloze_s',
+                                      'bar_code', 'qr_code', 'composition', 'solve0']:
+                        coordinates = region['bounding_box']
+                        xmin = coordinates['xmin']
+                        ymin = coordinates['ymin']
+                        xmax = coordinates['xmax']
+                        ymax = coordinates['ymax']
+                        existed_polygon = Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
+                        overlab_area = existed_polygon.intersection(exam_number_polygon).area
+                        iou = overlab_area / (exam_number_polygon.area + existed_polygon.area - overlab_area)
+                        if iou > 0:
+                            iou_cond = False
+                            break
+                if iou_cond:
+                    exam_number_dict_list_check.append(exam_number_polygon)
 
         return exam_number_dict_list_check
 
@@ -393,7 +395,8 @@ def exam_number_adjust_infer(image, regions):
         region = regions[i]
         if region['class_name'] == 'exam_number_s':
             loc = region['bounding_box']
-            box.append([loc['xmin'], loc['ymin'], loc['xmax'], loc['ymax']])
+            if loc['ymax'] - loc['ymin'] > loc['xmax'] - loc['xmin']:
+                box.append([loc['xmin'], loc['ymin'], loc['xmax'], loc['ymax']])
 
         if region['class_name'] == 'exam_number':
             loc = region['bounding_box']
@@ -403,7 +406,8 @@ def exam_number_adjust_infer(image, regions):
 
         if region['class_name'] == 'exam_number_w':
             loc = region['bounding_box']
-            box1.append([loc['xmin'], loc['ymax'], loc['xmax'], loc['ymax']])
+            if loc['xmax'] - loc['xmin'] > loc['ymax'] - loc['ymin']:
+                box1.append([loc['xmin'], loc['ymax'], loc['xmax'], loc['ymax']])  # exam_number_w 只取下界
 
     if not box:
         return image, regions

+ 9 - 5
segment/sheet_resolve/analysis/solve/optional_solve.py

@@ -48,10 +48,10 @@ def resolve_optional_choice(l_, t_, direction, image):
     ocr_res = get_ocr_text_and_coordinate(image)
     digital_list, chars_list, d_mean_height, d_mean_width = find_digital(ocr_res, 0, 0,)
     if not digital_list:
-        numbers = [501]
+        numbers = [501, 502]
         h, w = image.shape
         optional_choice_dict = {'class_name': 'optional_choice',
-                                'rows': 1, 'cols': 1,
+                                'rows': 1, 'cols': 2,
                                 'single_width': w,
                                 'single_height': h,
                                 'direction': direction,
@@ -63,7 +63,7 @@ def resolve_optional_choice(l_, t_, direction, image):
     else:
         numbers = sorted([ele['digital'] for ele in digital_list])
 
-        contours = find_contours(l_, t_, image, d_mean_width//2, d_mean_height//2)
+        contours = find_contours(l_, t_, image, d_mean_width*2, d_mean_height)
 
         res_region = []
         for contour in contours:
@@ -91,9 +91,13 @@ def resolve_optional_choice(l_, t_, direction, image):
         mean_w, mean_h = sum_w//len(numbers), sum_h//len(numbers)
 
         if direction == 180:
-            rows, cols = 1, len(numbers)
+            rows, cols = 1, max(2, len(numbers))
+            if len(numbers) < 2:
+                numbers.append(501)
         else:
-            rows, cols = len(numbers), 1
+            rows, cols = max(2, len(numbers)), 1
+            if len(numbers) < 2:
+                numbers.append(501)
 
         optional_choice_dict = {'class_name': 'optional_choice',
                                 'rows': rows, 'cols': cols,

+ 2 - 0
segment/sheet_resolve/tools/tf_sess.py

@@ -7,6 +7,8 @@ from segment.sheet_resolve.lib.model.config import cfg
 from segment.sheet_resolve.lib.nets.resnet_v1 import resnetv1
 from segment.sheet_resolve.tools.tf_settings import model_dict
 
+# os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
+
 
 class TfSess:
 

+ 1 - 1
segment/sheet_resolve/tools/tf_settings.py

@@ -9,7 +9,7 @@ subject_list = ['math', 'math_zxhx', 'english', 'chinese',
                 'geography', 'science_comprehensive', 'arts_comprehensive', 'cloze', 'choice']
 
 BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-decide_blank_model = model_dir_path = os.path.join(BASE_DIR, 'model', 'decide_blank', 'model.npy')
+decide_blank_model = os.path.join(BASE_DIR, 'model', 'decide_blank', 'model.npy')
 
 xml_template_path = os.path.join(BASE_DIR, 'labels', '000000-template.xml')
 images_dir_path = os.path.join(BASE_DIR, 'images')

+ 3 - 3
segment/sheet_resolve/tools/utils.py

@@ -155,8 +155,8 @@ def img_resize(analysis_type, im):
 
 
 def resize_faster_rcnn(analysis_type, im_orig):
-    min_size = 375
-    max_size = 500
+    min_size = 1500
+    max_size = 2000
     if analysis_type == 'math_blank':
         min_size = 1500
         max_size = 2000
@@ -172,7 +172,7 @@ def resize_faster_rcnn(analysis_type, im_orig):
     im = cv2.resize(im_orig, None, None, fx=im_scale, fy=im_scale,
                     interpolation=cv2.INTER_LINEAR)
 
-    return im, (im_scale, im_scale)
+    return im, (1/im_scale, 1/im_scale)
 
 
 def resize_by_percent(im, percent):

+ 7 - 3
segment/sheet_server.py

@@ -248,6 +248,7 @@ def sheet_row_col_resolve(raw_img, sheet_dict, choice_sess, cloze_sess, xml_save
     mns_thresh_0 = 0.3
     regions = sheet_dict['regions']
     classes_name = str([ele['class_name'] for ele in regions])
+    subject = sheet_dict['subject']
 
     region_tmp = regions.copy()
     # json.dumps(sheet_dict, ensure_ascii=False)
@@ -263,7 +264,7 @@ def sheet_row_col_resolve(raw_img, sheet_dict, choice_sess, cloze_sess, xml_save
 
     if 'choice_m' in classes_name:
         try:
-            choice_dict_list = choice_m_row_col(raw_img, regions, xml_save_path)
+            choice_dict_list = choice_m_row_col(raw_img, regions, subject, xml_save_path)
             if len(choice_dict_list) > 0:
                 region_tmp.extend(choice_dict_list)
         except Exception as e:
@@ -296,12 +297,13 @@ def sheet_row_col_resolve(raw_img, sheet_dict, choice_sess, cloze_sess, xml_save
 def sheet_detail_resolve(raw_img, sheet_dict, xml_save_path, shrink=True):
     regions = sheet_dict['regions']
     classes_names_list = set([ele['class_name'] for ele in regions])
+    subject = sheet_dict['subject']
 
     region_tmp = regions.copy()
     # json.dumps(sheet_dict, ensure_ascii=False)
     if 'choice_m' in classes_names_list:
         try:
-            choice_dict_list = choice_m_row_col(raw_img, regions, xml_save_path)
+            choice_dict_list = choice_m_row_col(raw_img, regions, subject, xml_save_path)
             if shrink:
                 for ele in choice_dict_list:
                     if 'all_small_coordinate' in ele.keys():
@@ -361,6 +363,7 @@ def sheet_points(sheet_dict_list, image_list, ocr_list, if_ocr=False):
     try:
         res = get_sheet_points(sheet_list)
         sheet_dict_list = [ele['sheet_dict'] for ele in res]
+        image_list = [ele['raw_image'] for ele in res]
     except Exception as e:
         traceback.print_exc()
         sheet_dict_list = [ele['sheet_dict'] for ele in sheet_list]
@@ -384,7 +387,7 @@ def sheet_points(sheet_dict_list, image_list, ocr_list, if_ocr=False):
     if if_ocr:
         for index, ele in enumerate(sheet_total_list):
             ele.update({'sheet_ocr': ocr_list[index]})
-    return sheet_total_list
+    return sheet_total_list, image_list
 
 
 def sheet_format_output(init_number, crt_numbers, sheet_dict, image, subject, shrink):
@@ -397,6 +400,7 @@ def sheet_format_output(init_number, crt_numbers, sheet_dict, image, subject, sh
     for col_regions in col_regions_list:
         _, init_number, crt_numbers = question_number_format(init_number, crt_numbers, col_regions)
 
+    # 合并跨栏主观题
     merge_span_boxes(col_regions_list)
 
     regions = list(itertools.chain(*col_regions_list))

+ 2 - 2
segment/sheet_views.py

@@ -438,7 +438,7 @@ def analysis_box_once_with_multiple_img(request):
                     detail_xml.append(small_box_xml_path)
                     shutil.copy(raw_xml_path, small_box_xml_path)
                     try:
-                        # 找题号
+                        # 找选择题行列题号, 考号行列
                         small_box_sheet_dict = sheet_detail_resolve(image, sheet_dict, small_box_xml_path, shrink=True)
                     except Exception as e:
                         traceback.print_exc()
@@ -460,7 +460,7 @@ def analysis_box_once_with_multiple_img(request):
                     error_info = error_info + 'box resolve error;'
 
             try:
-                res_info_list = sheet_points(res_info_list, image_list, ocr_list, if_ocr=False)
+                res_info_list, image_list = sheet_points(res_info_list, image_list, ocr_list, if_ocr=False)
             except Exception as e:
                 print(e)
                 traceback.print_exc()