mongodb_cloud_train.py 948 B

123456789101112131415161718192021222324252627
  1. import sys
  2. import time
  3. import config
  4. from data_preprocessing import DataPreProcessing
  5. # 数据清洗与句向量计算
  6. def clear_embedding_train(mongo_coll, sup, sub):
  7. origin_dataset = mongo_coll[0].find({"is_stop": 0}, no_cursor_timeout=True, batch_size=5)
  8. dpp = DataPreProcessing(mongo_coll, is_train=True)
  9. start = time.time()
  10. dpp(origin_dataset[sup:sub], hnsw_index=0)
  11. print("耗时:", time.time()-start)
  12. if __name__ == "__main__":
  13. # 获取shell输入参数
  14. argv_list = sys.argv
  15. if len(argv_list) == 1:
  16. sup, sub = None, None
  17. elif len(argv_list) == 2:
  18. sup, sub = argv_list[1].split(':')
  19. sup = None if sup == '' else int(sup)
  20. sub = None if sub == '' else int(sub)
  21. # 获取mongodb数据
  22. mongo_coll = config.mongo_coll_list
  23. # 清洗文本与计算句向量(train_mode=1表示需要进行文本清洗与句向量计算)
  24. clear_embedding_train(mongo_coll, sup, sub)