123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164 |
- from .aggregator import aggregate
- from .commands import MATRICES, COMMANDS
- from .symbols_parser import convert_symbol
- from .preprocessing import format_latex, last_clear
- from warnings import warn
- def _convert_command(key, num, params):
- if key == r'\frac':
- fz, fm = convert_symbol(params[0]), convert_symbol(params[1])
- fz = '(%s)' % (fz) if len(fz) > 1 else fz
- fm = '(%s)' % (fm) if len(fm) > 1 else fm
- # s = '%s/%s' % (fz, fm) if len(fz) == 1 and len(fm) == 1 else '(%s/%s)' % (fz, fm)
- s = '(%s/%s)' % (fz, fm)
- elif key == r'\sqrt':
- s = 'sqrt(%s)' % (convert_symbol(params[0]))
- elif key == r'_':
- if r'\log' in params or r'log' in params: # or r'\ln' in params or r'\lg' in params:
- s = '%s[%s]' % (convert_symbol(params[0]), convert_symbol(params[1]))
- else:
- s = '%s__%s' % (convert_symbol(params[0]), convert_symbol(params[1]))
- elif key == r'^':
- s = '%s^(%s)' % (convert_symbol(params[0]), convert_symbol(params[1])) if len(
- convert_symbol(params[1])) > 1 else '%s^%s' % (convert_symbol(params[0]), convert_symbol(params[1]))
- elif key == r'\right':
- s = convert_symbol(params[0]) if params[0] != '.' else ''
- elif key == r'\left':
- s = convert_symbol(params[0]) if params[0] != '.' else ''
- elif key == r'\text':
- s = '%s' % (convert_symbol(params[0]))
- elif key == r'\overrightarrow':
- s = '<%s>' % (convert_symbol(params[0]))
- elif key == r'\overrightarrowm':
- s = '<%s>' % (convert_symbol(params[0]))
- elif key == r'\overline':
- s = '一拔(%s)' % (convert_symbol(params[0]))
- elif key == r'\root':
- s = 'root%s(%s)' % (convert_symbol(params[1]),convert_symbol(params[0]))
- elif not num:
- s = ''
- else:
- s = '%s(%s)' % (key, convert_symbol(params[0]))
- return s
def convert_command(element):
    """Convert a flat sequence of tokens into a single text string.

    Falsy tokens are dropped first. A token found in ``COMMANDS`` consumes
    as many following tokens as its entry declares and is rendered by
    ``_convert_command``; every other token goes through ``convert_symbol``.

    Args:
        element: Iterable of tokens (expected to be strings).

    Returns:
        The converted text. If anything fails during conversion, a warning
        is issued and the tokens are simply concatenated instead, with one
        leading ``'^'`` removed.
    """
    try:
        # Rebind ``element`` so the fallback below also sees the filtered
        # tokens whenever the filtering itself succeeded.
        element = [token for token in element if token]
        pieces = []
        index = 0
        while index < len(element):
            token = element[index]
            if token in COMMANDS:
                # Each COMMANDS entry unpacks to three values; only the
                # argument count is needed here.
                param_num, _tag, _attributes = COMMANDS[token]
                # Indexing (not slicing) so a truncated argument list
                # raises IndexError and triggers the fallback.
                params = [element[index + offset]
                          for offset in range(1, param_num + 1)]
                pieces.append(_convert_command(token, param_num, params))
                index += param_num + 1
            else:
                pieces.append(convert_symbol(token))
                index += 1
        return ''.join(pieces)
    except Exception:
        # Best-effort fallback. ``Exception`` rather than a bare except so
        # KeyboardInterrupt and SystemExit still propagate.
        warn('COMMANDS warn')
        command_res = ''.join(element)
        if command_res.startswith('^'):
            command_res = command_res[1:]
        return command_res
def clear_leaves(iters):
    """Replace every list item of ``iters`` with its converted text, in place.

    Args:
        iters: List whose items are either lists of tokens or other values.

    Returns:
        The same ``iters`` object. Items that were lists have been replaced
        by the result of ``convert_command``; all other items are untouched.
    """
    for position, item in enumerate(iters):
        if isinstance(item, list):
            iters[position] = convert_command(item)
    return iters
