123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318 |
- import sys
- import re
- import time
- # sys.path.append(r'/home/cv/workspace/tujintao/document_segmentation')
- # from transformers import BertTokenizer
- from PointerNet.main import NerPipeline
- from PointerNet.model import UIEModel
- from Utils.washutil import simpwash, again_wash
- from PointerNet.config import NerArgs
- # from Utils.field_eq2latex import get_latex
- from Utils.answer_match import get_ans_match
- # from pprint import pprint
- # from washutil import HtmlWash
class Predictor:
    """Split washed exam-paper text into question stems and answers.

    The class wraps a ``NerPipeline`` built on a ``UIEModel``.  ``predict``
    washes the raw text, asks the pipeline for sentence spans, and then either
    builds question records (``pre_structure``) or splits an answer sheet
    (``pre_ans_split``).

    NOTE(review): several character classes in the patterns below contain the
    same character twice (``[((]``, ``[、..、]``, ``[::]``).  This looks like
    full-width punctuation that was folded to ASCII by an encoding
    normalisation step; the patterns are kept exactly as found -- confirm
    against the upstream file before widening them.
    """

    # Leading one/two-digit item number not followed by a digit, an operator
    # or a closing bracket.
    _ITEM_ID_RE = re.compile(r"([1-9][0-9]?)(?![0-9\+\-\*\\/\)\}])")
    # "12、" / "12." style numbering at the very start of an item.
    _NUMBER_PREFIX_RE = re.compile(r"^([1-9][0-9]?)\s*[、..、]")
    # Optional "[...]" tag followed by a "(... N分)" score annotation.
    _SCORE_PREFIX_RE = re.compile(r"^(\[.*?\])?[((].*?\d+分[))]")
    # A sentence that is a table (plus at most one trailing character).
    _TABLE_LINE_RE = re.compile("<table>.*?</table>.?$")

    _TABLE_ROW_RE = re.compile('<tr>(((?!(</?tr>)).)*)</tr>', re.S)
    _CELL_SPLIT_RE = re.compile(r'</p></td>|<td><p>|</td><td>|</td>|<td>')
    _NUMERIC_CELL_RE = re.compile(
        r"^\s*([1-9]|[1-2][0-9])\s*([((]\s*\d{1,2}\s*分?\s*[))])?\s*$")
    _LEADING_NO_ROW_RE = re.compile(r"^\s*([1-9]|[12][0-9])\s*[..、、::].+?")
    # Sub-question numbering such as "14(1)", "14-2" or "14①".
    _SPECIAL_NO_RE = re.compile(
        r"(\d+)\s*([((\-_]\s*\d|[((\-_]?\s*[①②③④])")

    _ANS_DIVIDER_RE = re.compile(
        r'(参考|试[题卷]|考试|物理|理综|数学|化学|生物)答案.{,5}$'
        r'|答案[和与及]?解析([((].*?[))])?$'
        r'|.{,15}(参考|考试|(考?试|检测)[题卷]|物理|理综|数学|化学|生物)(答案|解析|答案[及与和]评分(标准|意见|细则|参考))\s*$'
        r'|.{,15}评分(标准|意见|细则|参考)$'
        r'|((参考|(考?试|检测)[题卷]|考试|物理|理综|数学|化学|生物)答案|答案[和与及]解析)[\dA-E\s..、、]+$'
        r'|.{,15}(参考|考试|(考?试|检测)[题卷])(答案|解析|答案[及与和]评分(标准|意见|细则|参考))\s*(物理|理综|数学|化学|生物)?\s*$'
        r'|.{,15}解析[和与及]答案$')
    # Removed from a line before it is tested against _ANS_DIVIDER_RE.
    _TITLE_NOISE_RE = re.compile(
        r"[上下]?学[年期]|[\d—【】..·、、::(())年\s]|[中大]学|模拟|[中高]考|年级|[学期][末中]"
        r"|[高初][一二三]|部分")
    _TOPIC_TYPE_RE = re.compile(
        r"\n\s*[^必考基础综合中等::()()例训某课】]{,3}\s*[、..、]\s*([^必考基础综合中等例训创某课】]{2,4}题)")

    _ATTENTION_TYPES = frozenset([
        "选择", "单选", "多选", "单项选择", "多项选择", "双项选择", "单空", "多空", "填空",
        "实验", "简答", "解答", "综合", "作图", "判断", "计算",
    ])
    _MATH_SOLUTION_TYPES = frozenset(["实验", "简答", "解答", "综合", "作图", "判断", "计算"])
    _PHYSICS_SOLUTION_TYPES = frozenset(["简答", "综合", "计算"])

    def __init__(self, ner_args=None):
        """Build the model and pipeline from *ner_args* and load the weights."""
        model = UIEModel(ner_args)
        self.ner_pipeline = NerPipeline(model, ner_args)
        self.ner_pipeline.load_model()
        # Per-call state; overwritten by every predict() call.
        self.subject = "数学"
        self.sents_with_imginfo = []

    def predict(self, text, con_type="stem_block", paper_id=None, subject="数学"):
        """Wash *text*, run the pipeline and post-process its spans.

        Args:
            text: raw document text handed to ``simpwash``.
            con_type: ``"stem_block"`` builds question records via
                ``pre_structure``; any other value splits answers via
                ``pre_ans_split``.
            paper_id: identifier forwarded to the wash helpers.  When omitted
                a fresh ``"temp_<unix-seconds>"`` id is generated for this
                call (the previous default was computed only once, when the
                method was defined).
            subject: subject name, stored on the instance and read by
                ``pre_judge_topic_type``.

        Returns:
            ``[]`` when the pipeline yields no spans (in either mode);
            otherwise the ``pre_structure`` result for ``"stem_block"`` or
            the ``(all_ans, ans_no)`` tuple from ``pre_ans_split``.
            NOTE(review): answer-mode callers that unpack the result must
            handle the bare ``[]`` case.

        Design notes carried over from the original docstring: the washed
        sample is simplified, so a link back to the source document is needed.
        Washing with ``washutil.py`` was found to drop some special positions
        (e.g. an image at the very start of a question), so this project's
        ``simpwash`` / ``again_wash`` are used instead; the raw data behind
        the formula/image placeholders must be kept, tables must not be
        over-simplified, and formulas are already simplified and hard to
        restore.
        """
        if paper_id is None:
            paper_id = "temp_{}".format(int(time.time()))
        self.subject = subject
        item_str, _, _ = simpwash(text, paper_id)
        self.sents_with_imginfo, simply_sents = again_wash(item_str, paper_id)
        entities, split_topic_idx, topic_item_pred = \
            self.ner_pipeline.half_batch_predict(text_list=simply_sents)
        if not split_topic_idx:
            return []
        if con_type == "stem_block":
            return self.pre_structure(split_topic_idx, topic_item_pred)
        all_ans, ans_no = self.pre_ans_split(split_topic_idx)
        print(ans_no)
        return all_ans, ans_no

    def _split_number_prefix(self, raw):
        """Return ``(item_id, text)`` for one item.

        ``item_id`` is the leading item number of *raw* (``None`` if absent);
        ``text`` is *raw* stripped, with the leading numbering removed and a
        score annotation removed when it sits within the first 20 characters.
        """
        id_match = self._ITEM_ID_RE.match(raw)
        item_id = int(id_match.group(1)) if id_match else None
        text = self._NUMBER_PREFIX_RE.sub("", raw.strip())
        # Only the head is searched so a "(... N分)" deep inside the item
        # body is never removed.
        text = self._SCORE_PREFIX_RE.sub("", text[:20]) + text[20:]
        return item_id, text

    def pre_ans_split(self, split_topic_idx):
        """Split an answer document into per-item answers.

        Args:
            split_topic_idx: sequence of index sequences; the first and last
                element of each are used as a ``[start, end)`` slice into
                ``self.sents_with_imginfo``.

        Returns:
            ``(all_ans, ans_no)``: answer texts and their item numbers
            (``None`` when no number was found), in document order.
        """
        sents = self.sents_with_imginfo
        candidates = []
        last_idx = 0
        for topic_idx in split_topic_idx:
            start, end = topic_idx[0], topic_idx[-1]
            if start - last_idx > 0:
                # Sentences between two spans: only table lines are kept,
                # since they may hold a table of answers.
                gap = sents[last_idx:start]
                print(gap)
                candidates.extend(
                    line for line in gap if self._TABLE_LINE_RE.match(line))
            last_idx = end
            candidates.append("\n".join(sents[start:end]))

        all_ans, ans_no = [], []
        contain_table_ans = False
        for raw in candidates:
            item_id, ans_str = self._split_number_prefix(raw)
            if not ans_str.strip():
                continue
            got_table_ans = False
            # Tables are only parsed for the first answer or right after
            # another table answer.
            if not ans_no or contain_table_ans:
                if "</table>" in ans_str:
                    table_ans, table_ans_no = self.get_table_ans([ans_str])
                    if table_ans_no:
                        contain_table_ans = True
                        ans_no.extend(table_ans_no)
                        all_ans.extend(table_ans)
                        got_table_ans = True
                elif contain_table_ans:
                    contain_table_ans = False
            if not got_table_ans:
                ans_no.append(item_id)
                all_ans.append(ans_str)
        return all_ans, ans_no

    def pre_structure(self, split_topic_idx, topic_item_pred):
        """Pre-structure the document into question records.

        Spans before the answer divider become question records (dicts with
        ``stem``, ``item_id``, ``type`` and ``errmsgs``); spans after it are
        collected as answers.  Unnumbered stems are then merged into the
        preceding numbered stem where the numbering shows they were split by
        mistake, and the result is passed to ``get_ans_match`` together with
        the collected answers.

        Args:
            split_topic_idx: non-empty sequence of index sequences (see
                ``pre_ans_split``).  It is copied, not modified.
            topic_item_pred: per-sentence predictions; only their truthiness
                is used here.

        Returns:
            Whatever ``get_ans_match(split_res, all_ans, ans_no)`` returns.
        """
        split_topic_idx = list(split_topic_idx)  # do not mutate the caller's list
        sents = self.sents_with_imginfo
        split_res, all_ans, ans_no = [], [], []
        topic_no = []
        contain_table_ans = False
        is_ans_divider = False  # set once the reference-answer heading is seen
        pre_topic_type = ""
        last_idx = 0

        # Recover a span missed at the beginning: if every sentence from the
        # first positive prediction up to the first span is positive, treat
        # that run as an extra leading span.
        first_start = split_topic_idx[0][0]
        if first_start > 1:
            start_idx = [n for n, v in enumerate(topic_item_pred)
                         if v and n < first_start]
            if start_idx and all(topic_item_pred[start_idx[0]:first_start]):
                split_topic_idx.insert(0, (start_idx[0], first_start))

        for topic_idx in split_topic_idx:
            start, end = topic_idx[0], topic_idx[-1]
            if start - last_idx > 0:
                non_topic_content_list = sents[last_idx:start]
                print(non_topic_content_list)
                if not is_ans_divider:
                    is_ans_divider = self.ans_divider_judge(non_topic_content_list)
                pre_topic_type = self.pre_judge_topic_type(
                    "\n" + "\n".join(non_topic_content_list))
            last_idx = end
            print(start, end, pre_topic_type)
            item_list = sents[start:end]
            item_id, item_str = self._split_number_prefix("\n".join(item_list))

            # The paper header may be mis-predicted as a one-sentence item
            # (training samples are not rich enough); skip it.
            if start == 0 and end == 1 and item_id is None \
                    and len(split_topic_idx) > 1 and split_topic_idx[1][0] == 2:
                continue

            if not is_ans_divider:
                if item_str.strip() and len(item_str.strip()) > 5:
                    split_res.append({
                        "stem": item_str,
                        "item_id": item_id,
                        "type": pre_topic_type,
                        "errmsgs": [],
                    })
                    topic_no.append(item_id)
                continue

            got_table_ans = False
            # Tables are only parsed for the first answer or right after
            # another table answer.
            if not ans_no or contain_table_ans:
                if "</table>" in item_str:
                    table_ans, table_ans_no = self.get_table_ans(item_list)
                    if table_ans_no:
                        contain_table_ans = True
                        ans_no.extend(table_ans_no)
                        all_ans.extend(table_ans)
                        got_table_ans = True
                elif contain_table_ans:
                    contain_table_ans = False
            if not got_table_ans and len(item_str.strip()) > 5:
                ans_no.append(item_id)
                all_ans.append(item_str)

        split_res = self._merge_unnumbered_stems(split_res, topic_no)
        # Match questions with answers.
        return get_ans_match(split_res, all_ans, ans_no)

    @staticmethod
    def _merge_unnumbered_stems(split_res, topic_no):
        """Fold wrongly split, unnumbered stems into the preceding stem.

        The correction is deliberately conservative because some papers are
        genuinely out of order: unnumbered records are merged only when the
        numbered neighbours on both sides are consecutive, or when the very
        last record is unnumbered and follows a numbered one.  *split_res* is
        edited in place; the returned list has the merged-away slots removed.
        """
        if topic_no:
            print("topic_no:::", topic_no)
            for ni, no in enumerate(topic_no):
                if not (0 < ni < len(topic_no) - 1 and no is None and topic_no[ni - 1]):
                    continue
                prev_no = topic_no[ni - 1]
                next_no = topic_no[ni + 1]
                if next_no and next_no - prev_no == 1:
                    # Exactly one unnumbered record between consecutive numbers.
                    split_res[ni - 1]["stem"] += "\n" + split_res[ni]["stem"]
                    split_res[ni] = ''
                    continue
                for nj, no2 in enumerate(topic_no[ni + 1:]):
                    if no2 is None:
                        continue
                    if no2 - prev_no == 1 and nj > 0:
                        # Several unnumbered records between consecutive numbers.
                        split_res[ni - 1]["stem"] += "\n" + "\n".join(
                            [y["stem"] for y in split_res[ni:ni + 1 + nj]])
                        split_res[ni:ni + 1 + nj] = [''] * (nj + 1)
                    # Only the first numbered record after ni is considered.
                    break
            # The last record is unnumbered and directly follows a numbered one.
            if re.search(r"\d+None$", "".join([str(i) for i in topic_no])):
                split_res[-2]["stem"] += "\n" + split_res[-1]["stem"]
                split_res[-1] = ''
        return [res for res in split_res if res]

    @staticmethod
    def _cell(rows, row_idx, col_idx):
        """Return ``rows[row_idx][col_idx]`` or ``""`` when out of range."""
        if row_idx < len(rows) and col_idx < len(rows[row_idx]):
            return rows[row_idx][col_idx]
        return ""

    def _pair_numbers_with_answers(self, row_list, k, ans_no):
        """Read item numbers from row *k* and their answers from row *k + 1*.

        *ans_no* holds the numbers accepted from earlier row pairs and is
        only read.  Returns ``(item_no, item_ans)`` for this row pair.
        """
        item_no, item_ans = [], []
        for num, raw_cell in enumerate(row_list[k]):
            cell = raw_cell.strip()
            if not re.search(r"\d", cell):
                continue
            no_info = re.match(r"(\d+)", cell)
            no = int(no_info.group(1)) if no_info else 0
            answer = self._cell(row_list, k + 1, num)
            special_no_info = self._SPECIAL_NO_RE.match(cell)
            if special_no_info:
                # Sub-questions of one item ("14(1)", "14(2)") are joined
                # into a single answer separated by "#".
                no = int(special_no_info.group(1))
                if not item_no or no != item_no[-1]:
                    item_no.append(no)
                    item_ans.append(answer)
                else:
                    item_ans[-1] += "#" + answer
            elif (not ans_no and no) or (item_no and 1 <= no - item_no[-1] <= 2) or \
                    (not item_no and ans_no and 1 <= no - ans_no[-1] <= 2):
                # A plain number is accepted in the first number row, or
                # when it continues the running numbering by 1 or 2.
                item_no.append(no)
                item_ans.append(answer)
        return item_no, item_ans

    def get_table_ans(self, anss):
        """Extract answers laid out as tables at the head of *anss*.

        Leading entries of *anss* containing ``"table"`` are consumed one at
        a time.  A table with an even number of ``<tr>`` rows is read as
        alternating number rows and answer rows.  A table with an odd number
        of rows is not read as answers; in two recognised layouts its text
        is flattened in a local copy, which has no effect on the returned
        values.  Tables are expected at the very front, without merged cells.

        Args:
            anss: list of answer strings; it is copied, not modified.

        Returns:
            ``(table_ans, ans_no)``: the answers found and their item numbers.
        """
        anss = list(anss) if anss else []
        table_ans = []
        ans_no = []
        while anss and "table" in anss[0]:
            # Split into rows first, then each row into cells; empty cells
            # are kept because an answer may legitimately be blank.
            row_list = [self._CELL_SPLIT_RE.split(row.group(1))
                        for row in self._TABLE_ROW_RE.finditer(anss[0])]
            # Reset on every pass so an unparseable table is always dropped
            # and the loop is guaranteed to advance.
            is_other_table = False
            if row_list:
                print("^^^^^^存在答案放在表格里的情况!^^^^^^^")
                if len(row_list) % 2 != 0:
                    print('表格形式呈现的答案不是偶数行')
                    # Possibly fill-in-the-blank answers placed in a table.
                    if "题号" in row_list[0] and len(row_list) > 1:
                        cells = [self._NUMERIC_CELL_RE.sub(r"\n\1、", v)
                                 for row in row_list[1:] for v in row]
                        anss[0] = "\n" + " ".join(cells)
                        is_other_table = True
                    else:
                        leading = self._LEADING_NO_ROW_RE.search("".join(row_list[0]))
                        if leading and ans_no \
                                and 1 <= int(leading.group(1)) - ans_no[-1] <= 2:
                            anss[0] = "\n" + "\n".join(["".join(i) for i in row_list])
                            is_other_table = True
                else:
                    print("row_list:", row_list)
                    # Even-indexed rows hold the item numbers.
                    for k in range(0, len(row_list), 2):
                        item_no, item_ans = self._pair_numbers_with_answers(
                            row_list, k, ans_no)
                        ans_no.extend(item_no)
                        table_ans.extend(item_ans)
            if not is_other_table:
                anss = anss[1:]
        print("表格答案:", table_ans)
        print("表格答案题号:", ans_no)
        return table_ans, ans_no

    def ans_divider_judge(self, item_list):
        """Return True if any line in *item_list* is an answer-section heading.

        Each line is stripped of title noise (digits, punctuation, term and
        grade words, ...) and must then match one of the heading patterns
        from its start.
        """
        return any(
            self._ANS_DIVIDER_RE.match(self._TITLE_NOISE_RE.sub("", line.strip()))
            for line in item_list
        )

    def pre_judge_topic_type(self, item_str):
        """Guess the question type announced by a section heading.

        Args:
            item_str: text between two spans; each candidate heading must be
                preceded by a newline.

        Returns:
            The first recognised type (e.g. ``"多选题"``), normalised to
            ``"解答题"`` for some types depending on ``self.subject``;
            ``"解答题"`` when nothing is recognised but the first heading is
            ``"非选择题"``; otherwise ``""``.
        """
        all_type = self._TOPIC_TYPE_RE.findall(item_str)
        known = [name for name in all_type
                 if name.replace("题", "") in self._ATTENTION_TYPES]
        if not known:
            if all_type and all_type[0] == '非选择题':
                return "解答题"
            return ""
        topic_type = known[0]
        if "数学" in self.subject \
                and topic_type.replace("题", "") in self._MATH_SOLUTION_TYPES:
            topic_type = "解答题"
        if "物理" in self.subject \
                and topic_type.replace("题", "") in self._PHYSICS_SOLUTION_TYPES:
            topic_type = "解答题"
        return topic_type
if __name__ == '__main__':
    # Manual smoke run: load one sample HTML paper and pre-structure it.
    nerArgs = NerArgs()
    predict_tool = Predictor(nerArgs)
    path = "/home/cv/workspace/tujintao/document_segmentation/Data/samples/真实样例/663b246a1ec1003b58557467.html"
    # The context manager closes the sample file even if reading fails.
    with open(path, 'r', encoding='utf8') as f:
        text = f.read()
    entities = predict_tool.predict(text, con_type="stem_block", paper_id="6264fa25f84c0e279ac643ef", subject="物理")
|