import os
import time
import pickle

import hnswlib

import config
from data_preprocessing import shuffle_data_pair, DataPreProcessing
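
# Note: shuffle_data_pair (imported above) is assumed to shuffle two lists in
# unison so that each id stays aligned with its vector. A minimal sketch of
# such a helper (hypothetical; the real one lives in data_preprocessing):
#
#     import random
#
#     def shuffle_data_pair(idx_list, vec_list):
#         pairs = list(zip(idx_list, vec_list))
#         random.shuffle(pairs)
#         return [i for i, _ in pairs], [v for _, v in pairs]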


class HNSW_Model_Train:

    def __init__(self, logger=None):
        # Initial configuration
        self.mongo_coll = config.mongo_coll
        self.dpp = DataPreProcessing(self.mongo_coll, is_train=True)
        # # bert-whitening parameters
        # with open(config.whitening_path, 'rb') as f:
        #     self.kernel, self.bias = pickle.load(f)
        # Logging
        self.logger = logger
        self.log_msg = config.log_msg

    # Train the HNSW model and restart the service
    def __call__(self):
        # Train the HNSW model over all subjects
        start = time.time()
        self.subject_model_train()
        # Log the total training time (to the logger if present, else stdout)
        msg = self.log_msg.format(
            id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
            type="Total HNSW model training time:",
            message=time.time() - start)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
        # The service is restarted once training has finished

    # Train the HNSW model over all subjects
    def subject_model_train(self):
        # Clean new data and compute sentence vectors first
        self.clear_embedding_update()
        origin_dataset = self.mongo_coll.find(no_cursor_timeout=True, batch_size=5)
        self.hnsw_train(origin_dataset, config.hnsw_path)

    # Train the HNSW index
    def hnsw_train(self, origin_dataset, hnsw_path):
        start0 = time.time()
        idx_list = []
        vec_list = []
        for data in origin_dataset:
            # Skip documents without a precomputed sentence vector
            if "sentence_vec" not in data:
                continue
            sentence_vec = pickle.loads(data["sentence_vec"])
            # Skip vectors whose dimensionality does not match the index
            if sentence_vec.size != config.vector_dim:
                continue
            # sentence_vec = (sentence_vec + self.bias).dot(self.kernel).reshape(-1)
            idx_list.append(data["id"])
            vec_list.append(sentence_vec)

        # Initialize the HNSW search graph
        # possible spaces are l2, cosine or ip
        hnsw_p = hnswlib.Index(space=config.hnsw_metric, dim=config.vector_dim)
        # Initialize the HNSW index and its construction parameters
        hnsw_p.init_index(max_elements=config.num_elements, ef_construction=200, M=16)
        # ef must be greater than the number of results to recall
        hnsw_p.set_ef(config.hnsw_set_ef)
        # Number of threads to use during batch search/construction
        hnsw_p.set_num_threads(4)
        # Add the sentence vectors to the HNSW index
        if len(idx_list) > 0:
            # Randomly shuffle the data while keeping ids and vectors aligned
            idx_list, vec_list = shuffle_data_pair(idx_list, vec_list)
            # Build the HNSW graph from the vectors
            hnsw_p.add_items(vec_list, idx_list)
        # Save the HNSW index
        # NOTE: hnswlib saves relative to the working directory, so manage paths here ---tjt
        os.chdir(config.data_root_path)
        hnsw_p.save_index(hnsw_path)
        os.chdir(config.root_path)
        # Log the training time (to the logger if present, else stdout)
        msg = self.log_msg.format(
            id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
            type=hnsw_path + " model training time:",
            message=time.time() - start0)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)
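
    # Illustrative sketch (not in the original code): load the saved index back
    # and query it via hnswlib's standard API. `query_vec` should be a vector
    # or an (n, dim) array; `top_k` is a hypothetical parameter name.
    def hnsw_query(self, query_vec, top_k=10):
        hnsw_p = hnswlib.Index(space=config.hnsw_metric, dim=config.vector_dim)
        # The index was saved relative to data_root_path, so load it from there
        os.chdir(config.data_root_path)
        hnsw_p.load_index(config.hnsw_path, max_elements=config.num_elements)
        os.chdir(config.root_path)
        # ef must again exceed top_k for reliable recall
        hnsw_p.set_ef(config.hnsw_set_ef)
        # knn_query returns (labels, distances), each of shape (n_queries, top_k)
        labels, distances = hnsw_p.knn_query(query_vec, k=top_k)
        return labels, distances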

    # Data cleaning and sentence-vector computation
    def clear_embedding_update(self):
        # Only process documents that have not yet been used for training
        find_dict = {"sent_train_flag": {"$exists": 0}}
        origin_dataset = self.mongo_coll.find(find_dict, no_cursor_timeout=True, batch_size=5)
        self.dpp(origin_dataset)


if __name__ == "__main__":
    hm_train = HNSW_Model_Train()
    hm_train()
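    # Illustrative usage of the hnsw_query sketch above (hypothetical; assumes
    # the freshly built index is non-empty):
    # import numpy as np
    # labels, distances = hm_train.hnsw_query(np.random.rand(1, config.vector_dim))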