|
@@ -12,8 +12,8 @@ class UIEModel(nn.Module):
|
|
|
bert_dir = args.bert_dir
|
|
|
self.bert_config = BertConfig.from_pretrained(bert_dir)
|
|
|
self.bert_model = BertModel.from_pretrained(bert_dir)
|
|
|
- self.bert_model.load_state_dict(torch.load(self.args.bert_pt_dir, map_location="cuda"))
|
|
|
- self.bert_model = BertForSequenceClassification.from_pretrained(args.bert_pt_dir)
|
|
|
+ # self.bert_model.load_state_dict(torch.load(self.args.bert_pt_dir, map_location="cuda"))
|
|
|
+ # self.bert_model = BertForSequenceClassification.from_pretrained(args.bert_pt_dir)
|
|
|
|
|
|
if "ner" in args.tasks:
|
|
|
self.ner_num_labels = args.ner_num_labels
|