import time
import pickle
import requests
from bson.binary import Binary
from fuzzywuzzy import fuzz
from sentence_transformers import util
from pprint import pprint

import config
from data_preprocessing import DataPreProcessing


class Hnsw_Logic():
    def __init__(self, logger=None):
        # Initialize configuration data
        self.mongo_coll = config.mongo_coll_list
        self.vector_dim = config.vector_dim
        self.database_threshold = config.database_threshold
        self.hnsw_update_url = config.hnsw_update_url
        self.hnsw_retrieve_url = config.hnsw_retrieve_url
        # Logging
        self.logger = logger
        self.log_msg = config.log_msg
        # Initialize data preprocessing
        self.dpp = DataPreProcessing(self.mongo_coll, self.logger)

    # HNSW query logic dispatch
    def logic_process(self, retrieve_list, hnsw_index):
        # Run text cleaning/tokenization and sentence-vector computation
        sent_vec_list, cont_clear_list = self.dpp(retrieve_list, hnsw_index, is_retrieve=True)
        # HNSW duplicate check against the cloud question bank
        cloud_list = self.retrieve(retrieve_list, sent_vec_list, cont_clear_list, hnsw_index=0)
        if hnsw_index == 0:
            return cloud_list
        # HNSW duplicate check against the school question bank
        school_list = self.retrieve(retrieve_list, sent_vec_list, cont_clear_list, hnsw_index=1)
        # Iterate over retrieve_list and insert qualifying data into MongoDB
        for i, data in enumerate(retrieve_list):
            topic_id = data["topic_id"]
            # Skip near-duplicates: only store data below the similarity ceiling
            if len(school_list[i]) > 0 and school_list[i][0][1] > 0.97:
                continue
            # If the cleaned text is shorter than 10 characters, cleaning failed; filter it out
            if len(cont_clear_list[i]) < 10:
                continue
            # Guard against duplicate topic_id
            if self.mongo_coll[1].find_one({"topic_id": int(topic_id)}) is None:
                try:
                    self.school_mongodb_insert(data, sent_vec_list[i], cont_clear_list[i])
                    # Push the new data to the HNSW model in real time
                    self.update(int(topic_id), hnsw_index=1)
                except Exception as e:
                    # Logging
                    self.logger.info(self.log_msg.format(
                        id=int(topic_id),
                        type="chc retrieval insert",
                        message="chc dedup data insert failed-" + str(e)))
        return cloud_list, school_list

    # Insert received school-bank data into MongoDB
    def school_mongodb_insert(self, data, sentence_vec, content_clear):
        sentence_vec_byte = Binary(pickle.dumps(sentence_vec, protocol=-1), subtype=128)
        insert_dict = {
            'topic_id': int(data['topic_id']),
            'content_raw': data.get('stem', ''),
            'content': content_clear,
            'content_clear': content_clear,
            'sentence_vec': sentence_vec_byte,
            'sent_train_flag': config.sent_train_flag,
            'topic_type_id': int(data.get('topic_type_id', 0)),
            'school_id': int(data.get('school_id', 0)),
            'parse': data.get('parse', ''),
            'answer': data.get('answer', ''),
            'save_time': time.time(),
            'subject_id': int(data.get('subject_id', 0))
        }
        # Insert the document into MongoDB
        self.mongo_coll[1].insert_one(insert_dict)
        # Logging
        self.logger.info(self.log_msg.format(
            id=int(data['topic_id']),
            type="chc retrieval insert",
            message="chc dedup data inserted into mongo_coll_school"))

    # HNSW add/update
    def update(self, update_id, hnsw_index):
        if hnsw_index == 0:
            # Clean, tokenize and vectorize the data
            cloud_data = self.mongo_coll[hnsw_index].find_one({"topic_id": update_id})
            self.dpp([cloud_data], hnsw_index=0)
        # Call the HNSW service to update the index
        update_dict = dict(id=update_id, hnsw_index=hnsw_index)
        try:
            requests.post(self.hnsw_update_url, json=update_dict, timeout=10)
            # Logging
            if self.logger is not None:
                db_name = "cloud question bank" if hnsw_index == 0 else "school question bank"
                self.logger.info(config.log_msg.format(
                    id=update_id,
                    type="{} data update".format(db_name),
                    message="data update complete"))
        except Exception:
            # Logging
            if self.logger is not None:
                self.logger.error(self.log_msg.format(
                    id="HNSW update error",
                    type="HNSW update error for current question",
                    message=update_id))

    # HNSW query (supports mixed multi-subject duplicate checking)
    def retrieve(self, retrieve_list, sent_vec_list, cont_clear_list, hnsw_index):
        retrieve_res_list = []
        # Iterate over the query data
        for i, query_vec in enumerate(sent_vec_list):
            # Validate the sentence-vector dimension
