Pytorch加载预训练模型的坑

保存模型:

def save(model, model_path):
  torch.save(model.state_dict(), model_path)

加载模型:

def load(model, model_path):
  model.load_state_dict(torch.load(model_path))

这样会出现一个问题,即明明指定了某张卡,但总有一个模型的显存多出来,占到另一张卡上,很烦人,看到知乎有个方法可以解决

https://www.zhihu.com/question/67209417/answer/355059967

说是把模型的数据放在CPU上就可以解决,等试一下效果

def load(model, model_path):
  model.load_state_dict(torch.load(model_path,map_location=torch.device('cpu')))