dim_classify.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import torch
  2. import torch.nn as nn
  3. from transformers import AutoConfig, BertTokenizer, AutoModel
  4. import config
  class Solution_Model(nn.Module):
      """Pretrained encoder followed by a linear layer that emits 8 raw scores.

      The encoder and its configuration are both loaded from
      ``config.bert_path``. No activation is applied inside this module; the
      caller decides how to normalise the scores.
      """

      def __init__(self):
          super().__init__()
          self.bert_config = AutoConfig.from_pretrained(config.bert_path)
          self.bert = AutoModel.from_pretrained(config.bert_path)
          # The head is sized from the encoder configuration so that it always
          # matches the width of the encoder's hidden states.
          encoder_width = self.bert_config.hidden_size
          self.fc = nn.Linear(encoder_width, 8)

      def forward(self, input_ids, attention_mask):
          """Score the input using the hidden state at sequence position 0.

          Args:
              input_ids: Token-id tensor passed to the encoder as its first
                  positional argument (presumably ``(batch, seq_len)`` -- confirm
                  against the caller).
              attention_mask: Mask passed to the encoder as its second
                  positional argument.

          Returns:
              The linear head's output; its last dimension has size 8.
          """
          encoder_outputs = self.bert(input_ids, attention_mask)
          # Element 0 of the encoder output is indexed as a 3-D tensor; only the
          # slice at position 0 along dim 1 is kept (the first token).
          first_token_state = encoder_outputs[0][:, 0, :]
          return self.fc(first_token_state)
  class Difficulty_Model(nn.Module):
      """Pretrained encoder followed by a linear layer that emits one raw score.

      The encoder and its configuration are both loaded from
      ``config.bert_path``. No activation is applied inside this module; the
      caller decides how to normalise the score.
      """

      def __init__(self):
          super().__init__()
          self.bert_config = AutoConfig.from_pretrained(config.bert_path)
          self.bert = AutoModel.from_pretrained(config.bert_path)
          # Single output unit, sized from the encoder's hidden width.
          encoder_width = self.bert_config.hidden_size
          self.fc = nn.Linear(encoder_width, 1)

      def forward(self, input_ids, attention_mask):
          """Score the input using the hidden state at sequence position 0.

          Args:
              input_ids: Token-id tensor passed to the encoder as its first
                  positional argument (presumably ``(batch, seq_len)`` -- confirm
                  against the caller).
              attention_mask: Mask passed to the encoder as its second
                  positional argument.

          Returns:
              The linear head's output; its last dimension has size 1.
          """
          encoder_outputs = self.bert(input_ids, attention_mask)
          # Element 0 of the encoder output is indexed as a 3-D tensor; only the
          # slice at position 0 along dim 1 is kept (the first token).
          first_token_state = encoder_outputs[0][:, 0, :]
          return self.fc(first_token_state)
  class Dimension_Classification():
      """Tag a question text with solving types and a difficulty bucket.

      Two saved models are used, selected by ``dim_mode``:

      * ``0`` -- solving-type model only
      * ``1`` -- difficulty model only
      * ``2`` -- both models (default)

      Any other value loads neither model, and calling the instance then
      returns the defaults (an empty solving-type list and difficulty 0.6).
      """

      def __init__(self, dim_mode=2, logger=None):
          """Load the tokenizer and whichever models ``dim_mode`` selects.

          Args:
              dim_mode: Which models to load and run (see the class docstring).
              logger: Optional logger; stored on ``self.logger`` and not used
                  by any method of this class.
          """
          self.dim_mode = dim_mode
          self.tokenizer = BertTokenizer.from_pretrained(config.bert_path)
          self.solution_model, self.difficulty_model = None, None
          if self.dim_mode in {0, 2}:
              self.solution_model = self._load_model(config.solution_model_path)
          if self.dim_mode in {1, 2}:
              self.difficulty_model = self._load_model(config.difficulty_model_path)
          # Maximum number of *characters* kept from the input text. The
          # attribute name (including its spelling) is kept unchanged because
          # it is visible to code outside this class.
          self.max_squence_length = 500
          # Output index of the solving-type model -> label text.
          self.solving_type_dict = {
              0: "实验操作",
              1: "计算分析",
              2: "连线作图",
              3: "实验读数",
              4: "现象解释",
              5: "概念辨析",
              6: "规律理解",
              7: "物理学史",
          }
          # Log collection
          self.logger = logger

      @staticmethod
      def _load_model(path):
          """Load a saved model object from ``path`` onto the CPU.

          ``sentence_tokenize`` builds its tensors with ``torch.tensor`` and
          never moves them to another device, so the model has to live on the
          CPU too. ``map_location="cpu"`` also lets a checkpoint that was saved
          from a GPU load on a host without CUDA.

          SECURITY: ``torch.load`` unpickles the file, which can execute
          arbitrary code. Only point the configured paths at trusted files.
          NOTE(review): recent PyTorch releases default ``weights_only`` to
          True, which rejects fully pickled model objects -- confirm against
          the PyTorch version this project pins.
          """
          return torch.load(path, map_location="cpu")

      def __call__(self, sentence, quesType):
          """Classify ``sentence`` along every dimension enabled by ``dim_mode``.

          Args:
              sentence: Question text.
              quesType: Question-type string, forwarded to
                  ``solution_classify``.

          Returns:
              dict with ``"solving_type"`` (list of label strings; empty when
              the solving-type model is disabled) and ``"difficulty"`` (float;
              0.6 when the difficulty model is disabled).
          """
          solution_list, difficulty_value = [], 0.6
          if self.dim_mode in {0, 2}:
              solution_list = self.solution_classify(sentence, quesType)
          if self.dim_mode in {1, 2}:
              difficulty_value = self.difficulty_classify(sentence)
          return {
              "solving_type": solution_list,
              "difficulty": difficulty_value,
          }

      def solution_classify(self, sentence, quesType):
          """Return the solving-type labels predicted for ``sentence``.

          Every model output whose sigmoid score is at least 0.5 contributes
          its label. The question type can add one more label, and a fixed
          fallback label is used when nothing was selected.

          Returns:
              List of unique labels: model labels in output-index order,
              followed by the rule-based label if it is not already present.
              The order is stable from run to run.
          """
          scores = self.model_calculate(self.solution_model, sentence)
          # Row 0 is the only row: model_calculate always feeds a batch of one.
          selected = (scores[0] >= 0.5).tolist()
          labels = [
              self.solving_type_dict[index]
              for index, is_on in enumerate(selected)
              if is_on
          ]
          # Question-type rule: some question types imply a solving type.
          if quesType == "计算题":
              labels.append("计算分析")
          elif quesType == "作图题":
              labels.append("连线作图")
          if not labels:
              labels.append("规律理解")
          # De-duplicate while keeping insertion order. Going through set()
          # would make the order depend on per-process string hashing.
          return list(dict.fromkeys(labels))

      def difficulty_classify(self, sentence):
          """Map the difficulty model's sigmoid score onto one of three values.

          Returns:
              0.8 when the score is >= 0.8, 0.4 when it is <= 0.2, and 0.6
              otherwise.
          """
          score = self.model_calculate(self.difficulty_model, sentence).item()
          if score >= 0.8:
              return 0.8
          if score <= 0.2:
              return 0.4
          return 0.6

      def model_calculate(self, model, sentence):
          """Run ``model`` on ``sentence`` and return sigmoid-activated outputs.

          The model is switched to eval mode and called without gradient
          tracking. The attention mask is all ones because exactly one
          un-padded sequence is fed per call.
          """
          model.eval()
          with torch.no_grad():
              token_tensor = self.sentence_tokenize(sentence)
              mask_tensor = torch.ones_like(token_tensor, dtype=torch.float)
              raw_output = model(token_tensor, attention_mask=mask_tensor)
              return torch.sigmoid(raw_output)

      def sentence_tokenize(self, sentence):
          """Encode ``sentence`` into a tensor holding a single row of token ids.

          The text is simply cut to the first ``max_squence_length`` characters
          before encoding. When encoding, [CLS] -> 101 is added at the start,
          [SEP] -> 102 at the end, and unknown characters or words become
          [UNK] -> 100.
          """
          clipped_text = sentence[:self.max_squence_length]
          token_ids = self.tokenizer.encode(clipped_text)
          return torch.tensor([token_ids])
  if __name__ == "__main__":
      # Manual smoke test: load only the solving-type model (dim_mode=0),
      # classify one sample question with an empty question type, and print
      # the resulting dict.
      classifier = Dimension_Classification(dim_mode=0)
      sample_question = "请在图乙中的虚线框内画出与图甲中实物图对应的电路图。"
      print(classifier(sample_question, ""))