import os
import gc
import time
import pickle
import hnswlib
import config
from data_preprocessing import shuffle_data_pair, DataPreProcessing
from restart_server import server_killer, restart_hnsw_app
class HNSW_Model_Train:
    def __init__(self):
        # Initial configuration
        self.mongo_coll = config.mongo_coll_list
        self.dpp = DataPreProcessing(self.mongo_coll, is_train=True)

    # Train the HNSW models and restart the service (preferably run after ASR_app.py's
    # daily scheduled save of the currently running models)
    def __call__(self):
        # Kill the serving process
        server_killer(port=8858)
        # Clear the pending update data in hnsw_update_data.txt
        with open(config.hnsw_update_save_path, 'w', encoding='utf8') as f:
            f.write("")
        # Train the hnsw model for every subject
        start = time.time()
        for hnsw_index in range(config.hnsw_num):
            self.subject_model_train(hnsw_index)
        # Logging
        print(config.log_msg.format(id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                                    type="Total HNSW model training time:",
                                    message=time.time() - start))
        # Restart the service once training is finished
        self.restart_server()
    # Train the hnsw model for a single subject
    def subject_model_train(self, hnsw_index):
        if hnsw_index == 0:
            self.clear_embedding_update(hnsw_index)
        find_dict = {"is_stop": 0} if hnsw_index == 0 else {}
        origin_dataset = self.mongo_coll[hnsw_index].find(find_dict, no_cursor_timeout=True, batch_size=5)
        self.hnsw_train(origin_dataset, config.hnsw_path_list[hnsw_index])
    # Train the HNSW index
    def hnsw_train(self, origin_dataset, hnsw_path):
        start0 = time.time()

        # Initialize the HNSW search graph
        # possible options are l2, cosine or ip
        hnsw_p = hnswlib.Index(space=config.hnsw_metric, dim=config.vector_dim)
        # Initialize the HNSW index and its construction parameters
        hnsw_p.init_index(max_elements=config.num_elements, ef_construction=200, M=16)
        # ef must be larger than the number of neighbors to be recalled
        hnsw_p.set_ef(config.hnsw_set_ef)
        # Number of threads used during batch search/construction
        hnsw_p.set_num_threads(4)
        # Fetch sentence vectors and labels from MongoDB
        idx_list, vec_list = [], []
        for data_idx, data in enumerate(origin_dataset):
            if "sentence_vec" not in data:
                continue
            sentence_vec = pickle.loads(data["sentence_vec"])
            if sentence_vec.size != config.vector_dim:
                continue
            idx_list.append(data["topic_id"])
            vec_list.append(sentence_vec)
            # Flush in batches once the batch-size threshold is reached
            if (data_idx + 1) % 500000 == 0:
                # Shuffle ids and vectors together, preserving their pairing
                idx_list, vec_list = shuffle_data_pair(idx_list, vec_list)
                # Add the batch to the HNSW graph
                hnsw_p.add_items(vec_list, idx_list)
                idx_list, vec_list = [], []
        # Add the remaining sentence vectors to the HNSW graph
        if len(idx_list) > 0:
            # Shuffle ids and vectors together, preserving their pairing
            idx_list, vec_list = shuffle_data_pair(idx_list, vec_list)
            # Add the batch to the HNSW graph
            hnsw_p.add_items(vec_list, idx_list)

        # Save the HNSW graph model
        # Note: HNSW requires path management -------------------------------------tjt
        os.chdir(config.data_root_path)
        hnsw_p.save_index(hnsw_path)
        os.chdir(config.root_path)
        # Note: HNSW requires path management -------------------------------------tjt
        time.sleep(5)
        # Reclaim the memory held by the stale index object
        del hnsw_p
        gc.collect()
        # Logging
        print(config.log_msg.format(id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                                    type=hnsw_path + " model training time:",
                                    message=time.time() - start0))
    # Data cleaning and sentence-vector computation
    def clear_embedding_update(self, hnsw_index):
        find_dict = {"sent_train_flag": {"$exists": 0}}
        if hnsw_index == 0:
            find_dict["is_stop"] = 0
        origin_dataset = self.mongo_coll[hnsw_index].find(find_dict, no_cursor_timeout=True, batch_size=5)
        self.dpp(origin_dataset, hnsw_index)
    # Restart the service after training finishes
    def restart_server(self):
        restart_hnsw_app()
        # Logging
        print(config.log_msg.format(id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                                    type="Restart the hnsw_app service",
                                    message="HNSW model training finished; the hnsw_app service has been restarted"))

if __name__ == "__main__":
    hm_train = HNSW_Model_Train()
    hm_train()