from .aggregator import aggregate
from .commands import MATRICES, COMMANDS
from .symbols_parser import convert_symbol
from .preprocessing import format_latex, last_clear
from warnings import warn


def _convert_command(key, num, params):
    r"""Render one LaTeX command and its parameters as plain text.

    Args:
        key: The command token, e.g. ``\frac``, ``\sqrt``, ``_`` or ``^``.
        num: Parameter count declared for ``key`` in ``COMMANDS``.
        params: The tokens that followed ``key``; each one is passed through
            ``convert_symbol`` before being formatted.

    Returns:
        The rendered string, for example ``(a/(b+c))`` for ``\frac``,
        ``sqrt(x)`` for ``\sqrt`` or ``x^(10)`` for ``^``. Commands without a
        dedicated branch render as ``key(param0)`` when ``num`` is non-zero
        and as ``''`` otherwise.
    """
    if key == r'\frac':
        numerator = convert_symbol(params[0])
        denominator = convert_symbol(params[1])
        # Multi-character operands are parenthesized so the slash binds correctly.
        if len(numerator) > 1:
            numerator = '(%s)' % numerator
        if len(denominator) > 1:
            denominator = '(%s)' % denominator
        return '(%s/%s)' % (numerator, denominator)
    if key == r'\sqrt':
        return 'sqrt(%s)' % convert_symbol(params[0])
    if key == r'_':
        # Subscripts whose parameters include a log token render as x[y];
        # every other subscript renders as x__y.
        if r'\log' in params or r'log' in params:
            template = '%s[%s]'
        else:
            template = '%s__%s'
        return template % (convert_symbol(params[0]), convert_symbol(params[1]))
    if key == r'^':
        base = convert_symbol(params[0])
        exponent = convert_symbol(params[1])
        # Multi-character exponents are grouped: x^(10) rather than x^10.
        template = '%s^(%s)' if len(exponent) > 1 else '%s^%s'
        return template % (base, exponent)
    if key in (r'\left', r'\right'):
        # '.' (the blank delimiter in \left. / \right.) renders as nothing.
        return convert_symbol(params[0]) if params[0] != '.' else ''
    if key == r'\text':
        return '%s' % convert_symbol(params[0])
    if key in (r'\overrightarrow', r'\overrightarrowm'):
        return '<%s>' % convert_symbol(params[0])
    if key == r'\overline':
        return '一拔(%s)' % convert_symbol(params[0])
    if key == r'\root':
        # The second parameter follows "root"; the first goes in parentheses.
        return 'root%s(%s)' % (convert_symbol(params[1]), convert_symbol(params[0]))
    if not num:
        return ''
    return '%s(%s)' % (key, convert_symbol(params[0]))


def convert_command(element):
    """Convert a flat token list into plain text.

    Falsy tokens (``''``, ``None``, empty lists) are dropped first. A token
    found in ``COMMANDS`` consumes the number of following tokens that
    ``COMMANDS`` declares for it and is rendered by ``_convert_command``;
    every other token goes through ``convert_symbol``.

    This is best-effort. If anything fails (e.g. a command with too few
    tokens left, or a nested list where a token was expected), a
    ``'COMMANDS warn'`` warning is issued and the filtered tokens are
    concatenated as-is, minus a leading ``^``.
    """
    try:
        element = [token for token in element if token]
        pieces = []
        pos = 0
        while pos < len(element):
            token = element[pos]
            if token in COMMANDS:
                param_num, _tag, _attributes = COMMANDS[token]
                # Too few remaining tokens raises IndexError -> fallback below.
                params = [element[pos + k] for k in range(1, param_num + 1)]
                pieces.append(_convert_command(token, param_num, params))
                # Skip the command token and the parameters it consumed.
                pos += 1 + max(param_num, 0)
            else:
                pieces.append(convert_symbol(token))
                pos += 1
        return ''.join(pieces)
    except Exception:
        # Deliberate fallback (patched in by tjt); narrowed from a bare
        # ``except:`` so KeyboardInterrupt/SystemExit still propagate.
        warn('COMMANDS warn')
        command_res = ''.join(element)
        if command_res.startswith('^'):
            command_res = command_res[1:]
        return command_res


