model.py
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, BertForSequenceClassification


class UIEModel(nn.Module):
    def __init__(self, args):
        super(UIEModel, self).__init__()
        self.args = args
        self.tasks = args.tasks
        bert_dir = args.bert_dir
        self.bert_config = BertConfig.from_pretrained(bert_dir)
        self.bert_model = BertModel.from_pretrained(bert_dir)
        # self.bert_model.load_state_dict(torch.load(self.args.bert_pt_dir, map_location="cuda"))
        # self.bert_model = BertForSequenceClassification.from_pretrained(args.bert_pt_dir)
        if "ner" in args.tasks:
            self.ner_num_labels = args.ner_num_labels
            self.module_start_list = nn.ModuleList()
            self.module_end_list = nn.ModuleList()
            self.module_content_list = nn.ModuleList()  # extra head: is this line question content?
            for i in range(args.ner_num_labels):
                self.module_start_list.append(nn.Linear(self.bert_config.hidden_size, 1))
                self.module_end_list.append(nn.Linear(self.bert_config.hidden_size, 1))
                self.module_content_list.append(nn.Linear(self.bert_config.hidden_size, 1))
            self.ner_criterion = nn.BCEWithLogitsLoss()
        self.dropout = nn.Dropout(self.bert_config.hidden_dropout_prob)
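    # Architecture note: for each NER label the model owns three binary
    # heads (start / end / content), each a Linear(hidden_size, 1) trained
    # with BCEWithLogitsLoss, i.e. pointer-style scoring over sentences
    # rather than token-level sequence labeling.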
    @staticmethod
    def build_dummy_inputs():
        inputs = {}
        inputs['ner_input_ids'] = torch.LongTensor(
            torch.randint(low=1, high=10, size=(32, 56)))
        inputs['ner_attention_mask'] = torch.ones(size=(32, 56)).long()
        inputs['ner_token_type_ids'] = torch.zeros(size=(32, 56)).long()
        inputs['ner_start_labels'] = torch.zeros(size=(32, 8, 56)).float()
        inputs['ner_end_labels'] = torch.zeros(size=(32, 8, 56)).float()
        return inputs
    def get_pointer_loss(self,
                         start_logits,
                         end_logits,
                         attention_mask,
                         start_labels,
                         end_labels,
                         criterion):
        # Flatten logits/labels and keep only positions where the attention
        # mask is 1, so padding does not contribute to the loss.
        start_logits = start_logits.view(-1)
        end_logits = end_logits.view(-1)
        active_loss = attention_mask.view(-1) == 1
        active_start_logits = start_logits[active_loss]
        active_end_logits = end_logits[active_loss]
        active_start_labels = start_labels.view(-1)[active_loss]
        active_end_labels = end_labels.view(-1)[active_loss]
        start_loss = criterion(active_start_logits, active_start_labels)
        end_loss = criterion(active_end_logits, active_end_labels)
        loss = start_loss + end_loss
        return loss
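    # NOTE: `get_pointer_loss` is a standalone utility; the forward passes
    # below compute their losses inline with `self.ner_criterion`. A toy
    # usage sketch of the same masked-BCE idea appears at the bottom of
    # this file.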
    def ner_forward_1(self,
                      ner_input_ids,
                      ner_attention_mask,
                      ner_start_labels=None,
                      ner_end_labels=None):
        # All four arguments are lists of tensors: [tensor(), tensor(), ...].
        # Each call passes batch_size samples; each sample holds several sentences.
        # Samples are encoded one at a time; if a sample has too many sentences,
        # it is further split into chunks before encoding.
        all_start_logits = []
        all_end_logits = []
        ner_loss = None
        for i in range(len(ner_end_labels)):  # one iteration per sample/document
            input_ids = ner_input_ids[i].to(self.args.device)
            attention_mask = ner_attention_mask[i].to(self.args.device)
            # start_labels = ner_start_labels[i].to(self.args.device)
            # end_labels = ner_end_labels[i].to(self.args.device)
            # Encode in chunks according to sent_num (GPU memory runs out
            # when sent_num is too large).
            max_encoder_len = self.args.max_encoder_sent_len
            batch_num = input_ids.size(0) // max_encoder_len
            last_hidden_states = []
            if batch_num > 0:
                for j in range(batch_num):  # renamed from `i`, which shadowed the sample index
                    truncated_input_ids = input_ids[j * max_encoder_len:(j + 1) * max_encoder_len, :]
                    truncated_attention_mask = attention_mask[j * max_encoder_len:(j + 1) * max_encoder_len, :]
                    truncated_outputs = self.bert_model(truncated_input_ids, attention_mask=truncated_attention_mask)
                    last_hidden_states.append(truncated_outputs.last_hidden_state)  # .detach().cpu()
            if input_ids.size(0) - batch_num * max_encoder_len > 0:
                truncated_input_ids = input_ids[batch_num * max_encoder_len:, :]
                truncated_attention_mask = attention_mask[batch_num * max_encoder_len:, :]
                truncated_outputs = self.bert_model(truncated_input_ids, attention_mask=truncated_attention_mask)
                last_hidden_states.append(truncated_outputs.last_hidden_state)  # .detach().cpu()
            if len(last_hidden_states) > 1:
                seq_bert_output = torch.cat(last_hidden_states, dim=0)
            else:
                seq_bert_output = last_hidden_states[0]  # [sent_num, seq_len, hidden_dim]
            # Mean pooling that ignores padding:
            seq_bert_output = seq_bert_output[:, 1:-1, :]  # drop [CLS] and [SEP] ???
            expanded_attention_mask = attention_mask[:, 1:-1].unsqueeze(-1).expand_as(seq_bert_output)  # [sent_num, seq_len, hidden_size]
            sum_of_non_padded_output = (seq_bert_output * expanded_attention_mask).sum(dim=1)  # sum over valid positions only
            mean_encoder_outputs = sum_of_non_padded_output / expanded_attention_mask.sum(dim=1)  # [sent_num, hidden_size]
            # dropout
            pooled_output = self.dropout(mean_encoder_outputs)
            # Pointer logits.
            # for i in range(self.ner_num_labels):  # compute each NER label separately
            if self.ner_num_labels == 1:
                # One linear head per pointer position.
                start_logit = self.module_start_list[0](pooled_output).squeeze(1)  # [sent_num]
                end_logit = self.module_end_list[0](pooled_output).squeeze(1)  # [sent_num]
                # print(start_logit, start_logit.size())
                all_start_logits.append(start_logit)
                all_end_logits.append(end_logit)
        # Merge the whole batch into a single loss value (outside the
        # per-sample loop, so logits and labels line up).
        concat_start_logits = torch.cat(all_start_logits, dim=0)
        concat_end_logits = torch.cat(all_end_logits, dim=0)
        all_start_labels = torch.cat(ner_start_labels, dim=0).to(self.args.device)
        all_end_labels = torch.cat(ner_end_labels, dim=0).to(self.args.device)
        start_loss = self.ner_criterion(concat_start_logits, all_start_labels)  # start-position loss
        end_loss = self.ner_criterion(concat_end_logits, all_end_labels)  # end-position loss
        if ner_loss is None:
            ner_loss = start_loss + end_loss
        else:
            ner_loss += (start_loss + end_loss)
        res = {
            "ner_start_logits": [a.detach().cpu() for a in all_start_logits],
            "ner_end_logits": [a.detach().cpu() for a in all_end_logits],
            "ner_loss": ner_loss,
        }
        return res
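    # The chunked encoding above bounds peak GPU memory: a document's
    # sentences go through BERT at most `max_encoder_sent_len` rows at a
    # time and the hidden states are concatenated afterwards. Gradients
    # still flow through every chunk unless the commented-out
    # `.detach().cpu()` calls are restored.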
    def ner_bc_forward(self,
                       ner_input_ids,
                       ner_attention_mask,
                       ner_start_labels=None,
                       ner_end_labels=None,
                       ner_content_labels=None):
        """
        ner: topic recognition; bc: binary classification.
        This merges the start/end position prediction of exam questions with
        the judgment of whether a line actually belongs to question content.
        Reason: when segmenting questions from the predicted labels, splitting
        at the start positions gives the fewest errors, but question-type
        heading lines still get pulled into questions, so they need a
        separate judgment.
        All label arguments are lists of tensors: [tensor(), tensor(), ...].
        Each call passes a truncated slice of one sample (batch_size=1); each
        sample contains several sentences. Samples are encoded one at a time;
        if a sample has too many sentences, it is split into chunks.
        """
        all_start_logits = []
        all_end_logits = []
        all_content_logits = []
        ner_loss = None
        outputs = self.bert_model(ner_input_ids, attention_mask=ner_attention_mask)
        # Use the [CLS]-position representation of each sentence.
        cls_bert_output = outputs.pooler_output  # [sent_num, hidden_size]
        # dropout
        pooled_output = self.dropout(cls_bert_output)
        # One linear head per pointer position, plus one for the
        # question-content judgment.
        start_logit = self.module_start_list[0](pooled_output).squeeze(1)  # [sent_num]
        end_logit = self.module_end_list[0](pooled_output).squeeze(1)  # [sent_num]
        content_logit = self.module_content_list[0](pooled_output).squeeze(1)  # [sent_num]
        all_start_logits.append(start_logit)
        all_end_logits.append(end_logit)
        all_content_logits.append(content_logit)
        # Guard on all three label sets, so a missing ner_content_labels
        # does not crash the loss computation.
        if ner_start_labels is not None and ner_end_labels is not None and ner_content_labels is not None:
            start_loss = self.ner_criterion(start_logit, ner_start_labels)  # start-position loss
            end_loss = self.ner_criterion(end_logit, ner_end_labels)  # end-position loss
            content_loss = self.ner_criterion(content_logit, ner_content_labels)
            if ner_loss is None:
                ner_loss = start_loss + end_loss + content_loss
            else:
                ner_loss += (start_loss + end_loss + content_loss)
        res = {
            "ner_start_logits": [a.detach().cpu() for a in all_start_logits],
            "ner_end_logits": [a.detach().cpu() for a in all_end_logits],
            "ner_content_logits": [a.detach().cpu() for a in all_content_logits],
            "ner_loss": ner_loss,
        }
        return res
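    # Design note: in `ner_bc_forward` each row of `ner_input_ids` is one
    # line/sentence of a document, so `pooler_output` yields one vector per
    # line and the three heads score, per line: "a question starts here",
    # "a question ends here", and "this line is question content".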
    def ner_forward(self,
                    ner_input_ids,
                    ner_attention_mask,
                    ner_start_labels=None,
                    ner_end_labels=None):
        """
        All label arguments are lists of tensors: [tensor(), tensor(), ...].
        Each call passes a truncated slice of one sample (batch_size=1); each
        sample contains several sentences. Samples are encoded one at a time;
        if a sample has too many sentences, it is split into chunks.
        """
        all_start_logits = []
        all_end_logits = []
        ner_loss = None
        res = {
            "ner_start_logits": None,
            "ner_end_logits": None,
            "ner_loss": None
        }
        outputs = self.bert_model(ner_input_ids, attention_mask=ner_attention_mask)
        # Alternative: take every token representation, then average.
        # last_hidden_states = outputs.last_hidden_state  # [sent_num, seq_len, hidden_dim]
        # seq_bert_output = last_hidden_states[:, 1:-1, :]  # drop [CLS] and [SEP] ???
        # # Mean pooling that ignores padding:
        # expanded_attention_mask = ner_attention_mask[:, 1:-1].unsqueeze(-1).expand_as(seq_bert_output)  # [sent_num, seq_len, hidden_size]
        # sum_of_non_padded_output = (seq_bert_output * expanded_attention_mask).sum(dim=1)  # sum over valid positions only
        # mean_encoder_outputs = sum_of_non_padded_output / expanded_attention_mask.sum(dim=1)  # [sent_num, hidden_size]
        # Use the [CLS]-position representation of each sentence.
        cls_bert_output = outputs.pooler_output  # [sent_num, hidden_size]
        # dropout
        pooled_output = self.dropout(cls_bert_output)
        # Compute the pointer-position losses.
        if self.ner_num_labels == 1:
            # One linear head per pointer position;
            # squeeze(1) removes the size-1 second dimension.
            start_logit = self.module_start_list[0](pooled_output).squeeze(1)  # [sent_num]
            end_logit = self.module_end_list[0](pooled_output).squeeze(1)  # [sent_num]
            all_start_logits.append(start_logit)
            all_end_logits.append(end_logit)
            if ner_start_labels is not None and ner_end_labels is not None:
                start_loss = self.ner_criterion(start_logit, ner_start_labels)  # start-position loss
                end_loss = self.ner_criterion(end_logit, ner_end_labels)  # end-position loss
                if ner_loss is None:
                    ner_loss = start_loss + end_loss
                else:
                    ner_loss += (start_loss + end_loss)
            res = {
                "ner_start_logits": [a.detach().cpu() for a in all_start_logits],
                "ner_end_logits": [a.detach().cpu() for a in all_end_logits],
                "ner_loss": ner_loss,
            }
        return res
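    # `ner_forward` is the earlier start/end-only variant kept for reference;
    # `forward` below currently routes NER inputs through `ner_bc_forward`,
    # which adds the content head.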
    def forward(self,
                ner_input_ids=None,
                # ner_token_type_ids=None,
                ner_attention_mask=None,
                ner_start_labels=None,
                ner_end_labels=None,
                ner_content_labels=None,
                ):
        res = {
            "ner_output": None,
            "re_output": None,
            "event_output": None
        }
        if "ner" in self.tasks:
            # ner_output = self.ner_forward(
            #     ner_input_ids,
            #     # ner_token_type_ids,
            #     ner_attention_mask,
            #     ner_start_labels,
            #     ner_end_labels,
            # )
            ner_output = self.ner_bc_forward(
                ner_input_ids,
                # ner_token_type_ids,
                ner_attention_mask,
                ner_start_labels,
                ner_end_labels,
                ner_content_labels,
            )
            res["ner_output"] = ner_output
        return res
if __name__ == '__main__':
    inputs = UIEModel.build_dummy_inputs()

    class Args:
        bert_dir = "../chinese-bert-wwm-ext/"
        ner_num_labels = 8
        re_num_labels = 16
        tasks = ["re_rel"]

    args = Args()
    model = UIEModel(args)
    # `forward` does not accept token_type_ids (the parameter is commented
    # out above), so they are not passed here.
    res = model(
        ner_input_ids=inputs["ner_input_ids"],
        ner_attention_mask=inputs["ner_attention_mask"],
        ner_start_labels=inputs["ner_start_labels"],
        ner_end_labels=inputs["ner_end_labels"],
    )
    print(res)
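    # --- Toy check of the masked pointer loss (a minimal sketch; the toy
    # tensors below are made-up shapes, not real model outputs). It mirrors
    # what `get_pointer_loss` does: flatten, drop padded positions via the
    # attention mask, then apply BCE-with-logits to start and end logits.
    toy_start_logits = torch.randn(2, 6)  # [batch, seq_len]
    toy_end_logits = torch.randn(2, 6)
    toy_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                             [1, 1, 1, 0, 0, 0]])
    toy_start_labels = torch.zeros(2, 6)
    toy_end_labels = torch.zeros(2, 6)
    toy_start_labels[0, 1] = 1.0  # pretend a span starts at token 1 ...
    toy_end_labels[0, 3] = 1.0    # ... and ends at token 3
    criterion = nn.BCEWithLogitsLoss()
    active = toy_mask.view(-1) == 1  # keep only non-padding positions
    toy_loss = (criterion(toy_start_logits.view(-1)[active], toy_start_labels.view(-1)[active])
                + criterion(toy_end_logits.view(-1)[active], toy_end_labels.view(-1)[active]))
    print("toy pointer loss:", toy_loss.item())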