# hnsw_logic.py
import pickle
import time
from pprint import pprint

import requests
from bson.binary import Binary
from fuzzywuzzy import fuzz
from sentence_transformers import util

import config
from data_preprocessing import DataPreProcessing
  10. class Hnsw_Logic():
  11. def __init__(self, logger=None):
  12. # 配置初始数据
  13. self.mongo_coll = config.mongo_coll_list
  14. self.vector_dim = config.vector_dim
  15. self.database_threshold = config.database_threshold
  16. self.hnsw_update_url = config.hnsw_update_url
  17. self.hnsw_retrieve_url = config.hnsw_retrieve_url
  18. # 日志采集
  19. self.logger = logger
  20. self.log_msg = config.log_msg
  21. # 数据预处理初始化
  22. self.dpp = DataPreProcessing(self.mongo_coll, self.logger)
  23. # HNSW查询逻辑判断
  24. def logic_process(self, retrieve_list, hnsw_index):
  25. # 调用清洗分词函数和句向量计算函数
  26. sent_vec_list, cont_clear_list = self.dpp(retrieve_list, hnsw_index, is_retrieve=True)
  27. # 云题库HNSW查重
  28. cloud_list = self.retrieve(retrieve_list, sent_vec_list, cont_clear_list, hnsw_index=0)
  29. if hnsw_index == 0:
  30. return cloud_list
  31. # 校本题库HNSW查重
  32. school_list = self.retrieve(retrieve_list, sent_vec_list, cont_clear_list, hnsw_index=1)
  33. # 遍历retrieve_list, 将数据插入mongodb
  34. for i,data in enumerate(retrieve_list):
  35. topic_id = data["topic_id"]
  36. # 判断相似度并将符合要求的数据存入mongodb
  37. if len(school_list[i]) > 0 and school_list[i][0][1] > 0.97:
  38. continue
  39. # 判断清洗文本长度,若长度小于10,则表示清洗失败需要过滤
  40. if len(cont_clear_list[i]) < 10:
  41. continue
  42. # 防止出现重复的topic_id
  43. if self.mongo_coll[1].find_one({"topic_id": int(topic_id)}) is None:
  44. try:
  45. self.school_mongodb_insert(data, sent_vec_list[i], cont_clear_list[i])
  46. # 将数据实时更新至hnsw模型
  47. self.update(int(topic_id), hnsw_index=1)
  48. except Exception as e:
  49. # 日志采集
  50. self.logger.info(self.log_msg.format(id=int(topic_id),
  51. type="chc retrieval insert",
  52. message="chc查重数据插入失败-"+str(e)))
  53. return cloud_list, school_list
  54. # 将校本题库接收数据插入mongodb
  55. def school_mongodb_insert(self, data, sentence_vec, content_clear):
  56. sentence_vec_byte = Binary(pickle.dumps(sentence_vec, protocol=-1), subtype=128)
  57. insert_dict = {
  58. 'topic_id': int(data['topic_id']),
  59. 'content_raw': data['stem'] if 'stem' in data else '',
  60. 'content': content_clear,
  61. 'content_clear': content_clear,
  62. 'sentence_vec': sentence_vec_byte,
  63. 'sent_train_flag': config.sent_train_flag,
  64. 'topic_type_id': int(data['topic_type_id']) if 'topic_type_id' in data else 0,
  65. 'school_id': int(data['school_id']) if 'school_id' in data else 0,
  66. 'parse': data['parse'] if 'parse' in data else '',
  67. 'answer': data['answer'] if 'answer' in data else '',
  68. 'save_time': time.time(),
  69. 'subject_id': int(data['subject_id']) if 'subject_id' in data else 0
  70. }
  71. # 将数据插入mongodb
  72. self.mongo_coll[1].insert_one(insert_dict)
  73. # 日志采集
  74. self.logger.info(self.log_msg.format(id=int(data['topic_id']),
  75. type="chc retrieval insert",
  76. message="已将chc查重数据插入mongo_coll_school"))
  77. # HNSW增/改
  78. def update(self, update_id, hnsw_index):
  79. if hnsw_index == 0:
  80. # 数据清洗、分词与句向量化
  81. cloud_data = self.mongo_coll[hnsw_index].find_one({"topic_id": update_id})
  82. self.dpp([cloud_data], hnsw_index=0)
  83. # 调用hnsw接口更新数据
  84. update_dict = dict(id=update_id, hnsw_index=hnsw_index)
  85. try:
  86. requests.post(self.hnsw_update_url, json=update_dict, timeout=10)
  87. # 日志采集
  88. if self.logger is not None:
  89. db_name = "云题库" if hnsw_index == 0 else "校本题库"
  90. self.logger.info(config.log_msg.format(id=update_id,
  91. type="{}数据更新".format(db_name),
  92. message="数据更新完毕"))
  93. except Exception as e:
  94. # 日志采集
  95. if self.logger is not None:
  96. self.logger.error(self.log_msg.format(id="HNSW更新error",
  97. type="当前题目HNSW更新error",
  98. message=update_id))
  99. # HNSW查(支持多学科混合查重)
  100. def retrieve(self, retrieve_list, sent_vec_list, cont_clear_list, hnsw_index):
  101. retrieve_res_list = []
  102. # 遍历检索查重数据
  103. for i,query_vec in enumerate(sent_vec_list):
  104. # 判断句向量维度
  105. if "subject_id" not in retrieve_list[i] or query_vec.size != self.vector_dim:
  106. retrieve_res_list.append([])
  107. continue
  108. subject_id = int(retrieve_list[i]["subject_id"])
  109. # 调用hnsw接口检索数据
  110. post_dict = dict(query_vec=query_vec.tolist(), hnsw_index=hnsw_index)
  111. try:
  112. query_labels = requests.post(self.hnsw_retrieve_url, json=post_dict, timeout=10).json()
  113. except Exception as e:
  114. query_labels = []
  115. # 日志采集
  116. if self.logger is not None:
  117. self.logger.error(self.log_msg.format(id="HNSW检索error",
  118. type="当前题目HNSW检索error",
  119. message=retrieve_list[i]["topic_id"]))
  120. if len(query_labels) == 0:
  121. retrieve_res_list.append([])
  122. continue
  123. # 批量读取数据库
  124. mongo_find_dict = {"topic_id": {"$in": query_labels}, "subject_id": subject_id}
  125. # 通过{"is_stop": 0}来过滤被删除题目的topic_id
  126. if hnsw_index == 0:
  127. mongo_find_dict["is_stop"] = 0
  128. # 增加题型判断,准确定位重题
  129. if "topic_type_id" in retrieve_list[i]:
  130. mongo_find_dict["topic_type_id"] = int(retrieve_list[i]["topic_type_id"])
  131. query_dataset = self.mongo_coll[hnsw_index].find(mongo_find_dict)
  132. # 返回大于阈值的结果
  133. cos_threshold = self.database_threshold[hnsw_index][0]
  134. fuzz_threshold = self.database_threshold[hnsw_index][1]
  135. retrieve_value_list = []
  136. for label_data in query_dataset:
  137. # 防止出现重复topic_id, 预先进行过滤
  138. if "sentence_vec" not in label_data:
  139. continue
  140. # 计算余弦相似度得分
  141. label_vec = pickle.loads(label_data["sentence_vec"])
  142. if label_vec.size != self.vector_dim:
  143. continue
  144. cosine_score = util.cos_sim(query_vec, label_vec)[0][0]
  145. # 阈值判断
  146. if cosine_score < cos_threshold:
  147. continue
  148. # 对于学科进行编辑距离(fuzzywuzzy-200字符)验证,若小于设定分则过滤
  149. if fuzz.ratio(cont_clear_list[i][:200], label_data["content_clear"][:200]) / 100 < fuzz_threshold:
  150. continue
  151. retrieve_value = [label_data["topic_id"], int(cosine_score*100)/100]
  152. retrieve_value_list.append(retrieve_value)
  153. # 将组合结果按照score降序排序并取得分前十个结果
  154. score_sort_list = sorted(retrieve_value_list, key=lambda x: x[1], reverse=True)[:20]
  155. # 以列表形式返回最终查重结果
  156. retrieve_res_list.append(score_sort_list)
  157. # 日志采集
  158. if self.logger is not None:
  159. self.logger.info(self.log_msg.format(
  160. id="云题库查重" if hnsw_index == 0 else "校本题库查重",
  161. type="repeat检索" if hnsw_index == 0 else "chc检索",
  162. message=str({idx+1:ele for idx,ele in enumerate(retrieve_res_list)})))
  163. return retrieve_res_list
  164. if __name__ == "__main__":
  165. # 获取mongodb数据
  166. mongo_coll = config.mongo_coll
  167. hnsw_logic = Hnsw_Logic()
  168. test_data = []
  169. for topic_id in [201511100832270]:
  170. test_data.append(mongo_coll.find_one({"topic_id":topic_id}))
  171. res = hnsw_logic.retrieve(test_data, hnsw_index=0)
  172. pprint(res)