def clear_leaves(iters):
    """Render every list child of ``iters`` to text, in place.

    Each list child is replaced by ``convert_command(child)``; non-list
    children are left untouched. Returns ``iters`` itself.
    """
    for pos, item in enumerate(iters):
        if isinstance(item, list):
            iters[pos] = convert_command(item)
    return iters


def leaves(iter):
    """Collapse the token tree ``iter`` in place, bottom-up.

    Children of height 3 (lists whose own children are tokens or flat token
    lists) are cleared with ``clear_leaves``. Deeper children are processed
    recursively first. Once all children are reduced, ``iter`` itself is
    cleared the same way.

    For an input of height >= 3, every list child of ``iter`` has been
    replaced by a string on return. Shallower input is returned unchanged.

    Returns ``iter``.
    """
    if height(iter) == 3:
        return clear_leaves(iter)
    for i, child in enumerate(iter):
        child_height = height(child)
        if child_height == 3:
            iter[i] = clear_leaves(child)
        elif child_height > 3:
            leaves(child)
        # Children of height < 3 (strings, flat token lists) need no work.
        # The previous version re-entered leaves(iter) here. That call made
        # no progress whenever such a child preceded a deeper one, so it
        # recursed until RecursionError.
    # Every child is now at most height 2, so a node that started deep is
    # height 3 here. Clear it too, so callers never see nested lists.
    if height(iter) == 3:
        return clear_leaves(iter)
    return iter


def height(self):
    """Return the nesting height of the list ``self``.

    Non-list children count as height 1, so a flat list has height 2 and an
    empty list has height 1. A plain string is iterated character by
    character and therefore also reports 2 (1 when empty).
    """
    # ``self`` is just the node being measured; this is a free function.
    return 1 + max(
        (height(child) if isinstance(child, list) else 1 for child in self),
        default=0,
    )


def _remove_excess(iter):
    """Strip redundant wrappers, e.g. ``[[a, b]]`` -> ``[a, b]``.

    Unwraps repeatedly while the value is a one-element list whose only
    element is itself a list. Anything else is returned unchanged.
    """
    while isinstance(iter, list) and len(iter) == 1 and isinstance(iter[0], list):
        iter = iter[0]
    return iter


def remove_excess(iter):
    """Apply ``_remove_excess`` to every element of ``iter``, recursively, in place."""
    for i, element in enumerate(iter):
        iter[i] = _remove_excess(element)
        if isinstance(element, list):
            # Recurse into the original element. The unwrapped list is
            # nested inside it, so it is reached through it.
            remove_excess(element)


def convert(latex):
    """Parse ``latex`` with ``aggregate``, drop redundant nesting, and wrap it in a list."""
    latex = aggregate(latex)
    remove_excess(latex)
    return [latex]


def structured(latex):
    r"""Convert a LaTeX math string into a linear text expression.

    The pipeline is ``format_latex`` -> ``aggregate`` (via ``convert``) ->
    ``leaves`` -> ``convert_command`` -> ``last_clear``. Output looks like
    ``(a/b)``, ``sqrt(x)`` or ``x^2``.

    Empty input and the empty wrappers ``\[\]`` and ``$$`` yield ``''``. If
    the parsed tree contains nothing but brackets (e.g. ``[[]]``), a
    warning is issued instead. In that case the preprocessed source is
    returned stripped, with one outer ``[]``/``{}`` pair removed.
    """
    if latex in ('', '$$', r'\[\]'):
        return ''
    latex = format_latex(latex)
    tree = convert(latex)
    if not str(tree).replace('[', '').replace(']', ''):
        # Fallback patched in by tjt. The warning text is kept as-is for
        # compatibility.
        warn('RecursionError warn')
        latex = latex.strip()
        if latex != '' and latex[0] + latex[-1] in {'[]', '{}'}:
            latex = latex[1:-1]
        return latex
    leaves(tree)
    return last_clear(convert_command(tree))


if __name__ == '__main__':
    print(structured(
        r'$z = \left( {{m^2} - 5m + 6} \right) + \left( {m - 3} \right)i$'))