import inspect

import torch
import torch.nn as nn
from transformers import AutoConfig, BertTokenizer, AutoModel

import config


class _BertClsHead(nn.Module):
    """BERT encoder with a linear layer applied to the hidden state at position 0.

    Shared implementation of Solution_Model and Difficulty_Model. The attribute
    names ``bert_config``, ``bert`` and ``fc`` are unchanged, so state-dict keys
    are the same as before.
    """

    def __init__(self, out_features):
        super().__init__()
        self.bert_config = AutoConfig.from_pretrained(config.bert_path)
        self.bert = AutoModel.from_pretrained(config.bert_path)
        self.fc = nn.Linear(
            in_features=self.bert_config.hidden_size, out_features=out_features
        )

    def forward(self, input_ids, attention_mask):
        """Return raw (pre-sigmoid) logits for position 0 of each sequence.

        Args:
            input_ids: token id tensor accepted by the underlying encoder.
            attention_mask: mask tensor with the same shape as ``input_ids``.

        Returns:
            Tensor of shape (batch, out_features).
        """
        # Element [0] of the encoder output is the last hidden state; take the
        # first position of every sequence ([CLS] when the ids come from
        # Dimension_Classification.sentence_tokenize).
        # The mask is passed by keyword so it cannot bind to another positional
        # parameter of the encoder's forward().
        hidden = self.bert(input_ids, attention_mask=attention_mask)[0][:, 0, :]
        return self.fc(hidden)


class Solution_Model(_BertClsHead):
    """BERT classifier producing 8 logits, one per solving type
    (indexed like Dimension_Classification.solving_type_dict)."""

    def __init__(self):
        super().__init__(out_features=8)


class Difficulty_Model(_BertClsHead):
    """BERT regressor/classifier producing a single difficulty logit."""

    def __init__(self):
        super().__init__(out_features=1)


class Dimension_Classification:
    """Tag a question text with solving types and/or a coarse difficulty level.

    Args:
        dim_mode: 0 = solving types only, 1 = difficulty only, 2 = both.
        logger: optional logger, stored on the instance (log collection).
    """

    def __init__(self, dim_mode=2, logger=None):
        self.dim_mode = dim_mode
        self.tokenizer = BertTokenizer.from_pretrained(config.bert_path)
        self.solution_model, self.difficulty_model = None, None
        if self.dim_mode in {0, 2}:
            self.solution_model = self._load_model(config.solution_model_path)
        if self.dim_mode in {1, 2}:
            self.difficulty_model = self._load_model(config.difficulty_model_path)
        # Maximum number of characters kept from the input text (attribute name
        # kept as-is, typo included, for backward compatibility).
        self.max_squence_length = 500
        self.solving_type_dict = {
            0: "实验操作",
            1: "计算分析",
            2: "连线作图",
            3: "实验读数",
            4: "现象解释",
            5: "概念辨析",
            6: "规律理解",
            7: "物理学史",
        }
        # Log collection.
        self.logger = logger

    @staticmethod
    def _load_model(path):
        """Load a fully pickled model onto the CPU and switch it to eval mode."""
        load_kwargs = {"map_location": "cpu"}
        # torch>=2.6 defaults to weights_only=True, which refuses whole pickled
        # nn.Module objects such as these checkpoints; opt out where supported.
        # NOTE(review): unpickling can execute code — only point config at
        # trusted checkpoint files.
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False
        model = torch.load(path, **load_kwargs)
        model.eval()
        return model

    def __call__(self, sentence, quesType):
        """Classify ``sentence`` according to ``dim_mode``.

        Returns:
            dict with ``"solving_type"`` (list of labels, ``[]`` if not computed)
            and ``"difficulty"`` (0.4 / 0.6 / 0.8, ``0.6`` if not computed).
        """
        solution_list, difficulty_value = [], 0.6
        if self.dim_mode in {0, 2}:
            solution_list = self.solution_classify(sentence, quesType)
        if self.dim_mode in {1, 2}:
            difficulty_value = self.difficulty_classify(sentence)
        return {
            "solving_type": solution_list,
            "difficulty": difficulty_value,
        }

    def solution_classify(self, sentence, quesType):
        """Return the solving-type labels for ``sentence``.

        A label is selected when its sigmoid probability is >= 0.5. The question
        type may force an extra label, and "规律理解" is used as a fallback when
        nothing is selected. Duplicates are removed while keeping a deterministic
        (first-seen) order.
        """
        probs = self.model_calculate(self.solution_model, sentence)[0].tolist()
        solution_result = [
            self.solving_type_dict[i] for i, prob in enumerate(probs) if prob >= 0.5
        ]

        # Question-type rules: calculation questions always involve calculation
        # analysis, drawing questions always involve connecting/drawing.
        if quesType == "计算题":
            solution_result.append("计算分析")
        elif quesType == "作图题":
            solution_result.append("连线作图")

        if not solution_result:
            solution_result.append("规律理解")
        # dict.fromkeys de-duplicates but, unlike set(), keeps a stable order.
        return list(dict.fromkeys(solution_result))

    def difficulty_classify(self, sentence):
        """Map the difficulty model's sigmoid score to 0.8, 0.4 or 0.6.

        score >= 0.8 -> 0.8; score <= 0.2 -> 0.4; otherwise 0.6.
        """
        score = self.model_calculate(self.difficulty_model, sentence).item()
        if score >= 0.8:
            return 0.8
        if score <= 0.2:
            return 0.4
        return 0.6

    def model_calculate(self, model, sentence):
        """Run ``model`` on ``sentence`` without gradients; return sigmoid outputs."""
        model.eval()
        with torch.no_grad():
            token_tensor = self.sentence_tokenize(sentence)
            # Build inputs on the same device as the model's weights.
            first_param = next(model.parameters(), None)
            if first_param is not None:
                token_tensor = token_tensor.to(first_param.device)
            mask_tensor = torch.ones_like(token_tensor, dtype=torch.float)
            output_tensor = model(token_tensor, attention_mask=mask_tensor)
            return torch.sigmoid(output_tensor)

    def sentence_tokenize(self, sentence):
        """Encode ``sentence`` into a (1, seq_len) tensor of token ids."""
        # Truncate the raw text to max_squence_length characters before encoding.
        # Encoding prepends [CLS] (101), appends [SEP] (102) and maps unknown
        # characters/words to [UNK] (100).
        token_list = self.tokenizer.encode(sentence[:self.max_squence_length])
        return torch.tensor([token_list])


if __name__ == "__main__":
    dc = Dimension_Classification(dim_mode=0)
    sentence = "请在图乙中的虚线框内画出与图甲中实物图对应的电路图。"
    res = dc(sentence, "")
    print(res)