db_train_app.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import sys
  2. import time
  3. import json
  4. import config
  5. from data_preprocessing import DataPreProcessing
  6. from dim_classify import Dimension_Classification
  7. from physical_quantity_extract import physical_quantity_extract
  8. """
MongoDB document format (example):
{
    "id" : 20231001,
    "quesType" : {
        "quesType" : "单选题"
    },
    "quesBody" : "荔枝是一种岭南佳果,小明拿起一个荔枝,如题图所示,它的尺寸l大小约为( )<br/><img src=\"Upload/QBM/20231001.png\" /><br/>\nA. 0.1cm B. 3cm C. 0.3m D. 1m",
    "quesParse" : "......",
    "quesAnswer" : "【答案】见解析",
    "difficulty" : "一般",
    "knowledge" : [
        "长度的测量"
    ]
}
  23. """
  24. # 数据清洗与句向量计算
  25. def clear_embedding_train(mongo_coll, mongo_find_dict, sup, sub):
  26. origin_dataset = mongo_coll.find(mongo_find_dict, no_cursor_timeout=True, batch_size=5)
  27. dpp = DataPreProcessing(mongo_coll, is_train=True)
  28. start = time.time()
  29. dpp(origin_dataset[sup:sub])
  30. print("耗时:", time.time()-start)
  31. # 知识点转换成id用于mongodb检索/计算物理量/计算求解类型
  32. def convert_knowledge2id(mongo_coll, mongo_find_dict, sup, sub):
  33. # 加载知识点转ID数据
  34. with open("model_data/keyword_mapping.json", 'r', encoding="utf8") as f:
  35. knowledge2id = json.load(f)["knowledge2id"]
  36. dim_classify = Dimension_Classification(dim_mode=0)
  37. origin_dataset = mongo_coll.find(mongo_find_dict, no_cursor_timeout=True, batch_size=5)
  38. start = time.time()
  39. for data in origin_dataset[sup:sub]:
  40. condition = {"id": data["id"]}
  41. # 计算物理量
  42. physical_quantity_list = physical_quantity_extract(data["content_clear"])
  43. # 计算求解类型
  44. solution_list = dim_classify(data["content_clear"], data["quesType"])["solving_type"]
  45. # 知识点转ID
  46. knowledge_list = [knowledge2id[ele] for ele in data["knowledge"] if ele in knowledge2id]
  47. update_elements = {"$set": {"physical_quantity": physical_quantity_list,
  48. "solving_type": solution_list,
  49. "knowledge_id": knowledge_list}}
  50. mongo_coll.update_one(condition, update_elements)
  51. print(physical_quantity_list, solution_list)
  52. print("耗时:", time.time()-start)
  53. if __name__ == "__main__":
  54. # 获取shell输入参数
  55. argv_list = sys.argv
  56. if len(argv_list) == 1:
  57. sup, sub = None, None
  58. elif len(argv_list) == 2:
  59. sup, sub = argv_list[1].split(':')
  60. sup = None if sup == '' else int(sup)
  61. sub = None if sub == '' else int(sub)
  62. # 获取mongodb数据
  63. mongo_coll = config.mongo_coll
  64. # mongo_find_dict = {"sent_train_flag": {"$exists": 0}}
  65. mongo_find_dict = dict()
  66. # 清洗文本与计算句向量(train_mode=1表示需要进行文本清洗与句向量计算)
  67. clear_embedding_train(mongo_coll, mongo_find_dict, sup, sub)
  68. # 知识点转换成id用于mongodb检索
  69. convert_knowledge2id(mongo_coll, mongo_find_dict, sup, sub)