hnsw_model_train.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import os
  2. import gc
  3. import time
  4. import pickle
  5. import hnswlib
  6. import config
  7. from data_preprocessing import shuffle_data_pair, DataPreProcessing
  8. from restart_server import server_killer, restart_hnsw_app
  9. class HNSW_Model_Train():
  10. def __init__(self):
  11. # 配置初始数据
  12. self.mongo_coll = config.mongo_coll_list
  13. self.dpp = DataPreProcessing(self.mongo_coll, is_train=True)
  14. # 训练HNSW模型并重启服务(尽量在ASR_app.py每天定时保存当前运行模型之后运行)
  15. def __call__(self):
  16. # 关闭服务进程
  17. server_killer(port=8858)
  18. # 清空hnsw_update_data.txt中更新数据
  19. with open(config.hnsw_update_save_path, 'w', encoding='utf8') as f:
  20. f.write("")
  21. # 全学科hnsw模型训练
  22. start = time.time()
  23. for hnsw_index in range(config.hnsw_num):
  24. self.subject_model_train(hnsw_index)
  25. # 日志采集
  26. print(config.log_msg.format(id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
  27. type="HNSW模型训练总计耗时:",
  28. message=time.time()-start))
  29. # 模型训练结束后重启服务
  30. self.restart_server()
  31. # 全学科hnsw模型训练
  32. def subject_model_train(self, hnsw_index):
  33. if hnsw_index == 0:
  34. self.clear_embedding_update(hnsw_index)
  35. find_dict = {"is_stop": 0} if hnsw_index == 0 else {}
  36. origin_dataset = self.mongo_coll[hnsw_index].find(find_dict, no_cursor_timeout=True, batch_size=5)
  37. self.hnsw_train(origin_dataset, config.hnsw_path_list[hnsw_index])
  38. # 训练HNSW模型
  39. def hnsw_train(self, origin_dataset, hnsw_path):
  40. start0 = time.time()
  41. # 初始化HNSW搜索图
  42. # possible options are l2, cosine or ip
  43. hnsw_p = hnswlib.Index(space = config.hnsw_metric, dim = config.vector_dim)
  44. # 初始化HNSW索引及相关参数
  45. hnsw_p.init_index(max_elements = config.num_elements, ef_construction = 200, M = 16)
  46. # ef要大于召回数据数量
  47. hnsw_p.set_ef(config.hnsw_set_ef)
  48. # 设置线程数量-during batch search/construction
  49. hnsw_p.set_num_threads(4)
  50. # 从mongodb获取句向量和标签
  51. idx_list, vec_list = [], []
  52. for data_idx,data in enumerate(origin_dataset):
  53. if "sentence_vec" not in data:
  54. continue
  55. sentence_vec = pickle.loads(data["sentence_vec"])
  56. if sentence_vec.size != config.vector_dim:
  57. continue
  58. idx_list.append(data["topic_id"])
  59. vec_list.append(sentence_vec)
  60. # 设置批量处理长度,若满足条件则进行批量处理
  61. if (data_idx+1) % 500000 == 0:
  62. # 按数据对应顺序随机打乱数据
  63. idx_list, vec_list = shuffle_data_pair(idx_list, vec_list)
  64. # 将数据进行HNSW构图
  65. hnsw_p.add_items(vec_list, idx_list)
  66. idx_list, vec_list = [], []
  67. # 将句向量加入到HNSW
  68. if len(idx_list) > 0:
  69. # 按数据对应顺序随机打乱数据
  70. idx_list, vec_list = shuffle_data_pair(idx_list, vec_list)
  71. # 将数据进行HNSW构图
  72. hnsw_p.add_items(vec_list, idx_list)
  73. # 保存HNSW图模型
  74. # 注意:HNSW需要进行路径管理-------------------------------------tjt
  75. os.chdir(config.data_root_path)
  76. hnsw_p.save_index(hnsw_path)
  77. os.chdir(config.root_path)
  78. # 注意:HNSW需要进行路径管理-------------------------------------tjt
  79. time.sleep(5)
  80. # 无效变量内存回收
  81. del hnsw_p
  82. gc.collect()
  83. # 日志采集
  84. print(config.log_msg.format(id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
  85. type=hnsw_path+"模型训练耗时:",
  86. message=time.time()-start0))
  87. # 数据清洗与句向量计算
  88. def clear_embedding_update(self, hnsw_index):
  89. find_dict = {"sent_train_flag": {"$exists": 0}}
  90. if hnsw_index == 0:
  91. find_dict["is_stop"] = 0
  92. origin_dataset = self.mongo_coll[hnsw_index].find(find_dict, no_cursor_timeout=True, batch_size=5)
  93. self.dpp(origin_dataset, hnsw_index)
  94. # 模型训练结束后重启服务
  95. def restart_server(self):
  96. restart_hnsw_app()
  97. # 日志采集
  98. print(config.log_msg.format(id=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
  99. type="重启hnsw_app服务",
  100. message="HNSW模型训练完毕, 已重新启动hnsw_app服务"))
  101. if __name__ == "__main__":
  102. hm_train = HNSW_Model_Train()
  103. hm_train()