import os import pickle import hnswlib import numpy as np import config class HNSW(): def __init__(self, logger=None): # 配置初始数据 self.mongo_coll = config.mongo_coll_list self.vector_dim = config.vector_dim # 日志采集 self.logger = logger self.log_msg = config.log_msg self.hnsw_list = [self.load_hnsw(hnsw_path) for hnsw_path in config.hnsw_path_list] # hnsw模型更新标志列表 self.hnsw_update_flag_list = [0 for _ in range(config.hnsw_num)] # HNSW启动更新保存数据(以防程序错误导致模型数据丢失,读取待更新数据文件) self.hnsw_start_update_save() # 加载HNSW图模型 def load_hnsw(self, hnsw_path): #加载hnsw需要重新定义hnsw_p以及set_ef-------tjt # 初始化HNSW搜索图 hnsw_p = hnswlib.Index(space=config.hnsw_metric, dim=self.vector_dim) # 加载HNSW模型数据,可重新修改最大元素数量(max_elements) # 注意:HNSW需要进行路径管理-------------------------------------tjt os.chdir(config.data_root_path) hnsw_p.load_index(hnsw_path) os.chdir(config.root_path) # 注意:HNSW需要进行路径管理-------------------------------------tjt hnsw_p.set_ef(config.hnsw_set_ef) # ef should always be > k return hnsw_p # 将更新结果存入HNSW模型 def save_hnsw(self): # 清空hnsw_update_data.txt中更新数据 with open(config.hnsw_update_save_path, 'w', encoding='utf8') as f: f.write("") # 注意:HNSW需要进行路径管理-------------------------------------tjt os.chdir(config.data_root_path) for i,flag in enumerate(self.hnsw_update_flag_list): # 恢复hnsw_update_flag_list初始状态 self.hnsw_update_flag_list[i] = 0 self.hnsw_list[i].save_index(config.hnsw_path_list[i]) if flag == 1 else None os.chdir(config.root_path) # 注意:HNSW需要进行路径管理-------------------------------------tjt # HNSW启动更新保存数据(以防程序错误导致模型数据丢失,读取待更新数据文件) def hnsw_start_update_save(self): # 判断hnsw_update_data.txt是否存在 if os.path.exists(config.hnsw_update_save_path) is False: return # 读取hnsw_update_data.txt进行处理 with open(config.hnsw_update_save_path, 'r', encoding='utf8') as f: hnsw_update_data = f.read() if hnsw_update_data == "": return hnsw_update_list = [eval(ele) for ele in hnsw_update_data.split("\n") if ele != ""] for update_dict in hnsw_update_list: self.update(update_dict["id"], update_dict["hnsw_index"]) # 保存HNSW模型 self.save_hnsw() # HNSW增/改 def update(self, update_id, hnsw_index): # 从数据库读取新增数据 update_data = self.mongo_coll[hnsw_index].find_one({"topic_id": update_id}) sent_vec = pickle.loads(update_data["sentence_vec"]) # 将新增/修改数据构图 if sent_vec.size == self.vector_dim: self.hnsw_list[hnsw_index].add_items(sent_vec, update_id) # hnsw模型更新标志列表 self.hnsw_update_flag_list[hnsw_index] = 1 # HNSW查(支持多学科混合查重) def retrieve(self, query_vec, hnsw_index, k_num=50): try: query_labels, _ = self.hnsw_list[hnsw_index].knn_query(query_vec, k_num) except Exception as e: try: query_labels, _ = self.hnsw_list[hnsw_index].knn_query(query_vec, k=10) except Exception as e: query_labels = np.array([[]]) # 日志采集 if self.logger is not None: self.logger.error(self.log_msg.format(id="HNSW检索error(knn_query)", type="当前题目HNSW检索error", message="knn_query-topK数量不匹配")) return query_labels[0].tolist()