import os import hnswlib import numpy as np import config class HNSW(): def __init__(self, logger=None): # 配置初始数据 self.vector_dim = config.vector_dim # 日志采集 self.logger = logger self.log_msg = config.log_msg # 加载HNSW图模型 self.hnsw_p = self.load_hnsw(config.hnsw_path) # 加载HNSW图模型 def load_hnsw(self, hnsw_path): #加载hnsw需要重新定义hnsw_p以及set_ef-------tjt # 初始化HNSW搜索图 hnsw_p = hnswlib.Index(space=config.hnsw_metric, dim=self.vector_dim) # 加载HNSW模型数据,可重新修改最大元素数量(max_elements) # 注意:HNSW需要进行路径管理-------------------------------------tjt os.chdir(config.data_root_path) hnsw_p.load_index(hnsw_path) os.chdir(config.root_path) # 注意:HNSW需要进行路径管理-------------------------------------tjt hnsw_p.set_ef(config.hnsw_set_ef) # ef should always be > k return hnsw_p # HNSW查(支持多学科混合查重) def retrieve(self, query_vec, k_num=50): try: query_labels, _ = self.hnsw_p.knn_query(query_vec, k_num) except Exception as e: try: query_labels, _ = self.hnsw_p.knn_query(query_vec, k=10) except Exception as e: query_labels = np.array([[]]) # 日志采集 if self.logger is not None: self.logger.error(self.log_msg.format(id="HNSW检索error(knn_query)", type="当前题目HNSW检索error", message="knn_query-topK数量不匹配")) return query_labels[0].tolist()