cdZWj 6 месяцев назад
Родитель
Сommit
4fca00c7e2
2 измененных файлов с 58 добавлено и 22 удалено
  1. 57 21
      PointerNet/main.py
  2. 1 1
      PointerNet/predictor.py

+ 57 - 21
PointerNet/main.py

@@ -320,7 +320,33 @@ class NerPipeline:
                                                                                mirco_metrics[2]))
             print(classification_report(role_metric, self.args.labels, self.args.id2label, total_count))
 
-    def predict(self, **kwargs):
+    def predict(self, sentences_list):
+        with torch.no_grad():  # 不需要梯度计算的操作     
+            inputs = self.args.tokenizer(sentences_list, padding='max_length', truncation=True, 
+                                    max_length=self.args.max_seq_len, return_tensors='pt')
+            token_ids = inputs['input_ids'].to(self.args.device)
+            attention_mask = inputs['attention_mask'].to(self.args.device)
+
+            
+            # tokens = ['[CLS]'] + tokens + ['[SEP]']
+            # token_ids = torch.from_numpy(np.array(encode_dict['input_ids'])).unsqueeze(0).to(self.args.device)
+            # attention_mask = torch.from_numpy(np.array(encode_dict['attention_mask'])).unsqueeze(0).to(
+            #     self.args.device)
+            # token_type_ids = torch.from_numpy(np.array(encode_dict['token_type_ids'])).unsqueeze(0).to(self.args.device)
+            output = self.model(token_ids, attention_mask)
+            start_logits = output["ner_output"]["ner_start_logits"]
+            end_logits = output["ner_output"]["ner_end_logits"]
+            content_logits = output["ner_output"]["ner_content_logits"]
+
+            start_logits = sigmoid(start_logits[0])
+            end_logits = sigmoid(end_logits[0])
+            con_logits = sigmoid(content_logits[0])
+            return start_logits, end_logits, con_logits
+        
+    def half_batch_predict(self, **kwargs):
+        """
+        将一份文档截断分批次预测
+        """
         # self.load_model()
         # self.model.eval()
         # self.model.to(self.args.device)
@@ -340,32 +366,42 @@ class NerPipeline:
             #                         truncation="only_first",
             #                         return_token_type_ids=True,
             #                         return_attention_mask=True)
-            inputs = self.args.tokenizer(sentences, padding='max_length', truncation=True, 
-                                    max_length=self.args.max_seq_len, return_tensors='pt')
-            token_ids = inputs['input_ids'].to(self.args.device)
-            attention_mask = inputs['attention_mask'].to(self.args.device)
-
-            # tokens = ['[CLS]'] + tokens + ['[SEP]']
-            # token_ids = torch.from_numpy(np.array(encode_dict['input_ids'])).unsqueeze(0).to(self.args.device)
-            # attention_mask = torch.from_numpy(np.array(encode_dict['attention_mask'])).unsqueeze(0).to(
-            #     self.args.device)
-            # token_type_ids = torch.from_numpy(np.array(encode_dict['token_type_ids'])).unsqueeze(0).to(self.args.device)
-            output = self.model(token_ids, attention_mask)
-            start_logits = output["ner_output"]["ner_start_logits"]
-            end_logits = output["ner_output"]["ner_end_logits"]
-            content_logits = output["ner_output"]["ner_content_logits"]
+            # 需要分固定句子处理,句子不能太长,因为显存不够 
+            print("预测文档的句子数:", len(sentences))
+            max_input_len = 100
+            batch_num = int(len(sentences) / max_input_len)
+            start_logits, end_logits, con_logits = [], [], []
+            if batch_num > 0:
+                for i in range(batch_num):
+                    left, right = i*max_input_len, (i+1)*max_input_len
+                    if i == batch_num-1 and len(sentences) - (i+1)*max_input_len<28:
+                        batch_num -= 1
+                        break
+                    l_edge = 10 if left > 0 else 0  
+                    r_edge = 10  # 左右多加几句共同参与
+                    start_logit, end_logit, con_logit = self.predict(sentences[left-l_edge: right+r_edge])
+                    start_logits.append(start_logit[l_edge:-r_edge])
+                    end_logits.append(end_logit[l_edge:-r_edge])
+                    con_logits.append(con_logit[l_edge:-r_edge])
+            if len(sentences) - batch_num * max_input_len > 0:
+                left = batch_num*max_input_len
+                l_edge = 10 if left > 0 else 0
+                start_logit, end_logit, con_logit = self.predict(sentences[left-l_edge:])
+                start_logits.append(start_logit[l_edge:])
+                end_logits.append(end_logit[l_edge:])
+                con_logits.append(con_logit[l_edge:])
+            start_logits = torch.cat(start_logits, dim=0)
+            end_logits = torch.cat(end_logits, dim=0)
+            con_logits = torch.cat(con_logits, dim=0)
+            # 不分段预测
+            # start_logits, end_logits, con_logits = self.predict(sentences)
 
-            start_logits = sigmoid(start_logits[0])
-            end_logits = sigmoid(end_logits[0])
-            con_logits = sigmoid(content_logits[0])
-            
-            # print("start_logits:::", start_logits)
             pred_entities, topic_item_pred = topic_ner_decode(start_logits, end_logits, con_logits, sentences, self.args.id2label)
             # pprint(dict(pred_entities))
             split_topic_idx = []
             for i in pred_entities['TOPIC']:
                 split_topic_idx.append((i[-1], i[-1]+len(i[0])))
-            # print(split_topic_idx)
+            print("split_topic_idx:",split_topic_idx)
             return dict(pred_entities), split_topic_idx, topic_item_pred
 
 

+ 1 - 1
PointerNet/predictor.py

@@ -34,7 +34,7 @@ class Predictor:
 
     item_str, _, _ = simpwash(text, paper_id)
     self.sents_with_imginfo, simply_sents = again_wash(item_str, paper_id)
-    entities, split_topic_idx, topic_item_pred = self.ner_pipeline.predict(text_list=simply_sents)
+    entities, split_topic_idx, topic_item_pred = self.ner_pipeline.half_batch_predict(text_list=simply_sents)
     # pprint(entities)
     if not split_topic_idx:
       return []