cdZWj преди 5 месеца
родител
ревизия
e630f45fdc
променени са 1 файла, в които са добавени 3 реда и са изтрити 3 реда
  1. 3 3
      PointerNet/main.py

+ 3 - 3
PointerNet/main.py

@@ -34,9 +34,9 @@ class NerPipeline:
         torch.save(self.optimizer.state_dict(), self.args.optimizer_save_dir)
 
     def load_model(self):
-        # self.model.load_state_dict(torch.load(self.args.save_dir, map_location="cpu"))  #GPU 上训练的模型加载到CPU
-        self.model.load_state_dict(torch.load(self.args.save_dir))
-        self.model.to(self.args.device)
+        self.model.load_state_dict(torch.load(self.args.save_dir, map_location="cpu"))  #GPU 上训练的模型加载到CPU
+        # self.model.load_state_dict(torch.load(self.args.save_dir))
+        self.model.to(self.args.device)  # 耗内存
 
     def build_optimizer_and_scheduler(self, t_total):
         module = (