import sys
import re
import time
# sys.path.append(r'/home/cv/workspace/tujintao/document_segmentation')
# from transformers import BertTokenizer
from PointerNet.main import NerPipeline
from PointerNet.model import UIEModel
from Utils.washutil import simpwash, again_wash
from PointerNet.config import NerArgs
# from Utils.field_eq2latex import get_latex
from Utils.answer_match import get_ans_match
# from pprint import pprint
# from washutil import HtmlWash
class Predictor:
def __init__(self, ner_args=None):
model = UIEModel(ner_args)
self.ner_pipeline = NerPipeline(model, ner_args)
self.ner_pipeline.load_model()
def predict(self, text, con_type="stem_block", paper_id="temp_{}".format(int(time.time())), subject="数学"):
"""
数据样本清洗后格式简化,如何与源文档建立联系?????
>>>html文本先经过washutil.py得到row_list,再用换行符“\n”连接,再经过again_wash后再按“\n”切割得到all_sents,
需保证row_list和all_sents的长度相等,否则应该建立两者位置联系;通过模型切割对all_sents切割完后即可通过位置关系找到原始文本
发现:经washutil.py清洗后会丢掉一些特殊位置,如在题首的图片;
所以还是选用本项目中的清洗函数:simpwash和again_wash, “【公式】和【图片】”的原始数据要保留,且表格不能处理太简单,但公式也简化过了,不好还原
"""
self.subject = subject
# text, new_html = HtmlWash(text, '11111111',is_reparse=1,must_latex=1).html_cleal()
item_str, _, _ = simpwash(text, paper_id)
self.sents_with_imginfo, simply_sents = again_wash(item_str, paper_id)
entities, split_topic_idx, topic_item_pred = self.ner_pipeline.half_batch_predict(text_list=simply_sents)
# pprint(entities)
if not split_topic_idx:
return []
if con_type == "stem_block":
topic_item_res = self.pre_structure(split_topic_idx, topic_item_pred)
# pprint(topic_item_res)
return topic_item_res
else:
all_ans, ans_no = self.pre_ans_split(split_topic_idx)
# print("all_ans::",all_ans)
print(ans_no)
return all_ans, ans_no
def pre_ans_split(self, split_topic_idx):
all_ans, ans_no = [], []
contain_table_ans = False
last_idx = 0
new_ans = []
for topic_idx in split_topic_idx:
if topic_idx[0] - last_idx > 0:
may_non_topic_content_list = self.sents_with_imginfo[last_idx: topic_idx[0]]
print(may_non_topic_content_list)
for may_non in may_non_topic_content_list:
if re.match("
.?$", may_non):
new_ans.append(may_non)
last_idx = topic_idx[-1]
item_list = self.sents_with_imginfo[topic_idx[0]: topic_idx[-1]]
item_str = "\n".join(item_list)
new_ans.append(item_str)
for ans_str in new_ans:
item_id_info = re.match(r"([1-9][0-9]?)(?![0-9\+\-\*\\/\)\}])", ans_str)
item_id = int(item_id_info.group(1)) if item_id_info else None
ans_str = re.sub("^([1-9][0-9]?)\s*[、..、]", "", ans_str.strip())
ans_str = re.sub(r"^(\[.*?\])?[((].*?\d+分[))]", "", ans_str[:20]) + ans_str[20:]
if not ans_str.strip():
continue
get_table_ans_success = 0
if not ans_no or contain_table_ans: # 第一个或上一个是表格
if "" in ans_str:
table_ans, table_ans_no = self.get_table_ans([ans_str])
if table_ans_no:
contain_table_ans = True
ans_no.extend(table_ans_no)
all_ans.extend(table_ans)
get_table_ans_success = 1
elif contain_table_ans:
contain_table_ans = False
if not get_table_ans_success:
ans_no.append(item_id)
all_ans.append(ans_str)
return all_ans, ans_no
def pre_structure(self, split_topic_idx, topic_item_pred):
"""
预结构化
"""
split_res, all_ans, ans_no = [], [], []
contain_table_ans = False
last_idx = 0
pre_topic_type = ""
is_ans_divider = False # 参考答案分界线
topic_no = []
# 对开头遗漏部分判断
if split_topic_idx[0][0]>1: # pre_topic_type
start_idx = [n for n, v in enumerate(topic_item_pred) if v and n < split_topic_idx[0][0]]
if start_idx and all([True if i else False for i in topic_item_pred[start_idx[0]:split_topic_idx[0][0]]]):
split_topic_idx.insert(0, (start_idx[0], split_topic_idx[0][0]))
for nn, topic_idx in enumerate(split_topic_idx):
if topic_idx[0] - last_idx > 0:
non_topic_content_list = self.sents_with_imginfo[last_idx: topic_idx[0]]
print(non_topic_content_list)
if not is_ans_divider:
is_ans_divider = self.ans_divider_judge(non_topic_content_list)
pre_topic_type = self.pre_judge_topic_type("\n"+"\n".join(non_topic_content_list))
last_idx = topic_idx[-1]
print(topic_idx[0],topic_idx[-1], pre_topic_type)
item_list = self.sents_with_imginfo[topic_idx[0]: topic_idx[-1]]
item_str = "\n".join(item_list)
item_id_info = re.match(r"([1-9][0-9]?)(?![0-9\+\-\*\\/\)\}])", item_str)
item_id = int(item_id_info.group(1)) if item_id_info else None
# 试卷开头信息可能预测错误:样本丰富度不够
if topic_idx[0] == 0 and topic_idx[-1] == 1 and item_id is None \
and len(split_topic_idx)>1 and split_topic_idx[1][0]==2:
continue
item_str = re.sub("^([1-9][0-9]?)\s*[、..、]", "", item_str.strip())
item_str = re.sub(r"^(\[.*?\])?[((].*?\d+分[))]", "", item_str[:20]) + item_str[20:]
if not is_ans_divider:
if item_str.strip() and len(item_str.strip()) > 5:
split_res.append({
"stem": item_str,
"item_id": item_id,
"type": pre_topic_type,
"errmsgs": [],
})
topic_no.append(item_id)
else:
get_table_ans_success = 0
if not ans_no or contain_table_ans: # 第一个或上一个是表格
if "" in item_str:
table_ans, table_ans_no = self.get_table_ans(item_list)
if table_ans_no:
contain_table_ans = True
ans_no.extend(table_ans_no)
all_ans.extend(table_ans)
get_table_ans_success = 1
elif contain_table_ans:
contain_table_ans = False
if not get_table_ans_success and len(item_str.strip()) > 5:
ans_no.append(item_id)
all_ans.append(item_str)
# 根据topic_no纠正一下分错的topic:不能纠正的太过,有的就是乱序的
if topic_no:
print("topic_no:::", topic_no)
for ni,no in enumerate(topic_no):
if 0 < ni < len(topic_no)-1 and no is None and topic_no[ni-1]:
if topic_no[ni+1] and topic_no[ni+1] - topic_no[ni-1] == 1: # 两连续题号间隔着一个None
split_res[ni-1]["stem"] += "\n"+split_res[ni]["stem"]
split_res[ni] = ''
else:
for nj, no2 in enumerate(topic_no[ni+1:]):
if no2 is None:
continue
# print("nj:", nj, ni)
if no2 - topic_no[ni-1] == 1 and nj > 0: # 两连续题号间隔着多个None
# print([y["stem"] for y in split_res[ni:ni+1+nj]])
split_res[ni-1]["stem"] += "\n"+"\n".join([y["stem"] for y in split_res[ni:ni+1+nj]])
split_res[ni:ni+1+nj] = ['']*(nj+1)
break
# 最后1题题号为None
if re.search(r"\d+None$", "".join([str(i) for i in topic_no])):
split_res[-2]["stem"] += "\n"+split_res[-1]["stem"]
split_res[-1] = ''
split_res = [res for res in split_res if res]
# 试题、答案匹配
split_res = get_ans_match(split_res, all_ans, ans_no)
return split_res
def get_table_ans(self, anss):
# all_item_ans = [] # 前期只记录表格答案和排列型答案
table_ans = []
ans_no = [] # 只记录表格答案和排列型答案的id
# 默认表格答案放在最前面 !!!
while anss and "table" in anss[0]: # 答案以表格形式呈现, 表格应放在前两行位置,不要插在答案中间
row_list = [] # 要求表格形式为 横纵分明 ,不存在合并
for tt in re.finditer('(((?!(?tr>)).)*)
', anss[0], re.S): # 先划分每行
tt_list = re.split(r'|| | | | |', tt.group(1)) # 再划分每列
# row_list.append([col for col in tt_list if col.strip()]) # 也有可能答案为空
row_list.append(tt_list)
if row_list:
print("^^^^^^存在答案放在表格里的情况!^^^^^^^")
is_other_table = False
if len(row_list) % 2 != 0:
print('表格形式呈现的答案不是偶数行')
# 可能是填空题放在表格里,可以先去表格看看
if "题号" in row_list[0] and len(row_list) > 1:
row_list = [re.sub("^\s*([1-9]|[1-2][0-9])\s*([((]\s*\d{1,2}\s*分?\s*[))])?\s*$", r"\n\1、", v)
for row in row_list[1:] for v in row]
anss[0] = "\n" + " ".join(row_list)
is_other_table = True
elif re.search("^\s*([1-9]|[12][0-9])\s*[..、、::].+?", "".join(row_list[0])):
temp_id = re.search("^\s*([1-9]|[12][0-9])\s*[..、、::].+?", "".join(row_list[0])).group(1)
if ans_no and 1 <= int(temp_id) - ans_no[-1] <= 2:
# anss[0] = re.sub(r'?table>|?p>|?t[dr]>|?tbody>', '', anss[0])
anss[0] = "\n" + "\n".join(["".join(i) for i in row_list])
is_other_table = True
else:
print("row_list:", row_list)
for k, v in enumerate(row_list):
# print('-----',v)
if (k + 1) % 2 == 1: # 奇数行==》答案序号行
item_no = []
item_ans = []
# special_item_no = [] # 特殊的题号 如14(1)
for num, i in enumerate(v):
if re.sub(r"[^\d]", "", i.strip()):
no_info = re.match("(\d+)", i.strip())
no = no_info.group(1) if no_info else 0
special_no_info = re.match(r"(\d+)\s*([((\-_]\s*\d|[((\-_]?\s*[①②③④])", i.strip())
if special_no_info:
no = special_no_info.group(1)
# special_item_no.append(int(i))
if not item_no or int(no) != item_no[-1]:
item_no.append(int(no))
item_ans.append(row_list[k + 1][num])
elif item_no and int(no) == item_no[-1]:
item_ans[-1] += "#" + row_list[k + 1][num]
else:
if (not ans_no and int(no)) or (item_no and 1 <= int(no) - item_no[-1] <= 2) or \
(not item_no and ans_no and 1 <= int(no) - ans_no[-1] <= 2):
item_no.append(int(no))
if k + 1 < len(row_list) and num < len(row_list[k + 1]):
item_ans.append(row_list[k + 1][num])
else:
item_ans.append("")
ans_no.extend(item_no)
table_ans.extend(item_ans)
if not is_other_table:
anss = anss[1:]
print("表格答案:", table_ans)
print("表格答案题号:", ans_no)
# all_item_ans.extend(table_ans)
return table_ans, ans_no
def ans_divider_judge(self, item_list):
split_p1 = [k for k, v in enumerate(item_list)
if re.match(r'(参考|试[题卷]|考试|物理|理综|数学|化学|生物)答案.{,5}$'
r'|答案[和与及]?解析([((].*?[))])?$' # |答\s*案$
r'|.{,15}(参考|考试|(考?试|检测)[题卷]|物理|理综|数学|化学|生物)(答案|解析|答案[及与和]评分(标准|意见|细则|参考))\s*$'
r'|.{,15}评分(标准|意见|细则|参考)$'
r'|((参考|(考?试|检测)[题卷]|考试|物理|理综|数学|化学|生物)答案|答案[和与及]解析)[\dA-E\s..、、]+$'
r'|.{,15}(参考|考试|(考?试|检测)[题卷])(答案|解析|答案[及与和]评分(标准|意见|细则|参考))\s*(物理|理综|数学|化学|生物)?\s*$'
r'|.{,15}解析[和与及]答案$',
re.sub(r"[上下]?学[年期]|[\d—【】..·、、::(())年\s]|[中大]学|模拟|[中高]考|年级|[学期][末中]"
r"|[高初][一二三]|部分", "", v.strip()))]
if split_p1:
return True
return False
def pre_judge_topic_type(self, item_str):
# all_type = re.findall(r"\n\s*[一二三四五六七八九十]{,2}\s*[、..、]\s*([^必考基础综合中等]{2,4}题)") # no 非
all_type = re.findall(r"\n\s*[^必考基础综合中等::()()例训某课】]{,3}\s*[、..、]\s*([^必考基础综合中等例训创某课】]{2,4}题)", item_str)
attention_type = ["选择","单选","多选","单项选择","多项选择","双项选择","单空","多空","填空",
"实验","简答","解答","综合", "作图", "判断","计算"]
new_all_type = [type for type in all_type if type.replace("题", "") in attention_type]
if new_all_type:
if "数学" in self.subject:
new_all_type = ["解答题" if type.replace("题", "") in ["实验","简答","解答","综合", "作图", "判断","计算"]
else type for type in new_all_type]
if "物理" in self.subject:
new_all_type = ["解答题" if type.replace("题", "") in ["简答","综合","计算"] else type for type in new_all_type]
elif all_type and all_type[0] == '非选择题':
return "解答题"
if new_all_type:
return new_all_type[0]
return ""
if __name__ == '__main__':
nerArgs = NerArgs()
predict_tool = Predictor(nerArgs)
path = "/home/cv/workspace/tujintao/document_segmentation/Data/samples/真实样例/663b246a1ec1003b58557467.html"
f = open(path, 'r', encoding='utf8')
text = f.read()
# row_list, new_html, subs2table = HtmlWash(text, '11111111',is_reparse=1,must_latex=1).html_cleal()
# row_list1 = list(filter(lambda x: x.strip() != "", row_list))
# item_str, _, _ = simpwash(row_list, "", "")
# item_str, _, _ = simpwash(text, "", "")
# washed_text = again_wash(item_str)
# washed_text2 = re.sub("【公式latex提取失败】", "【公式】", washed_text)
# washed_text2 = re.sub("", "【图片】", washed_text2)
# washed_text2 = re.sub(r'<[a-z]+ [a-z]+="[^<>]*?"\s*/?>', "", washed_text2)
# item_str, _, _ = simpwash(text, "", "")
# sents_with_imginfo, simply_sents = again_wash(item_str)
# pprint(sents_with_imginfo)
# print(len(simply_sents), len(sents_with_imginfo))
# # print(len(row_list1))
# print(len(washed_text.split("\n")))
# print(len(washed_text2.split("\n")))
entities = predict_tool.predict(text,con_type="stem_block", paper_id="6264fa25f84c0e279ac643ef", subject="物理")
# item_str = '二、多选题:本题共3小题,每小题6分,共18分。在每小题给出的四个选项中,有多项符合题目要求。全部选对的得6分,选对但不全的得3分,有选错的得0分。'
# item_str = '三、非选择题,共54分,考生根据要求作答。'
# all_type = re.findall(r"\n\s*[^必考基础综合中等::()()例训某课】]{,3}\s*[、..、]\s*([^必考基础综合中等例训创某课】]{2,4}题)", "\n"+item_str)
# print(all_type)
|