# @Author : lightXu
# @File : formula_segment_and_show.py
# @Time : 2019/1/24 13:24 (PM)
"""Split an OCR'd image into CJK text runs and formulas, recognise the formulas with
Mathpix and reassemble every piece into text lines."""
import copy
import math
import re
import time
import xml.etree.ElementTree as ET  # cElementTree was removed in Python 3.9

import cv2
import numpy as np

from segment.formula import mathpix_ocr
from segment.server import get_ocr_text_and_coordinate_formula
from segment.image_operation import utils

# Spans that contain no CJK characters but must still be treated as text, not formulas.
_EXCLUDE_PATTERN = re.compile(r'{}|{}|{}|{}|{}|{}'.format(
    r'[ABCD]\.',  # option labels: A. B. C. D.
    r'[((][))]',  # empty brackets: ()
    r'^[((]*[\d]+[))]',  # leading item number: (1)
    # disabled: r'[((]*[a-zA-Z]{2,}[))]' for units such as (km), (kg)
    r'[①②③④⑤⑥⑦⑧⑨⑩]',  # circled numbers
    r'[\u4e00-\u9fa5][,;:。,;:.]',  # CJK character followed by punctuation, e.g. 中.
    r'[\u4e00-\u9fa5][\d]+[\u4e00-\u9fa5]'))  # digits between CJK characters, e.g. 中123中


