123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305 |
- import re
- import random
- import pickle
- from bson.binary import Binary
- from copy import deepcopy
- import numpy as np
- from concurrent.futures import ThreadPoolExecutor
- from sentence_transformers import SentenceTransformer
- import config
- import config
- from main_clear.sci_clear import get_maplef_items
- # 按数据对应顺序随机打乱数据
- def shuffle_data_pair(idx_list, vec_list):
- zip_list = list(zip(idx_list, vec_list))
- random.shuffle(zip_list)
- idx_list, vec_list = zip(*zip_list)
- return idx_list, vec_list
- # 通用公有变量
- public_topic_id = 0
- # 数据预处理
- class DataPreProcessing():
- def __init__(self, mongo_coll, logger=None, is_train=False):
- self.mongo_coll = mongo_coll
- self.sbert_model = SentenceTransformer(config.sbert_path)
- self.is_train = is_train
- # 日志采集
- self.logger = logger
- self.log_msg = config.log_msg
- # 主函数
- def __call__(self, origin_dataset, hnsw_index, is_retrieve=False):
- # 句向量存储列表
- sent_vec_list = []
- # 批量处理数据字典
- bp_dict = deepcopy(config.batch_processing_dict)
- if self.is_train is False:
- hnsw_index_list = [hnsw_index] * len(origin_dataset)
- with ThreadPoolExecutor(max_workers=5) as executor:
- executor_list = list(executor.map(self.content_clear_process, origin_dataset, hnsw_index_list))
- cont_clear_tuple, cont_cut_tuple = zip(*executor_list)
-
- for data_idx, data in enumerate(origin_dataset):
- # 通用公有变量
- global public_topic_id
- # 记录topic_id
- topic_id = data["topic_id"] if "topic_id" in data else data_idx + 1
- public_topic_id = topic_id
-
- print(topic_id) if self.logger is None else None
- if self.is_train is True:
- # 数据清洗处理函数
- content_clear, content_cut_list = self.content_clear_process(data, hnsw_index)
- # 根据self.is_train赋值content_clear, content_cut_list
- content_clear = content_clear if self.is_train else cont_clear_tuple[data_idx]
- content_cut_list = content_cut_list if self.is_train else cont_cut_tuple[data_idx]
- # 日志采集
- self.logger.info(self.log_msg.format(
- id=topic_id,
- type="数据清洗结果",
- message=content_clear)) if self.logger and is_retrieve else None
- print(content_clear) if self.logger is None else None
- bp_dict["topic_id_list"].append(data["topic_id"]) if is_retrieve is False else None
- bp_dict["cont_clear_list"].append(content_clear)
- # 将所有截断数据融合进行一次句向量计算
- bp_dict["cont_cut_list"].extend(content_cut_list)
- # 获取每条数据的截断长度
- bp_dict["cut_idx_list"].append(bp_dict["cut_idx_list"][-1]+len(content_cut_list))
- # 设置批量处理长度,若满足条件则进行批量处理
- if (data_idx+1) % 5000 == 0:
- sent_vec_list = self.batch_processing(sent_vec_list, bp_dict, hnsw_index, is_retrieve)
- # 数据满足条件处理完毕后,则重置数据结构
- bp_dict = deepcopy(config.batch_processing_dict)
-
- if len(bp_dict["cont_clear_list"]) > 0:
- sent_vec_list = self.batch_processing(sent_vec_list, bp_dict, hnsw_index, is_retrieve)
- return sent_vec_list, bp_dict["cont_clear_list"]
- # 数据批量处理计算句向量
- def batch_processing(self, sent_vec_list, bp_dict, hnsw_index, is_retrieve):
- vec_list = self.sbert_model.encode(bp_dict["cont_cut_list"])
- # 计算题目中每个句子的完整句向量
- sent_length = len(bp_dict["cut_idx_list"]) - 1
- for i in range(sent_length):
- sentence_vec = np.array([np.nan])
- if bp_dict["cont_clear_list"][i] != '':
- # 平均池化
- sentence_vec = np.sum(vec_list[bp_dict["cut_idx_list"][i]:bp_dict["cut_idx_list"][i+1]], axis=0) \
- /(bp_dict["cut_idx_list"][i+1]-bp_dict["cut_idx_list"][i])
- sent_vec_list.append(sentence_vec) if self.is_train is False else None
- # 将结果存入数据库
- if is_retrieve is False:
- condition = {"topic_id": bp_dict["topic_id_list"][i]}
- # 用二进制存储句向量以节约存储空间
- sentence_vec_byte = Binary(pickle.dumps(sentence_vec, protocol=-1), subtype=128)
- # 需要新增train_flag,防止机器奔溃重复训练
- update_dict = {"content_clear": bp_dict["cont_clear_list"][i],
- "sentence_vec": sentence_vec_byte,
- "sent_train_flag": config.sent_train_flag}
- if hnsw_index == 0:
- update_dict["group_id"] = 0
- self.mongo_coll[hnsw_index].update_one(condition, {"$set": update_dict})
- return sent_vec_list
- # 清洗函数
- def clear_func(self, content, hnsw_index):
- if content in {'', None}:
- return ''
- # 将content字符串化,防止content是int/float型
- if isinstance(content, str) is False:
- if isinstance(content, int) or isinstance(content, float):
- return str(content)
- try:
- # 进行文本清洗
- if "#$#" not in content:
- content_clear = get_maplef_items(content, hnsw_index, self.is_train)
- else:
- content_clear_split = content.split("#$#")
- content_clear_t = get_maplef_items(content_clear_split[0], hnsw_index, self.is_train)
- content_clear_x = get_maplef_items(content_clear_split[1], hnsw_index, self.is_train)
- content_clear = content_clear_t + content_clear_x
- except Exception as e:
- # 通用公有变量
- global public_topic_id
- # 日志采集
- print(self.log_msg.format(id=public_topic_id,
- type="清洗错误: "+str(e),
- message=str(content))) if self.logger is None else None
- self.logger.error(self.log_msg.format(id=public_topic_id,
- type="清洗错误: "+str(e),
- message=str(content))) if self.logger is not None else None
-
- # 对于无法清洗的文本通过正则表达式直接获取文本中的中文字符
- content_clear = re.sub(r'[^\u4e00-\u9fa5]', '', content)
- return content_clear
- # 重叠截取长文本进行Sentence-Bert训练
- def truncate_func(self, content):
- # 设置长文本截断长度
- cut_length = 150
- # 设置截断重叠长度
- overlap = 10
- content_cut_list = []
- # 若文本长度小于等于截断长度,则取消截取直接返回
- cont_length = len(content)
- if cont_length <= cut_length:
- content_cut_list = [content]
- return content_cut_list
- # 若文本长度大于截断长度,则进行重叠截断
- # 设定文本截断尾部合并阈值(针对尾部文本根据长度进行合并)
- # 防止截断后出现极短文本影响模型效果
- tail_merge_value = 0.5 * cut_length
- for i in range(0,cont_length,cut_length-overlap):
- tail_idx = i + cut_length
- cut_content = content[i:tail_idx]
- # 保留单词完整性
- # 判断尾部字符
- if cont_length - tail_idx > tail_merge_value:
- for j in range(len(cut_content)-1,-1,-1):
- # 判断当前字符是否为字母或者数字
- # 若不是字母或者数字则截取成功
- if re.search('[A-Za-z]', cut_content[j]) is None:
- cut_content = cut_content[:j+1]
- break
- else:
- cut_content = content[i:]
- # 判断头部字符
- if i != 0:
- for k in range(len(cut_content)):
- # 判断当前字符是否为字母或者数字
- # 若不是字母或者数字则截取成功
- if re.search('[A-Za-z]', cut_content[k]) is None:
- cut_content = cut_content[k+1:]
- break
- # 将头部和尾部都处理好的截断文本存入content_cut_list
- content_cut_list.append(cut_content)
- # 针对尾部文本截断长度为140-150以及满足尾部合并阈值的文本
- # 进行重叠截断进行特殊处理
- if cont_length - tail_idx <= tail_merge_value:
- break
-
- return content_cut_list
- # 数据清洗处理函数
- def content_clear_process(self, data, hnsw_index):
- # 全内容清洗组合列表
- content_clear_list = []
- if 'content' in data:
- content_clear_list.append(self.clear_func(data['content'], hnsw_index))
- elif 'stem' in data:
- content_clear_list.append(self.clear_func(data['stem'], hnsw_index))
- # 若题目中有小题,则对小题进行处理(递归实现)
- if 'slave' in data:
- content_clear_list = self.slave_func(data['slave'], content_clear_list)
-
- # 若题目中有选项,则对选项进行处理
- if 'option' in data:
- content_clear_list.extend(self.option_func(data['option'], hnsw_index))
- if 'options' in data:
- content_clear_list.extend(self.option_func(data['options'], hnsw_index))
- # 去除文本中的空格以及空字符串
- content_clear_list = [re.sub(r',+', ',', re.sub(r'[\s_]', '', content))
- for content in content_clear_list]
- content_clear_list = [content for content in content_clear_list if content != '']
- # 将清洗数据拼接
- content_clear = ";".join(content_clear_list)
- # 去除题目开头"(多选)/(..分)"
- content_clear = re.sub(r'^(\([单多]选\)|\[[单多]选\])', '', content_clear)
- content_clear = re.sub(r'^(\(.*?\d{1,2}分.*?\)|\[.*?\d{1,2}分.*?\])', '', content_clear)
- # 去除题目开头(...年...[中模月]考)文本
- head_search = re.search(r'^(\(.*?[\)\]]?\)|\[.*?[\)\]]?\])', content_clear)
- if head_search is not None and 5 < head_search.span(0)[1] < 40:
- head_value = content_clear[head_search.span(0)[0]+1:head_search.span(0)[1]-1]
- if re.search(r'.*?(\d{2}|[模检测训练考试验期省市县外第初高中学]).*?[模检测训练考试验期省市县外第初高中学].*?', head_value):
- content_clear = content_clear[head_search.span(0)[1]:].lstrip()
- # 将文本中的选项"A.B.C.D."改为";"
- content_clear = re.sub(r'[ABCD]\.', ';', content_clear)
- # 对于只有图片格式以及标点符号的信息进行特殊处理(去除标点符号/空格/连接符)
- if re.sub(r'[\.、。,;\:\?!#\-> ]+', '', content_clear) == '':
- content_clear = ''
-
- # 重叠截取长文本用于进行Sentence-Bert训练
- content_cut_list = self.truncate_func(content_clear)
-
- return content_clear, content_cut_list
- # 小题处理函数(递归实现)
- def slave_func(self, slave_data, content_clear_list, hnsw_index):
- # 若小题列表为空,则返回content_clear_list
- if slave_data is None or len(slave_data) == 0:
- return content_clear_list
- for slave in slave_data:
- if 'content' in slave:
- content_clear_list.append(self.clear_func(slave['content'], hnsw_index))
- if 'stem' in slave:
- content_clear_list.append(self.clear_func(slave['stem'], hnsw_index))
- if 'option' in slave:
- content_clear_list.extend(self.option_func(slave['option'], hnsw_index))
- if 'options' in slave:
- content_clear_list.extend(self.option_func(slave['options'], hnsw_index))
- if 'slave' in slave:
- content_clear_list = self.slave_func(slave['slave'], content_clear_list, hnsw_index)
- return content_clear_list
- # 选项处理函数
- def option_func(self, option_list, hnsw_index):
- # 若选项列表为空,则返回空列表
- if option_list is None or len(option_list) == 0:
- return []
- option_clear_list = []
- for option in option_list:
- if isinstance(option, dict):
- if 'content' in option:
- option_clear_list.append(self.clear_func(option['content'], hnsw_index))
- elif 'stem' in option:
- option_clear_list.append(self.clear_func(option['stem'], hnsw_index))
- elif isinstance(option, str):
- option_clear_list.append(self.clear_func(option, hnsw_index))
- elif isinstance(option, int) or isinstance(option, float):
- option_clear_list.append(str(option))
- # 若选项中只有1,2,3或(1)(2)(3),或者只有标点符号,则过滤该选项
- option_clear_list = [option for option in option_clear_list
- if re.sub(r'\(\d\)[、,;\.]?|(\d[、,;])+\d','',re.sub(r'\s','',option))!=''
- and re.sub(r'[\.、。,;\:\?!#\-> ]+','',option)!='']
-
- return option_clear_list
- if __name__ == "__main__":
- # 获取mongodb数据
- mongo_coll = config.mongo_coll_list
- test_data = [{'topic_id': '453368', 'topic_type_id': '1', 'subject_id': '3', 'stem': '<p>测试</p>', 'key': 'B', 'option': ['<p>1</p>', '<p>2</p>', '<p>3</p>', '<p>4</p>']}]
- dpp = DataPreProcessing(mongo_coll)
- # res = dpp.content_clear_process(test_data, hnsw_index=0)
- # print(res[0])
- res = dpp(test_data, hnsw_index=0)
- print(res[1])
- # print(dpp.content_clear_process(test_data[0], hnsw_index=0)[0])
- # print(dpp(test_data, hnsw_index=0, is_retrieve=True))
|