hnsw_model.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. import os
  2. import hnswlib
  3. import numpy as np
  4. import config
  5. class HNSW():
  6. def __init__(self, logger=None):
  7. # 配置初始数据
  8. self.vector_dim = config.vector_dim
  9. # 日志采集
  10. self.logger = logger
  11. self.log_msg = config.log_msg
  12. # 加载HNSW图模型
  13. self.hnsw_p = self.load_hnsw(config.hnsw_path)
  14. # 加载HNSW图模型
  15. def load_hnsw(self, hnsw_path):
  16. #加载hnsw需要重新定义hnsw_p以及set_ef-------tjt
  17. # 初始化HNSW搜索图
  18. hnsw_p = hnswlib.Index(space=config.hnsw_metric, dim=self.vector_dim)
  19. # 加载HNSW模型数据,可重新修改最大元素数量(max_elements)
  20. # 注意:HNSW需要进行路径管理-------------------------------------tjt
  21. os.chdir(config.data_root_path)
  22. hnsw_p.load_index(hnsw_path)
  23. os.chdir(config.root_path)
  24. # 注意:HNSW需要进行路径管理-------------------------------------tjt
  25. hnsw_p.set_ef(config.hnsw_set_ef) # ef should always be > k
  26. return hnsw_p
  27. # HNSW查(支持多学科混合查重)
  28. def retrieve(self, query_vec, k_num=50):
  29. try:
  30. query_labels, _ = self.hnsw_p.knn_query(query_vec, k_num)
  31. except Exception as e:
  32. try:
  33. query_labels, _ = self.hnsw_p.knn_query(query_vec, k=10)
  34. except Exception as e:
  35. query_labels = np.array([[]])
  36. # 日志采集
  37. if self.logger is not None:
  38. self.logger.error(self.log_msg.format(id="HNSW检索error(knn_query)",
  39. type="当前题目HNSW检索error",
  40. message="knn_query-topK数量不匹配"))
  41. return query_labels[0].tolist()