|
@@ -22,7 +22,7 @@ class Difficulty_Model(nn.Module):
|
|
super(Difficulty_Model, self).__init__()
|
|
super(Difficulty_Model, self).__init__()
|
|
self.bert_config = AutoConfig.from_pretrained(config.bert_path)
|
|
self.bert_config = AutoConfig.from_pretrained(config.bert_path)
|
|
self.bert = AutoModel.from_pretrained(config.bert_path)
|
|
self.bert = AutoModel.from_pretrained(config.bert_path)
|
|
- self.fc = nn.Linear(in_features=self.bert_config.hidden_size, out_features=8)
|
|
|
|
|
|
+ self.fc = nn.Linear(in_features=self.bert_config.hidden_size, out_features=1)
|
|
|
|
|
|
def forward(self, input_ids, attention_mask):
|
|
def forward(self, input_ids, attention_mask):
|
|
x = self.bert(input_ids, attention_mask)[0][:, 0, :]
|
|
x = self.bert(input_ids, attention_mask)[0][:, 0, :]
|
|
@@ -31,10 +31,14 @@ class Difficulty_Model(nn.Module):
|
|
return x
|
|
return x
|
|
|
|
|
|
class Dimension_Classification():
|
|
class Dimension_Classification():
|
|
- def __init__(self, logger=None):
|
|
|
|
|
|
+ def __init__(self, dim_mode=2, logger=None):
|
|
|
|
+ self.dim_mode = dim_mode
|
|
self.tokenizer = BertTokenizer.from_pretrained(config.bert_path)
|
|
self.tokenizer = BertTokenizer.from_pretrained(config.bert_path)
|
|
- self.solution_model = torch.load(config.solution_model_path)
|
|
|
|
- self.difficulty_model = torch.load(config.difficulty_model_path)
|
|
|
|
|
|
+ self.solution_model, self.difficulty_model = None, None
|
|
|
|
+ if self.dim_mode in {0, 2}:
|
|
|
|
+ self.solution_model = torch.load(config.solution_model_path)
|
|
|
|
+ if self.dim_mode in {1, 2}:
|
|
|
|
+ self.difficulty_model = torch.load(config.difficulty_model_path)
|
|
self.max_squence_length = 500
|
|
self.max_squence_length = 500
|
|
self.solving_type_dict = {
|
|
self.solving_type_dict = {
|
|
0: "实验操作",
|
|
0: "实验操作",
|
|
@@ -50,8 +54,11 @@ class Dimension_Classification():
|
|
self.logger = logger
|
|
self.logger = logger
|
|
|
|
|
|
def __call__(self, sentence, quesType):
|
|
def __call__(self, sentence, quesType):
|
|
- solution_list = self.solution_classify(sentence, quesType)
|
|
|
|
- difficulty_value = self.difficulty_classify(sentence)
|
|
|
|
|
|
+ solution_list, difficulty_value = [], 0.6
|
|
|
|
+ if self.dim_mode in {0, 2}:
|
|
|
|
+ solution_list = self.solution_classify(sentence, quesType)
|
|
|
|
+ if self.dim_mode in {1, 2}:
|
|
|
|
+ difficulty_value = self.difficulty_classify(sentence)
|
|
res_dict = {
|
|
res_dict = {
|
|
"solving_type": solution_list,
|
|
"solving_type": solution_list,
|
|
"difficulty": difficulty_value,
|
|
"difficulty": difficulty_value,
|
|
@@ -63,16 +70,17 @@ class Dimension_Classification():
|
|
solution_tensor = self.model_calculate(self.solution_model, sentence)
|
|
solution_tensor = self.model_calculate(self.solution_model, sentence)
|
|
solution_tensor[solution_tensor >= 0.5] = 1
|
|
solution_tensor[solution_tensor >= 0.5] = 1
|
|
solution_tensor[solution_tensor < 0.5] = 0
|
|
solution_tensor[solution_tensor < 0.5] = 0
|
|
- solution_list = [self.solving_type_dict[idx] for idx in solution_tensor[0].int().tolist() if idx == 1]
|
|
|
|
|
|
+ solution_list = solution_tensor[0].int().tolist()
|
|
|
|
+ solution_result = [self.solving_type_dict[i] for i,idx in enumerate(solution_list) if idx == 1]
|
|
# 题型判断
|
|
# 题型判断
|
|
if quesType == "计算题":
|
|
if quesType == "计算题":
|
|
- solution_list.append("计算分析")
|
|
|
|
|
|
+ solution_result.append("计算分析")
|
|
elif quesType == "作图题":
|
|
elif quesType == "作图题":
|
|
- solution_list.append("连线作图")
|
|
|
|
- if len(solution_list) == 0:
|
|
|
|
- solution_list.append("规律理解")
|
|
|
|
|
|
+ solution_result.append("连线作图")
|
|
|
|
+ if len(solution_result) == 0:
|
|
|
|
+ solution_result.append("规律理解")
|
|
|
|
|
|
- return list(set(solution_list))
|
|
|
|
|
|
+ return list(set(solution_result))
|
|
|
|
|
|
def difficulty_classify(self, sentence):
|
|
def difficulty_classify(self, sentence):
|
|
difficulty_tensor = self.model_calculate(self.difficulty_model, sentence).item()
|
|
difficulty_tensor = self.model_calculate(self.difficulty_model, sentence).item()
|
|
@@ -89,9 +97,9 @@ class Dimension_Classification():
|
|
def model_calculate(self, model, sentence):
|
|
def model_calculate(self, model, sentence):
|
|
model.eval()
|
|
model.eval()
|
|
with torch.no_grad():
|
|
with torch.no_grad():
|
|
- token_list = self.sentence_tokenize(sentence)
|
|
|
|
- mask_list = self.attention_mask(token_list)
|
|
|
|
- output_tensor = model(torch.tensor(token_list), attention_mask=torch.tensor(mask_list))
|
|
|
|
|
|
+ token_tensor = self.sentence_tokenize(sentence)
|
|
|
|
+ mask_tensor = torch.ones_like(token_tensor, dtype=torch.float)
|
|
|
|
+ output_tensor = model(token_tensor, attention_mask=mask_tensor)
|
|
output_tensor = torch.sigmoid(output_tensor)
|
|
output_tensor = torch.sigmoid(output_tensor)
|
|
|
|
|
|
return output_tensor
|
|
return output_tensor
|
|
@@ -100,23 +108,12 @@ class Dimension_Classification():
|
|
# 直接截断
|
|
# 直接截断
|
|
# 编码时: 开头添加[CLS]->101, 结尾添加[SEP]->102, 未知的字或单词变为[UNK]->100
|
|
# 编码时: 开头添加[CLS]->101, 结尾添加[SEP]->102, 未知的字或单词变为[UNK]->100
|
|
token_list = self.tokenizer.encode(sentence[:self.max_squence_length])
|
|
token_list = self.tokenizer.encode(sentence[:self.max_squence_length])
|
|
- # 补齐(pad的索引号就是0)
|
|
|
|
- if len(token_list) < self.max_squence_length + 2:
|
|
|
|
- token_list.extend([0] * (self.max_squence_length + 2 - len(token_list)))
|
|
|
|
|
|
|
|
- return [token_list]
|
|
|
|
-
|
|
|
|
- def attention_mask(self, tokens_list):
|
|
|
|
- # 在一个文本中,如果是PAD符号则是0,否则就是1
|
|
|
|
- mask_list = []
|
|
|
|
- for tokens in tokens_list:
|
|
|
|
- mask = [float(token > 0) for token in tokens]
|
|
|
|
- mask_list.append(mask)
|
|
|
|
|
|
+ return torch.tensor([token_list])
|
|
|
|
|
|
- return mask_list
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
- dc = Dimension_Classification()
|
|
|
|
- sentence = "荆门市是国家循环经济试点市,目前正在沙洋建设全国最大的秸秆气化发电厂.电厂建成后每年可消化秸秆13万吨,发电9*10^7*kW*h.同时电厂所产生的灰渣将生成肥料返还农民,焦油用于精细化工,实现“农业--工业--农业”循环.(1)若秸秆电厂正常工作时,每小时可发电2.5*10^5*kW*h,按每户居民每天使用5只20*W的节能灯、1个800*W的电饭锅、1台100*W的电视机计算,该发电厂同时可供多少户居民正常用电?(2)与同等规模的火电厂相比,该电厂每年可减少6.4万吨二氧化碳的排放量,若火电厂煤燃烧的热利用率为20%,秸秆电厂每年可节约多少吨标准煤?(标准煤的热值按3.6*10^7J/k*g计算)"
|
|
|
|
|
|
+ dc = Dimension_Classification(dim_mode=0)
|
|
|
|
+ sentence = "请在图乙中的虚线框内画出与图甲中实物图对应的电路图。"
|
|
res = dc(sentence, "")
|
|
res = dc(sentence, "")
|
|
print(res)
|
|
print(res)
|