# data_preprocessing.py
  1. import re
  2. import random
  3. import numpy as np
  4. import pickle
  5. from copy import deepcopy
  6. from bson.binary import Binary
  7. from concurrent.futures import ThreadPoolExecutor
  8. from sentence_transformers import SentenceTransformer
  9. import config
  10. from main_clear.sci_clear import get_maplef_items
  # Shuffle two parallel sequences with one shared permutation.
def shuffle_data_pair(idx_list, vec_list, rng=None):
    """Shuffle ``idx_list`` and ``vec_list`` together, keeping pairs aligned.

    Args:
        idx_list: First sequence (presumably ids, per its name).
        vec_list: Second sequence, paired element-wise with ``idx_list``.
            As with ``zip``, surplus elements of the longer sequence are
            silently dropped.
        rng: Object with a ``shuffle`` method, e.g. ``random.Random(seed)``
            for reproducible results. Defaults to the ``random`` module.

    Returns:
        Two tuples ``(idx_list, vec_list)`` in the same shuffled order. For
        empty input, two empty tuples.
    """
    pairs = list(zip(idx_list, vec_list))
    if not pairs:
        # zip(*[]) yields nothing, so unpacking it into two names would raise.
        return (), ()
    shuffler = random if rng is None else rng
    shuffler.shuffle(pairs)
    shuffled_idx, shuffled_vec = zip(*pairs)
    return shuffled_idx, shuffled_vec
  17. # 通用公有变量
  18. public_id = 0
  19. # 数据预处理
  20. class DataPreProcessing():
  21. def __init__(self, mongo_coll=None, logger=None, is_train=False):
  22. # 配置初始数据
  23. self.mongo_coll = mongo_coll
  24. self.sbert_model = SentenceTransformer(config.sbert_path)
  25. self.is_train = is_train
  26. # 日志采集
  27. self.logger = logger
  28. self.log_msg = config.log_msg
  29. # 主函数
  30. def __call__(self, origin_dataset, is_retrieve=False):
  31. # 句向量存储列表
  32. sent_vec_list = []
  33. # 批量处理数据字典
  34. bp_dict = deepcopy(config.batch_processing_dict)
  35. # 批量数据清洗
  36. if self.is_train is False:
  37. with ThreadPoolExecutor(max_workers=5) as executor:
  38. executor_list = list(executor.map(self.content_clear_process, origin_dataset))
  39. cont_clear_tuple, cont_cut_tuple = zip(*executor_list)
  40. for data_idx, data in enumerate(origin_dataset):
  41. # 通用公有变量
  42. global public_id
  43. # 记录id
  44. public_id = data["id"] if "id" in data else data_idx + 1
  45. print(public_id) if self.logger is None else None
  46. if self.is_train is True:
  47. content_clear, content_cut_list = self.content_clear_process(data)
  48. # 根据self.is_train赋值content_clear, content_cut_list
  49. content_clear = content_clear if self.is_train else cont_clear_tuple[data_idx]
  50. content_cut_list = content_cut_list if self.is_train else cont_cut_tuple[data_idx]
  51. # 日志采集
  52. self.logger.info(self.log_msg.format(
  53. id=public_id,
  54. type="数据清洗结果",
  55. message=content_clear)) if self.logger and is_retrieve else None
  56. print(content_clear) if self.logger is None else None
  57. bp_dict["id_list"].append(data["id"]) if is_retrieve is False else None
  58. bp_dict["cont_clear_list"].append(content_clear)
  59. # 将所有截断数据融合进行一次句向量计算
  60. bp_dict["cont_cut_list"].extend(content_cut_list)
  61. # 获取每条数据的截断长度
  62. bp_dict["cut_idx_list"].append(bp_dict["cut_idx_list"][-1] + len(content_cut_list))
  63. # 设置批量处理长度,若满足条件则进行批量处理
  64. if (data_idx+1) % 5000 == 0:
  65. sent_vec_list = self.batch_processing(sent_vec_list, bp_dict, is_retrieve)
  66. # 数据满足条件处理完毕后,则重置数据结构
  67. bp_dict = deepcopy(config.batch_processing_dict)
  68. if len(bp_dict["cont_clear_list"]) > 0:
  69. sent_vec_list = self.batch_processing(sent_vec_list, bp_dict, is_retrieve)
  70. return sent_vec_list, bp_dict["cont_clear_list"]
  71. # 数据批量处理计算句向量
  72. def batch_processing(self, sent_vec_list, bp_dict, is_retrieve):
  73. vec_list = self.sbert_model.encode(bp_dict["cont_cut_list"])
  74. # 计算题目中每个句子的完整句向量
  75. sent_length = len(bp_dict["cut_idx_list"]) - 1
  76. for i in range(sent_length):
  77. sentence_vec = np.array([np.nan])
  78. if bp_dict["cont_clear_list"][i] != '':
  79. # 平均池化
  80. sentence_vec = np.sum(vec_list[bp_dict["cut_idx_list"][i]:bp_dict["cut_idx_list"][i+1]], axis=0) \
  81. /(bp_dict["cut_idx_list"][i+1]-bp_dict["cut_idx_list"][i])
  82. sent_vec_list.append(sentence_vec) if self.is_train is False else None
  83. # 将结果存入数据库
  84. if is_retrieve is False:
  85. condition = {"id": bp_dict["id_list"][i]}
  86. # 用二进制存储句向量以节约存储空间
  87. sentence_vec_byte = Binary(pickle.dumps(sentence_vec, protocol=-1), subtype=128)
  88. # 需要新增train_flag,防止机器奔溃重复训练
  89. update_elements = {"$set": {"content_clear": bp_dict["cont_clear_list"][i],
  90. "sentence_vec": sentence_vec_byte,
  91. "sent_train_flag": config.sent_train_flag}}
  92. self.mongo_coll.update_one(condition, update_elements)
  93. return sent_vec_list
  94. # 清洗函数
  95. def clear_func(self, content):
  96. if content in {'', None}:
  97. return ''
  98. # 将content字符串化,防止content是int/float型
  99. if isinstance(content, str) is False:
  100. if isinstance(content, int) or isinstance(content, float):
  101. return str(content)
  102. try:
  103. # 进行文本清洗
  104. content_clear = get_maplef_items(content)
  105. except Exception as e:
  106. # 通用公有变量
  107. global public_id
  108. # 日志采集
  109. print(self.log_msg.format(id=public_id,
  110. type="清洗错误: "+str(e),
  111. message=str(content))) if self.logger is None else None
  112. self.logger.error(self.log_msg.format(id=public_id,
  113. type="清洗错误: "+str(e),
  114. message=str(content))) if self.logger is not None else None
  115. # 对于无法清洗的文本通过正则表达式直接获取文本中的中文字符
  116. content_clear = re.sub(r'[^\u4e00-\u9fa5]', '', content)
  117. return content_clear
  118. # 重叠截取长文本进行Sentence-Bert训练
  119. def truncate_func(self, content):
  120. # 设置长文本截断长度
  121. cut_length = 150
  122. # 设置截断重叠长度
  123. overlap = 10
  124. content_cut_list = []
  125. # 若文本长度小于等于截断长度,则取消截取直接返回
  126. cont_length = len(content)
  127. if cont_length <= cut_length:
  128. content_cut_list = [content]
  129. return content_cut_list
  130. # 若文本长度大于截断长度,则进行重叠截断
  131. # 设定文本截断尾部合并阈值(针对尾部文本根据长度进行合并)
  132. # 防止截断后出现极短文本影响模型效果
  133. tail_merge_value = 0.5 * cut_length
  134. for i in range(0,cont_length,cut_length-overlap):
  135. tail_idx = i + cut_length
  136. cut_content = content[i:tail_idx]
  137. # 保留单词完整性
  138. # 判断尾部字符
  139. if cont_length - tail_idx > tail_merge_value:
  140. for j in range(len(cut_content)-1,-1,-1):
  141. # 判断当前字符是否为字母或者数字
  142. # 若不是字母或者数字则截取成功
  143. if re.search('[A-Za-z]', cut_content[j]) is None:
  144. cut_content = cut_content[:j+1]
  145. break
  146. else:
  147. cut_content = content[i:]
  148. # 判断头部字符
  149. if i != 0:
  150. for k in range(len(cut_content)):
  151. # 判断当前字符是否为字母或者数字
  152. # 若不是字母或者数字则截取成功
  153. if re.search('[A-Za-z]', cut_content[k]) is None:
  154. cut_content = cut_content[k+1:]
  155. break
  156. # 将头部和尾部都处理好的截断文本存入content_cut_list
  157. content_cut_list.append(cut_content)
  158. # 针对尾部文本截断长度为140-150以及满足尾部合并阈值的文本
  159. # 进行重叠截断进行特殊处理
  160. if cont_length - tail_idx <= tail_merge_value:
  161. break
  162. return content_cut_list
  163. # 全文本数据清洗
  164. def content_clear_func(self, content):
  165. # 文本清洗
  166. content_clear = self.clear_func(content)
  167. # 去除文本中的空格以及空字符串
  168. content_clear = re.sub(r',+', ',', re.sub(r'[\s_]', '', content_clear))
  169. # 去除题目开头"【题文】(多选)/(..分)"
  170. content_clear = re.sub(r'\[题文\]', '', content_clear)
  171. content_clear = re.sub(r'(\([单多]选\)|\[[单多]选\])', '', content_clear)
  172. content_clear = re.sub(r'(\(\d{1,2}分\)|\[\d{1,2}分\])', '', content_clear)
  173. # 将文本中的选项"A.B.C.D."改为";"
  174. content_clear = re.sub(r'[ABCD]\.', ';', content_clear)
  175. # # 若选项中只有1,2,3或(1)(2)(3),或者只有标点符号,则过滤该选项
  176. # content_clear = re.sub(r'(\(\d\)[、,;\.]?)+\(\d\)|\d[、,;]+\d', '', content_clear)
  177. # 去除题目开头(...年...[中模月]考)文本
  178. head_search = re.search(r'^(\(.*?[\)\]]?\)|\[.*?[\)\]]?\])', content_clear)
  179. if head_search is not None and 5 < head_search.span(0)[1] < 40:
  180. head_value = content_clear[head_search.span(0)[0]+1:head_search.span(0)[1]-1]
  181. if re.search(r'.*?(\d{2}|[模检测训练考试验期省市县外第初高中学]).*?[模检测训练考试验期省市县外第初高中学].*?', head_value):
  182. content_clear = content_clear[head_search.span(0)[1]:].lstrip()
  183. # 对于只有图片格式以及标点符号的信息进行特殊处理(去除标点符号/空格/连接符)
  184. if re.sub(r'[\.、。,;\:\?!#\-> ]+', '', content_clear) == '':
  185. content_clear = ''
  186. return content_clear
  187. # 数据清洗与长文本重叠截取处理
  188. def content_clear_process(self, data):
  189. # 初始化content_clear
  190. content_clear = ''
  191. # 全文本数据清洗
  192. if "quesBody" in data:
  193. content_clear = self.content_clear_func(data["quesBody"])
  194. elif "stem" in data:
  195. content_clear = self.content_clear_func(data["stem"])
  196. # 重叠截取长文本用于进行Sentence-Bert训练
  197. content_cut_list = self.truncate_func(content_clear)
  198. return content_clear, content_cut_list
  # Manual smoke test for the cleaning step.
