# @Author : lightXu
# @File : formula_segment.py
# @Time : 2019/1/24 0024 下午 13:24
import time
import re
import copy
import math
import cv2
import numpy as np
import xml.etree.cElementTree as ET
from segment.formula import mathpix_ocr
from segment.server import get_ocr_text_and_coordinate_formula
from segment.image_operation import utils
def get_coordinates(word_res, formula_words_list):
    """Turn matched text spans into bounding-box records.

    :param word_res: one OCR line result; ``word_res['chars'][i]['location']``
        must provide ``left``, ``top``, ``width`` and ``height`` for the
        i-th character.
    :param formula_words_list: iterable of ``(text, (start, end))`` pairs,
        where ``start``/``end`` index into ``word_res['chars']`` (end exclusive).
    :return: list of dicts with keys ``chars``, ``raw_chars`` (both the matched
        text), ``coordinates`` ``(xmin, ymin, xmax, ymax)`` and ``middle``
        (the box centre, using floor division).
    """

    def _box_record(text, span):
        # Only the first and last character of the span define the box.
        first = word_res['chars'][span[0]]['location']
        last = word_res['chars'][span[1] - 1]['location']

        x_min = first['left']
        y_min = min(first['top'], last['top'])
        x_max = last['left'] + last['width']
        y_max = max(first['top'] + first['height'], last['top'] + last['height'])

        centre = (x_min + int((x_max - x_min) // 2),
                  y_min + int((y_max - y_min) // 2))
        return {'chars': text,
                'raw_chars': text,
                'coordinates': (x_min, y_min, x_max, y_max),
                'middle': centre}

    return [_box_record(entry[0], entry[1]) for entry in formula_words_list]
def generate_char(words, index_pair, zh=True):
    """Mask the span ``index_pair`` of ``words`` with filler characters.

    The span is overwritten with the same number of placeholder characters,
    so the string length (and therefore every other character index) is
    preserved.

    :param words: the text to mask.
    :param index_pair: ``(start, end)`` indices of the span (end exclusive);
        a falsy value leaves ``words`` untouched.
    :param zh: if True the filler is the Chinese character '中', otherwise 'F'.
    :return: the masked string.
    """
    if not index_pair:
        return words

    start, end = index_pair[0], index_pair[1]
    length = end - start
    if length <= 0:
        # Nothing to mask.
        return words

    filler = ('中' if zh else 'F') * length
    # Splice by position. A substring-based str.replace would also overwrite
    # every other occurrence of the same text elsewhere in the line.
    return words[:start] + filler + words[end:]
def _find_ocr_entry(box, formula_dicts):
    """Return the entry of ``formula_dicts`` whose 'coordinates' equal ``box``, else None."""
    for item in formula_dicts:
        if box == item['coordinates']:
            return item
    return None


def segment(img, save_path, access_token):
    """Split an image into text and formula fragments and return them line by line.

    Steps:
      1. Run the text OCR (``get_ocr_text_and_coordinate_formula``) on ``img``.
      2. For every OCR line, separate Chinese-text spans from formula spans
         with regular expressions and compute a bounding box for each span.
      3. Merge neighbouring formula boxes with ``combine`` and re-recognise
         each merged region with ``mathpix_ocr.mathpix_api``.
      4. Group all fragments into lines by the y coordinate of their centres
         and order each line by x.

    :param img: image passed to the OCR helpers and cropped per formula box.
        NOTE(review): presumably a numpy image array (``.shape`` is read on
        the crops) -- confirm against the callers.
    :param save_path: not used by this function; kept for interface compatibility.
    :param access_token: forwarded to ``get_ocr_text_and_coordinate_formula``.
    :return: ``(lines, raw_lines)``. When at least one line break is detected,
        each entry is one joined line ending in ``'\\n'``. When no break is
        detected, the entries are the individual fragments in x order, not
        joined and without a trailing newline. ``lines`` holds the rendered
        text, ``raw_lines`` the raw recognised text.
    """
    word_result_list = get_ocr_text_and_coordinate_formula(img, access_token)

    formula_coordinates_dict_list = []
    zh_coordinates_dict_list = []
    zh_char_height = 20  # default
    zh_char_width = 15  # default
    zh_char_height_list = []
    zh_char_width_list = []

    # Spans that must never be treated as formulas (question numbers, option
    # labels, punctuation after Chinese text, circled numbers, ...).
    # Raw strings keep the regex escapes (\d, \.) from being invalid string escapes.
    exclude_re = re.compile('|'.join((
        r'^[((]*[\d]+[))]',
        r'[ABCD]\.',
        r'[\u4e00-\u9fa5][,;:。,;:.]',
        r'[①②③④⑤⑥⑦⑧⑨⑩]',
        r'[((][))]',
        r'[\u4e00-\u9fa5][\d]+[\u4e00-\u9fa5]',
    )))
    formula_re = re.compile(r'[^\u4e00-\u9fa5._"“”]+')
    zh_re = re.compile(r'[\u4e00-\u9fa5._"“”]+')
    leading_zh_re = re.compile(r'[\u4e00-\u9fa5]+')

    for word_res in word_result_list:
        # Strip spaces and work around an OCR confusion between '兀' and 'π'.
        words = word_res['words'].replace(' ', '').replace('兀', 'π')
        excluded_spans = [(m.group(), m.span()) for m in exclude_re.finditer(words)]

        # Mask the excluded spans twice: once as Chinese (so the formula regex
        # skips them) and once as 'F' (so the Chinese regex skips them).
        masked_as_zh = words
        masked_as_formula = words
        for _, span in excluded_spans:
            masked_as_zh = generate_char(masked_as_zh, span, zh=True)
            masked_as_formula = generate_char(masked_as_formula, span, zh=False)

        formula_spans = [(m.group(), m.span()) for m in formula_re.finditer(masked_as_zh)]
        formula_coordinates_dict_list.extend(get_coordinates(word_res, formula_spans))

        zh_spans = [(m.group(), m.span()) for m in zh_re.finditer(masked_as_formula)]
        # The excluded spans are emitted as plain text with their original characters.
        zh_coordinates_dict_list.extend(get_coordinates(word_res, zh_spans + excluded_spans))

        # Sample one Chinese character per line to estimate the character size.
        # NOTE(review): ``match`` only succeeds when the line *starts* with a
        # Chinese character, so the sampled index is always 0 -- confirm
        # whether ``search`` was intended.
        leading_zh = leading_zh_re.match(words)
        if leading_zh:
            location = word_res['chars'][leading_zh.start()]['location']
            zh_char_height_list.append(location['height'])
            zh_char_width_list.append(location['width'])

    if zh_char_width_list and zh_char_height_list:
        zh_char_height = np.mean(zh_char_height_list)
        zh_char_width = np.mean(zh_char_width_list)

    formula_coordinates_list = [ele['coordinates'] for ele in formula_coordinates_dict_list]
    formula_combine_list = combine(formula_coordinates_list, zh_char_height, zh_char_width)

    formula_combine_dict_list = []
    for box in formula_combine_list:
        middle = (box[0] + int((box[2] - box[0]) // 2),
                  box[1] + int((box[3] - box[1]) // 2))
        ocr_region = utils.crop_region_direct(img, box)
        region_h, region_w = ocr_region.shape[0], ocr_region.shape[1]
        if min(region_h, region_w) <= 50:
            # Enlarge small crops before formula recognition.
            ocr_region = utils.resize_by_percent(ocr_region, 2.00)

        try:
            mathpix_raw_chars, latex_confidence = mathpix_ocr.mathpix_api(ocr_region)
            render_mathpix_chars = '{}'.format(mathpix_raw_chars)
            use_text_ocr = latex_confidence < 0.2
        except Exception:
            # Best effort: a failed formula recognition must not abort the
            # whole page, so fall back to a placeholder / the text OCR result.
            render_mathpix_chars = 'formula'
            mathpix_raw_chars = 'formula'
            use_text_ocr = True

        if use_text_ocr:
            # Only boxes that were not merged still match a text-OCR entry.
            entry = _find_ocr_entry(box, formula_coordinates_dict_list)
            if entry is not None:
                mathpix_raw_chars = entry['chars']
                render_mathpix_chars = '{}'.format(entry['chars'])

        formula_combine_dict_list.append({'chars': render_mathpix_chars,
                                          'middle': middle,
                                          'coordinates': box,
                                          'raw_chars': mathpix_raw_chars})

    all_dict_list = sorted(formula_combine_dict_list + zh_coordinates_dict_list,
                           key=lambda item: item['middle'][1])

    # A new line starts wherever two consecutive centres (sorted by y) are at
    # least one character height apart.
    center_ys = [item['middle'][1] for item in all_dict_list]
    break_after = [pos for pos, (upper, lower) in enumerate(zip(center_ys, center_ys[1:]))
                   if lower - upper >= zh_char_height]

    if not break_after:
        # Single line: return the fragments ordered by x.
        all_dict_list = sorted(all_dict_list, key=lambda item: item['middle'][0])
        lines = [item['chars'] for item in all_dict_list]
        raw_lines = [item['raw_chars'] for item in all_dict_list]
        return lines, raw_lines

    # Slice boundaries: start of list, position after every break, end of list.
    bounds = sorted({0, len(all_dict_list)}.union(pos + 1 for pos in break_after))
    lines = []
    raw_lines = []
    for start, end in zip(bounds, bounds[1:]):
        one_line = sorted(all_dict_list[start:end], key=lambda item: item['middle'][0])
        lines.append(''.join(item['chars'] for item in one_line) + '\n')
        raw_lines.append(''.join(item['raw_chars'] for item in one_line) + '\n')
    return lines, raw_lines
def combine(formula_coordinates_list, zh_char_height, zh_char_width):
    """Repeatedly merge formula boxes that touch or lie close to each other.

    Two boxes ``(xmin, ymin, xmax, ymax)`` are merged into their common
    bounding box when ``get_min_distance`` reports that they

    * intersect (flag ``'i'``),
    * are stacked vertically with a gap of at most 1 (flag ``'h'``),
    * sit side by side with a gap of at most ``zh_char_width`` (flag ``'w'``), or
    * are diagonal neighbours whose closest corners are at most 1 apart (flag ``'c'``).

    Merging is repeated until no pair qualifies any more.

    :param formula_coordinates_list: iterable of hashable box tuples.
    :param zh_char_height: not used by the merge rules; kept for interface
        compatibility.
    :param zh_char_width: maximum horizontal gap for merging side-by-side boxes.
    :return: list of merged boxes, sorted by ``(ymin, xmin)``.
    """
    # Largest allowed gap per relation reported by get_min_distance.
    max_gap = {'h': 1, 'w': zh_char_width, 'c': 1}

    # De-duplicate first. Two identical boxes would otherwise "merge" into
    # themselves, be marked as absorbed, and disappear from the result.
    boxes = set(formula_coordinates_list)

    while True:
        ordered = sorted(boxes, key=lambda box: (box[1], box[0]))
        merged = set()
        absorbed = set()

        # get_min_distance is symmetric, so each unordered pair is checked once.
        for pos, outer in enumerate(ordered):
            for inner in ordered[pos + 1:]:
                min_distance, flag = get_min_distance(outer, inner)
                if flag == 'i':
                    mergeable = True
                elif flag in max_gap:
                    mergeable = min_distance <= max_gap[flag]
                else:
                    mergeable = False
                if not mergeable:
                    continue

                union = (min(outer[0], inner[0]), min(outer[1], inner[1]),
                         max(outer[2], inner[2]), max(outer[3], inner[3]))
                merged.add(union)
                # A box that already equals the union (it fully encloses the
                # other one) survives; the other box is absorbed.
                if union != outer:
                    absorbed.add(outer)
                if union != inner:
                    absorbed.add(inner)

        if not merged:
            return ordered

        boxes = (boxes | merged) - absorbed
def get_min_distance_square(coordinate1, coordinate2):
    """Return the smallest squared Euclidean distance between the corners of two boxes.

    Each box is ``(xmin, ymin, xmax, ymax)``; all 4 x 4 corner pairs are
    compared and the minimum squared distance is returned.
    """

    def _corners(box):
        return [(x, y) for x in (box[0], box[2]) for y in (box[1], box[3])]

    return min(
        (ax - bx) ** 2 + (ay - by) ** 2
        for ax, ay in _corners(coordinate1)
        for bx, by in _corners(coordinate2)
    )
def get_min_distance(coordinate1, coordinate2):
    """Return the gap between two boxes and a flag describing their relation.

    Both boxes are ``(xmin, ymin, xmax, ymax)``. The result is
    ``(distance, flag)`` where ``flag`` is

    * ``'c'`` -- the boxes are diagonal to each other; distance is the
      Euclidean distance between the two closest corners,
    * ``'w'`` -- separated only along x; distance is the horizontal gap,
    * ``'h'`` -- separated only along y; distance is the vertical gap,
    * ``'i'`` -- the boxes overlap or touch; distance is 0.
    """
    (x1, y1, x1b, y1b) = coordinate1
    (x2, y2, x2b, y2b) = coordinate2

    left = x2b < x1    # box 2 lies entirely at smaller x than box 1
    right = x1b < x2   # box 2 lies entirely at larger x than box 1
    bottom = y2b < y1  # box 2 lies entirely at smaller y than box 1
    top = y1b < y2     # box 2 lies entirely at larger y than box 1

    # Diagonal cases, checked first and in this fixed order:
    # (condition, closest corner of box 1, closest corner of box 2).
    diagonal_cases = (
        (top and left, (x1, y1b), (x2b, y2)),
        (left and bottom, (x1, y1), (x2b, y2b)),
        (bottom and right, (x1b, y1), (x2, y2b)),
        (right and top, (x1b, y1b), (x2, y2)),
    )
    for applies, corner1, corner2 in diagonal_cases:
        if applies:
            squared = (corner1[0] - corner2[0]) ** 2 + (corner1[1] - corner2[1]) ** 2
            return math.sqrt(squared), 'c'

    # Separation along exactly one axis: (condition, gap, flag).
    axis_cases = (
        (left, x1 - x2b, 'w'),
        (right, x2 - x1b, 'w'),
        (bottom, y1 - y2b, 'h'),
        (top, y2 - y1b, 'h'),
    )
    for applies, gap, flag in axis_cases:
        if applies:
            return gap, flag

    # Neither axis separates the boxes: they intersect (or touch).
    return 0, 'i'