# latex2maple.py
  1. from .aggregator import aggregate
  2. from .commands import MATRICES, COMMANDS
  3. from .symbols_parser import convert_symbol
  4. from .preprocessing import format_latex, last_clear
  5. from warnings import warn
  6. def _convert_command(key, num, params):
  7. if key == r'\frac':
  8. fz, fm = convert_symbol(params[0]), convert_symbol(params[1])
  9. fz = '(%s)' % (fz) if len(fz) > 1 else fz
  10. fm = '(%s)' % (fm) if len(fm) > 1 else fm
  11. # s = '%s/%s' % (fz, fm) if len(fz) == 1 and len(fm) == 1 else '(%s/%s)' % (fz, fm)
  12. s = '(%s/%s)' % (fz, fm)
  13. elif key == r'\sqrt':
  14. s = 'sqrt(%s)' % (convert_symbol(params[0]))
  15. elif key == r'_':
  16. if r'\log' in params or r'log' in params: # or r'\ln' in params or r'\lg' in params:
  17. s = '%s[%s]' % (convert_symbol(params[0]), convert_symbol(params[1]))
  18. else:
  19. s = '%s__%s' % (convert_symbol(params[0]), convert_symbol(params[1]))
  20. elif key == r'^':
  21. s = '%s^(%s)' % (convert_symbol(params[0]), convert_symbol(params[1])) if len(
  22. convert_symbol(params[1])) > 1 else '%s^%s' % (convert_symbol(params[0]), convert_symbol(params[1]))
  23. elif key == r'\right':
  24. s = convert_symbol(params[0]) if params[0] != '.' else ''
  25. elif key == r'\left':
  26. s = convert_symbol(params[0]) if params[0] != '.' else ''
  27. elif key == r'\text':
  28. s = '%s' % (convert_symbol(params[0]))
  29. elif key == r'\overrightarrow':
  30. s = '<%s>' % (convert_symbol(params[0]))
  31. elif key == r'\overrightarrowm':
  32. s = '<%s>' % (convert_symbol(params[0]))
  33. elif key == r'\overline':
  34. s = '一拔(%s)' % (convert_symbol(params[0]))
  35. elif key == r'\root':
  36. s = 'root%s(%s)' % (convert_symbol(params[1]),convert_symbol(params[0]))
  37. elif not num:
  38. s = ''
  39. else:
  40. s = '%s(%s)' % (key, convert_symbol(params[0]))
  41. return s
  42. def convert_command(element):
  43. # 对特殊符号进行操作
  44. try:
  45. rs = ''
  46. element = [i for i in element if i]
  47. # 先得到关键词信息
  48. iterable = iter(range(len(element)))
  49. for i in iterable:
  50. params = []
  51. if element[i] in COMMANDS:
  52. key = element[i]
  53. param_num, tag, attributes = COMMANDS[element[i]]
  54. for _ in range(param_num):
  55. i += 1
  56. params.append(element[i])
  57. s = _convert_command(key, param_num, params)
  58. rs += s
  59. [next(iterable) for _ in range(param_num)]
  60. else:
  61. rs += convert_symbol(element[i])
  62. return rs
  63. except:
  64. # tjt修改
  65. warn('COMMANDS warn')
  66. command_res = ''.join(element)
  67. if len(command_res) > 0 and command_res[0] == '^':
  68. command_res = command_res[1:]
  69. return command_res
  70. # return ''.join(element)
  71. def clear_leaves(iters):
  72. iterable = iter(range(len(iters)))
  73. for i in iterable:
  74. if isinstance(iters[i], list):
  75. iters[i] = convert_command(iters[i])
  76. else:
  77. iters[i] = iters[i]
  78. return iters
  79. def leaves(iter):
  80. if height(iter) == 3:
  81. return clear_leaves(iter)
  82. for i in range(len(iter)):
  83. if height(iter[i]) == 3:
  84. iter[i] = clear_leaves(iter[i])
  85. elif height(iter[i]) > 3:
  86. leaves(iter[i])
  87. else:
  88. leaves(iter)
  89. def height(self):
  90. max_child_height = 0
  91. for child in self:
  92. if isinstance(child, list):
  93. max_child_height = max(max_child_height, height(child))
  94. else:
  95. max_child_height = max(max_child_height, 1)
  96. return max_child_height + 1
  97. def _remove_excess(iter):
  98. if isinstance(iter, list):
  99. if len(iter) == 1 and isinstance(iter[0], list):
  100. iter = iter[0]
  101. return _remove_excess(iter)
  102. return iter
  103. def remove_excess(iter):
  104. for i, j in enumerate(iter):
  105. x = _remove_excess(j)
  106. iter[i] = x
  107. if isinstance(j, list):
  108. remove_excess(j)
  109. def convert(latex):
  110. math = ''
  111. # 先加一个 mrow
  112. latex = aggregate(latex)
  113. remove_excess(latex)
  114. return [latex]
  115. def structured(latex):
  116. if latex=='\[\]' or latex=='' or latex=='$$':
  117. return ''
  118. latex= format_latex(latex)
  119. li = convert(latex)
  120. # tjt修改
  121. if str(li).replace('[','').replace(']',''):
  122. leaves(li)
  123. else:
  124. # raise RecursionError
  125. warn('RecursionError warn')
  126. latex = latex.strip()
  127. if latex != '' and latex[0]+latex[-1] in {'[]','{}'}:
  128. latex = latex[1:-1]
  129. return latex
  130. return last_clear(convert_command(li))
  131. if __name__ == '__main__':
  132. print(structured(
  133. r'$z = \left( {{m^2} - 5m + 6} \right) + \left( {m - 3} \right)i$'))