فهرست منبع

adjust model_load

cdZWj 6 ماه پیش
والد
کامیت
f89f651e5b
1فایلهای تغییر یافته به همراه2 افزوده شده و 2 حذف شده
  1. 2 2
      PointerNet/model.py

+ 2 - 2
PointerNet/model.py

@@ -12,8 +12,8 @@ class UIEModel(nn.Module):
         bert_dir = args.bert_dir
         self.bert_config = BertConfig.from_pretrained(bert_dir)
         self.bert_model = BertModel.from_pretrained(bert_dir)
-        self.bert_model.load_state_dict(torch.load(self.args.bert_pt_dir, map_location="cuda"))
-        self.bert_model = BertForSequenceClassification.from_pretrained(args.bert_pt_dir)
+        # self.bert_model.load_state_dict(torch.load(self.args.bert_pt_dir, map_location="cuda"))
+        # self.bert_model = BertForSequenceClassification.from_pretrained(args.bert_pt_dir)
 
         if "ner" in args.tasks:
             self.ner_num_labels = args.ner_num_labels