Pārlūkot izejas kodu

完善选择题推断中遇到单行的情况;
删除延长线补全中原有的选择题区域内补全;
题号格式化设置错误上限。

lighttxu 4 gadi atpakaļ
vecāks
revīzija
d17675cade

BIN
db.sqlite3


+ 9 - 8
segment/sheet_resolve/analysis/resolve.py

@@ -95,7 +95,7 @@ def sheet(series_number, image_path, image, conf_thresh, mns_thresh, subject, sh
 
         try:
             choice_m_list = infer_choice_m(image, regions, ocr)
-            remain_choice_m = []
+            #remain_choice_m = []
             if len(choice_m_list) > 0:
                 choice_m_old_list = [ele for ele in regions if 'choice_m' == ele['class_name']]
                 for infer_box in choice_m_list.copy():
@@ -104,20 +104,21 @@ def sheet(series_number, image_path, image, conf_thresh, mns_thresh, subject, sh
                     for tf_box in choice_m_old_list:
                         tf_loc = tf_box['bounding_box']
                         iou = utils.cal_iou(infer_loc, tf_loc)
-                        if iou[0] > 0.85 or iou[1] > 0.85:
-                            if infer_box not in remain_choice_m:
-                                remain_choice_m.append(infer_box)
-                                choice_m_list.remove(infer_box)
+                        if iou[0] > 0.70 or iou[1] > 0.70 or iou[2] > 0.70:
+                            # if infer_box not in remain_choice_m:
+                            #     remain_choice_m.append(infer_box)
+                            #     choice_m_list.remove(infer_box)
                             regions.remove(tf_box)
-                            break
+                            # break
                         elif iou[0] > 0:
                             choice_m_list.remove(infer_box)
                             break
 
-                remain_choice_m.extend(choice_m_list)
+                #remain_choice_m.extend(choice_m_list)
 
                 # regions = [ele for ele in regions if 'choice_m' != ele['class_name']]
-                regions.extend(remain_choice_m)
+                # regions.extend(remain_choice_m)
+                regions.extend(choice_m_list)
                 infer_choice_m_flag = True
 
         except Exception as e:

+ 2 - 2
segment/sheet_resolve/analysis/sheet/analysis_sheet.py

@@ -122,7 +122,7 @@ def question_number_format(init_number, crt_numbers, sheet_dict):
     for region in sheet_dict['regions']:
         numbers = region.get("number")
         if numbers and isinstance(numbers, int):
-            if numbers <= 0 or numbers in crt_numbers:
+            if numbers <= 0 or numbers in crt_numbers or numbers >= 1000:
                 numbers = init_number
                 crt_numbers.append(numbers)
                 init_number += 1
@@ -130,7 +130,7 @@ def question_number_format(init_number, crt_numbers, sheet_dict):
             crt_numbers.append(numbers)
         if numbers and isinstance(numbers, list):
             for i, num in enumerate(numbers):
-                if num <= 0 or num in crt_numbers:
+                if num <= 0 or num in crt_numbers or num >= 1000:
                     numbers[i] = init_number
                     crt_numbers.append(init_number)
                     init_number += 1

+ 9 - 8
segment/sheet_resolve/analysis/sheet/choice_infer.py

@@ -267,11 +267,11 @@ def cluster_and_anti_abnormal(image, xml_path, digital_list, chars_list,
         arr[i] = np.array([ele["loc"][-2], ele["loc"][-1]])
 
     if choice_s_height != 0:
-        eps = int(choice_s_height * 2)
+        eps = int(choice_s_height * 2.5)
     else:
-        eps = int(mean_height * 2.5)
+        eps = int(mean_height * 3)
     print("eps: ", eps)
-    db = DBSCAN(eps=eps, min_samples=2, metric='chebyshev').fit(arr)
+    db = DBSCAN(eps=eps, min_samples=1, metric='chebyshev').fit(arr)
 
     labels = db.labels_
     # print(labels)
@@ -402,7 +402,7 @@ def cluster_and_anti_abnormal(image, xml_path, digital_list, chars_list,
         current_row_choice_m_d = sorted(current_row_choice_m_d, key=lambda x: x["loc"][0])
         # current_row_choice_m_d.append(choice_m_numbers_list[random_index])
         split_pix = sorted([ele["loc"][0] for ele in current_row_choice_m_d])  # xmin排序
-        split_index = get_split_index(split_pix)
+        split_index = get_split_index(split_pix, dif=choice_s_width*0.8)
         split_pix = [split_pix[ele] for ele in split_index[:-1]]
 
         block_list = []
@@ -444,9 +444,9 @@ def cluster_and_anti_abnormal(image, xml_path, digital_list, chars_list,
 
         # split_index.append(row_chars_xmax)  # 边界
         split_pix.append(round(split_pix[-1] + choice_s_width * 1.2))
-        for i in range(0, len(split_index) - 1):
-            left_limit = split_index[i]
-            right_limit = split_index[i + 1]
+        for i in range(0, len(split_pix) - 1):
+            left_limit = split_pix[i]
+            right_limit = split_pix[i + 1]
             block_chars = [ele for ele in current_row_chars
                            if left_limit < (ele["location"]["left"] + ele["location"]["width"] // 2) < right_limit]
 
@@ -504,6 +504,7 @@ def cluster_and_anti_abnormal(image, xml_path, digital_list, chars_list,
 
             tmp_w, tmp_h = location['xmax'] - location['xmin'], location['ymax'] - location['ymin'],
             numbers = current_row_choice_m_d[i]["numbers"]
+
             direction = current_row_choice_m_d[i]["direction"]
             if direction == 180:
                 choice_m = dict(class_name='choice_m',
@@ -584,7 +585,7 @@ def cluster_and_anti_abnormal(image, xml_path, digital_list, chars_list,
     for ele in tmp:
         loc = ele["bounding_box"]
         w, h = loc['xmax'] - loc['xmin'], loc['ymax'] - loc['ymin']
-        if w*h < choice_s_width*choice_s_height:
+        if 2*w*h < choice_s_width*choice_s_height:
             choice_m_list.remove(ele)
     return choice_m_list
 

+ 11 - 9
segment/sheet_resolve/analysis/sheet/sheet_infer.py

@@ -681,11 +681,11 @@ def infer_sheet_box(image, sheet_dict, lon_split_line, exclude_classes):
     gen_polygon_list = [k for k, g in it]
 
     # 在选择题区域的infer polygon
-    gen_choice = []
-    for ele in gen_polygon_list:
-        for choice_p in choice_polygon:
-            if ele.within(choice_p):
-                gen_choice.append(ele)
+    # gen_choice = []
+    # for ele in gen_polygon_list:
+    #     for choice_p in choice_polygon:
+    #         if ele.within(choice_p):
+    #             gen_choice.append(ele)
 
     sheet_box_area = sum([polygon.area for polygon in sheet_polygons])
     image_area = width_max * height_max
@@ -810,9 +810,10 @@ def infer_sheet_box(image, sheet_dict, lon_split_line, exclude_classes):
 
     gen_polygon_list = [polygon for index, polygon in enumerate(gen_polygon_list) if polygon.area > min_polygon.area]
 
-    if gen_choice:
-        gen_choice = sorted(gen_choice, key=lambda x: x.area)[-1]
-        gen_polygon_list.append(gen_choice)
+    # if gen_choice:
+    #     gen_choice = sorted(gen_choice, key=lambda x: x.area)[-1]
+    #     gen_polygon_list.append(gen_choice)
+
     return gen_polygon_list
 
 
@@ -1053,8 +1054,9 @@ def box_infer_and_complete(image, sheet_region_dict, ocr=''):
         'page',
         'alarm_info',
         # 'score_collect',
+        'choice_m',
         'choice_s',
-    ]
+    ]  # 不找这些区域的延长线
     y, x = image.shape[0], image.shape[1]
     x1, x2 = subfield_answer_sheet(image, sheet_region_dict)
 

+ 2 - 2
segment/sheet_views.py

@@ -52,11 +52,11 @@ tf_sess_dict = {
     # 'math_zxhx': TfSess('math_zxhx'),
     # 'math_zxhx_detail': TfSess('math_zxhx_detail'),
 
-    'math': TfSess('math'),
+    # 'math': TfSess('math'),
     # 'english': TfSess('english'),
     # 'chinese': TfSess('chinese'),
     # 'physics': TfSess('physics'),
-    # 'chemistry': TfSess('chemistry'),
+    'chemistry': TfSess('chemistry'),
     # 'biology': TfSess('biology'),
     # 'politics': TfSess('politics'),
     # 'history': TfSess('history'),