123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654 |
- import cv2
- import matplotlib.pylab as plt
- import numpy as np
def read_single_img(img_path):
    """Read an image file and decode it without changing its channels.

    The file's bytes are loaded with ``np.fromfile`` and decoded in memory
    with ``cv2.imdecode`` instead of letting OpenCV open the path itself.

    Args:
        img_path: Path of the image file.

    Returns:
        The result of ``cv2.imdecode`` for the file's bytes, decoded with
        ``cv2.IMREAD_UNCHANGED``.
        NOTE(review): OpenCV documents that undecodable data yields an empty
        result rather than an exception — callers should confirm the result
        is usable.

    Raises:
        FileNotFoundError: If ``img_path`` does not exist (raised by
            ``np.fromfile``).
    """
    raw_bytes = np.fromfile(img_path, dtype=np.uint8)
    return cv2.imdecode(raw_bytes, cv2.IMREAD_UNCHANGED)
def pre_process(image, blank_top=20, blank_bottom=-20, blur_size=5, sigma=5, debug=0):
    """Return the inverted binary version of ``image``.

    The image is converted to gray scale, its four margins are painted
    white, the result is inverted, blurred and thresholded with Otsu's
    method.

    Args:
        image: 2-D (gray) or 3-D (colour) array. The input is never modified.
        blank_top: Rows ``[0, blank_top)`` are painted white.
        blank_bottom: Rows ``[blank_bottom:]`` are painted white (a negative
            index counts from the bottom).
        blur_size: Side length of the square Gaussian kernel.
        sigma: Sigma passed to ``cv2.GaussianBlur``.
        debug: When ``1``, show the gray image and the binary image with
            matplotlib.

    Returns:
        The thresholded image returned by ``cv2.threshold``.

    Raises:
        Exception: If ``image`` is neither 2-D nor 3-D.
    """
    blank_size = 20  # columns painted white on the left and right margins

    if image.ndim == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    elif image.ndim == 2:
        # Work on a copy: the margin blanking below writes into ``gray`` and
        # must not overwrite the caller's image.
        gray = image.copy()
    else:
        raise Exception('The dimension of the image should be 2 or 3!')

    # Blank the margins.
    gray[0:blank_top, :] = 255
    gray[blank_bottom:, :] = 255
    gray[:, 0:blank_size] = 255
    gray[:, -blank_size:] = 255

    pre = 255 - gray
    pre = cv2.GaussianBlur(pre, (blur_size, blur_size), sigma)
    binary = cv2.threshold(pre, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]

    if debug == 1:
        # Show the gray image and the binary image.
        plt.figure(figsize=(15, 10))
        plt.subplot(211)
        plt.title('gray')
        plt.imshow(gray, cmap='gray')
        plt.subplot(212)
        plt.title('binary')
        plt.imshow(255 - binary, cmap='gray')
        plt.show()
    return binary
def pre_process_for_anchors(image, blank_top=20, blank_bottom=-20, h_ratio=(0.1, 0.9),
                            blur_size=3, sigma=5, blank_size=20, debug=0):
    """Return the inverted binary image with the middle band of rows removed.

    Only the top and bottom strips of the page survive, which is where the
    anchor marks are looked for.

    Args:
        image: 2-D (gray) or 3-D (colour) array; it is not modified.
        blank_top: Rows ``[0, blank_top)`` are painted white.
        blank_bottom: Rows ``[blank_bottom:]`` are painted white.
        h_ratio: ``(start, stop)`` fractions of the image height; the rows in
            between are painted white. Pass ``0`` to keep the middle band.
        blur_size: Side length of the square Gaussian kernel.
        sigma: Sigma passed to ``cv2.GaussianBlur``.
        blank_size: Columns painted white on the left and right margins.
        debug: When ``1``, show the gray and binary images with matplotlib.

    Returns:
        The thresholded image returned by ``cv2.threshold`` (Otsu).

    Raises:
        Exception: If ``image`` is neither 2-D nor 3-D.
    """
    if image.ndim not in (2, 3):
        raise Exception('The dimension of the image should be 2 or 3!')
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if image.ndim == 3 else image.copy()

    white = 255
    # Blank the four margins.
    gray[:blank_top, :] = white
    gray[blank_bottom:, :] = white
    gray[:, :blank_size] = white
    gray[:, -blank_size:] = white

    # Blank the middle band unless the caller disabled it with h_ratio=0.
    if h_ratio != 0:
        band_start = int(image.shape[0] * h_ratio[0])
        band_stop = int(image.shape[0] * h_ratio[1])
        gray[band_start:band_stop, :] = white

    blurred = cv2.GaussianBlur(255 - gray, (blur_size, blur_size), sigma)
    binary = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]

    if debug == 1:
        # Show the gray image and the binary image.
        plt.figure(figsize=(15, 10))
        for position, title, picture in ((211, 'gray', gray), (212, 'binary', 255 - binary)):
            plt.subplot(position)
            plt.title(title)
            plt.imshow(picture, cmap='gray')
        plt.show()
    return binary
