import logging
import re
import sys
import time

from PointerNet.main import NerPipeline
from PointerNet.model import UIEModel
from Utils.washutil import simpwash, again_wash
from PointerNet.config import NerArgs
from Utils.answer_match import get_ans_match

logger = logging.getLogger(__name__)

# NOTE(review): the copy of this file that was reviewed had its line structure
# collapsed, every HTML tag inside string literals stripped, and full-width
# punctuation folded to ASCII (e.g. "[((]" where "[(（]" was evidently meant).
# The character classes below accept both the ASCII and the full-width forms;
# literals that depended on HTML tags are reconstructed and marked one by one.
# Confirm them against version control.

# Leading question number ("12" in "12、..."), rejected when followed by a digit
# or an operator so expressions such as "3+4" or "2/3" are not taken as numbers.
_ITEM_ID_RE = re.compile(r"([1-9][0-9]?)(?![0-9\+\-\*\\/\)\}])")
# "12、" / "12." numbering prefix at the start of a stem or answer.
_ITEM_NO_PREFIX_RE = re.compile(r"^([1-9][0-9]?)\s*[、﹑.．]")
# Optional "[tag]" followed by a parenthesised score note such as "(6分)".
_SCORE_PREFIX_RE = re.compile(r"^(\[.*?\])?[(（].*?\d+分[)）]")
# NOTE(review): in the reviewed copy this literal contained a bare line break
# between ".*?" and ".?$" (probably a stripped HTML tag).  Its value is kept,
# but if sentences never contain "\n" (the cleaning notes in predict() say they
# are produced by splitting on "\n") it can never match, so content between
# answer spans is effectively dropped.  Restore the intended pattern.
_GAP_ANSWER_RE = re.compile(".*?\n.?$")

# NOTE(review): reconstructed.  The reviewed copy read '(((?!()).)*)' for rows
# (an empty lookahead, so only empty strings ever match) and a cell separator
# made of bare '|' and newlines (empty alternatives, so the row was split
# between every character).  Restored to the evident intent of the original
# comments "split rows first, then columns": capture the inside of each
# <tr>...</tr>, then split it on <td>/<th> boundaries.
_TABLE_ROW_RE = re.compile(r"<tr(?:\s[^>]*)?>(((?!(</tr>)).)*)</tr>", re.S)
_TABLE_CELL_SEP_RE = re.compile(
    r"</t[dh]>\s*<t[dh](?:\s[^>]*)?>|<t[dh](?:\s[^>]*)?>|</t[dh]>"
)
# A cell holding only a question number, optionally with a score: "3", "3(2分)".
_TABLE_NO_CELL_RE = re.compile(
    r"^\s*([1-9]|[1-2][0-9])\s*([(（]\s*\d{1,2}\s*分?\s*[)）])?\s*$"
)
# First table row starting like "3." / "3、" / "3:" followed by answer text.
_TABLE_FIRST_ROW_ID_RE = re.compile(r"^\s*([1-9]|[12][0-9])\s*[.．、﹑:：].+?")
# Sub-question numbers such as "14(1)", "14-2", "14①".
_SPECIAL_NO_RE = re.compile(r"(\d+)\s*([(（\-_]\s*\d|[(（\-_]?\s*[①②③④])")

# Characters and words removed from a line before testing it as an
# answer-section heading (years, terms, grades, punctuation, ...).
_ANS_DIVIDER_NOISE_RE = re.compile(
    r"[上下]?学[年期]|[\d—【】.．·、﹑:：(（)）年\s]|[中大]学|模拟|[中高]考|年级|[学期][末中]"
    r"|[高初][一二三]|部分"
)
# Headings that mark the start of the reference-answer section.
_ANS_DIVIDER_RE = re.compile(
    r"(参考|试[题卷]|考试|物理|理综|数学|化学|生物)答案.{,5}$"
    r"|答案[和与及]?解析([(（].*?[)）])?$"
    r"|.{,15}(参考|考试|(考?试|检测)[题卷]|物理|理综|数学|化学|生物)(答案|解析|答案[及与和]评分(标准|意见|细则|参考))\s*$"
    r"|.{,15}评分(标准|意见|细则|参考)$"
    r"|((参考|(考?试|检测)[题卷]|考试|物理|理综|数学|化学|生物)答案|答案[和与及]解析)[\dA-E\s.．、﹑]+$"
    r"|.{,15}(参考|考试|(考?试|检测)[题卷])(答案|解析|答案[及与和]评分(标准|意见|细则|参考))\s*(物理|理综|数学|化学|生物)?\s*$"
    r"|.{,15}解析[和与及]答案$"
)

