|
@@ -34,9 +34,9 @@ class NerPipeline:
|
|
torch.save(self.optimizer.state_dict(), self.args.optimizer_save_dir)
|
|
torch.save(self.optimizer.state_dict(), self.args.optimizer_save_dir)
|
|
|
|
|
|
def load_model(self):
|
|
def load_model(self):
|
|
- # self.model.load_state_dict(torch.load(self.args.save_dir, map_location="cpu")) #GPU 上训练的模型加载到CPU
|
|
|
|
- self.model.load_state_dict(torch.load(self.args.save_dir))
|
|
|
|
- self.model.to(self.args.device)
|
|
|
|
|
|
+ self.model.load_state_dict(torch.load(self.args.save_dir, map_location="cpu")) #GPU 上训练的模型加载到CPU
|
|
|
|
+ # self.model.load_state_dict(torch.load(self.args.save_dir))
|
|
|
|
+ self.model.to(self.args.device) # 耗内存
|
|
|
|
|
|
def build_optimizer_and_scheduler(self, t_total):
|
|
def build_optimizer_and_scheduler(self, t_total):
|
|
module = (
|
|
module = (
|