import os
import time
import pickle

import hnswlib

import config
from data_preprocessing import shuffle_data_pair, DataPreProcessing


class HNSW_Model_Train():
    """Build and save an HNSW index from sentence vectors stored in MongoDB.

    Calling an instance does two things:
    1. Runs ``DataPreProcessing`` on documents that have no ``sent_train_flag``
       field.
    2. Builds an HNSW index over every document in ``config.mongo_coll`` that
       carries a valid ``sentence_vec``, and saves it to ``config.hnsw_path``
       (relative to ``config.data_root_path``).

    Timing messages are sent to ``logger`` when one is given, otherwise they
    are printed to stdout.
    """

    def __init__(self, logger=None):
        """
        Args:
            logger: optional logger exposing ``info``; when None, timing
                messages are printed instead.
        """
        # Initial configuration
        self.mongo_coll = config.mongo_coll
        self.dpp = DataPreProcessing(self.mongo_coll, is_train=True)
        # Disabled: bert-whitening parameters
        # with open(config.whitening_path, 'rb') as f:
        #     self.kernel, self.bias = pickle.load(f)
        # Logging
        self.logger = logger
        self.log_msg = config.log_msg

    def __call__(self):
        """Run the full training pipeline and report the total elapsed time."""
        start = time.time()
        # Train the all-subject HNSW model
        self.subject_model_train()
        self._log("HNSW模型训练总计耗时:", time.time() - start)

    def _log(self, msg_type, elapsed):
        """Emit one timing message via the logger, or print it if no logger is set."""
        # Format once so the logged/printed message carries a single timestamp
        # and a single elapsed value.
        message = self.log_msg.format(
            id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
            type=msg_type,
            message=elapsed)
        if self.logger is not None:
            self.logger.info(message)
        else:
            print(message)

    def subject_model_train(self):
        """Refresh sentence vectors, then rebuild the HNSW index from the whole collection."""
        self.clear_embedding_update()
        origin_dataset = self.mongo_coll.find(no_cursor_timeout=True, batch_size=5)
        try:
            self.hnsw_train(origin_dataset, config.hnsw_path)
        finally:
            # With no_cursor_timeout the server never reaps this cursor,
            # so it must be closed explicitly.
            origin_dataset.close()

    def hnsw_train(self, origin_dataset, hnsw_path):
        """Build an HNSW index from ``origin_dataset`` and save it to ``hnsw_path``.

        Args:
            origin_dataset: iterable of documents (dicts). Documents without a
                ``sentence_vec`` field, or whose unpickled vector does not have
                ``config.vector_dim`` elements, are skipped. ``data["id"]`` is
                used as the index label.
            hnsw_path: file name/path of the saved index, resolved relative to
                ``config.data_root_path``.
        """
        start0 = time.time()
        idx_list, vec_list = self._collect_vectors(origin_dataset)

        # Initialise the HNSW graph; possible spaces are l2, cosine or ip
        hnsw_p = hnswlib.Index(space=config.hnsw_metric, dim=config.vector_dim)
        # Capacity must cover every vector we add, otherwise add_items raises.
        max_elements = max(config.num_elements, len(idx_list))
        hnsw_p.init_index(max_elements=max_elements, ef_construction=200, M=16)
        # ef must be larger than the number of items recalled per query
        hnsw_p.set_ef(config.hnsw_set_ef)
        # Number of threads used during batch search/construction
        hnsw_p.set_num_threads(4)
        if idx_list:
            # Shuffle ids and vectors together, keeping each id paired with its vector
            idx_list, vec_list = shuffle_data_pair(idx_list, vec_list)
            hnsw_p.add_items(vec_list, idx_list)

        self._save_index(hnsw_p, hnsw_path)
        self._log(hnsw_path + "模型训练耗时:", time.time() - start0)

    def _collect_vectors(self, origin_dataset):
        """Return parallel lists ``(ids, vectors)`` for documents with a valid sentence vector."""
        idx_list = []
        vec_list = []
        for data in origin_dataset:
            if "sentence_vec" not in data:
                continue
            # NOTE(security): pickle.loads can execute arbitrary code for a crafted
            # payload; this is only safe while the collection is trusted/internal.
            sentence_vec = pickle.loads(data["sentence_vec"])
            if sentence_vec.size != config.vector_dim:
                continue
            # sentence_vec = (sentence_vec + self.bias).dot(self.kernel).reshape(-1)
            idx_list.append(data["id"])
            # Flatten so that e.g. a (1, dim) array, which passes the size check,
            # still yields a 2-D batch for add_items.
            vec_list.append(sentence_vec.reshape(-1))
        return idx_list, vec_list

    @staticmethod
    def _save_index(hnsw_p, hnsw_path):
        """Save ``hnsw_p`` from inside ``config.data_root_path``, then cd to ``config.root_path``."""
        # HNSW path management: the index is saved with a path relative to
        # data_root_path after changing directory (presumably a workaround for
        # hnswlib path handling -- confirm before replacing with os.path.join).
        # os.chdir is process-wide; the finally ensures the working directory
        # is restored even when saving fails.
        os.chdir(config.data_root_path)
        try:
            hnsw_p.save_index(hnsw_path)
        finally:
            os.chdir(config.root_path)

    def clear_embedding_update(self):
        """Run data cleaning / sentence-vector computation on documents without ``sent_train_flag``.

        Presumably ``DataPreProcessing`` sets ``sent_train_flag`` on the documents it
        processes -- verify in data_preprocessing.
        """
        find_dict = {"sent_train_flag": {"$exists": 0}}
        origin_dataset = self.mongo_coll.find(find_dict, no_cursor_timeout=True, batch_size=5)
        try:
            self.dpp(origin_dataset)
        finally:
            # no_cursor_timeout cursors must be closed explicitly.
            origin_dataset.close()


if __name__ == "__main__":
    hm_train = HNSW_Model_Train()
    hm_train()