def extract_feature(binary, method=4, ker_size1=2, ker_size2=10, debug=0):
    """Clean up a binary image with a fixed sequence of morphological steps.

    Args:
        binary: Binary image to process.
        method: Which pipeline to run:
            ``1`` close with the (5, 1) kernel, then open with the
            ``(ker_size2, ker_size1)`` kernel;
            ``2`` open with the (5, 1) kernel, then close with the 3x3 kernel;
            ``3`` open with ``(ker_size1, ker_size2)``, then with
            ``(ker_size2, ker_size1)``;
            ``4`` close with the 3x3 kernel, then the two openings of ``3``;
            anything else returns ``binary`` unchanged.
        ker_size1: First element of the bar kernels' size tuple.
        ker_size2: Second element of the bar kernels' size tuple.
        debug: When ``1``, show the image before and after with matplotlib.

    Returns:
        The processed image (``binary`` itself for an unknown ``method``).
    """
    # NOTE(review): the tuples below are handed to getStructuringElement as
    # its ksize argument unchanged — confirm the (width, height) order is the
    # intended one for each kernel.
    rect = cv2.MORPH_RECT
    square_3 = cv2.getStructuringElement(rect, (3, 3))
    bar_5x1 = cv2.getStructuringElement(rect, (5, 1))
    bar_s1_s2 = cv2.getStructuringElement(rect, (ker_size1, ker_size2))
    bar_s2_s1 = cv2.getStructuringElement(rect, (ker_size2, ker_size1))

    # Ordered (operation, kernel) steps for every supported method.
    pipelines = {
        1: ((cv2.MORPH_CLOSE, bar_5x1), (cv2.MORPH_OPEN, bar_s2_s1)),
        2: ((cv2.MORPH_OPEN, bar_5x1), (cv2.MORPH_CLOSE, square_3)),
        3: ((cv2.MORPH_OPEN, bar_s1_s2), (cv2.MORPH_OPEN, bar_s2_s1)),
        4: ((cv2.MORPH_CLOSE, square_3), (cv2.MORPH_OPEN, bar_s1_s2), (cv2.MORPH_OPEN, bar_s2_s1)),
    }

    ret = binary
    for operation, kernel in pipelines.get(method, ()):
        ret = cv2.morphologyEx(ret, operation, kernel)

    if debug == 1:
        # Show the image before and after feature extraction.
        plt.figure(figsize=(15, 10))
        for position, title, picture in ((211, 'before feature extraction', 255 - binary),
                                         (212, 'after feature extraction', 255 - ret)):
            plt.subplot(position)
            plt.title(title)
            plt.imshow(picture, cmap='gray')
        plt.show()
    return ret