def leaves(iter):
    """Collapse the nested token lists inside ``iter`` into text, in place.

    The tree is reduced bottom-up: children of height 3 have their inner
    lists converted by ``clear_leaves``, deeper children are handled
    recursively, and finally ``iter`` itself (by then of height 3) is
    converted the same way.

    Args:
        iter: Nested list of tokens. The name shadows the builtin but is
            kept for compatibility with existing callers.

    Returns:
        The same ``iter`` object. A tree of height below 3 has no nested
        lists left to convert and is returned unchanged.
    """
    depth = height(iter)
    if depth < 3:
        # Without this guard the function would call itself forever,
        # because such a tree can never reach the ``depth == 3`` exit.
        return iter
    if depth == 3:
        return clear_leaves(iter)

    for position, child in enumerate(iter):
        child_depth = height(child)
        if child_depth == 3:
            iter[position] = clear_leaves(child)
        elif child_depth > 3:
            leaves(child)

    # Every child has been flattened one level; take another pass over
    # this level, which now satisfies the ``depth == 3`` case.
    return leaves(iter)


def height(self):
    """Return the nesting height of ``self``.

    A non-list child counts as height 1, a list child counts as its own
    height, and the container adds 1 on top of its tallest child. An empty
    container therefore has height 1 and a flat list of strings height 2.

    Args:
        self: Any iterable; only ``list`` children are descended into.
    """
    tallest_child = 0
    for child in self:
        child_height = height(child) if isinstance(child, list) else 1
        if child_height > tallest_child:
            tallest_child = child_height
    return tallest_child + 1
- def _remove_excess(iter):
- if isinstance(iter, list):
- if len(iter) == 1 and isinstance(iter[0], list):
- iter = iter[0]
- return _remove_excess(iter)
- return iter
def remove_excess(iter):
    """Unwrap redundant single-item list nesting throughout ``iter``, in place.

    Each item is replaced by its ``_remove_excess`` result. The recursion
    then descends into the item as it was before unwrapping; when that item
    was a chain of wrappers, the walk reaches the unwrapped list through it.

    Args:
        iter: List to normalise. The name shadows the builtin but is kept
            for compatibility with existing callers.

    Returns:
        None; ``iter`` is modified in place.
    """
    for position, child in enumerate(iter):
        iter[position] = _remove_excess(child)
        if not isinstance(child, list):
            continue
        remove_excess(child)
def convert(latex):
    """Aggregate ``latex`` into a token tree and wrap it in an outer list.

    Args:
        latex: LaTeX source passed straight to ``aggregate``.

    Returns:
        A one-item list holding the aggregated tree after
        ``remove_excess`` has normalised it in place.
    """
    # NOTE(review): assumes ``aggregate`` returns a (nested) list, since
    # ``remove_excess`` assigns into it by index — confirm.
    tree = aggregate(latex)
    remove_excess(tree)
    return [tree]
def structured(latex):
    r"""Convert a LaTeX formula string into its plain-text form.

    Args:
        latex: LaTeX source. The empty markers ``\[\]``, ``''`` and ``$$``
            short-circuit to an empty result.

    Returns:
        The converted text. When the parsed tree contains nothing but
        brackets, a warning is issued and the formatted source is returned
        instead, stripped and with one enclosing ``[]`` or ``{}`` pair
        removed.
    """
    # Raw string: '\[' is not a valid escape sequence in a normal literal.
    if latex in (r'\[\]', '', '$$'):
        return ''

    latex = format_latex(latex)
    li = convert(latex)

    # A tree that prints as only brackets holds no tokens; ``leaves`` is
    # not run on it and the formatted source is returned instead.
    if not str(li).replace('[', '').replace(']', ''):
        warn('RecursionError warn')
        latex = latex.strip()
        if latex != '' and latex[0] + latex[-1] in {'[]', '{}'}:
            latex = latex[1:-1]
        return latex

    leaves(li)
    return last_clear(convert_command(li))
if __name__ == '__main__':
    # Manual smoke check: print the text form of one sample formula.
    # NOTE(review): this module uses relative imports, so it presumably
    # has to be started as a module of its package — confirm.
    sample = r'$z = \left( {{m^2} - 5m + 6} \right) + \left( {m - 3} \right)i$'
    print(structured(sample))
|