# data_preprocessing.py

import re
import random
import pickle
from copy import deepcopy
from concurrent.futures import ThreadPoolExecutor

import numpy as np
from bson.binary import Binary
from sentence_transformers import SentenceTransformer

import config
from main_clear.sci_clear import get_maplef_items


# Shuffle two lists in unison so that index/vector pairs stay aligned
def shuffle_data_pair(idx_list, vec_list):
    zip_list = list(zip(idx_list, vec_list))
    random.shuffle(zip_list)
    idx_list, vec_list = zip(*zip_list)
    return idx_list, vec_list
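
# Usage sketch (hypothetical inputs): the two lists are shuffled in unison,
# so index/vector pairs stay aligned. Note that zip(*...) returns tuples.
#   ids, vecs = shuffle_data_pair([1, 2, 3], ["a", "b", "c"])
#   assert len(ids) == len(vecs) == 3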

# Shared module-level state: id of the topic currently being processed
public_topic_id = 0


# Data preprocessing
class DataPreProcessing:
    def __init__(self, mongo_coll, logger=None, is_train=False):
        self.mongo_coll = mongo_coll
        self.sbert_model = SentenceTransformer(config.sbert_path)
        self.is_train = is_train
        # Logging
        self.logger = logger
        self.log_msg = config.log_msg

    # Main entry point
    def __call__(self, origin_dataset, hnsw_index, is_retrieve=False):
        # Collected sentence vectors
        sent_vec_list = []
        # Batch-processing buffer
        bp_dict = deepcopy(config.batch_processing_dict)
        if self.is_train is False:
            # Outside of training, clean the whole dataset concurrently
            hnsw_index_list = [hnsw_index] * len(origin_dataset)
            with ThreadPoolExecutor(max_workers=5) as executor:
                executor_list = list(executor.map(self.content_clear_process, origin_dataset, hnsw_index_list))
            cont_clear_tuple, cont_cut_tuple = zip(*executor_list)
        for data_idx, data in enumerate(origin_dataset):
            # Shared module-level state
            global public_topic_id
            # Record the topic_id
            topic_id = data["topic_id"] if "topic_id" in data else data_idx + 1
            public_topic_id = topic_id
            if self.logger is None:
                print(topic_id)
            if self.is_train:
                # During training, clean each sample on the fly
                content_clear, content_cut_list = self.content_clear_process(data, hnsw_index)
            else:
                # Otherwise reuse the results computed by the thread pool above
                content_clear = cont_clear_tuple[data_idx]
                content_cut_list = cont_cut_tuple[data_idx]
            # Logging
            if self.logger is not None and is_retrieve:
                self.logger.info(self.log_msg.format(
                    id=topic_id,
                    type="data cleaning result",
                    message=content_clear))
            if self.logger is None:
                print(content_clear)
            if is_retrieve is False:
                bp_dict["topic_id_list"].append(data["topic_id"])
            bp_dict["cont_clear_list"].append(content_clear)
            # Merge all truncated segments into a single encoding batch
            bp_dict["cont_cut_list"].extend(content_cut_list)
            # Record the cumulative segment count for each sample
            bp_dict["cut_idx_list"].append(bp_dict["cut_idx_list"][-1] + len(content_cut_list))
            # Flush the buffer once the batch threshold is reached
            if (data_idx + 1) % 5000 == 0:
                sent_vec_list = self.batch_processing(sent_vec_list, bp_dict, hnsw_index, is_retrieve)
                # Reset the buffer after the batch has been processed
                bp_dict = deepcopy(config.batch_processing_dict)
        # Flush any remaining samples
        if len(bp_dict["cont_clear_list"]) > 0:
            sent_vec_list = self.batch_processing(sent_vec_list, bp_dict, hnsw_index, is_retrieve)
        return sent_vec_list, bp_dict["cont_clear_list"]
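
    # Buffer bookkeeping sketch (hypothetical counts; assumes
    # config.batch_processing_dict initializes cut_idx_list as [0]):
    # with per-sample segment counts [2, 1, 3], cut_idx_list grows to
    # [0, 2, 3, 6], so sample i's segment vectors occupy
    # vec_list[cut_idx_list[i]:cut_idx_list[i+1]] inside batch_processing.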

    # Encode the buffered segments in one batch and pool per-sample vectors
    def batch_processing(self, sent_vec_list, bp_dict, hnsw_index, is_retrieve):
        vec_list = self.sbert_model.encode(bp_dict["cont_cut_list"])
        # Rebuild one full sentence vector per sample
        sent_length = len(bp_dict["cut_idx_list"]) - 1
        for i in range(sent_length):
            sentence_vec = np.array([np.nan])
            if bp_dict["cont_clear_list"][i] != '':
                # Mean pooling over the sample's segment vectors
                sentence_vec = np.sum(vec_list[bp_dict["cut_idx_list"][i]:bp_dict["cut_idx_list"][i+1]], axis=0) \
                               / (bp_dict["cut_idx_list"][i+1] - bp_dict["cut_idx_list"][i])
            if self.is_train is False:
                sent_vec_list.append(sentence_vec)
            # Persist the result to MongoDB
            if is_retrieve is False:
                condition = {"topic_id": bp_dict["topic_id_list"][i]}
                # Store the vector as binary to save space
                sentence_vec_byte = Binary(pickle.dumps(sentence_vec, protocol=-1), subtype=128)
                # sent_train_flag guards against re-training after a crash
                update_dict = {"content_clear": bp_dict["cont_clear_list"][i],
                               "sentence_vec": sentence_vec_byte,
                               "sent_train_flag": config.sent_train_flag}
                if hnsw_index == 0:
                    update_dict["group_id"] = 0
                self.mongo_coll[hnsw_index].update_one(condition, {"$set": update_dict})
        return sent_vec_list
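
    # Read-back sketch (hypothetical; "tid" stands for a topic_id written
    # above): the stored BSON Binary can be passed to pickle.loads directly,
    # since Binary is a bytes subclass.
    #   doc = mongo_coll[hnsw_index].find_one({"topic_id": tid})
    #   sentence_vec = pickle.loads(doc["sentence_vec"])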

    # Clean a single piece of content
    def clear_func(self, content, hnsw_index):
        if content in {'', None}:
            return ''
        # Stringify content in case it is an int/float
        if isinstance(content, str) is False:
            if isinstance(content, (int, float)):
                return str(content)
        try:
            # Run the text-cleaning pipeline
            if "#$#" not in content:
                content_clear = get_maplef_items(content, hnsw_index, self.is_train)
            else:
                content_clear_split = content.split("#$#")
                content_clear_t = get_maplef_items(content_clear_split[0], hnsw_index, self.is_train)
                content_clear_x = get_maplef_items(content_clear_split[1], hnsw_index, self.is_train)
                content_clear = content_clear_t + content_clear_x
        except Exception as e:
            # Shared module-level state
            global public_topic_id
            # Logging
            if self.logger is None:
                print(self.log_msg.format(id=public_topic_id,
                                          type="cleaning error: " + str(e),
                                          message=str(content)))
            else:
                self.logger.error(self.log_msg.format(id=public_topic_id,
                                                      type="cleaning error: " + str(e),
                                                      message=str(content)))
            # For text that cannot be cleaned, fall back to keeping only the
            # Chinese characters via regex
            content_clear = re.sub(r'[^\u4e00-\u9fa5]', '', str(content))
        return content_clear

    # Split long texts into overlapping windows for Sentence-BERT training
    def truncate_func(self, content):
        # Window length
        cut_length = 150
        # Overlap between consecutive windows
        overlap = 10
        content_cut_list = []
        # Return the text as a single window if it is short enough
        cont_length = len(content)
        if cont_length <= cut_length:
            content_cut_list = [content]
            return content_cut_list
        # Tail-merge threshold: a short trailing remainder is merged into the
        # last window so that very short segments do not hurt the model
        tail_merge_value = 0.5 * cut_length
        for i in range(0, cont_length, cut_length - overlap):
            tail_idx = i + cut_length
            cut_content = content[i:tail_idx]
            # Keep words intact at the tail:
            # trim back to the last character that is not a Latin letter
            if cont_length - tail_idx > tail_merge_value:
                for j in range(len(cut_content) - 1, -1, -1):
                    if re.search('[A-Za-z]', cut_content[j]) is None:
                        cut_content = cut_content[:j+1]
                        break
            else:
                # Remainder is short enough: extend this window to the end
                cut_content = content[i:]
            # Keep words intact at the head:
            # skip forward past any leading Latin-letter fragment
            if i != 0:
                for k in range(len(cut_content)):
                    if re.search('[A-Za-z]', cut_content[k]) is None:
                        cut_content = cut_content[k+1:]
                        break
            # Store the trimmed window
            content_cut_list.append(cut_content)
            # Stop once the remainder has been merged into this window
            if cont_length - tail_idx <= tail_merge_value:
                break
        return content_cut_list
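
    # Worked example (hypothetical 300-character all-Chinese text, so the
    # word-integrity trimming only drops one leading character from windows
    # after the first): window starts are 0 and 140 (step = cut_length -
    # overlap = 140); the second window's remainder 300 - 290 = 10 is <=
    # tail_merge_value (75), so it extends to the end and the loop stops:
    #   truncate_func(text) -> [text[0:150], text[141:300]]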

    # Full cleaning pipeline for a single sample
    def content_clear_process(self, data, hnsw_index):
        # Collect cleaned text from every part of the sample
        content_clear_list = []
        if 'content' in data:
            content_clear_list.append(self.clear_func(data['content'], hnsw_index))
        elif 'stem' in data:
            content_clear_list.append(self.clear_func(data['stem'], hnsw_index))
        # Process sub-questions recursively, if present
        if 'slave' in data:
            content_clear_list = self.slave_func(data['slave'], content_clear_list, hnsw_index)
        # Process options, if present
        if 'option' in data:
            content_clear_list.extend(self.option_func(data['option'], hnsw_index))
        if 'options' in data:
            content_clear_list.extend(self.option_func(data['options'], hnsw_index))
        # Strip whitespace/underscores and drop empty strings
        content_clear_list = [re.sub(r',+', ',', re.sub(r'[\s_]', '', content))
                              for content in content_clear_list]
        content_clear_list = [content for content in content_clear_list if content != '']
        # Join the cleaned pieces
        content_clear = ";".join(content_clear_list)
        # Strip a leading "(多选)" / "(..分)" marker
        content_clear = re.sub(r'^(\([单多]选\)|\[[单多]选\])', '', content_clear)
        content_clear = re.sub(r'^(\(.*?\d{1,2}分.*?\)|\[.*?\d{1,2}分.*?\])', '', content_clear)
        # Strip a leading "(...年...中/模/月考)" exam-source prefix
        head_search = re.search(r'^(\(.*?[\)\]]?\)|\[.*?[\)\]]?\])', content_clear)
        if head_search is not None and 5 < head_search.span(0)[1] < 40:
            head_value = content_clear[head_search.span(0)[0]+1:head_search.span(0)[1]-1]
            if re.search(r'.*?(\d{2}|[模检测训练考试验期省市县外第初高中学]).*?[模检测训练考试验期省市县外第初高中学].*?', head_value):
                content_clear = content_clear[head_search.span(0)[1]:].lstrip()
        # Replace the option markers "A." "B." "C." "D." with ";"
        content_clear = re.sub(r'[ABCD]\.', ';', content_clear)
        # Treat image-only / punctuation-only text as empty
        # (drop punctuation, spaces and connectors before the check)
        if re.sub(r'[\.、。,;\:\?!#\-> ]+', '', content_clear) == '':
            content_clear = ''
        # Split the long text into overlapping windows for Sentence-BERT
        content_cut_list = self.truncate_func(content_clear)
        return content_clear, content_cut_list
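
    # Illustrative sketch (hypothetical sample; the exact output depends on
    # get_maplef_items): a stem like "(多选)...A.甲B.乙" has its leading
    # "(多选)" stripped and its "A."/"B." markers replaced by ";", and the
    # cleaned stem and options are joined with ";".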

    # Sub-question handler (recursive)
    def slave_func(self, slave_data, content_clear_list, hnsw_index):
        # Return unchanged if there are no sub-questions
        if slave_data is None or len(slave_data) == 0:
            return content_clear_list
        for slave in slave_data:
            if 'content' in slave:
                content_clear_list.append(self.clear_func(slave['content'], hnsw_index))
            if 'stem' in slave:
                content_clear_list.append(self.clear_func(slave['stem'], hnsw_index))
            if 'option' in slave:
                content_clear_list.extend(self.option_func(slave['option'], hnsw_index))
            if 'options' in slave:
                content_clear_list.extend(self.option_func(slave['options'], hnsw_index))
            if 'slave' in slave:
                content_clear_list = self.slave_func(slave['slave'], content_clear_list, hnsw_index)
        return content_clear_list

    # Option handler
    def option_func(self, option_list, hnsw_index):
        # Return an empty list if there are no options
        if option_list is None or len(option_list) == 0:
            return []
        option_clear_list = []
        for option in option_list:
            if isinstance(option, dict):
                if 'content' in option:
                    option_clear_list.append(self.clear_func(option['content'], hnsw_index))
                elif 'stem' in option:
                    option_clear_list.append(self.clear_func(option['stem'], hnsw_index))
            elif isinstance(option, str):
                option_clear_list.append(self.clear_func(option, hnsw_index))
            elif isinstance(option, (int, float)):
                option_clear_list.append(str(option))
        # Drop options that contain only "1,2,3" / "(1)(2)(3)" enumerations
        # or nothing but punctuation
        option_clear_list = [option for option in option_clear_list
                             if re.sub(r'\(\d\)[、,;\.]?|(\d[、,;])+\d', '', re.sub(r'\s', '', option)) != ''
                             and re.sub(r'[\.、。,;\:\?!#\-> ]+', '', option) != '']
        return option_clear_list
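
    # Filtering sketch (hypothetical cleaned options): enumeration-only and
    # punctuation-only strings are dropped, substantive text is kept:
    #   ["(1)、(2)", "1、2、3", "。。", "有丝分裂"] -> ["有丝分裂"]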


if __name__ == "__main__":
    # Fetch the MongoDB collections
    mongo_coll = config.mongo_coll_list
    test_data = [{'topic_id': '453368', 'topic_type_id': '1', 'subject_id': '3', 'stem': '<p>测试</p>', 'key': 'B', 'option': ['<p>1</p>', '<p>2</p>', '<p>3</p>', '<p>4</p>']}]
    dpp = DataPreProcessing(mongo_coll)
    # res = dpp.content_clear_process(test_data[0], hnsw_index=0)
    # print(res[0])
    res = dpp(test_data, hnsw_index=0)
    print(res[1])
    # print(dpp(test_data, hnsw_index=0, is_retrieve=True))