|
@@ -320,7 +320,33 @@ class NerPipeline:
|
|
|
mirco_metrics[2]))
|
|
|
print(classification_report(role_metric, self.args.labels, self.args.id2label, total_count))
|
|
|
|
|
|
- def predict(self, **kwargs):
|
|
|
+ def predict(self, sentences_list):
|
|
|
+ with torch.no_grad(): # 不需要梯度计算的操作
|
|
|
+ inputs = self.args.tokenizer(sentences_list, padding='max_length', truncation=True,
|
|
|
+ max_length=self.args.max_seq_len, return_tensors='pt')
|
|
|
+ token_ids = inputs['input_ids'].to(self.args.device)
|
|
|
+ attention_mask = inputs['attention_mask'].to(self.args.device)
|
|
|
+
|
|
|
+
|
|
|
+ # tokens = ['[CLS]'] + tokens + ['[SEP]']
|
|
|
+ # token_ids = torch.from_numpy(np.array(encode_dict['input_ids'])).unsqueeze(0).to(self.args.device)
|
|
|
+ # attention_mask = torch.from_numpy(np.array(encode_dict['attention_mask'])).unsqueeze(0).to(
|
|
|
+ # self.args.device)
|
|
|
+ # token_type_ids = torch.from_numpy(np.array(encode_dict['token_type_ids'])).unsqueeze(0).to(self.args.device)
|
|
|
+ output = self.model(token_ids, attention_mask)
|
|
|
+ start_logits = output["ner_output"]["ner_start_logits"]
|
|
|
+ end_logits = output["ner_output"]["ner_end_logits"]
|
|
|
+ content_logits = output["ner_output"]["ner_content_logits"]
|
|
|
+
|
|
|
+ start_logits = sigmoid(start_logits[0])
|
|
|
+ end_logits = sigmoid(end_logits[0])
|
|
|
+ con_logits = sigmoid(content_logits[0])
|
|
|
+ return start_logits, end_logits, con_logits
|
|
|
+
|
|
|
+ def half_batch_predict(self, **kwargs):
|
|
|
+ """
|
|
|
+ 将一份文档截断分批次预测
|
|
|
+ """
|
|
|
# self.load_model()
|
|
|
# self.model.eval()
|
|
|
# self.model.to(self.args.device)
|
|
@@ -340,32 +366,42 @@ class NerPipeline:
|
|
|
# truncation="only_first",
|
|
|
# return_token_type_ids=True,
|
|
|
# return_attention_mask=True)
|
|
|
- inputs = self.args.tokenizer(sentences, padding='max_length', truncation=True,
|
|
|
- max_length=self.args.max_seq_len, return_tensors='pt')
|
|
|
- token_ids = inputs['input_ids'].to(self.args.device)
|
|
|
- attention_mask = inputs['attention_mask'].to(self.args.device)
|
|
|
-
|
|
|
- # tokens = ['[CLS]'] + tokens + ['[SEP]']
|
|
|
- # token_ids = torch.from_numpy(np.array(encode_dict['input_ids'])).unsqueeze(0).to(self.args.device)
|
|
|
- # attention_mask = torch.from_numpy(np.array(encode_dict['attention_mask'])).unsqueeze(0).to(
|
|
|
- # self.args.device)
|
|
|
- # token_type_ids = torch.from_numpy(np.array(encode_dict['token_type_ids'])).unsqueeze(0).to(self.args.device)
|
|
|
- output = self.model(token_ids, attention_mask)
|
|
|
- start_logits = output["ner_output"]["ner_start_logits"]
|
|
|
- end_logits = output["ner_output"]["ner_end_logits"]
|
|
|
- content_logits = output["ner_output"]["ner_content_logits"]
|
|
|
+ # 需要分固定句子处理,句子不能太长,因为显存不够
|
|
|
+ print("预测文档的句子数:", len(sentences))
|
|
|
+ max_input_len = 100
|
|
|
+ batch_num = int(len(sentences) / max_input_len)
|
|
|
+ start_logits, end_logits, con_logits = [], [], []
|
|
|
+ if batch_num > 0:
|
|
|
+ for i in range(batch_num):
|
|
|
+ left, right = i*max_input_len, (i+1)*max_input_len
|
|
|
+ if i == batch_num-1 and len(sentences) - (i+1)*max_input_len<28:
|
|
|
+ batch_num -= 1
|
|
|
+ break
|
|
|
+ l_edge = 10 if left > 0 else 0
|
|
|
+ r_edge = 10 # 左右多加几句共同参与
|
|
|
+ start_logit, end_logit, con_logit = self.predict(sentences[left-l_edge: right+r_edge])
|
|
|
+ start_logits.append(start_logit[l_edge:-r_edge])
|
|
|
+ end_logits.append(end_logit[l_edge:-r_edge])
|
|
|
+ con_logits.append(con_logit[l_edge:-r_edge])
|
|
|
+ if len(sentences) - batch_num * max_input_len > 0:
|
|
|
+ left = batch_num*max_input_len
|
|
|
+ l_edge = 10 if left > 0 else 0
|
|
|
+ start_logit, end_logit, con_logit = self.predict(sentences[left-l_edge:])
|
|
|
+ start_logits.append(start_logit[l_edge:])
|
|
|
+ end_logits.append(end_logit[l_edge:])
|
|
|
+ con_logits.append(con_logit[l_edge:])
|
|
|
+ start_logits = torch.cat(start_logits, dim=0)
|
|
|
+ end_logits = torch.cat(end_logits, dim=0)
|
|
|
+ con_logits = torch.cat(con_logits, dim=0)
|
|
|
+ # 不分段预测
|
|
|
+ # start_logits, end_logits, con_logits = self.predict(sentences)
|
|
|
|
|
|
- start_logits = sigmoid(start_logits[0])
|
|
|
- end_logits = sigmoid(end_logits[0])
|
|
|
- con_logits = sigmoid(content_logits[0])
|
|
|
-
|
|
|
- # print("start_logits:::", start_logits)
|
|
|
pred_entities, topic_item_pred = topic_ner_decode(start_logits, end_logits, con_logits, sentences, self.args.id2label)
|
|
|
# pprint(dict(pred_entities))
|
|
|
split_topic_idx = []
|
|
|
for i in pred_entities['TOPIC']:
|
|
|
split_topic_idx.append((i[-1], i[-1]+len(i[0])))
|
|
|
- # print(split_topic_idx)
|
|
|
+ print("split_topic_idx:",split_topic_idx)
|
|
|
return dict(pred_entities), split_topic_idx, topic_item_pred
|
|
|
|
|
|
|