123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122 |
- import torch
- import torch.nn as nn
- from transformers import AutoConfig, BertTokenizer, AutoModel
- import config
class Solution_Model(nn.Module):
    """Pretrained encoder followed by a linear layer with 8 outputs.

    The encoder and its configuration are both loaded from
    ``config.bert_path``. The linear layer maps the encoder's hidden size to
    8 raw scores; no activation is applied inside this module.
    """

    def __init__(self):
        super().__init__()
        self.bert_config = AutoConfig.from_pretrained(config.bert_path)
        self.bert = AutoModel.from_pretrained(config.bert_path)
        self.fc = nn.Linear(
            in_features=self.bert_config.hidden_size,
            out_features=8,
        )

    def forward(self, input_ids, attention_mask):
        """Return the 8 raw scores for each input sequence.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Mask forwarded to the encoder.

        Returns:
            The output of the linear layer applied to the slice
            ``[:, 0, :]`` of the encoder's first output (presumably the
            hidden state at the leading [CLS] position -- confirm against
            the tokenizer in use).
        """
        # Pass the mask by keyword: AutoModel can resolve to architectures
        # whose second positional forward() parameter is not the mask.
        encoder_outputs = self.bert(input_ids, attention_mask=attention_mask)
        first_position = encoder_outputs[0][:, 0, :]
        return self.fc(first_position)
class Difficulty_Model(nn.Module):
    """Pretrained encoder followed by a linear layer with 8 outputs.

    The encoder and its configuration are both loaded from
    ``config.bert_path``. No activation is applied inside this module.

    NOTE(review): ``Dimension_Classification.difficulty_classify`` calls
    ``.item()`` on the output of the loaded difficulty model, which needs a
    single-element tensor. If that checkpoint is an instance of this class,
    ``out_features=8`` looks inconsistent -- confirm against the training
    code before changing it, since the layer shape is part of the saved
    weights.
    """

    def __init__(self):
        super().__init__()
        self.bert_config = AutoConfig.from_pretrained(config.bert_path)
        self.bert = AutoModel.from_pretrained(config.bert_path)
        self.fc = nn.Linear(
            in_features=self.bert_config.hidden_size,
            out_features=8,
        )

    def forward(self, input_ids, attention_mask):
        """Return the raw linear-layer scores for each input sequence.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Mask forwarded to the encoder.

        Returns:
            The output of the linear layer applied to the slice
            ``[:, 0, :]`` of the encoder's first output.
        """
        # Pass the mask by keyword: AutoModel can resolve to architectures
        # whose second positional forward() parameter is not the mask.
        encoder_outputs = self.bert(input_ids, attention_mask=attention_mask)
        first_position = encoder_outputs[0][:, 0, :]
        return self.fc(first_position)
