Browse Source

model 加载方式

cdZWj 6 months ago
parent
commit
e630f45fdc
1 changed files with 3 additions and 3 deletions
  1. 3 3
      PointerNet/main.py

+ 3 - 3
PointerNet/main.py

@@ -34,9 +34,9 @@ class NerPipeline:
         torch.save(self.optimizer.state_dict(), self.args.optimizer_save_dir)
 
     def load_model(self):
-        # self.model.load_state_dict(torch.load(self.args.save_dir, map_location="cpu"))  #GPU 上训练的模型加载到CPU
-        self.model.load_state_dict(torch.load(self.args.save_dir))
-        self.model.to(self.args.device)
+        self.model.load_state_dict(torch.load(self.args.save_dir, map_location="cpu"))  #GPU 上训练的模型加载到CPU
+        # self.model.load_state_dict(torch.load(self.args.save_dir))
+        self.model.to(self.args.device)  # 耗内存
 
     def build_optimizer_and_scheduler(self, t_total):
         module = (