# Section headings such as "\n二、多选题" -> captures "多选题".
_TOPIC_TYPE_RE = re.compile(
    r"\n\s*[^必考基础综合中等:：()（）例训某课】]{,3}\s*[、﹑.．]\s*([^必考基础综合中等例训创某课】]{2,4}题)"
)
# Question types (without the trailing "题") that are reported.
_ATTENTION_TYPES = {
    "选择", "单选", "多选", "单项选择", "多项选择", "双项选择", "单空", "多空", "填空",
    "实验", "简答", "解答", "综合", "作图", "判断", "计算",
}
# Types reported as "解答题" for maths / physics papers.
_MATH_ANSWER_TYPES = {"实验", "简答", "解答", "综合", "作图", "判断", "计算"}
_PHYSICS_ANSWER_TYPES = {"简答", "综合", "计算"}


def _leading_item_id(text):
    """Return the question number that ``text`` starts with, or None."""
    match = _ITEM_ID_RE.match(text)
    return int(match.group(1)) if match else None


def _strip_item_prefix(text):
    """Strip ``text`` and drop a leading "12、"-style number and "(6分)"-style score note.

    The score note is only looked for in the first 20 characters.
    """
    text = _ITEM_NO_PREFIX_RE.sub("", text.strip())
    return _SCORE_PREFIX_RE.sub("", text[:20]) + text[20:]


def _merge_unnumbered_topics(split_res, topic_no):
    """Fold stems without a question number into the preceding numbered stem.

    ``topic_no`` is index-aligned with ``split_res`` (``None`` = no number).
    A run of un-numbered stems is merged into the stem before it only when the
    numbered stems on both sides are consecutive (n, None, ..., n + 1).  Other
    runs are left alone because some papers are legitimately out of order.
    A single trailing un-numbered stem is merged into the numbered stem before it.
    Merged entries are replaced by "" in ``split_res``, which is modified in place.
    """
    count = len(topic_no)
    for ni in range(1, count - 1):
        prev_no = topic_no[ni - 1]
        if topic_no[ni] is not None or prev_no is None:
            continue
        nxt = ni + 1
        while nxt < count and topic_no[nxt] is None:
            nxt += 1
        # Only look at the first numbered stem after the run.  Scanning further
        # could fold a numbered stem away and leave a "" that a later iteration
        # would then index as a dict.
        if nxt < count and topic_no[nxt] - prev_no == 1:
            split_res[ni - 1]["stem"] += "\n" + "\n".join(
                res["stem"] for res in split_res[ni:nxt]
            )
            split_res[ni:nxt] = [""] * (nxt - ni)
    # The last stem has no number but the one before it does.
    if count >= 2 and topic_no[-1] is None and topic_no[-2] is not None:
        split_res[-2]["stem"] += "\n" + split_res[-1]["stem"]
        split_res[-1] = ""


