data_preprocessing.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. import re
  2. import random
  3. import numpy as np
  4. import pickle
  5. from copy import deepcopy
  6. from bson.binary import Binary
  7. from sentence_transformers import SentenceTransformer
  8. import config
  9. from main_clear.sci_clear import get_maplef_items
  10. # 按数据对应顺序随机打乱数据
  11. def shuffle_data_pair(idx_list, vec_list):
  12. zip_list = list(zip(idx_list, vec_list))
  13. random.shuffle(zip_list)
  14. idx_list, vec_list = zip(*zip_list)
  15. return idx_list, vec_list
  16. # 通用公有变量
  17. public_id = 0
  18. public_data_index = -1
  19. # 数据预处理
  20. class DataPreProcessing():
  21. def __init__(self, mongo_coll=None, logger=None, is_train=False):
  22. # 配置初始数据
  23. self.mongo_coll = mongo_coll
  24. self.sbert_model = SentenceTransformer(config.sbert_path)
  25. self.is_train = is_train
  26. # 日志采集
  27. self.logger = logger
  28. self.log_msg = config.log_msg
  29. # 主函数
  30. def __call__(self, origin_dataset, is_retrieve=False):
  31. # 句向量存储列表
  32. sent_vec_list = []
  33. # 批量处理数据字典
  34. bp_dict = deepcopy(config.batch_processing_dict)
  35. for data_idx, data in enumerate(origin_dataset):
  36. # 通用公有变量
  37. global public_id
  38. global public_data_index
  39. public_data_index = data_idx + 1
  40. # 记录id
  41. public_id = data["id"] if "id" in data else 0
  42. print(public_id) if self.logger is None else None
  43. content_clear, content_cut_list = self.content_clear_process(data)
  44. # 日志采集
  45. self.logger.info(self.log_msg.format(
  46. id=public_id if public_id!=0 else public_data_index,
  47. type="数据清洗结果",
  48. message=content_clear)) if self.logger and is_retrieve else None
  49. print(content_clear) if self.logger is None else None
  50. bp_dict["id_list"].append(data["id"]) if is_retrieve is False else None
  51. bp_dict["cont_clear_list"].append(content_clear)
  52. # 将所有截断数据融合进行一次句向量计算
  53. bp_dict["cont_cut_list"].extend(content_cut_list)
  54. # 获取每条数据的截断长度
  55. bp_dict["cut_idx_list"].append(bp_dict["cut_idx_list"][-1] + len(content_cut_list))
  56. # 设置批量处理长度,若满足条件则进行批量处理
  57. if (data_idx+1) % 5000 == 0:
  58. sent_vec_list = self.batch_processing(sent_vec_list, bp_dict, is_retrieve)
  59. # 数据满足条件处理完毕后,则重置数据结构
  60. bp_dict = deepcopy(config.batch_processing_dict)
  61. if len(bp_dict["cont_clear_list"]) > 0:
  62. sent_vec_list = self.batch_processing(sent_vec_list, bp_dict, is_retrieve)
  63. return sent_vec_list, bp_dict["cont_clear_list"]
  64. # 数据批量处理计算句向量
  65. def batch_processing(self, sent_vec_list, bp_dict, is_retrieve):
  66. vec_list = self.sbert_model.encode(bp_dict["cont_cut_list"])
  67. # 计算题目中每个句子的完整句向量
  68. sent_length = len(bp_dict["cut_idx_list"]) - 1
  69. for i in range(sent_length):
  70. sentence_vec = np.array([np.nan])
  71. if bp_dict["cont_clear_list"][i] != '':
  72. # 平均池化
  73. sentence_vec = np.sum(vec_list[bp_dict["cut_idx_list"][i]:bp_dict["cut_idx_list"][i+1]], axis=0) \
  74. /(bp_dict["cut_idx_list"][i+1]-bp_dict["cut_idx_list"][i])
  75. sent_vec_list.append(sentence_vec) if self.is_train is False else None
  76. # 将结果存入数据库
  77. if is_retrieve is False:
  78. condition = {"id": bp_dict["id_list"][i]}
  79. # 用二进制存储句向量以节约存储空间
  80. sentence_vec_byte = Binary(pickle.dumps(sentence_vec, protocol=-1), subtype=128)
  81. # 需要新增train_flag,防止机器奔溃重复训练
  82. update_elements = {"$set": {"content_clear": bp_dict["cont_clear_list"][i],
  83. "sentence_vec": sentence_vec_byte,
  84. "sent_train_flag": config.sent_train_flag}}
  85. self.mongo_coll.update_one(condition, update_elements)
  86. return sent_vec_list
  87. # 清洗函数
  88. def clear_func(self, content):
  89. if content in {'', None}:
  90. return ''
  91. # 将content字符串化,防止content是int/float型
  92. if isinstance(content, str) is False:
  93. if isinstance(content, int) or isinstance(content, float):
  94. return str(content)
  95. try:
  96. # 进行文本清洗
  97. content_clear = get_maplef_items(content)
  98. except Exception as e:
  99. # 通用公有变量
  100. global public_id
  101. global public_data_index
  102. # 日志采集
  103. print(self.log_msg.format(id=public_id if public_id!=0 else public_data_index,
  104. type="清洗错误: "+str(e),
  105. message=str(content))) if self.logger is None else None
  106. self.logger.error(self.log_msg.format(id=public_id if public_id!=0 else public_data_index,
  107. type="清洗错误: "+str(e),
  108. message=str(content))) if self.logger is not None else None
  109. # 对于无法清洗的文本通过正则表达式直接获取文本中的中文字符
  110. content_clear = re.sub(r'[^\u4e00-\u9fa5]', '', content)
  111. return content_clear
  112. # 重叠截取长文本进行Sentence-Bert训练
  113. def truncate_func(self, content):
  114. # 设置长文本截断长度
  115. cut_length = 150
  116. # 设置截断重叠长度
  117. overlap = 10
  118. content_cut_list = []
  119. # 若文本长度小于等于截断长度,则取消截取直接返回
  120. cont_length = len(content)
  121. if cont_length <= cut_length:
  122. content_cut_list = [content]
  123. return content_cut_list
  124. # 若文本长度大于截断长度,则进行重叠截断
  125. # 设定文本截断尾部合并阈值(针对尾部文本根据长度进行合并)
  126. # 防止截断后出现极短文本影响模型效果
  127. tail_merge_value = 0.5 * cut_length
  128. for i in range(0,cont_length,cut_length-overlap):
  129. tail_idx = i + cut_length
  130. cut_content = content[i:tail_idx]
  131. # 保留单词完整性
  132. # 判断尾部字符
  133. if cont_length - tail_idx > tail_merge_value:
  134. for j in range(len(cut_content)-1,-1,-1):
  135. # 判断当前字符是否为字母或者数字
  136. # 若不是字母或者数字则截取成功
  137. if re.search('[A-Za-z]', cut_content[j]) is None:
  138. cut_content = cut_content[:j+1]
  139. break
  140. else:
  141. cut_content = content[i:]
  142. # 判断头部字符
  143. if i != 0:
  144. for k in range(len(cut_content)):
  145. # 判断当前字符是否为字母或者数字
  146. # 若不是字母或者数字则截取成功
  147. if re.search('[A-Za-z]', cut_content[k]) is None:
  148. cut_content = cut_content[k+1:]
  149. break
  150. # 将头部和尾部都处理好的截断文本存入content_cut_list
  151. content_cut_list.append(cut_content)
  152. # 针对尾部文本截断长度为140-150以及满足尾部合并阈值的文本
  153. # 进行重叠截断进行特殊处理
  154. if cont_length - tail_idx <= tail_merge_value:
  155. break
  156. return content_cut_list
  157. # 全文本数据清洗
  158. def content_clear_func(self, content):
  159. # 文本清洗
  160. content_clear = self.clear_func(content)
  161. # 去除文本中的空格以及空字符串
  162. content_clear = re.sub(r',+', ',', re.sub(r'[\s_]', '', content_clear))
  163. # 去除题目开头"【题文】(多选)/(..分)"
  164. content_clear = re.sub(r'\[题文\]', '', content_clear)
  165. content_clear = re.sub(r'(\([单多]选\)|\[[单多]选\])', '', content_clear)
  166. content_clear = re.sub(r'(\(\d{1,2}分\)|\[\d{1,2}分\])', '', content_clear)
  167. # 将文本中的选项"A.B.C.D."改为";"
  168. content_clear = re.sub(r'[ABCD]\.', ';', content_clear)
  169. # # 若选项中只有1,2,3或(1)(2)(3),或者只有标点符号,则过滤该选项
  170. # content_clear = re.sub(r'(\(\d\)[、,;\.]?)+\(\d\)|\d[、,;]+\d', '', content_clear)
  171. # 去除题目开头(...年...[中模月]考)文本
  172. head_search = re.search(r'^(\(.*?[\)\]]?\)|\[.*?[\)\]]?\])', content_clear)
  173. if head_search is not None and 5 < head_search.span(0)[1] < 40:
  174. head_value = content_clear[head_search.span(0)[0]+1:head_search.span(0)[1]-1]
  175. if re.search(r'.*?(\d{2}|[模检测训练考试验期省市县外第初高中学]).*?[模检测训练考试验期省市县外第初高中学].*?', head_value):
  176. content_clear = content_clear[head_search.span(0)[1]:].lstrip()
  177. # 对于只有图片格式以及标点符号的信息进行特殊处理(去除标点符号/空格/连接符)
  178. if re.sub(r'[\.、。,;\:\?!#\-> ]+', '', content_clear) == '':
  179. content_clear = ''
  180. return content_clear
  181. # 数据清洗与长文本重叠截取处理
  182. def content_clear_process(self, data):
  183. # 初始化content_clear
  184. content_clear = ''
  185. # 全文本数据清洗
  186. if "quesBody" in data:
  187. content_clear = self.content_clear_func(data["quesBody"])
  188. elif "stem" in data:
  189. content_clear = self.content_clear_func(data["stem"])
  190. # 重叠截取长文本用于进行Sentence-Bert训练
  191. content_cut_list = self.truncate_func(content_clear)
  192. return content_clear, content_cut_list
  193. if __name__ == "__main__":
  194. # 获取mongodb数据
  195. mongo_coll = config.mongo_coll
  196. test_data = {
  197. 'quesBody': """【题文】某同学设计了如下的电路测量电压表内阻,<i style=\"font-family: Times New Roman;\">R</i>为能够满足实验条件的滑动变阻器,<img src=\"Upload/QBM/be06c18fe93021f9c1f90227bd169e1b.png\"mathml=\"PG1hdGg+PG1yb3c+PG1zdXA+PG1pPlI8L21pPjxtbz7igLI8L21vPjwvbXN1cD48L21yb3c+PC9tYXRoPg==\"latex=\"${{R}^{\\prime }}$\"style=\"vertical-align:middle;\">为电阻箱,电压表量程合适。实验的粗略步骤如下:<br/><img src=\"Upload/QBM/f9db6da2-ddcd-422c-b87e-3481c242ff66.png\" style=\"vertical-align:middle;\" alt=\"\" width=\"164px\" height=\"198px\"><br /><br/>①闭合开关S<sub >1</sub>、S<sub >2</sub>,调节滑动变阻器<i style=\"font-family: Times New Roman;\">R</i>,使电压表指针指向满刻度的<img src=\"Upload/QBM/bf31876698721a199c7c53c6b320aa86.png\"mathml=\"PG1hdGg+PG1yb3c+PG1mcmFjPjxtbj4yPC9tbj48bW4+MzwvbW4+PC9tZnJhYz48L21yb3c+PC9tYXRoPg==\"latex=\"$\\frac{2}{3}$\"style=\"vertical-align:middle;\">处;<br/>②断开开关S<sub >2</sub>,调节某些仪器,使电压表指针指向满刻度的<img src=\"Upload/QBM/4dac452fbb5ef6dd653e7fbbef639484.png\"mathml=\"PG1hdGg+PG1yb3c+PG1mcmFjPjxtbj4xPC9tbj48bW4+MzwvbW4+PC9tZnJhYz48L21yb3c+PC9tYXRoPg==\"latex=\"$\\frac{1}{3}$\"style=\"vertical-align:middle;\">处;<br/>③读出电阻箱的阻值,该阻值即为电压表内阻的测量值;<br/>④断开开关S<sub >1</sub>、S<sub >2</sub>拆下实验仪器,整理器材。<br/>(1)上述实验步骤②中,调节某些仪器时,正确的操作是<bk size=\"10\" index=\"1\" type=\"underline\">__________</bk><br/>A.保持电阻箱阻值<img src=\"Upload/QBM/be06c18fe93021f9c1f90227bd169e1b.png\"mathml=\"PG1hdGg+PG1yb3c+PG1zdXA+PG1pPlI8L21pPjxtbz7igLI8L21vPjwvbXN1cD48L21yb3c+PC9tYXRoPg==\"latex=\"${{R}^{\\prime }}$\"style=\"vertical-align:middle;\">不变,调节滑动变阻器的滑片,使电压表指针指向满刻度的<img src=\"Upload/QBM/4dac452fbb5ef6dd653e7fbbef639484.png\"mathml=\"PG1hdGg+PG1yb3c+PG1mcmFjPjxtbj4xPC9tbj48bW4+MzwvbW4+PC9tZnJhYz48L21yb3c+PC9tYXRoPg==\"latex=\"$\\frac{1}{3}$\"style=\"vertical-align:middle;\">处<br/>B.保持滑动变阻器的滑片位置不变,调节电阻箱阻值<img src=\"Upload/QBM/be06c18fe93021f9c1f90227bd169e1b.png\"mathml=\"PG1hdGg+PG1yb3c+PG1zdXA+PG1pPlI8L21pPjxtbz7igLI8L21vPjwvbXN1cD48L21yb3c+PC9tYXRoPg==\"latex=\"${{R}^{\\prime }}$\"style=\"vertical-align:middle;\">,使电压表指针指向满刻度的<img src=\"Upload/QBM/4dac452fbb5ef6dd653e7fbbef639484.png\"mathml=\"PG1hdGg+PG1yb3c+PG1mcmFjPjxtbj4xPC9tbj48bW4+MzwvbW4+PC9tZnJhYz48L21yb3c+PC9tYXRoPg==\"latex=\"$\\frac{1}{3}$\"style=\"vertical-align:middle;\">处<br/>C.同时调节滑动变阻器和电阻箱的阻值,使电压表指针指向满刻度的<img src=\"Upload/QBM/4dac452fbb5ef6dd653e7fbbef639484.png\"mathml=\"PG1hdGg+PG1yb3c+PG1mcmFjPjxtbj4xPC9tbj48bW4+MzwvbW4+PC9tZnJhYz48L21yb3c+PC9tYXRoPg==\"latex=\"$\\frac{1}{3}$\"style=\"vertical-align:middle;\">处<br/>(2)此实验电压表内阻的测量值与真实值相比<bk size=\"8\" index=\"2\" type=\"underline\">________</bk>(选填“偏大”“偏小”或“相等”);<br/>(3)如实验测得该电压表内阻为8500Ω,要将其量程扩大为原来的<img src=\"Upload/QBM/5c55e4f3eda94bc505f103b10bc1fee7.png\"mathml=\"PG1hdGg+PG1yb3c+PG1mcmFjPjxtbj42PC9tbj48bW4+NTwvbW4+PC9tZnJhYz48L21yb3c+PC9tYXRoPg==\"latex=\"$\\frac{6}{5}$\"style=\"vertical-align:middle;\">倍,需串联<bk size=\"5\" index=\"3\" type=\"underline\">_____</bk>Ω的电阻。""",
  198. 'option': ['$\\left\\{-2,0\\right\\}$', '$\\left\\{-2,0,2\\right\\}$', '$\\left\\{-1,1,2\\right\\}$', '$\\left\\{-1,0,2\\right\\}$']}
  199. dpp = DataPreProcessing(mongo_coll)
  200. string = """如图三角形的每个顶点均在格点上,且每个小正方形的边长为1.<br/><img src="http://tkimgs.zhixinhuixue.net/image/word/2021/09/05/1630825213599903.png" width="197px" height="180px" /><br/>(1)<img src="http://tkimgs.zhixinhuixue.net/image/word/2021/09/05/1630825213498556.png" data-latex="${\angle 1+\angle 2+\angle 3=}$" width="118",height="12" />________<img src="http://tkimgs.zhixinhuixue.net/image/word/2021/09/05/1630825213387699.png" data-latex="${^{\circ } }$" width="5",height="5" />;<br/>(2)求<img src="http://tkimgs.zhixinhuixue.net/image/word/2021/09/05/1630825213446366.png" data-latex="${\triangle ABC}$" width="54",height="11" />的面积."""
  201. string = """已知c水=4.2×103J/(kg·℃),求"""
  202. res = dpp.clear_func(string)
  203. print(res)
  204. # res = dpp.content_clear_process(test_data)
  205. # print(res[0])
  206. # print(dpp.content_clear_process(mongo_coll.find_one({}))[0])
  207. # print(dpp(test_data,is_retrieve=True))