# db_train_app.py

  1. import sys
  2. import time
  3. import json
  4. import config
  5. from data_preprocessing import DataPreProcessing
  6. # 数据清洗与句向量计算
  7. def clear_embedding_train(mongo_coll, mongo_find_dict, sup, sub):
  8. origin_dataset = mongo_coll.find(mongo_find_dict, no_cursor_timeout=True, batch_size=5)
  9. dpp = DataPreProcessing(mongo_coll, is_train=True)
  10. start = time.time()
  11. dpp(origin_dataset[sup:sub])
  12. print("耗时:", time.time()-start)
  13. # 知识点转换成id用于mongodb检索
  14. def convert_knowledge2id(mongo_coll, mongo_find_dict, sup, sub):
  15. with open("model_data/keyword_mapping.json", 'r', encoding="utf8") as f:
  16. knowledge2id = json.load(f)["knowledge2id"]
  17. origin_dataset = mongo_coll.find(mongo_find_dict, no_cursor_timeout=True, batch_size=5)
  18. start = time.time()
  19. for data in origin_dataset[sup:sub]:
  20. print(data["knowledge"])
  21. condition = {"id": data["id"]}
  22. # 需要新增train_flag,防止机器奔溃重复训练
  23. knowledge_list = [knowledge2id[ele] for ele in data["knowledge"] if ele in knowledge2id]
  24. update_elements = {"$set": {"knowledge_id": knowledge_list}}
  25. mongo_coll.update_one(condition, update_elements)
  26. print("耗时:", time.time()-start)
  27. if __name__ == "__main__":
  28. # 获取shell输入参数
  29. argv_list = sys.argv
  30. if len(argv_list) == 1:
  31. sup, sub = None, None
  32. elif len(argv_list) == 2:
  33. sup, sub = argv_list[1].split(':')
  34. sup = None if sup == '' else int(sup)
  35. sub = None if sub == '' else int(sub)
  36. # 获取mongodb数据
  37. mongo_coll = config.mongo_coll
  38. # mongo_find_dict = {"sent_train_flag": {"$exists": 0}}
  39. mongo_find_dict = dict()
  40. # 清洗文本与计算句向量(train_mode=1表示需要进行文本清洗与句向量计算)
  41. clear_embedding_train(mongo_coll, mongo_find_dict, sup, sub)
  42. # 知识点转换成id用于mongodb检索
  43. convert_knowledge2id(mongo_coll, mongo_find_dict, sup, sub)