predictor.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. import sys
  2. import re
  3. import time
  4. # sys.path.append(r'/home/cv/workspace/tujintao/document_segmentation')
  5. # from transformers import BertTokenizer
  6. from PointerNet.main import NerPipeline
  7. from PointerNet.model import UIEModel
  8. from Utils.washutil import simpwash, again_wash
  9. from PointerNet.config import NerArgs
  10. # from Utils.field_eq2latex import get_latex
  11. from Utils.answer_match import get_ans_match
  12. # from pprint import pprint
  13. # from washutil import HtmlWash
  14. class Predictor:
  15. def __init__(self, ner_args=None):
  16. model = UIEModel(ner_args)
  17. self.ner_pipeline = NerPipeline(model, ner_args)
  18. self.ner_pipeline.load_model()
  19. def predict(self, text, con_type="stem_block", paper_id="temp_{}".format(int(time.time())), subject="数学"):
  20. """
  21. 数据样本清洗后格式简化,如何与源文档建立联系?????
  22. >>>html文本先经过washutil.py得到row_list,再用换行符“\n”连接,再经过again_wash后再按“\n”切割得到all_sents,
  23. 需保证row_list和all_sents的长度相等,否则应该建立两者位置联系;通过模型切割对all_sents切割完后即可通过位置关系找到原始文本
  24. 发现:经washutil.py清洗后会丢掉一些特殊位置,如在题首的图片;
  25. 所以还是选用本项目中的清洗函数:simpwash和again_wash, “【公式】和【图片】”的原始数据要保留,且表格不能处理太简单,但公式也简化过了,不好还原
  26. """
  27. self.subject = subject
  28. # text, new_html = HtmlWash(text, '11111111',is_reparse=1,must_latex=1).html_cleal()
  29. item_str, _, _ = simpwash(text, paper_id)
  30. self.sents_with_imginfo, simply_sents = again_wash(item_str, paper_id)
  31. entities, split_topic_idx, topic_item_pred = self.ner_pipeline.half_batch_predict(text_list=simply_sents)
  32. # pprint(entities)
  33. if not split_topic_idx:
  34. return []
  35. if con_type == "stem_block":
  36. topic_item_res = self.pre_structure(split_topic_idx, topic_item_pred)
  37. # pprint(topic_item_res)
  38. return topic_item_res
  39. else:
  40. all_ans, ans_no = self.pre_ans_split(split_topic_idx)
  41. # print("all_ans::",all_ans)
  42. print(ans_no)
  43. return all_ans, ans_no
  44. def pre_ans_split(self, split_topic_idx):
  45. all_ans, ans_no = [], []
  46. contain_table_ans = False
  47. last_idx = 0
  48. new_ans = []
  49. for topic_idx in split_topic_idx:
  50. if topic_idx[0] - last_idx > 0:
  51. may_non_topic_content_list = self.sents_with_imginfo[last_idx: topic_idx[0]]
  52. print(may_non_topic_content_list)
  53. for may_non in may_non_topic_content_list:
  54. if re.match("<table>.*?</table>.?$", may_non):
  55. new_ans.append(may_non)
  56. last_idx = topic_idx[-1]
  57. item_list = self.sents_with_imginfo[topic_idx[0]: topic_idx[-1]]
  58. item_str = "\n".join(item_list)
  59. new_ans.append(item_str)
  60. for ans_str in new_ans:
  61. item_id_info = re.match(r"([1-9][0-9]?)(?![0-9\+\-\*\\/\)\}])", ans_str)
  62. item_id = int(item_id_info.group(1)) if item_id_info else None
  63. ans_str = re.sub("^([1-9][0-9]?)\s*[、..、]", "", ans_str.strip())
  64. ans_str = re.sub(r"^(\[.*?\])?[((].*?\d+分[))]", "", ans_str[:20]) + ans_str[20:]
  65. if not ans_str.strip():
  66. continue
  67. get_table_ans_success = 0
  68. if not ans_no or contain_table_ans: # 第一个或上一个是表格
  69. if "</table>" in ans_str:
  70. table_ans, table_ans_no = self.get_table_ans([ans_str])
  71. if table_ans_no:
  72. contain_table_ans = True
  73. ans_no.extend(table_ans_no)
  74. all_ans.extend(table_ans)
  75. get_table_ans_success = 1
  76. elif contain_table_ans:
  77. contain_table_ans = False
  78. if not get_table_ans_success:
  79. ans_no.append(item_id)
  80. all_ans.append(ans_str)
  81. return all_ans, ans_no
  82. def pre_structure(self, split_topic_idx, topic_item_pred):
  83. """
  84. 预结构化
  85. """
  86. split_res, all_ans, ans_no = [], [], []
  87. contain_table_ans = False
  88. last_idx = 0
  89. pre_topic_type = ""
  90. is_ans_divider = False # 参考答案分界线
  91. topic_no = []
  92. # 对开头遗漏部分判断
  93. if split_topic_idx[0][0]>1: # pre_topic_type
  94. start_idx = [n for n, v in enumerate(topic_item_pred) if v and n < split_topic_idx[0][0]]
  95. if start_idx and all([True if i else False for i in topic_item_pred[start_idx[0]:split_topic_idx[0][0]]]):
  96. split_topic_idx.insert(0, (start_idx[0], split_topic_idx[0][0]))
  97. for nn, topic_idx in enumerate(split_topic_idx):
  98. if topic_idx[0] - last_idx > 0:
  99. non_topic_content_list = self.sents_with_imginfo[last_idx: topic_idx[0]]
  100. print(non_topic_content_list)
  101. if not is_ans_divider:
  102. is_ans_divider = self.ans_divider_judge(non_topic_content_list)
  103. pre_topic_type = self.pre_judge_topic_type("\n"+"\n".join(non_topic_content_list))
  104. last_idx = topic_idx[-1]
  105. print(topic_idx[0],topic_idx[-1], pre_topic_type)
  106. item_list = self.sents_with_imginfo[topic_idx[0]: topic_idx[-1]]
  107. item_str = "\n".join(item_list)
  108. item_id_info = re.match(r"([1-9][0-9]?)(?![0-9\+\-\*\\/\)\}])", item_str)
  109. item_id = int(item_id_info.group(1)) if item_id_info else None
  110. # 试卷开头信息可能预测错误:样本丰富度不够
  111. if topic_idx[0] == 0 and topic_idx[-1] == 1 and item_id is None \
  112. and len(split_topic_idx)>1 and split_topic_idx[1][0]==2:
  113. continue
  114. item_str = re.sub("^([1-9][0-9]?)\s*[、..、]", "", item_str.strip())
  115. item_str = re.sub(r"^(\[.*?\])?[((].*?\d+分[))]", "", item_str[:20]) + item_str[20:]
  116. if not is_ans_divider:
  117. if item_str.strip() and len(item_str.strip()) > 5:
  118. split_res.append({
  119. "stem": item_str,
  120. "item_id": item_id,
  121. "type": pre_topic_type,
  122. "errmsgs": [],
  123. })
  124. topic_no.append(item_id)
  125. else:
  126. get_table_ans_success = 0
  127. if not ans_no or contain_table_ans: # 第一个或上一个是表格
  128. if "</table>" in item_str:
  129. table_ans, table_ans_no = self.get_table_ans(item_list)
  130. if table_ans_no:
  131. contain_table_ans = True
  132. ans_no.extend(table_ans_no)
  133. all_ans.extend(table_ans)
  134. get_table_ans_success = 1
  135. elif contain_table_ans:
  136. contain_table_ans = False
  137. if not get_table_ans_success and len(item_str.strip()) > 5:
  138. ans_no.append(item_id)
  139. all_ans.append(item_str)
  140. # 根据topic_no纠正一下分错的topic:不能纠正的太过,有的就是乱序的
  141. if topic_no:
  142. print("topic_no:::", topic_no)
  143. for ni,no in enumerate(topic_no):
  144. if 0 < ni < len(topic_no)-1 and no is None and topic_no[ni-1]:
  145. if topic_no[ni+1] and topic_no[ni+1] - topic_no[ni-1] == 1: # 两连续题号间隔着一个None
  146. split_res[ni-1]["stem"] += "\n"+split_res[ni]["stem"]
  147. split_res[ni] = ''
  148. else:
  149. for nj, no2 in enumerate(topic_no[ni+1:]):
  150. if no2 is None:
  151. continue
  152. # print("nj:", nj, ni)
  153. if no2 - topic_no[ni-1] == 1 and nj > 0: # 两连续题号间隔着多个None
  154. # print([y["stem"] for y in split_res[ni:ni+1+nj]])
  155. split_res[ni-1]["stem"] += "\n"+"\n".join([y["stem"] for y in split_res[ni:ni+1+nj]])
  156. split_res[ni:ni+1+nj] = ['']*(nj+1)
  157. break
  158. # 最后1题题号为None
  159. if re.search(r"\d+None$", "".join([str(i) for i in topic_no])):
  160. split_res[-2]["stem"] += "\n"+split_res[-1]["stem"]
  161. split_res[-1] = ''
  162. split_res = [res for res in split_res if res]
  163. # 试题、答案匹配
  164. split_res = get_ans_match(split_res, all_ans, ans_no)
  165. return split_res
  166. def get_table_ans(self, anss):
  167. # all_item_ans = [] # 前期只记录表格答案和排列型答案
  168. table_ans = []
  169. ans_no = [] # 只记录表格答案和排列型答案的id
  170. # 默认表格答案放在最前面 !!!
  171. while anss and "table" in anss[0]: # 答案以表格形式呈现, 表格应放在前两行位置,不要插在答案中间
  172. row_list = [] # 要求表格形式为 横纵分明 ,不存在合并
  173. for tt in re.finditer('<tr>(((?!(</?tr>)).)*)</tr>', anss[0], re.S): # 先划分每行
  174. tt_list = re.split(r'</p></td>|<td><p>|</td><td>|</td>|<td>', tt.group(1)) # 再划分每列
  175. # row_list.append([col for col in tt_list if col.strip()]) # 也有可能答案为空
  176. row_list.append(tt_list)
  177. if row_list:
  178. print("^^^^^^存在答案放在表格里的情况!^^^^^^^")
  179. is_other_table = False
  180. if len(row_list) % 2 != 0:
  181. print('表格形式呈现的答案不是偶数行')
  182. # 可能是填空题放在表格里,可以先去表格看看
  183. if "题号" in row_list[0] and len(row_list) > 1:
  184. row_list = [re.sub("^\s*([1-9]|[1-2][0-9])\s*([((]\s*\d{1,2}\s*分?\s*[))])?\s*$", r"\n\1、", v)
  185. for row in row_list[1:] for v in row]
  186. anss[0] = "\n" + " ".join(row_list)
  187. is_other_table = True
  188. elif re.search("^\s*([1-9]|[12][0-9])\s*[..、、::].+?", "".join(row_list[0])):
  189. temp_id = re.search("^\s*([1-9]|[12][0-9])\s*[..、、::].+?", "".join(row_list[0])).group(1)
  190. if ans_no and 1 <= int(temp_id) - ans_no[-1] <= 2:
  191. # anss[0] = re.sub(r'</?table>|</?p>|</?t[dr]>|</?tbody>', '', anss[0])
  192. anss[0] = "\n" + "\n".join(["".join(i) for i in row_list])
  193. is_other_table = True
  194. else:
  195. print("row_list:", row_list)
  196. for k, v in enumerate(row_list):
  197. # print('-----',v)
  198. if (k + 1) % 2 == 1: # 奇数行==》答案序号行
  199. item_no = []
  200. item_ans = []
  201. # special_item_no = [] # 特殊的题号 如14(1)
  202. for num, i in enumerate(v):
  203. if re.sub(r"[^\d]", "", i.strip()):
  204. no_info = re.match("(\d+)", i.strip())
  205. no = no_info.group(1) if no_info else 0
  206. special_no_info = re.match(r"(\d+)\s*([((\-_]\s*\d|[((\-_]?\s*[①②③④])", i.strip())
  207. if special_no_info:
  208. no = special_no_info.group(1)
  209. # special_item_no.append(int(i))
  210. if not item_no or int(no) != item_no[-1]:
  211. item_no.append(int(no))
  212. item_ans.append(row_list[k + 1][num])
  213. elif item_no and int(no) == item_no[-1]:
  214. item_ans[-1] += "#" + row_list[k + 1][num]
  215. else:
  216. if (not ans_no and int(no)) or (item_no and 1 <= int(no) - item_no[-1] <= 2) or \
  217. (not item_no and ans_no and 1 <= int(no) - ans_no[-1] <= 2):
  218. item_no.append(int(no))
  219. if k + 1 < len(row_list) and num < len(row_list[k + 1]):
  220. item_ans.append(row_list[k + 1][num])
  221. else:
  222. item_ans.append("")
  223. ans_no.extend(item_no)
  224. table_ans.extend(item_ans)
  225. if not is_other_table:
  226. anss = anss[1:]
  227. print("表格答案:", table_ans)
  228. print("表格答案题号:", ans_no)
  229. # all_item_ans.extend(table_ans)
  230. return table_ans, ans_no
  231. def ans_divider_judge(self, item_list):
  232. split_p1 = [k for k, v in enumerate(item_list)
  233. if re.match(r'(参考|试[题卷]|考试|物理|理综|数学|化学|生物)答案.{,5}$'
  234. r'|答案[和与及]?解析([((].*?[))])?$' # |答\s*案$
  235. r'|.{,15}(参考|考试|(考?试|检测)[题卷]|物理|理综|数学|化学|生物)(答案|解析|答案[及与和]评分(标准|意见|细则|参考))\s*$'
  236. r'|.{,15}评分(标准|意见|细则|参考)$'
  237. r'|((参考|(考?试|检测)[题卷]|考试|物理|理综|数学|化学|生物)答案|答案[和与及]解析)[\dA-E\s..、、]+$'
  238. r'|.{,15}(参考|考试|(考?试|检测)[题卷])(答案|解析|答案[及与和]评分(标准|意见|细则|参考))\s*(物理|理综|数学|化学|生物)?\s*$'
  239. r'|.{,15}解析[和与及]答案$',
  240. re.sub(r"[上下]?学[年期]|[\d—【】..·、、::(())年\s]|[中大]学|模拟|[中高]考|年级|[学期][末中]"
  241. r"|[高初][一二三]|部分", "", v.strip()))]
  242. if split_p1:
  243. return True
  244. return False
  245. def pre_judge_topic_type(self, item_str):
  246. # all_type = re.findall(r"\n\s*[一二三四五六七八九十]{,2}\s*[、..、]\s*([^必考基础综合中等]{2,4}题)") # no 非
  247. all_type = re.findall(r"\n\s*[^必考基础综合中等::()()例训某课】]{,3}\s*[、..、]\s*([^必考基础综合中等例训创某课】]{2,4}题)", item_str)
  248. attention_type = ["选择","单选","多选","单项选择","多项选择","双项选择","单空","多空","填空",
  249. "实验","简答","解答","综合", "作图", "判断","计算"]
  250. new_all_type = [type for type in all_type if type.replace("题", "") in attention_type]
  251. if new_all_type:
  252. if "数学" in self.subject:
  253. new_all_type = ["解答题" if type.replace("题", "") in ["实验","简答","解答","综合", "作图", "判断","计算"]
  254. else type for type in new_all_type]
  255. if "物理" in self.subject:
  256. new_all_type = ["解答题" if type.replace("题", "") in ["简答","综合","计算"] else type for type in new_all_type]
  257. elif all_type and all_type[0] == '非选择题':
  258. return "解答题"
  259. if new_all_type:
  260. return new_all_type[0]
  261. return ""
  262. if __name__ == '__main__':
  263. nerArgs = NerArgs()
  264. predict_tool = Predictor(nerArgs)
  265. path = "/home/cv/workspace/tujintao/document_segmentation/Data/samples/真实样例/663b246a1ec1003b58557467.html"
  266. f = open(path, 'r', encoding='utf8')
  267. text = f.read()
  268. # row_list, new_html, subs2table = HtmlWash(text, '11111111',is_reparse=1,must_latex=1).html_cleal()
  269. # row_list1 = list(filter(lambda x: x.strip() != "", row_list))
  270. # item_str, _, _ = simpwash(row_list, "", "")
  271. # item_str, _, _ = simpwash(text, "", "")
  272. # washed_text = again_wash(item_str)
  273. # washed_text2 = re.sub("【<img .*?\"\s*/?>公式latex提取失败】", "【公式】", washed_text)
  274. # washed_text2 = re.sub("<img .*?[\"']\s*/?>", "【图片】", washed_text2)
  275. # washed_text2 = re.sub(r'<[a-z]+ [a-z]+="[^<>]*?"\s*/?>', "", washed_text2)
  276. # item_str, _, _ = simpwash(text, "", "")
  277. # sents_with_imginfo, simply_sents = again_wash(item_str)
  278. # pprint(sents_with_imginfo)
  279. # print(len(simply_sents), len(sents_with_imginfo))
  280. # # print(len(row_list1))
  281. # print(len(washed_text.split("\n")))
  282. # print(len(washed_text2.split("\n")))
  283. entities = predict_tool.predict(text,con_type="stem_block", paper_id="6264fa25f84c0e279ac643ef", subject="物理")
  284. # item_str = '二、多选题:本题共3小题,每小题6分,共18分。在每小题给出的四个选项中,有多项符合题目要求。全部选对的得6分,选对但不全的得3分,有选错的得0分。'
  285. # item_str = '三、非选择题,共54分,考生根据要求作答。'
  286. # all_type = re.findall(r"\n\s*[^必考基础综合中等::()()例训某课】]{,3}\s*[、..、]\s*([^必考基础综合中等例训创某课】]{2,4}题)", "\n"+item_str)
  287. # print(all_type)