import re
import random
import numpy as np
import pickle
from copy import deepcopy
from bson.binary import Binary
from concurrent.futures import ThreadPoolExecutor
from sentence_transformers import SentenceTransformer
import config
from main_clear.sci_clear import get_maplef_items
# 按数据对应顺序随机打乱数据
def shuffle_data_pair(idx_list, vec_list):
zip_list = list(zip(idx_list, vec_list))
random.shuffle(zip_list)
idx_list, vec_list = zip(*zip_list)
return idx_list, vec_list
# 通用公有变量
public_id = 0
# 数据预处理
class DataPreProcessing():
def __init__(self, mongo_coll=None, logger=None, is_train=False):
# 配置初始数据
self.mongo_coll = mongo_coll
self.sbert_model = SentenceTransformer(config.sbert_path)
self.is_train = is_train
# 日志采集
self.logger = logger
self.log_msg = config.log_msg
# 主函数
def __call__(self, origin_dataset, is_retrieve=False):
# 句向量存储列表
sent_vec_list = []
# 批量处理数据字典
bp_dict = deepcopy(config.batch_processing_dict)
# 批量数据清洗
if self.is_train is False:
with ThreadPoolExecutor(max_workers=5) as executor:
executor_list = list(executor.map(self.content_clear_process, origin_dataset))
cont_clear_tuple, cont_cut_tuple = zip(*executor_list)
for data_idx, data in enumerate(origin_dataset):
# 通用公有变量
global public_id
# 记录id
public_id = data["id"] if "id" in data else data_idx + 1
print(public_id) if self.logger is None else None
if self.is_train is True:
content_clear, content_cut_list = self.content_clear_process(data)
# 根据self.is_train赋值content_clear, content_cut_list
content_clear = content_clear if self.is_train else cont_clear_tuple[data_idx]
content_cut_list = content_cut_list if self.is_train else cont_cut_tuple[data_idx]
# 日志采集
self.logger.info(self.log_msg.format(
id=public_id,
type="数据清洗结果",
message=content_clear)) if self.logger and is_retrieve else None
print(content_clear) if self.logger is None else None
bp_dict["id_list"].append(data["id"]) if is_retrieve is False else None
bp_dict["cont_clear_list"].append(content_clear)
# 将所有截断数据融合进行一次句向量计算
bp_dict["cont_cut_list"].extend(content_cut_list)
# 获取每条数据的截断长度
bp_dict["cut_idx_list"].append(bp_dict["cut_idx_list"][-1] + len(content_cut_list))
# 设置批量处理长度,若满足条件则进行批量处理
if (data_idx+1) % 5000 == 0:
sent_vec_list = self.batch_processing(sent_vec_list, bp_dict, is_retrieve)
# 数据满足条件处理完毕后,则重置数据结构
bp_dict = deepcopy(config.batch_processing_dict)
if len(bp_dict["cont_clear_list"]) > 0:
sent_vec_list = self.batch_processing(sent_vec_list, bp_dict, is_retrieve)
return sent_vec_list, bp_dict["cont_clear_list"]
# 数据批量处理计算句向量
def batch_processing(self, sent_vec_list, bp_dict, is_retrieve):
vec_list = self.sbert_model.encode(bp_dict["cont_cut_list"])
# 计算题目中每个句子的完整句向量
sent_length = len(bp_dict["cut_idx_list"]) - 1
for i in range(sent_length):
sentence_vec = np.array([np.nan])
if bp_dict["cont_clear_list"][i] != '':
# 平均池化
sentence_vec = np.sum(vec_list[bp_dict["cut_idx_list"][i]:bp_dict["cut_idx_list"][i+1]], axis=0) \
/(bp_dict["cut_idx_list"][i+1]-bp_dict["cut_idx_list"][i])
sent_vec_list.append(sentence_vec) if self.is_train is False else None
# 将结果存入数据库
if is_retrieve is False:
condition = {"id": bp_dict["id_list"][i]}
# 用二进制存储句向量以节约存储空间
sentence_vec_byte = Binary(pickle.dumps(sentence_vec, protocol=-1), subtype=128)
# 需要新增train_flag,防止机器奔溃重复训练
update_elements = {"$set": {"content_clear": bp_dict["cont_clear_list"][i],
"sentence_vec": sentence_vec_byte,
"sent_train_flag": config.sent_train_flag}}
self.mongo_coll.update_one(condition, update_elements)
return sent_vec_list
# 清洗函数
def clear_func(self, content):
if content in {'', None}:
return ''
# 将content字符串化,防止content是int/float型
if isinstance(content, str) is False:
if isinstance(content, int) or isinstance(content, float):
return str(content)
try:
# 进行文本清洗
content_clear = get_maplef_items(content)
except Exception as e:
# 通用公有变量
global public_id
# 日志采集
print(self.log_msg.format(id=public_id,
type="清洗错误: "+str(e),
message=str(content))) if self.logger is None else None
self.logger.error(self.log_msg.format(id=public_id,
type="清洗错误: "+str(e),
message=str(content))) if self.logger is not None else None
# 对于无法清洗的文本通过正则表达式直接获取文本中的中文字符
content_clear = re.sub(r'[^\u4e00-\u9fa5]', '', content)
return content_clear
# 重叠截取长文本进行Sentence-Bert训练
def truncate_func(self, content):
# 设置长文本截断长度
cut_length = 150
# 设置截断重叠长度
overlap = 10
content_cut_list = []
# 若文本长度小于等于截断长度,则取消截取直接返回
cont_length = len(content)
if cont_length <= cut_length:
content_cut_list = [content]
return content_cut_list
# 若文本长度大于截断长度,则进行重叠截断
# 设定文本截断尾部合并阈值(针对尾部文本根据长度进行合并)
# 防止截断后出现极短文本影响模型效果
tail_merge_value = 0.5 * cut_length
for i in range(0,cont_length,cut_length-overlap):
tail_idx = i + cut_length
cut_content = content[i:tail_idx]
# 保留单词完整性
# 判断尾部字符
if cont_length - tail_idx > tail_merge_value:
for j in range(len(cut_content)-1,-1,-1):
# 判断当前字符是否为字母或者数字
# 若不是字母或者数字则截取成功
if re.search('[A-Za-z]', cut_content[j]) is None:
cut_content = cut_content[:j+1]
break
else:
cut_content = content[i:]
# 判断头部字符
if i != 0:
for k in range(len(cut_content)):
# 判断当前字符是否为字母或者数字
# 若不是字母或者数字则截取成功
if re.search('[A-Za-z]', cut_content[k]) is None:
cut_content = cut_content[k+1:]
break
# 将头部和尾部都处理好的截断文本存入content_cut_list
content_cut_list.append(cut_content)
# 针对尾部文本截断长度为140-150以及满足尾部合并阈值的文本
# 进行重叠截断进行特殊处理
if cont_length - tail_idx <= tail_merge_value:
break
return content_cut_list
# 全文本数据清洗
def content_clear_func(self, content):
# 文本清洗
content_clear = self.clear_func(content)
# 去除文本中的空格以及空字符串
content_clear = re.sub(r',+', ',', re.sub(r'[\s_]', '', content_clear))
# 去除题目开头"【题文】(多选)/(..分)"
content_clear = re.sub(r'\[题文\]', '', content_clear)
content_clear = re.sub(r'(\([单多]选\)|\[[单多]选\])', '', content_clear)
content_clear = re.sub(r'(\(\d{1,2}分\)|\[\d{1,2}分\])', '', content_clear)
# 将文本中的选项"A.B.C.D."改为";"
content_clear = re.sub(r'[ABCD]\.', ';', content_clear)
# # 若选项中只有1,2,3或(1)(2)(3),或者只有标点符号,则过滤该选项
# content_clear = re.sub(r'(\(\d\)[、,;\.]?)+\(\d\)|\d[、,;]+\d', '', content_clear)
# 去除题目开头(...年...[中模月]考)文本
head_search = re.search(r'^(\(.*?[\)\]]?\)|\[.*?[\)\]]?\])', content_clear)
if head_search is not None and 5 < head_search.span(0)[1] < 40:
head_value = content_clear[head_search.span(0)[0]+1:head_search.span(0)[1]-1]
if re.search(r'.*?(\d{2}|[模检测训练考试验期省市县外第初高中学]).*?[模检测训练考试验期省市县外第初高中学].*?', head_value):
content_clear = content_clear[head_search.span(0)[1]:].lstrip()
# 对于只有图片格式以及标点符号的信息进行特殊处理(去除标点符号/空格/连接符)
if re.sub(r'[\.、。,;\:\?!#\-> ]+', '', content_clear) == '':
content_clear = ''
return content_clear
# 数据清洗与长文本重叠截取处理
def content_clear_process(self, data):
# 初始化content_clear
content_clear = ''
# 全文本数据清洗
if "quesBody" in data:
content_clear = self.content_clear_func(data["quesBody"])
elif "stem" in data:
content_clear = self.content_clear_func(data["stem"])
# 重叠截取长文本用于进行Sentence-Bert训练
content_cut_list = self.truncate_func(content_clear)
return content_clear, content_cut_list
if __name__ == "__main__":
# 获取mongodb数据
mongo_coll = config.mongo_coll
test_data = {
'quesBody': """【题文】某同学设计了如下的电路测量电压表内阻,R为能够满足实验条件的滑动变阻器,为电阻箱,电压表量程合适。实验的粗略步骤如下:
①闭合开关S1、S2,调节滑动变阻器R,使电压表指针指向满刻度的处;
②断开开关S2,调节某些仪器,使电压表指针指向满刻度的处;
③读出电阻箱的阻值,该阻值即为电压表内阻的测量值;
④断开开关S1、S2拆下实验仪器,整理器材。
(1)上述实验步骤②中,调节某些仪器时,正确的操作是__________
A.保持电阻箱阻值不变,调节滑动变阻器的滑片,使电压表指针指向满刻度的处
B.保持滑动变阻器的滑片位置不变,调节电阻箱阻值,使电压表指针指向满刻度的处
C.同时调节滑动变阻器和电阻箱的阻值,使电压表指针指向满刻度的处
(2)此实验电压表内阻的测量值与真实值相比________(选填“偏大”“偏小”或“相等”);
(3)如实验测得该电压表内阻为8500Ω,要将其量程扩大为原来的倍,需串联_____Ω的电阻。""",
'option': ['$\\left\\{-2,0\\right\\}$', '$\\left\\{-2,0,2\\right\\}$', '$\\left\\{-1,1,2\\right\\}$', '$\\left\\{-1,0,2\\right\\}$']}
dpp = DataPreProcessing(mongo_coll)
string = """如图三角形的每个顶点均在格点上,且每个小正方形的边长为1.
(1)________;
(2)求的面积."""
string = """已知c水=4.2×103J/(kg·℃),求"""
res = dpp.clear_func(string)
print(res)
# res = dpp.content_clear_process(test_data)
# print(res[0])
# print(dpp.content_clear_process(mongo_coll.find_one({}))[0])
# print(dpp(test_data,is_retrieve=True))