123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
- import os
- import pickle
- import hnswlib
- import numpy as np
- import config
- class HNSW():
- def __init__(self, logger=None):
- # 配置初始数据
- self.mongo_coll = config.mongo_coll_list
- self.vector_dim = config.vector_dim
- # 日志采集
- self.logger = logger
- self.log_msg = config.log_msg
- self.hnsw_list = [self.load_hnsw(hnsw_path)
- for hnsw_path in config.hnsw_path_list]
- # hnsw模型更新标志列表
- self.hnsw_update_flag_list = [0 for _ in range(config.hnsw_num)]
- # HNSW启动更新保存数据(以防程序错误导致模型数据丢失,读取待更新数据文件)
- self.hnsw_start_update_save()
-
- # 加载HNSW图模型
- def load_hnsw(self, hnsw_path):
- #加载hnsw需要重新定义hnsw_p以及set_ef-------tjt
- # 初始化HNSW搜索图
- hnsw_p = hnswlib.Index(space=config.hnsw_metric, dim=self.vector_dim)
- # 加载HNSW模型数据,可重新修改最大元素数量(max_elements)
- # 注意:HNSW需要进行路径管理-------------------------------------tjt
- os.chdir(config.data_root_path)
- hnsw_p.load_index(hnsw_path)
- os.chdir(config.root_path)
- # 注意:HNSW需要进行路径管理-------------------------------------tjt
- hnsw_p.set_ef(config.hnsw_set_ef) # ef should always be > k
-
- return hnsw_p
- # 将更新结果存入HNSW模型
- def save_hnsw(self):
- # 清空hnsw_update_data.txt中更新数据
- with open(config.hnsw_update_save_path, 'w', encoding='utf8') as f:
- f.write("")
- # 注意:HNSW需要进行路径管理-------------------------------------tjt
- os.chdir(config.data_root_path)
- for i,flag in enumerate(self.hnsw_update_flag_list):
- # 恢复hnsw_update_flag_list初始状态
- self.hnsw_update_flag_list[i] = 0
- self.hnsw_list[i].save_index(config.hnsw_path_list[i]) if flag == 1 else None
- os.chdir(config.root_path)
- # 注意:HNSW需要进行路径管理-------------------------------------tjt
-
- # HNSW启动更新保存数据(以防程序错误导致模型数据丢失,读取待更新数据文件)
- def hnsw_start_update_save(self):
- # 判断hnsw_update_data.txt是否存在
- if os.path.exists(config.hnsw_update_save_path) is False:
- return
- # 读取hnsw_update_data.txt进行处理
- with open(config.hnsw_update_save_path, 'r', encoding='utf8') as f:
- hnsw_update_data = f.read()
- if hnsw_update_data == "":
- return
- hnsw_update_list = [eval(ele) for ele in hnsw_update_data.split("\n") if ele != ""]
- for update_dict in hnsw_update_list:
- self.update(update_dict["id"], update_dict["hnsw_index"])
- # 保存HNSW模型
- self.save_hnsw()
- # HNSW增/改
- def update(self, update_id, hnsw_index):
- # 从数据库读取新增数据
- update_data = self.mongo_coll[hnsw_index].find_one({"topic_id": update_id})
- sent_vec = pickle.loads(update_data["sentence_vec"])
- # 将新增/修改数据构图
- if sent_vec.size == self.vector_dim:
- self.hnsw_list[hnsw_index].add_items(sent_vec, update_id)
- # hnsw模型更新标志列表
- self.hnsw_update_flag_list[hnsw_index] = 1
- # HNSW查(支持多学科混合查重)
- def retrieve(self, query_vec, hnsw_index, k_num=50):
- try:
- query_labels, _ = self.hnsw_list[hnsw_index].knn_query(query_vec, k_num)
- except Exception as e:
- try:
- query_labels, _ = self.hnsw_list[hnsw_index].knn_query(query_vec, k=10)
- except Exception as e:
- query_labels = np.array([[]])
- # 日志采集
- if self.logger is not None:
- self.logger.error(self.log_msg.format(id="HNSW检索error(knn_query)",
- type="当前题目HNSW检索error",
- message="knn_query-topK数量不匹配"))
-
- return query_labels[0].tolist()
|