# dim_classify.py (5.1 KB)

  1. import torch
  2. import torch.nn as nn
  3. from transformers import AutoConfig, BertTokenizer, AutoModel
  4. import config
  5. class Solution_Model(nn.Module):
  6. def __init__(self):
  7. super(Solution_Model, self).__init__()
  8. self.bert_config = AutoConfig.from_pretrained(config.bert_path)
  9. self.bert = AutoModel.from_pretrained(config.bert_path)
  10. self.fc = nn.Linear(in_features=self.bert_config.hidden_size, out_features=8)
  11. def forward(self, input_ids, attention_mask):
  12. x = self.bert(input_ids, attention_mask)[0][:, 0, :]
  13. x = self.fc(x)
  14. return x
  15. class Difficulty_Model(nn.Module):
  16. def __init__(self):
  17. super(Difficulty_Model, self).__init__()
  18. self.bert_config = AutoConfig.from_pretrained(config.bert_path)
  19. self.bert = AutoModel.from_pretrained(config.bert_path)
  20. self.fc = nn.Linear(in_features=self.bert_config.hidden_size, out_features=8)
  21. def forward(self, input_ids, attention_mask):
  22. x = self.bert(input_ids, attention_mask)[0][:, 0, :]
  23. x = self.fc(x)
  24. return x
  25. class Dimension_Classification():
  26. def __init__(self, logger=None):
  27. self.tokenizer = BertTokenizer.from_pretrained(config.bert_path)
  28. self.solution_model = torch.load(config.solution_model_path)
  29. self.difficulty_model = torch.load(config.difficulty_model_path)
  30. self.max_squence_length = 500
  31. self.solving_type_dict = {
  32. 0: "实验操作",
  33. 1: "计算分析",
  34. 2: "连线作图",
  35. 3: "实验读数",
  36. 4: "现象解释",
  37. 5: "概念辨析",
  38. 6: "规律理解",
  39. 7: "物理学史"
  40. }
  41. # 日志采集
  42. self.logger = logger
  43. def __call__(self, sentence, quesType):
  44. solution_list = self.solution_classify(sentence, quesType)
  45. difficulty_value = self.difficulty_classify(sentence)
  46. res_dict = {
  47. "solving_type": solution_list,
  48. "difficulty": difficulty_value,
  49. }
  50. return res_dict
  51. def solution_classify(self, sentence, quesType):
  52. solution_tensor = self.model_calculate(self.solution_model, sentence)
  53. solution_tensor[solution_tensor >= 0.5] = 1
  54. solution_tensor[solution_tensor < 0.5] = 0
  55. solution_list = [self.solving_type_dict[idx] for idx in solution_tensor[0].int().tolist() if idx == 1]
  56. # 题型判断
  57. if quesType == "计算题":
  58. solution_list.append("计算分析")
  59. elif quesType == "作图题":
  60. solution_list.append("连线作图")
  61. if len(solution_list) == 0:
  62. solution_list.append("规律理解")
  63. return list(set(solution_list))
  64. def difficulty_classify(self, sentence):
  65. difficulty_tensor = self.model_calculate(self.difficulty_model, sentence).item()
  66. difficulty_value = 0.6
  67. if difficulty_tensor >= 0.8:
  68. difficulty_value = 0.8
  69. elif difficulty_tensor <= 0.2:
  70. difficulty_value = 0.4
  71. else:
  72. difficulty_value = 0.6
  73. return difficulty_value
  74. def model_calculate(self, model, sentence):
  75. model.eval()
  76. with torch.no_grad():
  77. token_list = self.sentence_tokenize(sentence)
  78. mask_list = self.attention_mask(token_list)
  79. output_tensor = model(torch.tensor(token_list), attention_mask=torch.tensor(mask_list))
  80. output_tensor = torch.sigmoid(output_tensor)
  81. return output_tensor
  82. def sentence_tokenize(self, sentence):
  83. # 直接截断
  84. # 编码时: 开头添加[LCS]->101, 结尾添加[SEP]->102, 未知的字或单词变为[UNK]->100
  85. token_list = self.tokenizer.encode(sentence[:self.max_squence_length])
  86. # 补齐(pad的索引号就是0)
  87. if len(token_list) < self.max_squence_length + 2:
  88. token_list.extend([0] * (self.max_squence_length + 2 - len(token_list)))
  89. return [token_list]
  90. def attention_mask(self, tokens_list):
  91. # 在一个文本中,如果是PAD符号则是0,否则就是1
  92. mask_list = []
  93. for tokens in tokens_list:
  94. mask = [float(token > 0) for token in tokens]
  95. mask_list.append(mask)
  96. return mask_list
  97. if __name__ == "__main__":
  98. dc = Dimension_Classification()
  99. sentence = "荆门市是国家循环经济试点市,目前正在沙洋建设全国最大的秸秆气化发电厂.电厂建成后每年可消化秸秆13万吨,发电9*10^7*kW*h.同时电厂所产生的灰渣将生成肥料返还农民,焦油用于精细化工,实现“农业--工业--农业”循环.(1)若秸秆电厂正常工作时,每小时可发电2.5*10^5*kW*h,按每户居民每天使用5只20*W的节能灯、1个800*W的电饭锅、1台100*W的电视机计算,该发电厂同时可供多少户居民正常用电?(2)与同等规模的火电厂相比,该电厂每年可减少6.4万吨二氧化碳的排放量,若火电厂煤燃烧的热利用率为20%,秸秆电厂每年可节约多少吨标准煤?(标准煤的热值按3.6*10^7J/k*g计算)"
  100. res = dc(sentence, "")
  101. print(res)