소스 검색

model 加载方式

cdZWj 4 달 전
부모
커밋
e630f45fdc
1개의 변경된 파일3개의 추가작업 그리고 3개의 파일을 삭제
  1. 3 3
      PointerNet/main.py

+ 3 - 3
PointerNet/main.py

@@ -34,9 +34,9 @@ class NerPipeline:
         torch.save(self.optimizer.state_dict(), self.args.optimizer_save_dir)
 
     def load_model(self):
-        # self.model.load_state_dict(torch.load(self.args.save_dir, map_location="cpu"))  #GPU 上训练的模型加载到CPU
-        self.model.load_state_dict(torch.load(self.args.save_dir))
-        self.model.to(self.args.device)
+        self.model.load_state_dict(torch.load(self.args.save_dir, map_location="cpu"))  #GPU 上训练的模型加载到CPU
+        # self.model.load_state_dict(torch.load(self.args.save_dir))
+        self.model.to(self.args.device)  # 耗内存
 
     def build_optimizer_and_scheduler(self, t_total):
         module = (