"subject_id" not in retrieve_list[i] or query_vec.size != self.vector_dim: retrieve_res_list.append([]) continue subject_id = int(retrieve_list[i]["subject_id"]) # 调用hnsw接口检索数据 post_dict = dict(query_vec=query_vec.tolist(), hnsw_index=hnsw_index) try: query_labels = requests.post(self.hnsw_retrieve_url, json=post_dict, timeout=10).json() except Exception as e: query_labels = [] # 日志采集 if self.logger is not None: self.logger.error(self.log_msg.format(id="HNSW检索error", type="当前题目HNSW检索error", message=retrieve_list[i]["topic_id"])) if len(query_labels) == 0: retrieve_res_list.append([]) continue # 批量读取数据库 mongo_find_dict = {"topic_id": {"$in": query_labels}, "subject_id": subject_id} # 通过{"is_stop": 0}来过滤被删除题目的topic_id if hnsw_index == 0: mongo_find_dict["is_stop"] = 0 # 增加题型判断,准确定位重题 if "topic_type_id" in retrieve_list[i]: mongo_find_dict["topic_type_id"] = int(retrieve_list[i]["topic_type_id"]) query_dataset = self.mongo_coll[hnsw_index].find(mongo_find_dict) # 返回大于阈值的结果 cos_threshold = self.database_threshold[hnsw_index][0] fuzz_threshold = self.database_threshold[hnsw_index][1] retrieve_value_list = [] for label_data in query_dataset: # 防止出现重复topic_id, 预先进行过滤 if "sentence_vec" not in label_data: continue # 计算余弦相似度得分 label_vec = pickle.loads(label_data["sentence_vec"]) if label_vec.size != self.vector_dim: continue cosine_score = util.cos_sim(query_vec, label_vec)[0][0] # 阈值判断 if cosine_score < cos_threshold: continue # 对于学科进行编辑距离(fuzzywuzzy-200字符)验证,若小于设定分则过滤 if fuzz.ratio(cont_clear_list[i][:200], label_data["content_clear"][:200]) / 100 < fuzz_threshold: continue retrieve_value = [label_data["topic_id"], int(cosine_score*100)/100] retrieve_value_list.append(retrieve_value) # 将组合结果按照score降序排序并取得分前十个结果 score_sort_list = sorted(retrieve_value_list, key=lambda x: x[1], reverse=True)[:20] # 以列表形式返回最终查重结果 retrieve_res_list.append(score_sort_list) # 日志采集 if self.logger is not None: self.logger.info(self.log_msg.format( id="云题库查重" if hnsw_index == 0 else "校本题库查重", type="repeat检索" if hnsw_index == 0 else "chc检索", message=str({idx+1:ele for idx,ele in enumerate(retrieve_res_list)}))) return retrieve_res_list if __name__ == "__main__": # 获取mongodb数据 mongo_coll = config.mongo_coll hnsw_logic = Hnsw_Logic() test_data = [] for topic_id in [201511100832270]: test_data.append(mongo_coll.find_one({"topic_id":topic_id})) res = hnsw_logic.retrieve(test_data, hnsw_index=0) pprint(res)