def draw_contour(binary):
    """Return one box record per external contour found in ``binary``.

    Args:
        binary: Binary image handed to ``cv2.findContours``.

    Returns:
        A list of ``[xmin, ymin, xmax, ymax, [cx, cy], area]`` records, where
        the centre and the area are those of the contour's bounding
        rectangle (``area`` is ``w * h``, not the contour's own area).
    """
    # findContours returns (image, contours, hierarchy) in OpenCV 3.x and
    # (contours, hierarchy) otherwise; the contour list is always the
    # second-to-last item, so no version-string parsing is needed.
    contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    boxes = []
    for contour in contours:
        xmin, ymin, w, h = cv2.boundingRect(contour)
        centroid = [xmin + w // 2, ymin + h // 2]
        boxes.append([xmin, ymin, xmin + w, ymin + h, centroid, w * h])
    return boxes
def draw_connected_component(binary):
    """Return one box record per connected component of ``binary``.

    Components are found with 8-connectivity. Label 0 is skipped.
    NOTE(review): OpenCV documents label 0 as the background — confirm.

    Args:
        binary: Binary image handed to ``cv2.connectedComponentsWithStats``.

    Returns:
        A list of ``[xmin, ymin, xmax, ymax, [cx, cy], area]`` records; the
        centroid coordinates are truncated to ``int`` and ``area`` is the
        component's ``CC_STAT_AREA``.
    """
    label_count, _label_map, stats, centroids = cv2.connectedComponentsWithStats(binary, connectivity=8)

    boxes = []
    for label in range(1, label_count):
        row = stats[label]
        left = row[cv2.CC_STAT_LEFT]
        top = row[cv2.CC_STAT_TOP]
        right = row[cv2.CC_STAT_WIDTH] + left
        bottom = row[cv2.CC_STAT_HEIGHT] + top
        centre = [int(centroids[label][0]), int(centroids[label][1])]
        boxes.append([left, top, right, bottom, centre, row[cv2.CC_STAT_AREA]])
    return boxes
def find_boxes(binary, method='connected', debug=0):
    """Locate candidate boxes in a binary image.

    Args:
        binary: Binary image to search.
        method: ``'contour'`` delegates to ``draw_contour``; ``'connected'``
            delegates to ``draw_connected_component``.
        debug: When ``1``, the boxes are sorted in place by centroid ``y``
            and their shape statistics are printed.

    Returns:
        The list of box records produced by the selected finder.

    Raises:
        Exception: If ``method`` is not one of the two supported names.
    """
    if method not in ('contour', 'connected'):
        raise Exception('Wrong method for finding boxes!')
    finder = draw_contour if method == 'contour' else draw_connected_component
    boxes = finder(binary)

    if debug == 1:
        # Print the statistics of every box, top to bottom.
        boxes.sort(key=lambda record: record[4][1])
        print('The features of the boxes after feature extractions:')
        for record in boxes:
            box_w = record[2] - record[0]
            box_h = record[3] - record[1]
            aspect = box_w / box_h
            pixels = record[-1]
            centre = record[-2]
            fill = pixels / (box_w * box_h)
            print('width:{}, height:{}, centroid:{}, w_to_h:{}, area:{}, area ratio:{}'.
                  format(box_w, box_h, centre, aspect, pixels, fill))
    return boxes
# Shape limits used when the caller does not override them.
_DEFAULT_SHAPE_PARA = {'height': (80, 10), 'w2h': (3, 0.6), 'area': (6000, 500), 'area_ratio': 0.5}


def _print_marker_features(boxes):
    """Print one line of shape statistics for every box record in ``boxes``."""
    for box in boxes:
        width = box[2] - box[0]
        height = box[3] - box[1]
        w_to_h = width / height
        area = box[-1]
        centroid = box[-2]
        area_ratio = area / (width * height)
        print('width:{}, height:{}, centroid:{}, w_to_h:{}, area:{}, area ratio:{}'.
              format(width, height, centroid, w_to_h, area, area_ratio))


def find_marker_by_shape(boxes, shape_para=None, debug=0):
    """Select the boxes whose shape matches a marker.

    A box ``[xmin, ymin, xmax, ymax, centroid, area]`` is accepted when its
    area lies within the ``'area'`` limits and either

    * it fills at least 96% of its bounding rectangle, or
    * its height and width/height ratio lie within the ``'height'`` and
      ``'w2h'`` limits and it fills at least ``'area_ratio'`` of its
      bounding rectangle.

    Args:
        boxes: Box records to filter.
        shape_para: Optional dict overriding any of the keys ``'height'``
            ``(max, min)``, ``'w2h'`` ``(max, min)``, ``'area'``
            ``(max, min)`` and ``'area_ratio'`` (minimum). Missing keys fall
            back to the defaults ``{'height': (80, 10), 'w2h': (3, 0.6),
            'area': (6000, 500), 'area_ratio': 0.5}``.
        debug: ``1`` sorts the accepted markers by descending area and prints
            them. ``2`` appends *every* input box to the accepted markers,
            prints the combined list and returns it (so accepted boxes appear
            twice).

    Returns:
        The list of accepted box records (see ``debug`` for the exceptions).
    """
    area_ratio_threshold = 0.96  # fill ratio above which a box counts as a solid rectangle

    # Merge into a fresh dict so neither the defaults nor the caller's dict
    # is ever shared or mutated.
    params = dict(_DEFAULT_SHAPE_PARA)
    if shape_para is not None:
        params.update(shape_para)
    max_height, min_height = params['height']
    max_w2h, min_w2h = params['w2h']
    max_area, min_area = params['area']
    min_area_ratio = params['area_ratio']

    markers = []
    for box in boxes:
        w = box[2] - box[0]
        h = box[3] - box[1]
        area = box[-1]
        if not min_area <= area <= max_area:
            continue
        if area >= area_ratio_threshold * w * h:
            markers.append(box)
        elif min_height <= h <= max_height and min_w2h <= w / h <= max_w2h \
                and area >= min_area_ratio * w * h:
            markers.append(box)

    if debug == 1:
        # Print the markers selected by shape, largest first.
        markers.sort(reverse=True, key=lambda x: x[-1])
        print('The features of the markers picking up by shapes:')
        _print_marker_features(markers)
    elif debug == 2:
        # Print every box, selected or not.
        print('The features of the markers without picking up by shapes:')
        markers.extend(boxes)
        _print_marker_features(markers)
    return markers
def find_box_list_by_position(box, box_list, method='h', shift_threshold=30, slope_threshold=0.2, area_threshold=0.28):
    """Add ``box`` to the closest group in ``box_list``, or start a new group.

    ``box`` is compared with the *last* box of every existing group.

    Args:
        box: Box record ``[xmin, ymin, xmax, ymax, [cx, cy], area]``.
        box_list: List of groups (lists of box records); modified in place.
        method: ``'h'`` compares centroid ``y`` (rows), ``'v'`` compares
            centroid ``x`` and additionally requires the ``x`` offset to be
            below ``slope_threshold`` times the ``y`` offset (columns),
            ``'s'`` compares the relative area difference.
        shift_threshold: Exclusive upper bound of the offset for ``'h'``/``'v'``.
        slope_threshold: Maximum ``dx / dy`` for ``'v'``.
        area_threshold: Exclusive upper bound of the relative area
            difference for ``'s'``.

    Returns:
        ``box_list`` itself. With any other ``method`` a non-empty
        ``box_list`` is returned untouched.
    """
    if len(box_list) == 0:
        box_list.append([box])
        return box_list

    # Pick the distance measure and its limit; ``None`` means "not eligible".
    if method == 'h':  # laid out horizontally
        limit = shift_threshold

        def measure(tail):
            return abs(box[4][1] - tail[4][1])
    elif method == 'v':  # laid out vertically
        limit = shift_threshold

        def measure(tail):
            dx = abs(box[4][0] - tail[4][0])
            dy = abs(box[4][1] - tail[4][1])
            return dx if dx < dy * slope_threshold else None
    elif method == 's':  # similar area
        limit = area_threshold

        def measure(tail):
            return abs((box[-1] - tail[-1]) / tail[-1])
    else:
        return box_list

    chosen = -1
    for position, group in enumerate(box_list):
        value = measure(group[-1])
        if value is not None and value < limit:
            limit = value
            chosen = position

    if chosen < 0:
        box_list.append([box])
    else:
        box_list[chosen].append(box)
    return box_list
def collect_markers_by_position(boxes, method='h', shift_threshold=30, slope_threshold=0.2, area_threshold=0.28, debug=0):
    """Group box records that lie close to each other.

    Args:
        boxes: Box records ``[xmin, ymin, xmax, ymax, [cx, cy], area]``.
            The list is sorted **in place** (by ``cx`` for ``'h'``, by ``cy``
            for ``'v'``, by descending area for ``'s'``).
        method: ``'h'`` builds rows, ``'v'`` builds columns, ``'s'`` builds
            groups of similar area. Any other value yields an empty result.
        shift_threshold: Forwarded to ``find_box_list_by_position``.
        slope_threshold: Forwarded to ``find_box_list_by_position``.
        area_threshold: Forwarded to ``find_box_list_by_position``.
        debug: When ``1``, print the slope between consecutive boxes of every
            group (``'h'`` and ``'v'`` only).

    Returns:
        A list of groups. Rows are ordered top to bottom, columns left to
        right, area groups largest first (each judged by its first box).
    """
    box_list = []
    if method == 'h':  # rows: feed boxes left to right, order rows by y
        boxes.sort(key=lambda x: x[4][0])
        group_key, reverse = (lambda x: x[0][4][1]), False
    elif method == 'v':  # columns: feed boxes top to bottom, order columns by x
        boxes.sort(key=lambda x: x[4][1])
        group_key, reverse = (lambda x: x[0][4][0]), False
    elif method == 's':  # similar area: feed and order largest first
        boxes.sort(reverse=True, key=lambda x: x[-1])
        group_key, reverse = (lambda x: x[0][-1]), True
    else:
        return box_list

    for b in boxes:
        box_list = find_box_list_by_position(b, box_list, method=method, shift_threshold=shift_threshold,
                                             slope_threshold=slope_threshold, area_threshold=area_threshold)
    box_list.sort(reverse=reverse, key=group_key)

    if debug == 1:
        print('box list slope')
        if method in ('h', 'v'):
            # 'h' reports dy/dx along a row, 'v' reports dx/dy along a column.
            rise_axis, run_axis = (1, 0) if method == 'h' else (0, 1)
            for group in box_list:
                for current, following in zip(group, group[1:]):
                    rise = following[4][rise_axis] - current[4][rise_axis]
                    run = following[4][run_axis] - current[4][run_axis]
                    # Two boxes of a group may share the running coordinate;
                    # report an infinite slope instead of crashing the debug
                    # output with ZeroDivisionError.
                    print(rise / run if run != 0 else float('inf'))
    return box_list
def collect_markers_by_area(boxes, area_threshold=2):
    """Drop the boxes that are much smaller than the largest one.

    Args:
        boxes: Box records whose last element is the area. A non-empty list
            is sorted **in place** by descending area.
        area_threshold: A box is kept only if its area is strictly greater
            than ``largest_area / area_threshold``.

    Returns:
        A new list with the surviving boxes, largest first; an empty list
        when ``boxes`` is empty.
    """
    if not boxes:
        # Nothing to compare against; previously this raised IndexError.
        return []
    boxes.sort(reverse=True, key=lambda x: x[-1])
    min_area = boxes[0][-1] / area_threshold
    return [b for b in boxes if b[-1] > min_area]
def check_with_anchor(problem_markers, top_anchors, page_width, column_num):
    """Drop marker pairs that are not aligned with the columns given by the top anchors.

    The filter only runs when exactly ``column_num + 1`` top anchors are
    available. The first column's reference ``x`` is the first anchor's
    centroid ``x``. For the remaining columns it is either the ``x`` of
    anchors ``2 .. column_num`` minus ``page_width`` (when the first gap
    between anchors is smaller than the last gap) or the ``x`` of anchors
    ``1 .. column_num - 1``.

    Args:
        problem_markers: One list of box records per column, read as
            consecutive ``(left, right)`` pairs. The outer list is updated in
            place; an unpaired trailing marker is always kept.
        top_anchors: Box records of the top anchors.
        page_width: Expected ``x`` distance between the markers of a pair.
        column_num: Number of columns.

    Returns:
        ``problem_markers``. A pair is removed when its left marker is more
        than 100 px from the column's reference ``x`` or its right marker is
        more than 100 px from that ``x`` plus ``page_width``. Columns that
        have no reference position are left untouched.
    """
    min_shift = 100
    if len(top_anchors) != column_num + 1:
        return problem_markers

    column_pos = [top_anchors[0][4][0]]
    first_gap = top_anchors[1][4][0] - top_anchors[0][4][0]
    last_gap = top_anchors[-1][4][0] - top_anchors[-2][4][0]
    if first_gap < last_gap:
        column_pos.extend(top_anchors[i][4][0] - page_width for i in range(2, column_num + 1))
    else:
        column_pos.extend(top_anchors[i][4][0] for i in range(1, column_num))

    for index, markers in enumerate(problem_markers):
        if index >= len(column_pos):
            # No reference position for this column: leave it alone instead
            # of failing with IndexError.
            continue
        left_x = column_pos[index]
        pair_count = len(markers) // 2
        kept = []
        for i in range(pair_count):
            left, right = markers[2 * i], markers[2 * i + 1]
            if abs(left[4][0] - left_x) > min_shift or \
                    abs(right[4][0] - left_x - page_width) > min_shift:
                continue
            kept.extend([left, right])
        kept.extend(markers[2 * pair_count:])  # unpaired trailing marker, if any
        problem_markers[index] = kept
    return problem_markers
def remove_abnormal_marker(markers, page_width, debug=0):
    """Remove inconsistent marker pairs and re-measure the page width.

    ``markers`` is read as consecutive pairs: element ``2*i`` and element
    ``2*i + 1`` form pair ``i``, and ``marker[4]`` is the ``[x, y]`` centroid.
    (The names below treat the even element as the left marker and the odd
    one as the right marker.)

    Three filters run in order:

    1. Pairs whose right markers lie within ``error`` (Manhattan distance)
       of each other are duplicates; only the pair whose own offset is
       closest to ``(page_width, 0)`` survives.
    2. With at least six markers, the slope ``|dx / dy|`` between consecutive
       left markers and between consecutive right markers is compared with
       ``min_slope`` and pairs around an over-steep segment are removed.
    3. With at least four markers, a pair is removed when its left ``x``,
       its vertical shift or its width is more than ``max_std`` standard
       deviations from the mean (each test is skipped when the deviation is
       not above ``min_std``). With exactly two markers, the pair is removed
       when ``|width - page_width| + |shift|`` exceeds ``min_distance``.

    Args:
        markers: List of box records, see above.
        page_width: Expected ``x`` distance between the markers of a pair.
        debug: ``1`` prints the statistics of filter 3; ``2`` prints the old
            and new page width.

    Returns:
        ``(markers, new_page_width)`` where ``new_page_width`` is the ``x``
        distance between the first two surviving markers, or ``page_width``
        when fewer than two remain.
    """
    # Remove abnormal points from markers.
    error = 10
    max_std = 3
    min_std = 0.1
    min_area_ratio = 0.9  # not used in this function
    min_distance = 60
    min_slope = 0.2
    distance_list = []
    remove_list = []
    # Filter 1: de-duplicate pairs that share (almost) the same right marker.
    for i in range(len(markers)//2-1):
        # min_flag is the index of the best pair seen so far in this group,
        # distance_flag its deviation from the expected (page_width, 0) offset.
        min_flag = i
        distance_flag = abs(markers[2*i+1][4][0] - markers[2*i][4][0] - page_width) + abs(markers[2*i+1][4][1] -
                                                                                          markers[2*i][4][1])
        for j in range(i+1, len(markers)//2):
            if abs(markers[2*i+1][4][0] - markers[2*j+1][4][0]) + abs(markers[2*i+1][4][1] - markers[2*j+1][4][1]) \
                    < error:
                if distance_flag < abs(markers[2*j+1][4][0] - markers[2*j][4][0] - page_width) + \
                        abs(markers[2*j+1][4][1] - markers[2*j][4][1]):
                    remove_list.extend([2*j, 2*j+1])
                else:
                    distance_flag = abs(markers[2*j+1][4][0] - markers[2*j][4][0] - page_width) + \
                                    abs(markers[2*j+1][4][1] - markers[2*j][4][1])
                    remove_list.extend([2*min_flag, 2*min_flag+1])
                    min_flag = j
    markers = [markers[i] for i in range(len(markers)) if i not in remove_list]
    remove_list = []
    # Filter 2: slope between consecutive pairs (needs at least three pairs).
    if len(markers) >= 6:
        left_slope_list = np.asarray([abs((markers[2 * i][4][0] - markers[2 * i + 2][4][0])
                                          / (markers[2 * i][4][1] - markers[2 * i + 2][4][1]))
                                      for i in range(len(markers)//2-1)])
        right_slope_list = np.asarray([abs((markers[2 * i + 1][4][0] - markers[2 * i + 3][4][0]) /
                                           (markers[2 * i + 1][4][1] - markers[2 * i + 3][4][1]))
                                       for i in range(len(markers) // 2 - 1)])
        # From here on the lists hold booleans: True marks an over-steep segment.
        left_slope_list = left_slope_list > min_slope
        right_slope_list = right_slope_list > min_slope
        # Segment i joins pair i and pair i+1. For a steep segment: the last
        # segment removes pair i+1 if the previous segment was fine; otherwise
        # pair i+1 is removed when the next segment is steep too, and pair i
        # when the next segment is fine.
        for i in range(len(left_slope_list)):
            if left_slope_list[i]:
                if i == len(left_slope_list) - 1:
                    if not left_slope_list[i-1]:
                        remove_list.extend([2*(i+1), 2*(i+1)+1])
                elif left_slope_list[i+1]:
                    remove_list.extend([2*(i+1), 2*(i+1)+1])
                elif not left_slope_list[i+1]:
                    remove_list.extend([2*i, 2*i+1])
        for i in range(len(right_slope_list)):
            if right_slope_list[i]:
                if i == len(right_slope_list) - 1:
                    if not right_slope_list[i-1]:
                        remove_list.extend([2*(i+1), 2*(i+1)+1])
                elif right_slope_list[i+1]:
                    remove_list.extend([2*(i+1), 2*(i+1)+1])
                elif not right_slope_list[i+1]:
                    remove_list.extend([2*i, 2*i+1])
    markers = [markers[i] for i in range(len(markers)) if i not in set(remove_list)]
    remove_list = []
    # Filter 3: statistical outliers (>= 4 markers) or a single implausible pair.
    if len(markers) >= 2:
        left_x_list = np.asarray([markers[2*i][4][0] for i in range(len(markers)//2)])
        left_y_list = np.asarray([markers[2*i][4][1] for i in range(len(markers)//2)])
        right_x_list = np.asarray([markers[2*i+1][4][0] for i in range(len(markers)//2)])
        right_y_list = np.asarray([markers[2*i+1][4][1] for i in range(len(markers)//2)])
        distance_list = right_x_list - left_x_list
        shift_list = right_y_list - left_y_list
        left_x_mean = left_x_list.mean()
        distance_mean = distance_list.mean()
        shift_mean = shift_list.mean()
        left_x_std = left_x_list.std()
        distance_std = distance_list.std()
        shift_std = shift_list.std()
        if len(markers) >= 4:
            for i in range(len(markers)//2):
                if left_x_std > min_std and abs(left_x_list[i] - left_x_mean) / left_x_std > max_std:
                    remove_list.extend([2*i, 2*i+1])
                elif shift_std > min_std and abs(shift_list[i] - shift_mean) / shift_std > max_std:
                    remove_list.extend([2 * i, 2 * i + 1])
                elif distance_std > min_std and abs(distance_list[i] - distance_mean) / distance_std > max_std:
                    remove_list.extend([2 * i, 2 * i + 1])
        elif len(markers) == 2:
            # area_ratio_list = np.asarray([m[-1]/((m[2]-m[0])*(m[3]-m[1])) for m in markers])
            # distance_list and shift_list hold a single element here.
            if abs(distance_list-page_width) + abs(shift_list) > min_distance:
                remove_list.extend([0, 1])
    markers = [markers[i] for i in range(len(markers)) if i not in remove_list]
    if len(markers) >= 2:
        new_page_width = markers[1][4][0] - markers[0][4][0]
    else:
        new_page_width = page_width
    if debug == 1:
        print(len(markers))
        # The statistics printed below were computed before filter 3 removed
        # anything.
        if len(markers) >= 4:
            print('left', left_x_mean, left_x_std)
            print(left_x_list)
            for i, x in enumerate(left_x_list):
                delta = abs(x - left_x_mean) / left_x_std
                print(delta, left_y_list[i])
            print('shift', shift_mean, shift_std)
            print(shift_list)
            for i, shift in enumerate(shift_list):
                delta = abs(shift - shift_mean) / shift_std
                print(delta, left_y_list[i])
            print('distance', distance_mean, distance_std)
            print(distance_list)
            for i, distance in enumerate(distance_list):
                delta = abs(distance - distance_mean) / distance_std
                print(delta, left_y_list[i])
            print('total')
            for i in range(len(left_x_list)):
                delta = abs(left_x_list[i] - left_x_mean) / left_x_std + abs(shift_list[i] - shift_mean) / shift_std \
                        + abs(distance_list[i] - distance_mean) / distance_std
                d = abs(left_x_list[i] - left_x_mean) + abs(shift_list[i] - shift_mean) + \
                    abs(distance_list[i] - distance_mean)
                print(delta, d, left_y_list[i])
    if debug == 2:
        if len(markers) >= 2:
            print('page width', page_width, 'new page width:', new_page_width,
                  'distant difference:', abs(new_page_width - page_width) + abs(markers[1][4][1]-markers[0][4][1]))
    return markers, new_page_width
def draw_box(image, boxes, color=(0, 255, 0), debug=0):
    """Draw a rectangle on ``image`` for every non-empty box.

    Args:
        image: Image drawn on in place with ``cv2.rectangle``.
        boxes: Iterable of boxes. The first four items of a box are
            ``xmin, ymin, xmax, ymax``; empty boxes are skipped.
        color: Rectangle colour.
        debug: When ``1``, print the statistics of every box: four-item boxes
            report their rectangle only, longer boxes also report centroid
            (``box[4]``) and area (``box[-1]``).

    Returns:
        ``None``.
    """
    line_width = 5
    for box in boxes:
        if len(box) == 0:
            continue
        cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), color, line_width)

    if debug != 1:
        return

    # Print the information of every box.
    for box in boxes:
        if len(box) < 4:
            continue
        box_w = box[2] - box[0]
        box_h = box[3] - box[1]
        aspect = box_w / box_h
        if len(box) == 4:
            print('width:{}, height:{}, w_to_h:{}, area:{}, top_left:{}, bottom_right:{},'.
                  format(box_w, box_h, aspect, box_w * box_h, box[0:2], box[2:4]))
        else:
            x_fraction = box[4][0] / image.shape[1]
            pixels = box[-1]
            fill = pixels / (box_w * box_h)
            print('width:{}, height:{}, centroid:{}, position ratio:{}, w_to_h:{}, area:{}, area ratio:{}'.
                  format(box_w, box_h, box[4], x_fraction, aspect, pixels, fill))
def find_pair(marker, boxes, page_width, threshold=100):
    """Find the box that best pairs with ``marker`` across one page width.

    The cost of a candidate is ``|dy| + |dx - page_width|`` measured between
    centroids (``box[4]``). A positive ``page_width`` therefore looks for the
    partner on the right, a negative one for the partner on the left.

    Args:
        marker: Box record to find a partner for.
        boxes: Candidate box records.
        page_width: Signed expected ``x`` offset of the partner.
        threshold: Exclusive upper bound of an acceptable cost.

    Returns:
        ``(box, index, cost)`` of the cheapest candidate (the first one on
        ties), or ``([], -1, threshold)`` when no candidate is below the
        threshold.
    """
    best_index = -1
    best_cost = threshold
    for position, candidate in enumerate(boxes):
        cost = abs(marker[4][1] - candidate[4][1]) + abs(candidate[4][0] - marker[4][0] - page_width)
        # best_cost never exceeds threshold, so this single test also
        # enforces the threshold itself.
        if cost < best_cost:
            best_cost = cost
            best_index = position

    if best_index < 0:
        return [], best_index, best_cost
    return boxes[best_index], best_index, best_cost
def find_pair_list(marker_list, all_list, page_width, horizontal_threshold=100, debug=0):
    """Pick the list in ``all_list`` that pairs best with ``marker_list``.

    Every marker is matched against every candidate list with ``find_pair``.
    The winner is the list with the most matched markers; ties are broken by
    the smaller mean pairing cost.

    Args:
        marker_list: Box records to find partners for.
        all_list: Candidate lists of box records.
        page_width: Signed expected ``x`` offset, forwarded to ``find_pair``.
        horizontal_threshold: Cost threshold, forwarded to ``find_pair``.
        debug: When ``1`` and a list was found, print the measured ``x``
            distance between the first markers next to ``page_width``.

    Returns:
        ``(pair_list, index, mean_cost)``. When no list matches any marker
        (including an empty ``all_list``) the result is
        ``([], -1, horizontal_threshold)``, mirroring ``find_pair``.
    """
    max_count = 0
    index_flag = -1
    min_distance = horizontal_threshold
    for index, candidates in enumerate(all_list):
        count = 0
        total = 0
        for m in marker_list:
            # One lookup per marker yields both the hit flag and its cost.
            _, pair_index, pair_distance = find_pair(m, candidates, page_width, horizontal_threshold)
            if pair_index >= 0:
                count += 1
                total += pair_distance
        if count == 0:
            continue
        mean_distance = total / count
        if count > max_count or (count == max_count and mean_distance < min_distance):
            max_count = count
            index_flag = index
            min_distance = mean_distance

    if index_flag < 0:
        # Indexing all_list with -1 would silently hand back the last list
        # (or raise IndexError on an empty all_list).
        return [], index_flag, min_distance

    if debug == 1:
        print('page width:', abs(marker_list[0][4][0] - all_list[index_flag][0][4][0]), 'anchor width:', page_width)
    return all_list[index_flag], index_flag, min_distance
def find_column(anchors, width, column_num=2, debug=0):
    """Work out the column count, the column width and the outer column positions.

    Args:
        anchors: Sequence whose first two items are the top anchors and the
            bottom anchors (lists of box records; ``box[4][0]`` is the
            centroid ``x``).
        width: Width of the whole image.
        column_num: Column count assumed when the top anchors do not
            reveal it.
        debug: When ``1``, print the anchors and the result.

    Returns:
        ``(page_width, column_num, column_pos)`` where ``column_pos`` holds
        the ``x`` position of the first and of the last column (it may hold
        fewer entries when no rule applies).
    """
    double_page_width_ratio = 0.42  # default width ratio of a column on a two-column page
    three_page_width_ratio = 0.29   # default width ratio of a column on a three-column page
    double_page_separation = 250    # default gap between columns on a two-column page
    three_page_separation = 100     # default gap between columns on a three-column page
    horizontal_threshold = 80       # tolerance when matching a measured width

    top_anchors, bottom_anchors = anchors[:2]
    page_width = width

    # Compare the gap between neighbouring top anchors (whole, halved and
    # divided by three) with the expected two- and three-column widths; the
    # first match decides the column count and may shrink page_width.
    expectations = ((2, double_page_width_ratio), (3, three_page_width_ratio))
    for left, right in zip(top_anchors, top_anchors[1:]):
        gap = right[4][0] - left[4][0]
        hit = next(
            ((columns, candidate)
             for candidate in (gap, gap // 2, gap // 3)
             for columns, ratio in expectations
             if abs(candidate - width * ratio) < horizontal_threshold),
            None,
        )
        if hit is None:
            continue
        column_num, candidate = hit
        if candidate < page_width:
            page_width = candidate

    # No suitable anchor spacing found: fall back to the default width.
    if page_width == width:
        if column_num == 2:
            page_width = int(width * double_page_width_ratio)
        elif column_num == 3:
            page_width = int(width * three_page_width_ratio)

    # Locate the first and the last column.
    column_pos = []
    if len(top_anchors) >= 1:
        first_x = top_anchors[0][4][0]
        last_x = top_anchors[-1][4][0]
        # Step left from the first anchor until one more step leaves the page.
        steps_left = next((i for i in range(4) if first_x - (i + 1) * page_width < 0), None)
        if steps_left is not None:
            column_pos.append(first_x - steps_left * page_width)
        # Step right from the last anchor until one more step leaves the page.
        steps_right = next((i for i in range(4) if last_x + (i + 1) * page_width > width), None)
        if steps_right is not None:
            column_pos.append(last_x + (steps_right - 1) * page_width)
    elif len(bottom_anchors) == 2:
        column_pos = [bottom_anchors[0][4][0], bottom_anchors[-1][4][0] - page_width]
    elif column_num == 2:
        column_pos = [(width - double_page_separation) // 2 - page_width, (width + double_page_separation) // 2]
    elif column_num == 3:
        column_pos = [width // 2 - three_page_separation - page_width * 3 // 2,
                      width // 2 + three_page_separation + page_width // 2]

    if debug == 1:
        print('top anchors')
        for anchor in top_anchors:
            print(anchor[4])
        print('bottom anchors')
        for anchor in bottom_anchors:
            print(anchor[4])
        print('page width:', page_width, 'column number:', column_num, 'column position:', column_pos)
    return page_width, column_num, column_pos
|