# hnsw_model_train.py

import os
import time
import pickle

import hnswlib

import config
from data_preprocessing import shuffle_data_pair, DataPreProcessing


class HNSW_Model_Train:
    def __init__(self, logger=None):
        # Initial configuration
        self.mongo_coll = config.mongo_coll
        self.dpp = DataPreProcessing(self.mongo_coll, is_train=True)
        # # bert-whitening parameters
        # with open(config.whitening_path, 'rb') as f:
        #     self.kernel, self.bias = pickle.load(f)
        # Logging
        self.logger = logger
        self.log_msg = config.log_msg

    # Train the HNSW model, then restart the service
    def __call__(self):
        start = time.time()
        # Train the HNSW model for all subjects
        self.subject_model_train()
        # Log total training time; print to stdout when no logger is configured
        msg = self.log_msg.format(
            id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
            type="Total HNSW model training time:",
            message=time.time() - start)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)

    # Train the HNSW model for all subjects
    def subject_model_train(self):
        self.clear_embedding_update()
        origin_dataset = self.mongo_coll.find(no_cursor_timeout=True, batch_size=5)
        self.hnsw_train(origin_dataset, config.hnsw_path)

    # Train the HNSW model
    def hnsw_train(self, origin_dataset, hnsw_path):
        start0 = time.time()
        idx_list = []
        vec_list = []
        for data in origin_dataset:
            # Skip documents without a precomputed sentence vector
            if "sentence_vec" not in data:
                continue
            sentence_vec = pickle.loads(data["sentence_vec"])
            # Skip vectors whose dimensionality does not match the index
            if sentence_vec.size != config.vector_dim:
                continue
            # sentence_vec = (sentence_vec + self.bias).dot(self.kernel).reshape(-1)
            idx_list.append(data["id"])
            vec_list.append(sentence_vec)
        # Initialize the HNSW search graph; possible spaces are l2, cosine, or ip
        hnsw_p = hnswlib.Index(space=config.hnsw_metric, dim=config.vector_dim)
        # Initialize the index and its construction parameters
        hnsw_p.init_index(max_elements=config.num_elements, ef_construction=200, M=16)
        # ef must be larger than the number of neighbors to recall
        hnsw_p.set_ef(config.hnsw_set_ef)
        # Number of threads used during batch search/construction
        hnsw_p.set_num_threads(4)
        # Add the sentence vectors to the HNSW graph
        if len(idx_list) > 0:
            # Shuffle ids and vectors together, keeping the pairs aligned
            idx_list, vec_list = shuffle_data_pair(idx_list, vec_list)
            # Build the HNSW graph from the vectors
            hnsw_p.add_items(vec_list, idx_list)
        # Save the HNSW graph model
        # NOTE: HNSW paths must be managed explicitly -------------------tjt
        os.chdir(config.data_root_path)
        hnsw_p.save_index(hnsw_path)
        os.chdir(config.root_path)
        # NOTE: HNSW paths must be managed explicitly -------------------tjt
        # Log per-index training time; print to stdout when no logger is configured
        msg = self.log_msg.format(
            id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
            type=hnsw_path + " model training time:",
            message=time.time() - start0)
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)

    # Data cleaning and sentence-vector computation
    def clear_embedding_update(self):
        # Only embed documents that have not been processed for training yet
        find_dict = {"sent_train_flag": {"$exists": 0}}
        origin_dataset = self.mongo_coll.find(find_dict, no_cursor_timeout=True, batch_size=5)
        self.dpp(origin_dataset)
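

# A minimal query-side sketch (not part of the original file): how the index
# saved by hnsw_train could be loaded and searched. It reuses the config
# values from training (hnsw_metric, vector_dim, num_elements, hnsw_set_ef)
# and assumes config.hnsw_path is relative to config.data_root_path,
# mirroring the os.chdir handling around save_index above.
def load_hnsw_index():
    index = hnswlib.Index(space=config.hnsw_metric, dim=config.vector_dim)
    # load_index reads the graph written by save_index; max_elements caps capacity
    index.load_index(os.path.join(config.data_root_path, config.hnsw_path),
                     max_elements=config.num_elements)
    # ef must stay larger than k to keep recall high at query time
    index.set_ef(config.hnsw_set_ef)
    return index

# Example usage (query_vec is a hypothetical 1 x vector_dim numpy array):
#   index = load_hnsw_index()
#   labels, distances = index.knn_query(query_vec, k=10)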


if __name__ == "__main__":
    hm_train = HNSW_Model_Train()
    hm_train()