class Predictor:
    """Split a cleaned exam paper into question stems or answers using the PointerNet NER pipeline."""

    def __init__(self, ner_args=None):
        model = UIEModel(ner_args)
        self.ner_pipeline = NerPipeline(model, ner_args)
        self.ner_pipeline.load_model()
        # Set by predict(); initialised so the helper methods work on their own.
        self.subject = "数学"
        self.sents_with_imginfo = []

    def predict(self, text, con_type="stem_block", paper_id=None, subject="数学"):
        """Clean ``text`` and split it into questions or answers.

        The raw (HTML) text is cleaned with this project's ``simpwash`` and
        ``again_wash`` rather than ``washutil.HtmlWash``, which was found to drop
        some content (e.g. an image at the start of a question).  ``again_wash``
        returns two lists used position-for-position.  The model splits the
        simplified list; the resulting indices are applied to
        ``self.sents_with_imginfo``, which presumably keeps the original
        【公式】/【图片】 content, to recover the source text.
        Open issue from the original notes: tables should not be simplified too
        much, and formulas are already simplified and hard to restore.

        Args:
            text: raw paper text (HTML, per the cleaning notes).
            con_type: "stem_block" to build question stems; any other value
                splits an answer document.
            paper_id: id passed to the cleaners.  Defaults to
                "temp_<unix time>", computed at call time.
            subject: subject name; "数学" / "物理" change question-type mapping.

        Returns:
            For "stem_block": the result of ``get_ans_match`` (``[]`` when the
            model finds no topics).  Otherwise: ``(all_ans, ans_no)``, which is
            ``([], [])`` when no topics are found.
        """
        if paper_id is None:
            paper_id = "temp_{}".format(int(time.time()))
        self.subject = subject
        item_str, _, _ = simpwash(text, paper_id)
        self.sents_with_imginfo, simply_sents = again_wash(item_str, paper_id)
        _, split_topic_idx, topic_item_pred = self.ner_pipeline.half_batch_predict(
            text_list=simply_sents
        )

        if con_type == "stem_block":
            if not split_topic_idx:
                return []
            return self.pre_structure(split_topic_idx, topic_item_pred)

        if not split_topic_idx:
            return [], []
        all_ans, ans_no = self.pre_ans_split(split_topic_idx)
        logger.debug("ans_no: %s", ans_no)
        return all_ans, ans_no

    def _collect_table_ans(self, ans_str, table_source, all_ans, ans_no, contain_table_ans):
        """Try to record ``ans_str`` as a table of answers.

        Only attempted for the first answer or right after a table answer,
        because table answers are assumed to come first.  On success the table's
        answers and numbers are appended to ``all_ans`` / ``ans_no`` in place.

        Args:
            ans_str: the cleaned answer text, checked for a table.
            table_source: list passed to ``get_table_ans``; its first element is
                the one parsed.
            all_ans, ans_no: accumulators, extended in place on success.
            contain_table_ans: whether the previous answer came from a table.

        Returns:
            ``(handled, contain_table_ans)``, where ``handled`` is True if table
            answers were recorded.  The flag is reset when a table run ends.
        """
        if not ans_no or contain_table_ans:
            # NOTE(review): the reviewed copy read `"" in ans_str` (always true,
            # which made the reset below unreachable); the stripped "<table"
            # is restored here.
            if "<table" in ans_str:
                table_ans, table_ans_no = self.get_table_ans(table_source)
                if table_ans_no:
                    ans_no.extend(table_ans_no)
                    all_ans.extend(table_ans)
                    return True, True
            elif contain_table_ans:
                return False, False
        return False, contain_table_ans

    def pre_ans_split(self, split_topic_idx):
        """Split an answer-only document into numbered answers.

        Each predicted span over ``self.sents_with_imginfo`` (``[first, last)``
        of each span) becomes one answer.  Sentences between spans are kept only
        if they match ``_GAP_ANSWER_RE``.  Leading numbers and score notes are
        stripped, and a leading table of answers is expanded via ``get_table_ans``.

        Returns:
            ``(all_ans, ans_no)``: index-aligned answer texts and question
            numbers (``None`` where no number was found).
        """
        candidates = []
        last_idx = 0
        for topic_idx in split_topic_idx:
            start, end = topic_idx[0], topic_idx[-1]
            if start > last_idx:
                gap = self.sents_with_imginfo[last_idx:start]
                logger.debug("content between answers: %s", gap)
                candidates.extend(sent for sent in gap if _GAP_ANSWER_RE.match(sent))
            last_idx = end
            candidates.append("\n".join(self.sents_with_imginfo[start:end]))

        all_ans, ans_no = [], []
        contain_table_ans = False
        for ans_str in candidates:
            item_id = _leading_item_id(ans_str)
            ans_str = _strip_item_prefix(ans_str)
            if not ans_str.strip():
                continue
            handled, contain_table_ans = self._collect_table_ans(
                ans_str, [ans_str], all_ans, ans_no, contain_table_ans
            )
            if not handled:
                ans_no.append(item_id)
                all_ans.append(ans_str)
        return all_ans, ans_no

    def pre_structure(self, split_topic_idx, topic_item_pred):
        """Pre-structure a paper: build question stems, then match answers to them.

        Args:
            split_topic_idx: non-empty sequence of spans over
                ``self.sents_with_imginfo``.  Only the first and last element of
                each span are used, and the end is exclusive.  Not modified.
            topic_item_pred: per-sentence model output, presumably truthy for
                sentences predicted to belong to a topic.  It is only used to
                recover a topic missed before the first span.

        Content between spans is treated as section headings.  It sets the
        question type of the following stems.  Once an answer-section heading
        has been seen, every later span is collected as an answer instead of a stem.

        Returns:
            The result of ``get_ans_match(split_res, all_ans, ans_no)``.
        """
        split_topic_idx = list(split_topic_idx)
        split_res, all_ans, ans_no, topic_no = [], [], [], []
        contain_table_ans = False
        is_ans_divider = False  # True once the reference-answer heading is seen
        pre_topic_type = ""
        last_idx = 0

        # Recover a leading topic the model missed: if every sentence from the
        # first positive prediction up to the first span is positive, treat
        # that run as its own topic.
        first_start = split_topic_idx[0][0]
        if first_start > 1:
            positives = [n for n, v in enumerate(topic_item_pred) if v and n < first_start]
            if positives and all(topic_item_pred[positives[0]:first_start]):
                split_topic_idx.insert(0, (positives[0], first_start))

        for topic_idx in split_topic_idx:
            start, end = topic_idx[0], topic_idx[-1]
            if start > last_idx:
                headings = self.sents_with_imginfo[last_idx:start]
                logger.debug("non-topic content: %s", headings)
                if not is_ans_divider:
                    is_ans_divider = self.ans_divider_judge(headings)
                pre_topic_type = self.pre_judge_topic_type("\n" + "\n".join(headings))
            last_idx = end
            logger.debug("topic %s-%s type=%r", start, end, pre_topic_type)

            item_list = self.sents_with_imginfo[start:end]
            item_str = "\n".join(item_list)
            item_id = _leading_item_id(item_str)
            # The paper header can be mis-predicted as a one-sentence topic
            # (the training samples are not varied enough); skip it.
            if (start == 0 and end == 1 and item_id is None
                    and len(split_topic_idx) > 1 and split_topic_idx[1][0] == 2):
                continue
            item_str = _strip_item_prefix(item_str)
            long_enough = len(item_str.strip()) > 5

            if not is_ans_divider:
                if long_enough:
                    split_res.append({
                        "stem": item_str,
                        "item_id": item_id,
                        "type": pre_topic_type,
                        "errmsgs": [],
                    })
                    topic_no.append(item_id)  # kept index-aligned with split_res
            else:
                handled, contain_table_ans = self._collect_table_ans(
                    item_str, item_list, all_ans, ans_no, contain_table_ans
                )
                if not handled and long_enough:
                    ans_no.append(item_id)
                    all_ans.append(item_str)

        # Repair stems split off by mistake, using the question numbers.
        if topic_no:
            logger.debug("topic_no: %s", topic_no)
            _merge_unnumbered_topics(split_res, topic_no)
            split_res = [res for res in split_res if res]

        # Match questions with answers.
        return get_ans_match(split_res, all_ans, ans_no)

    def get_table_ans(self, anss):
        """Extract answers laid out as an HTML table at the head of ``anss``.

        Only table answers (and "arranged" answers) are recorded here.  The
        table is assumed to come first rather than in the middle of the
        answers, and to be a plain grid without merged cells.

        With an even number of rows, rows are read in pairs: the first row of
        each pair holds question numbers and the next holds the matching
        answers.  Sub-question numbers such as "14(1)" or "14①" are folded into
        one entry whose answers are joined with "#".  Plain numbers are accepted
        only when they continue the numbering (a step of 1-2); the very first
        number is always accepted.

        With an odd number of rows the table may hold fill-in answers.  Two
        cases rewrite ``anss[0]`` into plain text instead of collecting anything:
        a "题号" header row, or a first row that continues the numbering already
        collected in this call.  NOTE(review): the caller's list is modified in
        place (``anss[0]``), but the rewritten text is not returned; confirm
        whether callers are meant to use it.

        Args:
            anss: list of answer strings; ``anss[0]`` is parsed on each pass.

        Returns:
            ``(table_ans, ans_no)``: answers and question numbers in table
            order.  Both are empty when ``anss[0]`` holds no table.
        """
        table_ans, ans_no = [], []
        while anss and "table" in anss[0]:
            row_list = [
                _TABLE_CELL_SEP_RE.split(row.group(1))
                for row in _TABLE_ROW_RE.finditer(anss[0])
            ]
            if row_list:
                logger.debug("^^^^^^存在答案放在表格里的情况!^^^^^^^")
            is_other_table = False

            if len(row_list) % 2 != 0:
                logger.debug("表格形式呈现的答案不是偶数行")
                head_match = _TABLE_FIRST_ROW_ID_RE.search("".join(row_list[0]))
                if "题号" in row_list[0] and len(row_list) > 1:
                    cells = [
                        _TABLE_NO_CELL_RE.sub(r"\n\1、", cell)
                        for row in row_list[1:]
                        for cell in row
                    ]
                    anss[0] = "\n" + " ".join(cells)
                    is_other_table = True
                elif head_match:
                    if ans_no and 1 <= int(head_match.group(1)) - ans_no[-1] <= 2:
                        anss[0] = "\n" + "\n".join("".join(row) for row in row_list)
                        is_other_table = True
            else:
                logger.debug("row_list: %s", row_list)
                for k in range(0, len(row_list), 2):
                    no_row, ans_row = row_list[k], row_list[k + 1]
                    item_no, item_ans = [], []
                    for num, cell in enumerate(no_row):
                        cell = cell.strip()
                        if not re.search(r"\d", cell):
                            continue
                        # The answer row can be shorter than the number row.
                        answer = ans_row[num] if num < len(ans_row) else ""
                        special = _SPECIAL_NO_RE.match(cell)
                        if special:
                            no = int(special.group(1))
                            if item_no and no == item_no[-1]:
                                item_ans[-1] += "#" + answer
                            else:
                                item_no.append(no)
                                item_ans.append(answer)
                            continue
                        no_info = re.match(r"(\d+)", cell)
                        no = int(no_info.group(1)) if no_info else 0
                        if ((not ans_no and no)
                                or (item_no and 1 <= no - item_no[-1] <= 2)
                                or (not item_no and ans_no and 1 <= no - ans_no[-1] <= 2)):
                            item_no.append(no)
                            item_ans.append(answer)
                    ans_no.extend(item_no)
                    table_ans.extend(item_ans)

            if not is_other_table:
                anss = anss[1:]
        logger.debug("表格答案: %s", table_ans)
        logger.debug("表格答案题号: %s", ans_no)
        return table_ans, ans_no

    def ans_divider_judge(self, item_list):
        """Return True if any line in ``item_list`` looks like a reference-answer heading."""
        return any(
            _ANS_DIVIDER_RE.match(_ANS_DIVIDER_NOISE_RE.sub("", line.strip()))
            for line in item_list
        )

    def pre_judge_topic_type(self, item_str):
        """Guess the question type from a section heading such as "二、多选题".

        ``item_str`` should start with "\\n" so the heading pattern can anchor
        on a line start.  For maths, non-choice/fill-in types become "解答题".
        For physics, "简答/综合/计算" become "解答题".  A bare "非选择题" also
        maps to "解答题".  Returns "" when nothing is recognised.
        """
        all_type = _TOPIC_TYPE_RE.findall(item_str)
        known = [t for t in all_type if t.replace("题", "") in _ATTENTION_TYPES]
        if not known:
            if all_type and all_type[0] == "非选择题":
                return "解答题"
            return ""
        topic_type = known[0]
        if "数学" in self.subject and topic_type.replace("题", "") in _MATH_ANSWER_TYPES:
            topic_type = "解答题"
        if "物理" in self.subject and topic_type.replace("题", "") in _PHYSICS_ANSWER_TYPES:
            topic_type = "解答题"
        return topic_type


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    predict_tool = Predictor(NerArgs())
    path = "/home/cv/workspace/tujintao/document_segmentation/Data/samples/真实样例/663b246a1ec1003b58557467.html"
    with open(path, "r", encoding="utf8") as f:
        text = f.read()
    entities = predict_tool.predict(
        text, con_type="stem_block", paper_id="6264fa25f84c0e279ac643ef", subject="物理"
    )