import gc

import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, BertForSequenceClassification


class UIEModel(nn.Module):
    def __init__(self, args):
        super(UIEModel, self).__init__()
        self.args = args

        self.tasks = args.tasks
        bert_dir = args.bert_dir
        self.bert_config = BertConfig.from_pretrained(bert_dir)
        self.bert_model = BertModel.from_pretrained(bert_dir)
        # self.bert_model.load_state_dict(torch.load(self.args.bert_pt_dir, map_location="cuda"))
        # self.bert_model = BertForSequenceClassification.from_pretrained(args.bert_pt_dir)
        if "ner" in args.tasks:
            self.ner_num_labels = args.ner_num_labels
            self.module_start_list = nn.ModuleList()
            self.module_end_list = nn.ModuleList()
            self.module_content_list = nn.ModuleList()  # extra head: does the sentence belong to question content?
            for _ in range(args.ner_num_labels):
                self.module_start_list.append(nn.Linear(self.bert_config.hidden_size, 1))
                self.module_end_list.append(nn.Linear(self.bert_config.hidden_size, 1))
                self.module_content_list.append(nn.Linear(self.bert_config.hidden_size, 1))
            self.ner_criterion = nn.BCEWithLogitsLoss()

        self.dropout = nn.Dropout(self.bert_config.hidden_dropout_prob)

    @staticmethod
    def build_dummy_inputs():
        inputs = {}
        # 32 sentences of 56 tokens each; 8 NER label types.
        inputs['ner_input_ids'] = torch.randint(low=1, high=10, size=(32, 56))
        inputs['ner_attention_mask'] = torch.ones(size=(32, 56)).long()
        inputs['ner_token_type_ids'] = torch.zeros(size=(32, 56)).long()
        inputs['ner_start_labels'] = torch.zeros(size=(32, 8, 56)).float()
        inputs['ner_end_labels'] = torch.zeros(size=(32, 8, 56)).float()
        return inputs

    def get_pointer_loss(self,
                         start_logits,
                         end_logits,
                         attention_mask,
                         start_labels,
                         end_labels,
                         criterion):
        start_logits = start_logits.view(-1)
        end_logits = end_logits.view(-1)

        active_loss = attention_mask.view(-1) == 1
        active_start_logits = start_logits[active_loss]
        active_end_logits = end_logits[active_loss]
        active_start_labels = start_labels.view(-1)[active_loss]
        active_end_labels = end_labels.view(-1)[active_loss]

        start_loss = criterion(active_start_logits, active_start_labels)
        end_loss = criterion(active_end_logits, active_end_labels)
        loss = start_loss + end_loss
        return loss
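
    # Shape note on get_pointer_loss (editorial): the logits and labels are
    # expected as [batch, seq_len]; flattening and masking with
    # `attention_mask.view(-1) == 1` keeps only real-token positions, so padding
    # never contributes to the BCE loss. This helper is defined but not called
    # by the forward methods below, which compute sentence-level losses directly.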

    def ner_forward_1(self,
                      ner_input_ids,
                      ner_attention_mask,
                      ner_start_labels=None,
                      ner_end_labels=None):
        # All four arguments are lists of tensors: [tensor(), tensor(), ...].
        # Each call carries batch_size samples, and every sample holds several
        # sentences. Samples are encoded one at a time; when a sample has too
        # many sentences, it is encoded in chunks so GPU memory suffices.
        all_start_logits = []
        all_end_logits = []
        ner_loss = None

        for i in range(len(ner_end_labels)):  # one iteration per sample/document
            input_ids = ner_input_ids[i].to(self.args.device)
            attention_mask = ner_attention_mask[i].to(self.args.device)
            # start_labels = ner_start_labels[i].to(self.args.device)
            # end_labels = ner_end_labels[i].to(self.args.device)

            # Encode at most max_encoder_sent_len sentences at a time
            # (encoding every sentence at once can exhaust GPU memory).
            max_encoder_len = self.args.max_encoder_sent_len
            batch_num = input_ids.size(0) // max_encoder_len
            last_hidden_states = []
            if batch_num > 0:
                # The chunk index must not reuse `i`, which would clobber the
                # outer sample index (a bug in the original).
                for j in range(batch_num):
                    truncated_input_ids = input_ids[j * max_encoder_len:(j + 1) * max_encoder_len, :]
                    truncated_attention_mask = attention_mask[j * max_encoder_len:(j + 1) * max_encoder_len, :]
                    truncated_outputs = self.bert_model(truncated_input_ids, attention_mask=truncated_attention_mask)
                    last_hidden_states.append(truncated_outputs.last_hidden_state)  # .detach().cpu()
            if input_ids.size(0) - batch_num * max_encoder_len > 0:
                truncated_input_ids = input_ids[batch_num * max_encoder_len:, :]
                truncated_attention_mask = attention_mask[batch_num * max_encoder_len:, :]
                truncated_outputs = self.bert_model(truncated_input_ids, attention_mask=truncated_attention_mask)
                last_hidden_states.append(truncated_outputs.last_hidden_state)  # .detach().cpu()
            if len(last_hidden_states) > 1:
                seq_bert_output = torch.cat(last_hidden_states, dim=0)
            else:
                seq_bert_output = last_hidden_states[0]  # [sent_num, seq_len, hidden_dim]

            # Mean-pool each sentence over its tokens while ignoring padding.
            seq_bert_output = seq_bert_output[:, 1:-1, :]  # drop [CLS] and the final position ([SEP] when unpadded)
            expanded_attention_mask = attention_mask[:, 1:-1].unsqueeze(-1).expand_as(seq_bert_output)  # [sent_num, seq_len, hidden_size]
            sum_of_non_padded_output = (seq_bert_output * expanded_attention_mask).sum(dim=1)  # sum over valid positions only
            mean_encoder_outputs = sum_of_non_padded_output / expanded_attention_mask.sum(dim=1)  # [sent_num, hidden_size]

            pooled_output = self.dropout(mean_encoder_outputs)

            # for k in range(self.ner_num_labels):  # each NER label type handled separately
            if self.ner_num_labels == 1:
                # One linear head per pointer position.
                start_logit = self.module_start_list[0](pooled_output).squeeze(1)  # [sent_num]
                end_logit = self.module_end_list[0](pooled_output).squeeze(1)  # [sent_num]
                all_start_logits.append(start_logit)
                all_end_logits.append(end_logit)

        # Merge the whole batch into a single loss value.
        concat_start_logits = torch.cat(all_start_logits, dim=0)
        concat_end_logits = torch.cat(all_end_logits, dim=0)
        all_start_labels = torch.cat(ner_start_labels, dim=0).to(self.args.device)
        all_end_labels = torch.cat(ner_end_labels, dim=0).to(self.args.device)
        start_loss = self.ner_criterion(concat_start_logits, all_start_labels)  # loss at start positions
        end_loss = self.ner_criterion(concat_end_logits, all_end_labels)  # loss at end positions
        ner_loss = start_loss + end_loss

        res = {
            "ner_start_logits": [a.detach().cpu() for a in all_start_logits],
            "ner_end_logits": [a.detach().cpu() for a in all_end_logits],
            "ner_loss": ner_loss,
        }
        return res
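
    # Note on the chunked encoding above: chunk outputs stay on the GPU (the
    # `.detach().cpu()` variants are commented out), so during training the
    # activations for an entire document still accumulate in memory. Chunking
    # bounds the encoder's peak batch size, not the total stored activations.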

    def ner_bc_forward(self,
                       ner_input_ids,
                       ner_attention_mask,
                       ner_start_labels=None,
                       ner_end_labels=None,
                       ner_content_labels=None):
        """
        ner: question (topic) span recognition; bc: binary classification.

        Start/end position prediction is merged with judging whether a sentence
        belongs to question content. Rationale: when splitting questions by
        predicted labels, cutting at the predicted start positions yields the
        fewest errors, but question-type heading lines then get swept into the
        questions, so they need a separate content check.

        Each argument is a tensor holding one chunk of a single sample
        (batch_size = 1); a sample contains several sentences. Samples are
        encoded one by one, and overly long samples are chunked beforehand.
        """
        all_start_logits = []
        all_end_logits = []
        all_content_logits = []
        ner_loss = None

        outputs = self.bert_model(ner_input_ids, attention_mask=ner_attention_mask)
        # Use the pooled [CLS] representation of each sentence.
        cls_bert_output = outputs.pooler_output  # [sent_num, hidden_size]
        pooled_output = self.dropout(cls_bert_output)
        # Release the encoder outputs early to ease memory pressure.
        del outputs, cls_bert_output
        gc.collect()

        # One linear head per task: start pointer, end pointer, content check.
        start_logit = self.module_start_list[0](pooled_output).squeeze(1)  # [sent_num]
        end_logit = self.module_end_list[0](pooled_output).squeeze(1)  # [sent_num]
        content_logit = self.module_content_list[0](pooled_output).squeeze(1)  # [sent_num]
        all_start_logits.append(start_logit)
        all_end_logits.append(end_logit)
        all_content_logits.append(content_logit)

        if ner_start_labels is not None and ner_end_labels is not None:
            start_loss = self.ner_criterion(start_logit, ner_start_labels)  # loss at start positions
            end_loss = self.ner_criterion(end_logit, ner_end_labels)  # loss at end positions
            content_loss = self.ner_criterion(content_logit, ner_content_labels)
            ner_loss = start_loss + end_loss + content_loss

        res = {
            "ner_start_logits": [a.detach().cpu() for a in all_start_logits],
            "ner_end_logits": [a.detach().cpu() for a in all_end_logits],
            "ner_content_logits": [a.detach().cpu() for a in all_content_logits],
            "ner_loss": ner_loss,
        }
        return res
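
    # ner_bc_forward treats every sentence as one classification unit: each head
    # maps the sentence's pooled [CLS] vector to a single logit, so
    # BCEWithLogitsLoss expects float label tensors of shape [sent_num] (one 0/1
    # flag per sentence), not the token-level [sent_num, num_labels, seq_len]
    # labels that build_dummy_inputs produces.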

    def ner_forward(self,
                    ner_input_ids,
                    ner_attention_mask,
                    ner_start_labels=None,
                    ner_end_labels=None):
        """
        Each argument is a tensor holding one chunk of a single sample
        (batch_size = 1); a sample contains several sentences. Samples are
        encoded one by one, and overly long samples are chunked beforehand.
        """
        all_start_logits = []
        all_end_logits = []
        ner_loss = None
        res = {
            "ner_start_logits": None,
            "ner_end_logits": None,
            "ner_loss": None
        }

        outputs = self.bert_model(ner_input_ids, attention_mask=ner_attention_mask)
        # Alternative: mean-pool the token representations instead of using [CLS].
        # last_hidden_states = outputs.last_hidden_state  # [sent_num, seq_len, hidden_dim]
        # seq_bert_output = last_hidden_states[:, 1:-1, :]  # drop [CLS] and [SEP]
        # expanded_attention_mask = ner_attention_mask[:, 1:-1].unsqueeze(-1).expand_as(seq_bert_output)  # [sent_num, seq_len, hidden_size]
        # sum_of_non_padded_output = (seq_bert_output * expanded_attention_mask).sum(dim=1)  # sum over valid positions only
        # mean_encoder_outputs = sum_of_non_padded_output / expanded_attention_mask.sum(dim=1)  # [sent_num, hidden_size]
        # Use the pooled [CLS] representation of each sentence.
        cls_bert_output = outputs.pooler_output  # [sent_num, hidden_size]
        pooled_output = self.dropout(cls_bert_output)

        # Pointer logits and loss.
        if self.ner_num_labels == 1:
            # squeeze(1) removes the size-1 second dimension.
            start_logit = self.module_start_list[0](pooled_output).squeeze(1)  # [sent_num]
            end_logit = self.module_end_list[0](pooled_output).squeeze(1)  # [sent_num]
            all_start_logits.append(start_logit)
            all_end_logits.append(end_logit)
            if ner_start_labels is not None and ner_end_labels is not None:
                start_loss = self.ner_criterion(start_logit, ner_start_labels)  # loss at start positions
                end_loss = self.ner_criterion(end_logit, ner_end_labels)  # loss at end positions
                ner_loss = start_loss + end_loss

        res = {
            "ner_start_logits": [a.detach().cpu() for a in all_start_logits],
            "ner_end_logits": [a.detach().cpu() for a in all_end_logits],
            "ner_loss": ner_loss,
        }
        return res

    def forward(self,
                ner_input_ids=None,
                # ner_token_type_ids=None,
                ner_attention_mask=None,
                ner_start_labels=None,
                ner_end_labels=None,
                ner_content_labels=None,
                ):
        res = {
            "ner_output": None,
            "re_output": None,
            "event_output": None
        }

        if "ner" in self.tasks:
            # ner_output = self.ner_forward(
            #     ner_input_ids,
            #     # ner_token_type_ids,
            #     ner_attention_mask,
            #     ner_start_labels,
            #     ner_end_labels,
            # )
            ner_output = self.ner_bc_forward(
                ner_input_ids,
                # ner_token_type_ids,
                ner_attention_mask,
                ner_start_labels,
                ner_end_labels,
                ner_content_labels,
            )
            res["ner_output"] = ner_output

        return res
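

# A minimal decoding sketch (an editorial assumption, not part of the original
# code): it shows how the sentence-level outputs of ner_bc_forward might be
# turned into question spans, following the rationale in its docstring: cut at
# predicted start lines, then drop lines the content head rejects. The function
# name and the 0.5 threshold are hypothetical.
def decode_question_spans(start_logits, content_logits, threshold=0.5):
    """Split sentences into spans at predicted starts; keep content lines only."""
    start_prob = torch.sigmoid(start_logits)      # [sent_num]
    content_prob = torch.sigmoid(content_logits)  # [sent_num]
    starts = (start_prob > threshold).nonzero(as_tuple=True)[0].tolist()
    spans = []
    for k, s in enumerate(starts):
        end = starts[k + 1] if k + 1 < len(starts) else start_prob.size(0)
        # Keep only the sentences judged to be question content.
        kept = [idx for idx in range(s, end) if content_prob[idx] > threshold]
        if kept:
            spans.append((kept[0], kept[-1]))
    return spans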


if __name__ == '__main__':
    inputs = UIEModel.build_dummy_inputs()

    class Args:
        bert_dir = "../chinese-bert-wwm-ext/"
        ner_num_labels = 8
        re_num_labels = 16
        tasks = ["re_rel"]

    args = Args()
    model = UIEModel(args)

    # forward() does not accept ner_token_type_ids (the parameter is commented
    # out), so it is not passed here. With tasks = ["re_rel"], no NER heads are
    # built and every entry of res stays None.
    res = model(
        ner_input_ids=inputs["ner_input_ids"],
        ner_attention_mask=inputs["ner_attention_mask"],
        ner_start_labels=inputs["ner_start_labels"],
        ner_end_labels=inputs["ner_end_labels"],
    )
    print(res)
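
    # Illustrative use of the hypothetical decode_question_spans above, with
    # random logits standing in for real model outputs:
    fake_start_logits = torch.randn(10)
    fake_content_logits = torch.randn(10)
    print(decode_question_spans(fake_start_logits, fake_content_logits))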