hnsw_model.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import os
  2. import pickle
  3. import hnswlib
  4. import numpy as np
  5. import config
  6. class HNSW():
  7. def __init__(self, logger=None):
  8. # 配置初始数据
  9. self.mongo_coll = config.mongo_coll_list
  10. self.vector_dim = config.vector_dim
  11. # 日志采集
  12. self.logger = logger
  13. self.log_msg = config.log_msg
  14. self.hnsw_list = [self.load_hnsw(hnsw_path)
  15. for hnsw_path in config.hnsw_path_list]
  16. # hnsw模型更新标志列表
  17. self.hnsw_update_flag_list = [0 for _ in range(config.hnsw_num)]
  18. # HNSW启动更新保存数据(以防程序错误导致模型数据丢失,读取待更新数据文件)
  19. self.hnsw_start_update_save()
  20. # 加载HNSW图模型
  21. def load_hnsw(self, hnsw_path):
  22. #加载hnsw需要重新定义hnsw_p以及set_ef-------tjt
  23. # 初始化HNSW搜索图
  24. hnsw_p = hnswlib.Index(space=config.hnsw_metric, dim=self.vector_dim)
  25. # 加载HNSW模型数据,可重新修改最大元素数量(max_elements)
  26. # 注意:HNSW需要进行路径管理-------------------------------------tjt
  27. os.chdir(config.data_root_path)
  28. hnsw_p.load_index(hnsw_path)
  29. os.chdir(config.root_path)
  30. # 注意:HNSW需要进行路径管理-------------------------------------tjt
  31. hnsw_p.set_ef(config.hnsw_set_ef) # ef should always be > k
  32. return hnsw_p
  33. # 将更新结果存入HNSW模型
  34. def save_hnsw(self):
  35. # 清空hnsw_update_data.txt中更新数据
  36. with open(config.hnsw_update_save_path, 'w', encoding='utf8') as f:
  37. f.write("")
  38. # 注意:HNSW需要进行路径管理-------------------------------------tjt
  39. os.chdir(config.data_root_path)
  40. for i,flag in enumerate(self.hnsw_update_flag_list):
  41. # 恢复hnsw_update_flag_list初始状态
  42. self.hnsw_update_flag_list[i] = 0
  43. self.hnsw_list[i].save_index(config.hnsw_path_list[i]) if flag == 1 else None
  44. os.chdir(config.root_path)
  45. # 注意:HNSW需要进行路径管理-------------------------------------tjt
  46. # HNSW启动更新保存数据(以防程序错误导致模型数据丢失,读取待更新数据文件)
  47. def hnsw_start_update_save(self):
  48. # 判断hnsw_update_data.txt是否存在
  49. if os.path.exists(config.hnsw_update_save_path) is False:
  50. return
  51. # 读取hnsw_update_data.txt进行处理
  52. with open(config.hnsw_update_save_path, 'r', encoding='utf8') as f:
  53. hnsw_update_data = f.read()
  54. if hnsw_update_data == "":
  55. return
  56. hnsw_update_list = [eval(ele) for ele in hnsw_update_data.split("\n") if ele != ""]
  57. for update_dict in hnsw_update_list:
  58. self.update(update_dict["id"], update_dict["hnsw_index"])
  59. # 保存HNSW模型
  60. self.save_hnsw()
  61. # HNSW增/改
  62. def update(self, update_id, hnsw_index):
  63. # 从数据库读取新增数据
  64. update_data = self.mongo_coll[hnsw_index].find_one({"topic_id": update_id})
  65. sent_vec = pickle.loads(update_data["sentence_vec"])
  66. # 将新增/修改数据构图
  67. if sent_vec.size == self.vector_dim:
  68. self.hnsw_list[hnsw_index].add_items(sent_vec, update_id)
  69. # hnsw模型更新标志列表
  70. self.hnsw_update_flag_list[hnsw_index] = 1
  71. # HNSW查(支持多学科混合查重)
  72. def retrieve(self, query_vec, hnsw_index, k_num=50):
  73. try:
  74. query_labels, _ = self.hnsw_list[hnsw_index].knn_query(query_vec, k_num)
  75. except Exception as e:
  76. try:
  77. query_labels, _ = self.hnsw_list[hnsw_index].knn_query(query_vec, k=10)
  78. except Exception as e:
  79. query_labels = np.array([[]])
  80. # 日志采集
  81. if self.logger is not None:
  82. self.logger.error(self.log_msg.format(id="HNSW检索error(knn_query)",
  83. type="当前题目HNSW检索error",
  84. message="knn_query-topK数量不匹配"))
  85. return query_labels[0].tolist()