import os
import gc
import time
import pickle

import hnswlib

import config
from data_preprocessing import shuffle_data_pair, DataPreProcessing
from restart_server import server_killer, restart_hnsw_app


class HNSW_Model_Train:
    def __init__(self):
        # Initial configuration: MongoDB collections and the preprocessing pipeline
        self.mongo_coll = config.mongo_coll_list
        self.dpp = DataPreProcessing(self.mongo_coll, is_train=True)

    # Train the HNSW models and restart the service (best run right after
    # ASR_app.py's daily scheduled save of the currently running model)
    def __call__(self):
        # Kill the running service process
        server_killer(port=8858)
        # Clear the pending update data in hnsw_update_data.txt
        with open(config.hnsw_update_save_path, 'w', encoding='utf8') as f:
            f.write("")
        # Train the HNSW model for every subject
        start = time.time()
        for hnsw_index in range(config.hnsw_num):
            self.subject_model_train(hnsw_index)
        # Logging
        print(config.log_msg.format(id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                                    type="Total HNSW model training time:",
                                    message=time.time() - start))
        # Restart the service once training finishes
        self.restart_server()

    # Train the HNSW model for a single subject collection
    def subject_model_train(self, hnsw_index):
        if hnsw_index == 0:
            self.clear_embedding_update(hnsw_index)
        find_dict = {"is_stop": 0} if hnsw_index == 0 else {}
        origin_dataset = self.mongo_coll[hnsw_index].find(find_dict, no_cursor_timeout=True, batch_size=5)
        self.hnsw_train(origin_dataset, config.hnsw_path_list[hnsw_index])

    # Train one HNSW index from a MongoDB cursor and save it to hnsw_path
    def hnsw_train(self, origin_dataset, hnsw_path):
        start0 = time.time()
        # Initialize the HNSW search graph
        # possible metric options are l2, cosine or ip
        hnsw_p = hnswlib.Index(space=config.hnsw_metric, dim=config.vector_dim)
        # Initialize the HNSW index and its construction parameters
        hnsw_p.init_index(max_elements=config.num_elements, ef_construction=200, M=16)
        # ef must be larger than the number of items to recall
        hnsw_p.set_ef(config.hnsw_set_ef)
        # Number of threads used during batch search/construction
        hnsw_p.set_num_threads(4)
        # Fetch sentence vectors and their ids from MongoDB
        idx_list, vec_list = [], []
        for data in origin_dataset:
            if "sentence_vec" not in data:
                continue
            sentence_vec = pickle.loads(data["sentence_vec"])
            if sentence_vec.size != config.vector_dim:
                continue
            idx_list.append(data["topic_id"])
            vec_list.append(sentence_vec)
            # Flush every 500,000 valid vectors (counting the list length rather
            # than the raw cursor index, so skipped documents cannot cause a
            # flush boundary to be missed)
            if len(idx_list) == 500000:
                # Shuffle ids and vectors with the same permutation
                idx_list, vec_list = shuffle_data_pair(idx_list, vec_list)
                # Add the batch to the HNSW graph
                hnsw_p.add_items(vec_list, idx_list)
                idx_list, vec_list = [], []
        # Add the remaining sentence vectors to the HNSW graph
        if len(idx_list) > 0:
            # Shuffle ids and vectors with the same permutation
            idx_list, vec_list = shuffle_data_pair(idx_list, vec_list)
            # Add the batch to the HNSW graph
            hnsw_p.add_items(vec_list, idx_list)
        # Save the HNSW index
        # Note: save_index() resolves hnsw_path relative to the working directory,
        # so switch into the data root before saving --------------------------tjt
        os.chdir(config.data_root_path)
        hnsw_p.save_index(hnsw_path)
        os.chdir(config.root_path)
        # Note: HNSW requires this working-directory management ----------------tjt
        time.sleep(5)
        # Reclaim the memory held by the index
        del hnsw_p
        gc.collect()
        # Logging
        print(config.log_msg.format(id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                                    type=hnsw_path + " model training time:",
                                    message=time.time() - start0))

    # Data cleaning and sentence-vector computation for documents that have
    # not been embedded yet (no sent_train_flag)
    def clear_embedding_update(self, hnsw_index):
        find_dict = {"sent_train_flag": {"$exists": 0}}
        if hnsw_index == 0:
            find_dict["is_stop"] = 0
        origin_dataset = self.mongo_coll[hnsw_index].find(find_dict, no_cursor_timeout=True, batch_size=5)
        self.dpp(origin_dataset, hnsw_index)

    # Restart the service after training finishes
    def restart_server(self):
        restart_hnsw_app()
        # Logging
        print(config.log_msg.format(id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                                    type="Restart hnsw_app service",
                                    message="HNSW model training finished; hnsw_app service has been restarted"))


if __name__ == "__main__":
    hm_train = HNSW_Model_Train()
    hm_train()
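

# ---------------------------------------------------------------------------
# For reference: a minimal sketch of the contract that shuffle_data_pair() is
# assumed to fulfil at its call sites above (the real implementation lives in
# data_preprocessing.py and may differ): shuffle two parallel lists with one
# shared permutation so that each id keeps its vector.
def _shuffle_data_pair_sketch(idx_list, vec_list):
    import random
    paired = list(zip(idx_list, vec_list))
    random.shuffle(paired)
    if not paired:
        return [], []
    idx_shuffled, vec_shuffled = zip(*paired)
    return list(idx_shuffled), list(vec_shuffled)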
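

# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not invoked by the pipeline above): how a
# consumer such as hnsw_app might load and query an index saved by
# hnsw_train(). The function name, `query_vec` parameter (a numpy array of
# shape (n, config.vector_dim)), and `top_k` default are illustrative; the
# config fields mirror the ones used in this module. load_index() resolves
# the path relative to the working directory, hence the same os.chdir()
# bracketing as in hnsw_train().
def load_and_query_index(hnsw_path, query_vec, top_k=10):
    hnsw_q = hnswlib.Index(space=config.hnsw_metric, dim=config.vector_dim)
    os.chdir(config.data_root_path)
    hnsw_q.load_index(hnsw_path, max_elements=config.num_elements)
    os.chdir(config.root_path)
    # ef must be at least top_k for the query to return top_k neighbors
    hnsw_q.set_ef(max(config.hnsw_set_ef, top_k))
    # Returns (labels, distances), each of shape (num_queries, top_k)
    return hnsw_q.knn_query(query_vec, k=top_k)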