12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
- import sys
- import time
- import json
- import config
- from data_preprocessing import DataPreProcessing
def clear_embedding_train(mongo_coll, mongo_find_dict, sup, sub):
    """Run data cleaning and sentence-embedding computation on a document range.

    Queries ``mongo_coll`` with ``mongo_find_dict``, slices the resulting
    cursor with ``[sup:sub]`` and hands the slice to a
    ``DataPreProcessing(mongo_coll, is_train=True)`` instance. The elapsed
    wall-clock time is printed on success.

    Args:
        mongo_coll: Collection object exposing ``find``; also passed to
            ``DataPreProcessing``.
        mongo_find_dict: Query filter forwarded to ``mongo_coll.find``.
        sup: Slice start applied to the cursor, or ``None`` for the beginning.
        sub: Slice stop applied to the cursor, or ``None`` for no upper bound.
    """
    origin_dataset = mongo_coll.find(mongo_find_dict, no_cursor_timeout=True, batch_size=5)
    try:
        dpp = DataPreProcessing(mongo_coll, is_train=True)
        start = time.time()
        dpp(origin_dataset[sup:sub])
        print("耗时:", time.time() - start)
    finally:
        # The cursor is opened with no_cursor_timeout=True, so it is closed
        # explicitly here even when preprocessing raises part-way through.
        origin_dataset.close()
def convert_knowledge2id(mongo_coll, mongo_find_dict, sup, sub):
    """Store the ids of each document's knowledge points for MongoDB retrieval.

    Loads the ``knowledge2id`` mapping from ``model_data/keyword_mapping.json``,
    iterates the documents matched by ``mongo_find_dict`` (sliced with
    ``[sup:sub]``) and writes the mapped ids to each document's
    ``knowledge_id`` field, matching documents by their ``id`` field.
    Knowledge entries absent from the mapping are skipped. The elapsed
    wall-clock time is printed on success.

    Args:
        mongo_coll: Collection object exposing ``find`` and ``update_one``.
        mongo_find_dict: Query filter forwarded to ``mongo_coll.find``.
        sup: Slice start applied to the cursor, or ``None`` for the beginning.
        sub: Slice stop applied to the cursor, or ``None`` for no upper bound.
    """
    with open("model_data/keyword_mapping.json", 'r', encoding="utf8") as f:
        knowledge2id = json.load(f)["knowledge2id"]

    origin_dataset = mongo_coll.find(mongo_find_dict, no_cursor_timeout=True, batch_size=5)
    try:
        start = time.time()
        for data in origin_dataset[sup:sub]:
            print(data["knowledge"])
            condition = {"id": data["id"]}
            # TODO: add a train_flag so a machine crash does not cause
            # already-processed documents to be handled again.
            # Membership test rather than truthiness of the mapped value, so a
            # knowledge point whose id is 0 is kept instead of being dropped.
            # NOTE(review): assumes 0 is a valid id, not a reserved
            # "unknown" value -- confirm against keyword_mapping.json.
            knowledge_list = [
                knowledge2id[ele] for ele in data["knowledge"] if ele in knowledge2id
            ]
            update_elements = {"$set": {"knowledge_id": knowledge_list}}
            mongo_coll.update_one(condition, update_elements)
        print("耗时:", time.time() - start)
    finally:
        # The cursor is opened with no_cursor_timeout=True, so it is closed
        # explicitly here even when an update raises part-way through.
        origin_dataset.close()
def parse_slice_arg(argv_list):
    """Parse the optional ``start:stop`` command-line argument.

    Args:
        argv_list: Full argument vector, program name first (``sys.argv``).

    Returns:
        A ``(sup, sub)`` tuple. Each bound is an ``int``, or ``None`` when
        that side of the colon is empty or no argument was given at all.

    Raises:
        ValueError: If more than one argument is given, the argument does not
            contain exactly one ``:``, or a bound is not an integer.
    """
    prog = argv_list[0] if argv_list else "script"
    usage = "usage: %s [start:stop]" % prog
    if len(argv_list) == 1:
        return None, None
    if len(argv_list) != 2:
        raise ValueError(usage)
    bounds = argv_list[1].split(':')
    if len(bounds) != 2:
        raise ValueError(usage)
    try:
        sup, sub = [int(bound) if bound else None for bound in bounds]
    except ValueError:
        raise ValueError(usage)
    return sup, sub


if __name__ == "__main__":
    # Read the optional "start:stop" range from the shell arguments; exit with
    # a usage message instead of failing later on unbound sup/sub.
    try:
        sup, sub = parse_slice_arg(sys.argv)
    except ValueError as err:
        sys.exit(str(err))
    # Collection holding the data to process.
    mongo_coll = config.mongo_coll
    # Alternative filter, currently disabled:
    # mongo_find_dict = {"sent_train_flag": {"$exists": 0}}
    mongo_find_dict = dict()
    # Clean the text and compute sentence embeddings.
    clear_embedding_train(mongo_coll, mongo_find_dict, sup, sub)
    # Convert knowledge points to ids for MongoDB retrieval.
    convert_knowledge2id(mongo_coll, mongo_find_dict, sup, sub)
|