# hnsw_model_train.py

import os
import time
import numpy as np
import pickle
import hnswlib
import config
from data_preprocessing import shuffle_data_pair, DataPreProcessing
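
# The `config` module is not shown here; from its usage below it is assumed to
# provide at least: mongo_coll (a pymongo collection), log_msg (a log format
# string with id/type/message fields), hnsw_metric ("l2", "cosine" or "ip"),
# vector_dim (sentence-vector dimensionality), num_elements (index capacity),
# hnsw_set_ef (query-time ef), hnsw_path, data_root_path, root_path and
# whitening_path. Likewise, shuffle_data_pair is assumed to shuffle the id list
# and vector list in unison and return both. These descriptions are inferred
# from usage, not confirmed by the source.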


class HNSW_Model_Train():
    def __init__(self, logger=None):
        # Initial configuration
        self.mongo_coll = config.mongo_coll
        self.dpp = DataPreProcessing(self.mongo_coll, is_train=True)
        # # bert-whitening parameters
        # with open(config.whitening_path, 'rb') as f:
        #     self.kernel, self.bias = pickle.load(f)
        # Logging
        self.logger = logger
        self.log_msg = config.log_msg

    # Train the HNSW model and restart the service
    def __call__(self):
        start = time.time()
        # Train the HNSW model for all subjects
        self.subject_model_train()
        # Logging
        if self.logger:
            self.logger.info(self.log_msg.format(
                id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                type="Total HNSW model training time:",
                message=time.time() - start))
        else:
            print(self.log_msg.format(
                id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                type="Total HNSW model training time:",
                message=time.time() - start))
        # Restart the service after model training finishes

    # Train the HNSW model for all subjects
    def subject_model_train(self):
        self.clear_embedding_update()
        # no_cursor_timeout=True keeps the server-side cursor alive during the
        # long scan; batch_size=5 keeps each fetch of pickled vectors small
        origin_dataset = self.mongo_coll.find(no_cursor_timeout=True, batch_size=5)
        self.hnsw_train(origin_dataset, config.hnsw_path)

    # Train the HNSW index
    def hnsw_train(self, origin_dataset, hnsw_path):
        start0 = time.time()
        # Initialize the HNSW search graph
        # possible metric options are l2, cosine or ip
        hnsw_p = hnswlib.Index(space=config.hnsw_metric, dim=config.vector_dim)
        # Initialize the HNSW index and its construction parameters
        hnsw_p.init_index(max_elements=config.num_elements, ef_construction=200, M=16)
        # ef must be larger than the number of neighbors to recall
        hnsw_p.set_ef(config.hnsw_set_ef)
        # Number of threads used during batch search/construction
        hnsw_p.set_num_threads(4)
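        # Tuning notes (from hnswlib's documentation): M is the maximum number
        # of outgoing graph links per element, ef_construction trades build
        # speed for recall quality, and ef is the query-time candidate-list
        # size, which must be at least the k passed to knn_query.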
        idx_list, vec_list = [], []
        for data in origin_dataset:
            if "sentence_vec" not in data:
                continue
            sentence_vec = pickle.loads(data["sentence_vec"]).astype(np.float32)
            if sentence_vec.size != config.vector_dim:
                continue
            idx_list.append(data["id"])
            vec_list.append(sentence_vec)
            # Flush once the batch reaches 300000 vectors (counting kept items,
            # so skipped documents cannot cause a flush to be missed)
            if len(idx_list) >= 300000:
                # Shuffle ids and vectors in unison
                idx_list, vec_list = shuffle_data_pair(idx_list, vec_list)
                # Add the batch to the HNSW graph
                hnsw_p.add_items(vec_list, idx_list)
                idx_list, vec_list = [], []
        # Add any remaining sentence vectors to the HNSW graph
        if len(idx_list) > 0:
            # Shuffle ids and vectors in unison
            idx_list, vec_list = shuffle_data_pair(idx_list, vec_list)
            # Add the final batch to the HNSW graph
            hnsw_p.add_items(vec_list, idx_list)
        # Save the HNSW index
        # Note: HNSW needs path management here -------------------------------tjt
        os.chdir(config.data_root_path)
        hnsw_p.save_index(hnsw_path)
        os.chdir(config.root_path)
        # Note: HNSW needs path management here -------------------------------tjt
        # Logging
        if self.logger:
            self.logger.info(self.log_msg.format(
                id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                type=hnsw_path + " model training time:",
                message=time.time() - start0))
        else:
            print(self.log_msg.format(
                id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
                type=hnsw_path + " model training time:",
                message=time.time() - start0))

    # Data cleaning and sentence-vector computation
    def clear_embedding_update(self):
        # Only process documents that do not carry the training flag yet
        find_dict = {"sent_train_flag": {"$exists": 0}}
        origin_dataset = self.mongo_coll.find(find_dict, no_cursor_timeout=True, batch_size=5)
        self.dpp(origin_dataset)
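

# A minimal usage sketch (not part of the original training flow): load the
# saved index and query it. It assumes config exposes the same hnsw_metric,
# vector_dim, num_elements, hnsw_set_ef and paths used above; query_vec is a
# hypothetical (vector_dim,)-shaped float32 numpy array.
def load_and_query(query_vec, k=10):
    hnsw_q = hnswlib.Index(space=config.hnsw_metric, dim=config.vector_dim)
    # Mirror the chdir pattern used when saving, so the relative index path
    # resolves the same way it did during training
    os.chdir(config.data_root_path)
    hnsw_q.load_index(config.hnsw_path, max_elements=config.num_elements)
    os.chdir(config.root_path)
    # ef must be at least k at query time
    hnsw_q.set_ef(max(config.hnsw_set_ef, k))
    labels, distances = hnsw_q.knn_query(query_vec, k=k)
    return labels, distances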


if __name__ == "__main__":
    hm_train = HNSW_Model_Train()
    hm_train()
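
# Hypothetical use of the load_and_query sketch above, e.g. from an
# interactive session after a training run has saved the index:
#   vec = np.random.rand(config.vector_dim).astype(np.float32)
#   labels, distances = load_and_query(vec, k=5)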