model.py

import gc

import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, BertForSequenceClassification


class UIEModel(nn.Module):
    def __init__(self, args):
        super(UIEModel, self).__init__()
        self.args = args
        self.tasks = args.tasks
        bert_dir = args.bert_dir
        self.bert_config = BertConfig.from_pretrained(bert_dir)
        self.bert_model = BertModel.from_pretrained(bert_dir)
        # self.bert_model.load_state_dict(torch.load(self.args.bert_pt_dir, map_location="cuda"))
        # self.bert_model = BertForSequenceClassification.from_pretrained(args.bert_pt_dir)
        if "ner" in args.tasks:
            self.ner_num_labels = args.ner_num_labels
            self.module_start_list = nn.ModuleList()
            self.module_end_list = nn.ModuleList()
            self.module_content_list = nn.ModuleList()  # extra head that judges whether a sentence is question content
            for i in range(args.ner_num_labels):
                self.module_start_list.append(nn.Linear(self.bert_config.hidden_size, 1))
                self.module_end_list.append(nn.Linear(self.bert_config.hidden_size, 1))
                self.module_content_list.append(nn.Linear(self.bert_config.hidden_size, 1))
            self.ner_criterion = nn.BCEWithLogitsLoss()
        self.dropout = nn.Dropout(self.bert_config.hidden_dropout_prob)
    @staticmethod
    def build_dummy_inputs():
        inputs = {}
        # torch.randint already returns a long tensor, so no extra cast is needed
        inputs['ner_input_ids'] = torch.randint(low=1, high=10, size=(32, 56))
        inputs['ner_attention_mask'] = torch.ones(size=(32, 56)).long()
        inputs['ner_token_type_ids'] = torch.zeros(size=(32, 56)).long()
        inputs['ner_start_labels'] = torch.zeros(size=(32, 8, 56)).float()
        inputs['ner_end_labels'] = torch.zeros(size=(32, 8, 56)).float()
        return inputs
    def get_pointer_loss(self,
                         start_logits,
                         end_logits,
                         attention_mask,
                         start_labels,
                         end_labels,
                         criterion):
        start_logits = start_logits.view(-1)
        end_logits = end_logits.view(-1)
        # only score positions that are real tokens (attention_mask == 1)
        active_loss = attention_mask.view(-1) == 1
        active_start_logits = start_logits[active_loss]
        active_end_logits = end_logits[active_loss]
        active_start_labels = start_labels.view(-1)[active_loss]
        active_end_labels = end_labels.view(-1)[active_loss]
        start_loss = criterion(active_start_logits, active_start_labels)
        end_loss = criterion(active_end_logits, active_end_labels)
        loss = start_loss + end_loss
        return loss
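
    # A minimal sketch (illustrative only, not part of the original file) of how
    # get_pointer_loss masks padding: with a toy 1x4 batch whose last position is
    # padding, only the first three positions enter the BCE terms, so even an
    # extreme logit at the padded slot cannot move the loss.
    #
    #   criterion = nn.BCEWithLogitsLoss()
    #   start_logits = torch.tensor([[2.0, -1.0, 0.5, 9.9]])   # [1, seq_len=4]
    #   end_logits = torch.tensor([[-0.5, 3.0, -2.0, 9.9]])
    #   attention_mask = torch.tensor([[1, 1, 1, 0]])          # last token is padding
    #   start_labels = torch.tensor([[1.0, 0.0, 0.0, 0.0]])
    #   end_labels = torch.tensor([[0.0, 1.0, 0.0, 0.0]])
    #   loss = model.get_pointer_loss(start_logits, end_logits, attention_mask,
    #                                 start_labels, end_labels, criterion)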
    def ner_forward_1(self,
                      ner_input_ids,
                      ner_attention_mask,
                      ner_start_labels=None,
                      ner_end_labels=None):
        # All four arguments are lists of tensors: [tensor(), tensor(), ...]
        # Each call receives batch_size samples; each sample contains multiple sentences.
        # Samples are encoded one by one; if a sample has too many sentences,
        # it is chunked and encoded in batches.
        all_start_logits = []
        all_end_logits = []
        ner_loss = None
        for i in range(len(ner_end_labels)):  # one iteration per sample/document
            input_ids = ner_input_ids[i].to(self.args.device)
            attention_mask = ner_attention_mask[i].to(self.args.device)
            # start_labels = ner_start_labels[i].to(self.args.device)
            # end_labels = ner_end_labels[i].to(self.args.device)
            # Encode in chunks of at most max_encoder_sent_len sentences
            # (encoding too many sentences at once exhausts GPU memory).
            max_encoder_len = self.args.max_encoder_sent_len
            batch_num = input_ids.size(0) // max_encoder_len
            last_hidden_states = []
            if batch_num > 0:
                for j in range(batch_num):  # j, not i: reusing i would clobber the sample index
                    truncated_input_ids = input_ids[j * max_encoder_len:(j + 1) * max_encoder_len, :]
                    truncated_attention_mask = attention_mask[j * max_encoder_len:(j + 1) * max_encoder_len, :]
                    truncated_outputs = self.bert_model(truncated_input_ids, attention_mask=truncated_attention_mask)
                    last_hidden_states.append(truncated_outputs.last_hidden_state)  # .detach().cpu()
            if input_ids.size(0) - batch_num * max_encoder_len > 0:
                truncated_input_ids = input_ids[batch_num * max_encoder_len:, :]
                truncated_attention_mask = attention_mask[batch_num * max_encoder_len:, :]
                truncated_outputs = self.bert_model(truncated_input_ids, attention_mask=truncated_attention_mask)
                last_hidden_states.append(truncated_outputs.last_hidden_state)  # .detach().cpu()
            if len(last_hidden_states) > 1:
                seq_bert_output = torch.cat(last_hidden_states, dim=0)
            else:
                seq_bert_output = last_hidden_states[0]  # [sent_num, seq_len, hidden_dim]
            # Mean pooling that ignores padding positions
            seq_bert_output = seq_bert_output[:, 1:-1, :]  # .to(self.args.device)  # drop [CLS] and [SEP] ???
            expanded_attention_mask = attention_mask[:, 1:-1].unsqueeze(-1).expand_as(seq_bert_output)  # [sent_num, seq_len, hidden_size]
            sum_of_non_padded_output = (seq_bert_output * expanded_attention_mask).sum(dim=1)  # sum over valid positions only
            mean_encoder_outputs = sum_of_non_padded_output / expanded_attention_mask.sum(dim=1)  # [sent_num, hidden_size]
            # dropout
            pooled_output = self.dropout(mean_encoder_outputs)
            # Pointer logits
            # for i in range(self.ner_num_labels):  # compute each NER task separately
            if self.ner_num_labels == 1:
                # one linear head per pointer position
                start_logit = self.module_start_list[0](pooled_output).squeeze(1)  # [sent_num]
                end_logit = self.module_end_list[0](pooled_output).squeeze(1)  # [sent_num]
                # print(start_logit, start_logit.size())
                all_start_logits.append(start_logit)
                all_end_logits.append(end_logit)
        # Merge the whole batch into a single loss value
        concat_start_logits = torch.cat(all_start_logits, dim=0)
        concat_end_logits = torch.cat(all_end_logits, dim=0)
        all_start_labels = torch.cat(ner_start_labels, dim=0).to(self.args.device)
        all_end_labels = torch.cat(ner_end_labels, dim=0).to(self.args.device)
        start_loss = self.ner_criterion(concat_start_logits, all_start_labels)  # start-position loss
        end_loss = self.ner_criterion(concat_end_logits, all_end_labels)  # end-position loss
        if ner_loss is None:
            ner_loss = start_loss + end_loss
        else:
            ner_loss += (start_loss + end_loss)
        res = {
            "ner_start_logits": [a.detach().cpu() for a in all_start_logits],
            "ner_end_logits": [a.detach().cpu() for a in all_end_logits],
            "ner_loss": ner_loss,
        }
        return res
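
    # Sketch of the masked mean pooling used in ner_forward_1, on toy shapes
    # (illustrative only, not from the original file). It carries the same caveat
    # as the ??? comment above: the [:, 1:-1] slice only lines up with the true
    # [SEP] position for the longest sentence in the batch, so shorter sentences
    # still include their [SEP] token in the mean.
    #
    #   hidden = torch.randn(2, 6, 4)             # [sent_num=2, seq_len=6, hidden=4]
    #   mask = torch.tensor([[1, 1, 1, 1, 0, 0],
    #                        [1, 1, 1, 1, 1, 1]])
    #   inner = hidden[:, 1:-1, :]                # drop first/last positions -> [2, 4, 4]
    #   m = mask[:, 1:-1].unsqueeze(-1).expand_as(inner)
    #   mean = (inner * m).sum(dim=1) / m.sum(dim=1)   # [2, 4]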
    def ner_bc_forward(self,
                       ner_input_ids,
                       ner_attention_mask,
                       ner_start_labels=None,
                       ner_end_labels=None,
                       ner_content_labels=None):
        """
        ner: topic recognition task; bc: binary classification task.
        Predicting a question's start/end positions is merged here with judging
        whether a sentence belongs to question content.
        Reason: when splitting questions by predicted labels, splitting on start
        positions produces the fewest errors, but question-type heading lines get
        pulled into questions too, so they need a separate judgment.
        Each call receives one chunk of a single sample (batch_size = 1); a sample
        contains multiple sentences, and overly long samples are chunked upstream.
        """
        all_start_logits = []
        all_end_logits = []
        all_content_logits = []
        ner_loss = None
        outputs = self.bert_model(ner_input_ids, attention_mask=ner_attention_mask)
        # use the [CLS]-position (pooler) representation
        cls_bert_output = outputs.pooler_output  # [sent_num, hidden_size]
        # dropout
        pooled_output = self.dropout(cls_bert_output)
        # try to reclaim memory
        del outputs, cls_bert_output
        gc.collect()
        # one linear head per pointer position,
        # plus a linear head for the content-judgment task
        start_logit = self.module_start_list[0](pooled_output).squeeze(1)  # [sent_num]
        end_logit = self.module_end_list[0](pooled_output).squeeze(1)  # [sent_num]
        content_logit = self.module_content_list[0](pooled_output).squeeze(1)  # [sent_num]
        all_start_logits.append(start_logit)
        all_end_logits.append(end_logit)
        all_content_logits.append(content_logit)
        if ner_start_labels is not None and ner_end_labels is not None:
            start_loss = self.ner_criterion(start_logit, ner_start_labels)  # start-position loss
            end_loss = self.ner_criterion(end_logit, ner_end_labels)  # end-position loss
            content_loss = self.ner_criterion(content_logit, ner_content_labels)
            if ner_loss is None:
                ner_loss = start_loss + end_loss + content_loss
            else:
                ner_loss += (start_loss + end_loss + content_loss)
        res = {
            "ner_start_logits": [a.detach().cpu() for a in all_start_logits],
            "ner_end_logits": [a.detach().cpu() for a in all_end_logits],
            "ner_content_logits": [a.detach().cpu() for a in all_content_logits],
            "ner_loss": ner_loss,
        }
        return res
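
    # Hypothetical sketch of consuming the three heads at inference time (this
    # decoding step is an assumption, not code from this file): sigmoid each
    # per-sentence logit and threshold at 0.5, split the document at predicted
    # start lines, and drop sentences not judged to be question content.
    #
    #   is_start = torch.sigmoid(start_logit) > 0.5      # [sent_num] bool
    #   is_end = torch.sigmoid(end_logit) > 0.5
    #   is_content = torch.sigmoid(content_logit) > 0.5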
    def ner_forward(self,
                    ner_input_ids,
                    ner_attention_mask,
                    ner_start_labels=None,
                    ner_end_labels=None):
        """
        All four arguments are lists of tensors: [tensor(), tensor(), ...]
        Each call receives one chunk of a single sample (batch_size = 1); each
        sample contains multiple sentences. Samples are encoded one at a time,
        and overly long samples are split into chunks.
        """
        all_start_logits = []
        all_end_logits = []
        ner_loss = None
        res = {
            "ner_start_logits": None,
            "ner_end_logits": None,
            "ner_loss": None
        }
        outputs = self.bert_model(ner_input_ids, attention_mask=ner_attention_mask)
        # Alternative: average the token representations
        # last_hidden_states = outputs.last_hidden_state  # [sent_num, seq_len, hidden_dim]
        # seq_bert_output = last_hidden_states[:, 1:-1, :]  # drop [CLS] and [SEP] ???
        # # mean pooling that ignores padding:
        # expanded_attention_mask = ner_attention_mask[:, 1:-1].unsqueeze(-1).expand_as(seq_bert_output)  # [sent_num, seq_len, hidden_size]
        # sum_of_non_padded_output = (seq_bert_output * expanded_attention_mask).sum(dim=1)  # sum over valid positions only
        # mean_encoder_outputs = sum_of_non_padded_output / expanded_attention_mask.sum(dim=1)  # [sent_num, hidden_size]
        # use the [CLS]-position (pooler) representation
        cls_bert_output = outputs.pooler_output  # [sent_num, hidden_size]
        # dropout
        pooled_output = self.dropout(cls_bert_output)
        # pointer loss
        if self.ner_num_labels == 1:
            # one linear head per pointer position
            # tensor.squeeze(1) removes the size-1 second dimension
            start_logit = self.module_start_list[0](pooled_output).squeeze(1)  # [sent_num]
            end_logit = self.module_end_list[0](pooled_output).squeeze(1)  # [sent_num]
            all_start_logits.append(start_logit)
            all_end_logits.append(end_logit)
            if ner_start_labels is not None and ner_end_labels is not None:
                start_loss = self.ner_criterion(start_logit, ner_start_labels)  # start-position loss
                end_loss = self.ner_criterion(end_logit, ner_end_labels)  # end-position loss
                if ner_loss is None:
                    ner_loss = start_loss + end_loss
                else:
                    ner_loss += (start_loss + end_loss)
            res = {
                "ner_start_logits": [a.detach().cpu() for a in all_start_logits],
                "ner_end_logits": [a.detach().cpu() for a in all_end_logits],
                "ner_loss": ner_loss,
            }
        return res
    def forward(self,
                ner_input_ids=None,
                # ner_token_type_ids=None,
                ner_attention_mask=None,
                ner_start_labels=None,
                ner_end_labels=None,
                ner_content_labels=None,
                ):
        res = {
            "ner_output": None,
            "re_output": None,
            "event_output": None
        }
        if "ner" in self.tasks:
            # ner_output = self.ner_forward(
            #     ner_input_ids,
            #     # ner_token_type_ids,
            #     ner_attention_mask,
            #     ner_start_labels,
            #     ner_end_labels,
            # )
            ner_output = self.ner_bc_forward(
                ner_input_ids,
                # ner_token_type_ids,
                ner_attention_mask,
                ner_start_labels,
                ner_end_labels,
                ner_content_labels,
            )
            res["ner_output"] = ner_output
        return res
if __name__ == '__main__':
    inputs = UIEModel.build_dummy_inputs()

    class Args:
        bert_dir = "../chinese-bert-wwm-ext/"
        ner_num_labels = 8
        re_num_labels = 16
        tasks = ["re_rel"]

    args = Args()
    model = UIEModel(args)
    # forward() takes no ner_token_type_ids parameter (it is commented out in the
    # signature), so the dummy token_type_ids are not passed here
    res = model(
        ner_input_ids=inputs["ner_input_ids"],
        ner_attention_mask=inputs["ner_attention_mask"],
        ner_start_labels=inputs["ner_start_labels"],
        ner_end_labels=inputs["ner_end_labels"],
    )
    print(res)
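
    # A second, hypothetical run that exercises the NER branch (everything below
    # is an assumption chosen to match ner_bc_forward, not part of the original
    # file): with tasks = ["ner"] and ner_num_labels = 1, forward() dispatches to
    # ner_bc_forward, which expects 2-D inputs of shape [sent_num, seq_len] and
    # per-sentence float labels of shape [sent_num].
    class NerArgs:
        bert_dir = "../chinese-bert-wwm-ext/"
        ner_num_labels = 1
        re_num_labels = 16
        tasks = ["ner"]
        device = "cpu"

    ner_model = UIEModel(NerArgs())
    sent_num, seq_len = 4, 16
    ner_res = ner_model(
        ner_input_ids=torch.randint(low=1, high=10, size=(sent_num, seq_len)),
        ner_attention_mask=torch.ones(sent_num, seq_len).long(),
        ner_start_labels=torch.zeros(sent_num).float(),
        ner_end_labels=torch.zeros(sent_num).float(),
        ner_content_labels=torch.zeros(sent_num).float(),
    )
    print(ner_res["ner_output"]["ner_loss"])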