123456789101112131415161718192021222324252627282930313233343536373839404142434445464748 |
- import os
- import hnswlib
- import numpy as np
- import config
class HNSW:
    """Wrapper around an ``hnswlib.Index`` that is loaded from disk.

    Vector dimension, distance metric, index path, ``ef`` and the log
    message template are all read from the project ``config`` module.
    """

    # Neighbour count used for the single retry when the requested top-k
    # query raises.
    FALLBACK_K = 10

    def __init__(self, logger=None):
        """Read settings from ``config`` and load the HNSW index.

        Args:
            logger: Optional logger. Only its ``error`` method is called.
                When ``None``, a failed retrieval is not logged.
        """
        # Initial configuration.
        self.vector_dim = config.vector_dim
        # Logging.
        self.logger = logger
        self.log_msg = config.log_msg
        # Load the HNSW graph index.
        self.hnsw_p = self.load_hnsw(config.hnsw_path)

    def load_hnsw(self, hnsw_path):
        """Build a fresh index object and load the saved graph into it.

        The working directory is switched to ``config.data_root_path`` for
        the duration of the load and is always set to ``config.root_path``
        afterwards, even when loading raises.

        Args:
            hnsw_path: Path handed to ``load_index``; presumably relative to
                ``config.data_root_path`` -- confirm against the config.

        Returns:
            The loaded index with ``ef`` set to ``config.hnsw_set_ef``.
        """
        # NOTE (tjt): loading requires creating a new index object and
        # calling set_ef again.
        index = hnswlib.Index(space=config.hnsw_metric, dim=self.vector_dim)
        # The maximum element count (max_elements) may be re-specified when
        # loading the saved index data.
        # NOTE (tjt): HNSW loading needs explicit path management.
        os.chdir(config.data_root_path)
        try:
            index.load_index(hnsw_path)
        finally:
            # Never leave the process inside the data directory, not even
            # when the index file is missing or corrupt.
            os.chdir(config.root_path)
        index.set_ef(config.hnsw_set_ef)  # ef should always be > k
        return index

    def retrieve(self, query_vec, k_num=50):
        """Return the neighbour labels found for ``query_vec``.

        HNSW lookup (per the original author: supports mixed multi-subject
        duplicate checking). The query is attempted with ``k_num``
        neighbours first and, if that raises, once more with
        ``FALLBACK_K``. If both attempts raise, the failure is logged (when
        a logger was supplied) and an empty list is returned.

        Args:
            query_vec: Query data passed straight to ``knn_query``.
            k_num: Number of neighbours requested on the first attempt.

        Returns:
            The first row of the labels returned by ``knn_query`` as a
            plain list, or ``[]`` when every attempt failed.
        """
        for k in (k_num, self.FALLBACK_K):
            try:
                query_labels, _ = self.hnsw_p.knn_query(query_vec, k=k)
            except Exception:
                # Deliberate best-effort: any query failure moves on to the
                # next (smaller) k instead of propagating to the caller.
                continue
            else:
                return query_labels[0].tolist()

        # Every attempt failed: record it and hand back an empty result.
        if self.logger is not None:
            self.logger.error(self.log_msg.format(id="HNSW检索error(knn_query)",
                                                  type="当前题目HNSW检索error",
                                                  message="knn_query-topK数量不匹配"))
        return []
|