import re import random import pickle from bson.binary import Binary from copy import deepcopy import numpy as np from concurrent.futures import ThreadPoolExecutor from sentence_transformers import SentenceTransformer import config import config from main_clear.sci_clear import get_maplef_items # 按数据对应顺序随机打乱数据 def shuffle_data_pair(idx_list, vec_list): zip_list = list(zip(idx_list, vec_list)) random.shuffle(zip_list) idx_list, vec_list = zip(*zip_list) return idx_list, vec_list # 通用公有变量 public_topic_id = 0 # 数据预处理 class DataPreProcessing(): def __init__(self, mongo_coll, logger=None, is_train=False): self.mongo_coll = mongo_coll self.sbert_model = SentenceTransformer(config.sbert_path) self.is_train = is_train # 日志采集 self.logger = logger self.log_msg = config.log_msg # 主函数 def __call__(self, origin_dataset, hnsw_index, is_retrieve=False): # 句向量存储列表 sent_vec_list = [] # 批量处理数据字典 bp_dict = deepcopy(config.batch_processing_dict) if self.is_train is False: hnsw_index_list = [hnsw_index] * len(origin_dataset) with ThreadPoolExecutor(max_workers=5) as executor: executor_list = list(executor.map(self.content_clear_process, origin_dataset, hnsw_index_list)) cont_clear_tuple, cont_cut_tuple = zip(*executor_list) for data_idx, data in enumerate(origin_dataset): # 通用公有变量 global public_topic_id # 记录topic_id topic_id = data["topic_id"] if "topic_id" in data else data_idx + 1 public_topic_id = topic_id print(topic_id) if self.logger is None else None if self.is_train is True: # 数据清洗处理函数 content_clear, content_cut_list = self.content_clear_process(data, hnsw_index) # 根据self.is_train赋值content_clear, content_cut_list content_clear = content_clear if self.is_train else cont_clear_tuple[data_idx] content_cut_list = content_cut_list if self.is_train else cont_cut_tuple[data_idx] # 日志采集 self.logger.info(self.log_msg.format( id=topic_id, type="数据清洗结果", message=content_clear)) if self.logger and is_retrieve else None print(content_clear) if self.logger is None else None bp_dict["topic_id_list"].append(data["topic_id"]) if is_retrieve is False else None bp_dict["cont_clear_list"].append(content_clear) # 将所有截断数据融合进行一次句向量计算 bp_dict["cont_cut_list"].extend(content_cut_list) # 获取每条数据的截断长度 bp_dict["cut_idx_list"].append(bp_dict["cut_idx_list"][-1]+len(content_cut_list)) # 设置批量处理长度,若满足条件则进行批量处理 if (data_idx+1) % 5000 == 0: sent_vec_list = self.batch_processing(sent_vec_list, bp_dict, hnsw_index, is_retrieve) # 数据满足条件处理完毕后,则重置数据结构 bp_dict = deepcopy(config.batch_processing_dict) if len(bp_dict["cont_clear_list"]) > 0: sent_vec_list = self.batch_processing(sent_vec_list, bp_dict, hnsw_index, is_retrieve) return sent_vec_list, bp_dict["cont_clear_list"] # 数据批量处理计算句向量 def batch_processing(self, sent_vec_list, bp_dict, hnsw_index, is_retrieve): vec_list = self.sbert_model.encode(bp_dict["cont_cut_list"]) # 计算题目中每个句子的完整句向量 sent_length = len(bp_dict["cut_idx_list"]) - 1 for i in range(sent_length): sentence_vec = np.array([np.nan]) if bp_dict["cont_clear_list"][i] != '': # 平均池化 sentence_vec = np.sum(vec_list[bp_dict["cut_idx_list"][i]:bp_dict["cut_idx_list"][i+1]], axis=0) \ /(bp_dict["cut_idx_list"][i+1]-bp_dict["cut_idx_list"][i]) sent_vec_list.append(sentence_vec) if self.is_train is False else None # 将结果存入数据库 if is_retrieve is False: condition = {"topic_id": bp_dict["topic_id_list"][i]} # 用二进制存储句向量以节约存储空间 sentence_vec_byte = Binary(pickle.dumps(sentence_vec, protocol=-1), subtype=128) # 需要新增train_flag,防止机器奔溃重复训练 update_dict = {"content_clear": bp_dict["cont_clear_list"][i], "sentence_vec": sentence_vec_byte, "sent_train_flag": config.sent_train_flag} if hnsw_index == 0: update_dict["group_id"] = 0 self.mongo_coll[hnsw_index].update_one(condition, {"$set": update_dict}) return sent_vec_list # 清洗函数 def clear_func(self, content, hnsw_index): if content in {'', None}: return '' # 将content字符串化,防止content是int/float型 if isinstance(content, str) is False: if isinstance(content, int) or isinstance(content, float): return str(content) try: # 进行文本清洗 if "#$#" not in content: content_clear = get_maplef_items(content, hnsw_index, self.is_train) else: content_clear_split = content.split("#$#") content_clear_t = get_maplef_items(content_clear_split[0], hnsw_index, self.is_train) content_clear_x = get_maplef_items(content_clear_split[1], hnsw_index, self.is_train) content_clear = content_clear_t + content_clear_x except Exception as e: # 通用公有变量 global public_topic_id # 日志采集 print(self.log_msg.format(id=public_topic_id, type="清洗错误: "+str(e), message=str(content))) if self.logger is None else None self.logger.error(self.log_msg.format(id=public_topic_id, type="清洗错误: "+str(e), message=str(content))) if self.logger is not None else None # 对于无法清洗的文本通过正则表达式直接获取文本中的中文字符 content_clear = re.sub(r'[^\u4e00-\u9fa5]', '', content) return content_clear # 重叠截取长文本进行Sentence-Bert训练 def truncate_func(self, content): # 设置长文本截断长度 cut_length = 150 # 设置截断重叠长度 overlap = 10 content_cut_list = [] # 若文本长度小于等于截断长度,则取消截取直接返回 cont_length = len(content) if cont_length <= cut_length: content_cut_list = [content] return content_cut_list # 若文本长度大于截断长度,则进行重叠截断 # 设定文本截断尾部合并阈值(针对尾部文本根据长度进行合并) # 防止截断后出现极短文本影响模型效果 tail_merge_value = 0.5 * cut_length for i in range(0,cont_length,cut_length-overlap): tail_idx = i + cut_length cut_content = content[i:tail_idx] # 保留单词完整性 # 判断尾部字符 if cont_length - tail_idx > tail_merge_value: for j in range(len(cut_content)-1,-1,-1): # 判断当前字符是否为字母或者数字 # 若不是字母或者数字则截取成功 if re.search('[A-Za-z]', cut_content[j]) is None: cut_content = cut_content[:j+1] break else: cut_content = content[i:] # 判断头部字符 if i != 0: for k in range(len(cut_content)): # 判断当前字符是否为字母或者数字 # 若不是字母或者数字则截取成功 if re.search('[A-Za-z]', cut_content[k]) is None: cut_content = cut_content[k+1:] break # 将头部和尾部都处理好的截断文本存入content_cut_list content_cut_list.append(cut_content) # 针对尾部文本截断长度为140-150以及满足尾部合并阈值的文本 # 进行重叠截断进行特殊处理 if cont_length - tail_idx <= tail_merge_value: break return content_cut_list # 数据清洗处理函数 def content_clear_process(self, data, hnsw_index): # 全内容清洗组合列表 content_clear_list = [] if 'content' in data: content_clear_list.append(self.clear_func(data['content'], hnsw_index)) elif 'stem' in data: content_clear_list.append(self.clear_func(data['stem'], hnsw_index)) # 若题目中有小题,则对小题进行处理(递归实现) if 'slave' in data: content_clear_list = self.slave_func(data['slave'], content_clear_list) # 若题目中有选项,则对选项进行处理 if 'option' in data: content_clear_list.extend(self.option_func(data['option'], hnsw_index)) if 'options' in data: content_clear_list.extend(self.option_func(data['options'], hnsw_index)) # 去除文本中的空格以及空字符串 content_clear_list = [re.sub(r',+', ',', re.sub(r'[\s_]', '', content)) for content in content_clear_list] content_clear_list = [content for content in content_clear_list if content != ''] # 将清洗数据拼接 content_clear = ";".join(content_clear_list) # 去除题目开头"(多选)/(..分)" content_clear = re.sub(r'^(\([单多]选\)|\[[单多]选\])', '', content_clear) content_clear = re.sub(r'^(\(.*?\d{1,2}分.*?\)|\[.*?\d{1,2}分.*?\])', '', content_clear) # 去除题目开头(...年...[中模月]考)文本 head_search = re.search(r'^(\(.*?[\)\]]?\)|\[.*?[\)\]]?\])', content_clear) if head_search is not None and 5 < head_search.span(0)[1] < 40: head_value = content_clear[head_search.span(0)[0]+1:head_search.span(0)[1]-1] if re.search(r'.*?(\d{2}|[模检测训练考试验期省市县外第初高中学]).*?[模检测训练考试验期省市县外第初高中学].*?', head_value): content_clear = content_clear[head_search.span(0)[1]:].lstrip() # 将文本中的选项"A.B.C.D."改为";" content_clear = re.sub(r'[ABCD]\.', ';', content_clear) # 对于只有图片格式以及标点符号的信息进行特殊处理(去除标点符号/空格/连接符) if re.sub(r'[\.、。,;\:\?!#\-> ]+', '', content_clear) == '': content_clear = '' # 重叠截取长文本用于进行Sentence-Bert训练 content_cut_list = self.truncate_func(content_clear) return content_clear, content_cut_list # 小题处理函数(递归实现) def slave_func(self, slave_data, content_clear_list, hnsw_index): # 若小题列表为空,则返回content_clear_list if slave_data is None or len(slave_data) == 0: return content_clear_list for slave in slave_data: if 'content' in slave: content_clear_list.append(self.clear_func(slave['content'], hnsw_index)) if 'stem' in slave: content_clear_list.append(self.clear_func(slave['stem'], hnsw_index)) if 'option' in slave: content_clear_list.extend(self.option_func(slave['option'], hnsw_index)) if 'options' in slave: content_clear_list.extend(self.option_func(slave['options'], hnsw_index)) if 'slave' in slave: content_clear_list = self.slave_func(slave['slave'], content_clear_list, hnsw_index) return content_clear_list # 选项处理函数 def option_func(self, option_list, hnsw_index): # 若选项列表为空,则返回空列表 if option_list is None or len(option_list) == 0: return [] option_clear_list = [] for option in option_list: if isinstance(option, dict): if 'content' in option: option_clear_list.append(self.clear_func(option['content'], hnsw_index)) elif 'stem' in option: option_clear_list.append(self.clear_func(option['stem'], hnsw_index)) elif isinstance(option, str): option_clear_list.append(self.clear_func(option, hnsw_index)) elif isinstance(option, int) or isinstance(option, float): option_clear_list.append(str(option)) # 若选项中只有1,2,3或(1)(2)(3),或者只有标点符号,则过滤该选项 option_clear_list = [option for option in option_clear_list if re.sub(r'\(\d\)[、,;\.]?|(\d[、,;])+\d','',re.sub(r'\s','',option))!='' and re.sub(r'[\.、。,;\:\?!#\-> ]+','',option)!=''] return option_clear_list if __name__ == "__main__": # 获取mongodb数据 mongo_coll = config.mongo_coll_list test_data = [{'topic_id': '453368', 'topic_type_id': '1', 'subject_id': '3', 'stem': '
测试
', 'key': 'B', 'option': ['1
', '2
', '3
', '4
']}] dpp = DataPreProcessing(mongo_coll) # res = dpp.content_clear_process(test_data, hnsw_index=0) # print(res[0]) res = dpp(test_data, hnsw_index=0) print(res[1]) # print(dpp.content_clear_process(test_data[0], hnsw_index=0)[0]) # print(dpp(test_data, hnsw_index=0, is_retrieve=True))