def _middle(box):
    """Return the centre ``(x, y)`` of an ``(xmin, ymin, xmax, ymax)`` box."""
    return (box[0] + int((box[2] - box[0]) // 2),
            box[1] + int((box[3] - box[1]) // 2))


def _draw_boxes(image, boxes, thickness, scale=1):
    """Draw every ``(xmin, ymin, xmax, ymax)`` box, multiplied by *scale*, in green on *image*."""
    for box in boxes:
        xmin, ymin, xmax, ymax = (value * scale for value in box)
        cv2.rectangle(image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 255, 0), thickness)


def get_coordinates(word_res, formula_words_list):
    """Turn regex matches over one OCR line into bounding-box dicts.

    :param word_res: one OCR line; ``word_res['chars']`` is a list of per-character entries
        whose ``'location'`` dict holds ``left``, ``top``, ``width`` and ``height``.
        NOTE(review): assumes indices into the space-stripped line text line up with
        ``word_res['chars']`` -- confirm against the OCR response format.
    :param formula_words_list: ``(matched_text, (start, end))`` pairs, ``end`` exclusive.
    :return: list of dicts with ``chars`` and ``raw_chars`` (both the matched text),
        ``coordinates`` ``(xmin, ymin, xmax, ymax)`` and ``middle`` (box centre).
        The box spans from the left edge of the first character to the right edge of the
        last one, vertically covering both of them.
    """
    res_list = []
    for formula_raw in formula_words_list:
        matched_text = formula_raw[0]
        first = word_res['chars'][formula_raw[1][0]]['location']
        last = word_res['chars'][formula_raw[1][1] - 1]['location']
        coordinates = (first['left'],  # xmin
                       min(first['top'], last['top']),  # ymin
                       last['left'] + last['width'],  # xmax
                       max(first['top'] + first['height'], last['top'] + last['height']))  # ymax
        res_list.append({'chars': matched_text,
                         'raw_chars': matched_text,
                         'coordinates': coordinates,
                         'middle': _middle(coordinates)})
    return res_list


def generate_char(words, index_pair, zh=True):
    """Mask the characters of *words* in the span *index_pair* with a placeholder.

    :param words: the text to mask.
    :param index_pair: ``(start, end)`` span (end exclusive); a falsy value leaves *words* as is.
    :param zh: mask with the CJK character '中' when True, otherwise with 'F', so the span
        later matches (or fails to match) the CJK-run regexes used by :func:`segment`.
    :return: *words* with the span replaced by the same number of placeholder characters.
    """
    if not index_pair:
        return words
    start, end = index_pair[0], index_pair[1]
    placeholder = ('中' if zh else 'F') * (end - start)
    # Replace only this span: str.replace would also mask every other occurrence of the
    # same substring, including ones the exclusion pattern deliberately did not match.
    return words[:start] + placeholder + words[end:]


def _find_ocr_chars(box, formula_coordinates_dict_list):
    """Return the OCR text of the first formula candidate whose box equals *box*, else None."""
    return next((item['chars'] for item in formula_coordinates_dict_list
                 if item['coordinates'] == box), None)


def _recognize_formula(img, box, formula_coordinates_dict_list):
    """Recognise one merged formula box with Mathpix, falling back to the OCR text.

    :return: ``(render_chars, raw_chars)`` for the box.
    """
    ocr_region = utils.crop_region_direct(img, box)
    if min(ocr_region.shape[0], ocr_region.shape[1]) <= 50:
        ocr_region = utils.resize_by_percent(ocr_region, 1.50)  # enlarge small crops first
    ocr_chars = _find_ocr_chars(box, formula_coordinates_dict_list)
    # NOTE(review): the render template below is an empty string, so the rendered value is
    # always '' (or 'formula' when Mathpix fails and no OCR text matches the box).
    # Confirm the intended template.
    try:
        raw_chars, latex_confidence = mathpix_ocr.mathpix_api(ocr_region)
        render_chars = ''.format(raw_chars)
        # Low-confidence, empty or single-character results are not trusted.
        if latex_confidence < 0.2 or raw_chars == '' or len(raw_chars) == 1:
            if ocr_chars is not None:
                raw_chars = ocr_chars
                render_chars = ''.format(ocr_chars)
    except Exception:  # best effort: any Mathpix failure falls back to the OCR text
        render_chars = 'formula'
        raw_chars = 'formula'
        if ocr_chars is not None:
            raw_chars = ocr_chars
            render_chars = ''.format(ocr_chars)
    return render_chars, raw_chars


def segment(img, save_path, access_token):
    """Split an image of mixed CJK text and formulas into recognised text lines.

    Each OCR line (from ``get_ocr_text_and_coordinate_formula``) is split into CJK runs
    plus ``_EXCLUDE_PATTERN`` spans (treated as text) and the remaining runs (formula
    candidates). Nearby formula boxes are merged with :func:`combine` and recognised with
    Mathpix. All pieces are then grouped into lines by the y of their centres and ordered
    by x within a line.

    Debug images are written to *save_path* with ``'.jpg'`` replaced by ``'_@_NN.jpg'`` /
    ``'_@_final.jpg'``. A resized copy of *img* with every box drawn is written to
    *save_path* itself.

    :param img: image array (``img.shape[0]`` is the height, ``img.shape[1]`` the width).
    :param save_path: output path of the annotated image; expected to contain ``'.jpg'``.
    :param access_token: passed through to the OCR service.
    :return: ``(lines, raw_lines, h)`` with *h* the height of *img*. When no line break is
        detected, *lines*/*raw_lines* hold one string per piece; otherwise one string per
        line, each ending with a newline. *lines* uses the rendered formula text and
        *raw_lines* the raw recognised text.
    """
    word_result_list = get_ocr_text_and_coordinate_formula(img, access_token)

    formula_coordinates_dict_list = []  # formula candidates straight from the OCR text
    zh_coordinates_dict_list = []  # CJK runs and excluded spans
    zh_char_height = 20  # fallback when no OCR line starts with a CJK character
    zh_char_width = 15  # fallback, as above
    zh_char_height_list = []
    zh_char_width_list = []

    for word_res in word_result_list:
        # Drop spaces (spurious ones come from the Baidu OCR API) and fix the 兀 -> π misread.
        words = word_res['words'].replace(' ', '').replace('兀', 'π')
        abcd_index_list = [(m.group(), m.span()) for m in _EXCLUDE_PATTERN.finditer(words)]

        # Formula candidates: mask excluded spans as CJK, then take the non-CJK runs.
        words_tmp_zh = words
        for _, span in abcd_index_list:
            words_tmp_zh = generate_char(words_tmp_zh, span, zh=True)
        formula_index_list = [(m.group(), m.span())
                              for m in re.finditer(r'[^\u4e00-\u9fa5_"“”]+', words_tmp_zh)]
        formula_coordinates_dict_list += get_coordinates(word_res, formula_index_list)

        # Text pieces: mask excluded spans as non-CJK, take the CJK runs, then add the
        # excluded spans back as text.
        words_tmp_formula = words
        for _, span in abcd_index_list:
            words_tmp_formula = generate_char(words_tmp_formula, span, zh=False)
        zh_index_list = [(m.group(), m.span())
                         for m in re.finditer(r'[\u4e00-\u9fa5_"“”]+', words_tmp_formula)]
        zh_coordinates_dict_list += get_coordinates(word_res, zh_index_list + abcd_index_list)

        # Sample the CJK character size. re.match only samples lines that *start* with a
        # CJK character. NOTE(review): re.search may have been intended.
        first_zh_m = re.match(r'[\u4e00-\u9fa5]+', words)
        if first_zh_m:
            location = word_res['chars'][first_zh_m.span()[0]]['location']
            zh_char_height_list.append(location['height'])
            zh_char_width_list.append(location['width'])

    if zh_char_width_list and zh_char_height_list:
        zh_char_height = np.mean(zh_char_height_list)
        zh_char_width = np.mean(zh_char_width_list)

    formula_coordinates_list = [item['coordinates'] for item in formula_coordinates_dict_list]
    raw_boxes_img = img.copy()
    _draw_boxes(raw_boxes_img, formula_coordinates_list, 1)
    utils.write_single_img(raw_boxes_img, save_path.replace('.jpg', '_@_{:02d}.jpg'.format(1)))

    # Merge formula fragments (debug image 01 is the one written just above).
    formula_combine_list = combine(img, save_path, formula_coordinates_list,
                                   zh_char_height, zh_char_width, 1)

    formula_combine_dict_list = []
    for box in formula_combine_list:
        render_chars, raw_chars = _recognize_formula(img, box, formula_coordinates_dict_list)
        print(render_chars)
        formula_combine_dict_list.append({'chars': render_chars,
                                          'middle': _middle(box),
                                          'coordinates': box,
                                          'raw_chars': raw_chars})

    # Order all pieces top to bottom; a vertical gap between consecutive centres of at
    # least one CJK character height starts a new line.
    all_dict_list = sorted(zh_coordinates_dict_list + formula_combine_dict_list,
                           key=lambda item: item['middle'][1])
    centre_ys = np.array([item['middle'][1] for item in all_dict_list])
    gaps = centre_ys[1:] - centre_ys[:-1]
    split_x_index = [i for i, gap in enumerate(gaps) if gap >= zh_char_height]

    # The annotated output is scaled to at most 1000 px wide (or at least 100 px tall).
    scale = 1
    h, w = img.shape[0], img.shape[1]
    if w > 1000:
        scale = float(1000 / w)
    elif h < 100:
        scale = float(100 / h)
    img_resize = utils.resize_by_percent(img, scale)
    utils.write_single_img(img_resize, save_path)

    if not split_x_index:
        # A single line: one entry per piece, ordered left to right.
        all_dict_list = sorted(all_dict_list, key=lambda item: item['middle'][0])
        lines = [item['chars'] for item in all_dict_list]
        raw_lines = [item['raw_chars'] for item in all_dict_list]
        _draw_boxes(img_resize, [item['coordinates'] for item in all_dict_list], 3, scale)
        utils.write_single_img(img_resize, save_path)
        return lines, raw_lines, h

    boundaries = [0] + [i + 1 for i in split_x_index] + [len(all_dict_list)]
    lines = []
    raw_lines = []
    for start, end in zip(boundaries, boundaries[1:]):
        one_line = sorted(all_dict_list[start:end], key=lambda item: item['middle'][0])
        lines.append(''.join(item['chars'] for item in one_line) + '\n')
        raw_lines.append(''.join(item['raw_chars'] for item in one_line) + '\n')
        _draw_boxes(img_resize, [item['coordinates'] for item in one_line], 1, scale)
    utils.write_single_img(img_resize, save_path)
    return lines, raw_lines, h


def _should_merge(distance, relation, zh_char_width):
    """Decide from a :func:`get_min_distance` result whether two boxes are merged."""
    if relation == 'i':  # intersecting
        return True
    if relation == 'w':  # side by side: allow a gap of up to 2/3 of a CJK character width
        return distance <= zh_char_width * 2 // 3
    if relation in ('h', 'c'):  # stacked or diagonal: only when (almost) touching
        return distance <= 1
    return False


def _combine_once(boxes, zh_char_width):
    """Run one pairwise merging pass over *boxes*; return ``(new_boxes, changed)``."""
    # De-duplicate first: two identical boxes used to mark each other (and their identical
    # union) for deletion, which made duplicated formula boxes vanish entirely.
    # Sort by y, then x (stable sort), so pairs are visited top-to-bottom, left-to-right.
    ordered = sorted(sorted(set(boxes), key=lambda b: b[0]), key=lambda b: b[1])
    kept = list(ordered)
    deleted = set()
    changed = False
    for i, outer in enumerate(ordered):  # boxes are (xmin, ymin, xmax, ymax)
        for j, inner in enumerate(ordered):
            if i == j:
                continue
            distance, relation = get_min_distance(outer, inner)
            if not _should_merge(distance, relation, zh_char_width):
                continue
            changed = True
            union = (min(outer[0], inner[0]), min(outer[1], inner[1]),
                     max(outer[2], inner[2]), max(outer[3], inner[3]))
            # A box that already contains the other absorbs it; otherwise both are replaced.
            if union != outer:
                deleted.add(outer)
            if union != inner:
                deleted.add(inner)
            kept.append(union)
    return list(set(kept) - deleted), changed


def combine(img, save_path, formula_coordinates_list, zh_char_height, zh_char_width, draw_index):
    """Repeatedly merge formula boxes that overlap or nearly touch until nothing changes.

    Two boxes are replaced by their union when they intersect, are stacked or diagonal
    with a gap of at most 1 px, or are side by side with a gap of at most two thirds of
    *zh_char_width*.

    After every round that merged something, the boxes are drawn on a fresh copy of *img*
    and written to *save_path* with ``'.jpg'`` replaced by ``'_@_NN.jpg'`` (NN is
    *draw_index*, incremented per round). The final boxes go to ``'_@_final.jpg'``.

    :param formula_coordinates_list: hashable ``(xmin, ymin, xmax, ymax)`` tuples.
    :param zh_char_height: unused; kept for interface compatibility.
    :return: list of merged box tuples, in no particular order.
    """
    boxes = formula_coordinates_list
    # Iterative form of the former tail recursion; writes the same debug images.
    while True:
        boxes, changed = _combine_once(boxes, zh_char_width)
        img_draw = img.copy()
        _draw_boxes(img_draw, boxes, 1)
        if not changed:
            utils.write_single_img(img_draw, save_path.replace('.jpg', '_@_final.jpg'))
            return boxes
        draw_index = draw_index + 1
        utils.write_single_img(img_draw,
                               save_path.replace('.jpg', '_@_{:02d}.jpg'.format(draw_index)))


def get_min_distance_square(coordinate1, coordinate2):
    """Return the smallest squared Euclidean distance between a corner of each box.

    Boxes are ``(xmin, ymin, xmax, ymax)``. Only the four corners of each box are
    compared, so this is not the distance between the rectangles themselves.
    """
    corners1 = [(x, y) for x in (coordinate1[0], coordinate1[2])
                for y in (coordinate1[1], coordinate1[3])]
    corners2 = [(x, y) for x in (coordinate2[0], coordinate2[2])
                for y in (coordinate2[1], coordinate2[3])]
    return min((p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2 for p in corners1 for q in corners2)


def get_min_distance(coordinate1, coordinate2):
    """Return ``(distance, relation)`` between two ``(xmin, ymin, xmax, ymax)`` boxes.

    *relation* is:
    - ``'c'``: box 2 is diagonal to box 1; *distance* is the Euclidean distance between
      the nearest corners.
    - ``'w'``: the boxes are side by side; *distance* is the horizontal gap.
    - ``'h'``: the boxes are stacked; *distance* is the vertical gap.
    - ``'i'``: the boxes intersect; *distance* is 0.

    Separation is strict, so boxes sharing an edge coordinate count as overlapping on
    that axis. Image coordinates are assumed, i.e. y grows downward.
    """
    def dist(point1, point2):
        distance = (point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2
        return math.sqrt(distance)

    (x1, y1, x1b, y1b) = coordinate1
    (x2, y2, x2b, y2b) = coordinate2
    left = x2b < x1  # box 2 lies entirely to the left of box 1
    right = x1b < x2  # box 2 lies entirely to the right of box 1
    above = y2b < y1  # box 2 lies entirely above box 1 (smaller y)
    below = y1b < y2  # box 2 lies entirely below box 1 (larger y)
    if below and left:
        return dist((x1, y1b), (x2b, y2)), 'c'
    elif left and above:
        return dist((x1, y1), (x2b, y2b)), 'c'
    elif above and right:
        return dist((x1b, y1), (x2, y2b)), 'c'
    elif right and below:
        return dist((x1b, y1b), (x2, y2)), 'c'
    elif left:
        return x1 - x2b, 'w'
    elif right:
        return x2 - x1b, 'w'
    elif above:
        return y1 - y2b, 'h'
    elif below:
        return y2 - y1b, 'h'
    else:  # rectangles intersect
        return 0, 'i'