class Dimension_Classification():
    """Classify a question text by solving type and by difficulty bucket.

    Two saved models are loaded at construction time: one scores the
    solving-type labels in ``solving_type_dict``, the other produces a
    single difficulty score. Calling the instance returns both results in
    one dictionary.
    """

    def __init__(self, logger=None):
        """Load the tokenizer and both models from the paths in ``config``.

        Args:
            logger: Optional logger; stored on the instance, not used by
                any method of this class.
        """
        self.tokenizer = BertTokenizer.from_pretrained(config.bert_path)
        # NOTE(review): torch.load unpickles arbitrary objects, so both
        # paths must point at trusted files.
        self.solution_model = torch.load(config.solution_model_path)
        self.difficulty_model = torch.load(config.difficulty_model_path)
        # Number of leading characters of the input text that are kept.
        # (The attribute name keeps its original spelling for callers.)
        self.max_squence_length = 500
        # Output position of the solution model -> solving-type label.
        self.solving_type_dict = {
            0: "实验操作",
            1: "计算分析",
            2: "连线作图",
            3: "实验读数",
            4: "现象解释",
            5: "概念辨析",
            6: "规律理解",
            7: "物理学史"
        }
        # Log collection
        self.logger = logger

    def __call__(self, sentence, quesType):
        """Run both classifiers on one question.

        Args:
            sentence: Question text.
            quesType: Question-type string; see ``solution_classify``.

        Returns:
            A dict with key ``"solving_type"`` (list of label strings) and
            key ``"difficulty"`` (0.4, 0.6 or 0.8).
        """
        solution_list = self.solution_classify(sentence, quesType)
        difficulty_value = self.difficulty_classify(sentence)
        return {
            "solving_type": solution_list,
            "difficulty": difficulty_value,
        }

    def solution_classify(self, sentence, quesType):
        """Return the solving-type labels predicted for ``sentence``.

        A label is selected when its sigmoid score is at least 0.5. The
        question type then adds a label: "计算题" adds "计算分析" and
        "作图题" adds "连线作图". If nothing has been selected by then,
        "规律理解" is used as the fallback.

        Args:
            sentence: Question text.
            quesType: Question-type string.

        Returns:
            A list of distinct label strings, in order of first appearance
            (model labels by position, then the question-type label).
        """
        probabilities = self.model_calculate(self.solution_model, sentence)
        # Only the first row is used: model_calculate builds a batch of one.
        selected_flags = (probabilities[0] >= 0.5).tolist()
        # Map by POSITION of the positive flag, not by the flag's value;
        # keying on the value would always look up label 1.
        solution_list = [
            self.solving_type_dict[position]
            for position, is_selected in enumerate(selected_flags)
            if is_selected
        ]
        # Question-type rule
        if quesType == "计算题":
            solution_list.append("计算分析")
        elif quesType == "作图题":
            solution_list.append("连线作图")
        if not solution_list:
            solution_list.append("规律理解")

        # De-duplicate while keeping a deterministic order; a plain set
        # would make the order depend on string hashing.
        return list(dict.fromkeys(solution_list))

    def difficulty_classify(self, sentence):
        """Bucket the difficulty score of ``sentence``.

        Args:
            sentence: Question text.

        Returns:
            0.8 when the sigmoid score is >= 0.8, 0.4 when it is <= 0.2,
            and 0.6 otherwise.

        Raises:
            Whatever ``Tensor.item()`` raises when the model output does
            not hold exactly one element.
        """
        score = self.model_calculate(self.difficulty_model, sentence).item()
        if score >= 0.8:
            return 0.8
        if score <= 0.2:
            return 0.4
        return 0.6

    def model_calculate(self, model, sentence):
        """Run ``model`` on ``sentence`` and return sigmoid-activated output.

        The model is switched to eval mode and called under
        ``torch.no_grad()`` with a batch containing the single sentence.

        Args:
            model: Callable module taking token ids and an
                ``attention_mask`` keyword.
            sentence: Question text.

        Returns:
            ``torch.sigmoid`` of the model output.
        """
        model.eval()
        with torch.no_grad():
            token_list = self.sentence_tokenize(sentence)
            mask_list = self.attention_mask(token_list)
            output_tensor = model(
                torch.tensor(token_list),
                attention_mask=torch.tensor(mask_list),
            )
            output_tensor = torch.sigmoid(output_tensor)
        return output_tensor

    def sentence_tokenize(self, sentence):
        """Encode ``sentence`` into a one-row batch of padded token ids.

        Args:
            sentence: Question text; only its first
                ``max_squence_length`` characters are encoded.

        Returns:
            A list holding one list of token ids, right-padded with 0 up
            to ``max_squence_length + 2`` entries when shorter.
        """
        # Hard truncation by character count.
        # Per the original note: encoding prepends [CLS] -> 101, appends
        # [SEP] -> 102, and maps unknown characters/words to [UNK] -> 100.
        token_list = self.tokenizer.encode(sentence[:self.max_squence_length])
        # Padding (the original note states the pad index is 0).
        padded_length = self.max_squence_length + 2
        if len(token_list) < padded_length:
            token_list.extend([0] * (padded_length - len(token_list)))

        return [token_list]

    def attention_mask(self, tokens_list):
        """Build float masks for a batch of token-id lists.

        Args:
            tokens_list: Iterable of token-id sequences.

        Returns:
            A list of lists with 1.0 where the token id is greater than 0
            and 0.0 elsewhere (id 0 is the padding value used above).
        """
        return [
            [float(token > 0) for token in tokens]
            for tokens in tokens_list
        ]
if __name__ == "__main__":
    # Manual smoke run: classify one sample question with an empty
    # question type and print the resulting dictionary.
    sample_question = "荆门市是国家循环经济试点市,目前正在沙洋建设全国最大的秸秆气化发电厂.电厂建成后每年可消化秸秆13万吨,发电9*10^7*kW*h.同时电厂所产生的灰渣将生成肥料返还农民,焦油用于精细化工,实现“农业--工业--农业”循环.(1)若秸秆电厂正常工作时,每小时可发电2.5*10^5*kW*h,按每户居民每天使用5只20*W的节能灯、1个800*W的电饭锅、1台100*W的电视机计算,该发电厂同时可供多少户居民正常用电?(2)与同等规模的火电厂相比,该电厂每年可减少6.4万吨二氧化碳的排放量,若火电厂煤燃烧的热利用率为20%,秸秆电厂每年可节约多少吨标准煤?(标准煤的热值按3.6*10^7J/k*g计算)"
    classifier = Dimension_Classification()
    print(classifier(sample_question, ""))
|