if __name__ == "__main__":
    # Constructing DataPreProcessing also loads the Sentence-BERT model from
    # config.sbert_path, even though only clear_func is exercised below.
    mongo_coll = config.mongo_coll
    dpp = DataPreProcessing(mongo_coll)
    samples = [
        # Raw string: the LaTeX contains "\a" (\angle) and "\t" (\triangle),
        # which a normal literal would turn into BEL/TAB control characters.
        r"""如图三角形的每个顶点均在格点上,且每个小正方形的边长为1.<br/><img src="http://tkimgs.zhixinhuixue.net/image/word/2021/09/05/1630825213599903.png" width="197px" height="180px" /><br/>(1)<img src="http://tkimgs.zhixinhuixue.net/image/word/2021/09/05/1630825213498556.png" data-latex="${\angle 1+\angle 2+\angle 3=}$" width="118",height="12" />________<img src="http://tkimgs.zhixinhuixue.net/image/word/2021/09/05/1630825213387699.png" data-latex="${^{\circ } }$" width="5",height="5" />;<br/>(2)求<img src="http://tkimgs.zhixinhuixue.net/image/word/2021/09/05/1630825213446366.png" data-latex="${\triangle ABC}$" width="54",height="11" />的面积.""",
        """已知c水=4.2×103J/(kg·℃),求""",
    ]
    for sample in samples:
        print(dpp.clear_func(sample))