0


torch.load()加载模型及其map_location参数

参考

  • TORCH.LOAD

torch.load()

函数格式为:

torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)

,一般我们使用的时候,基本只使用前两个参数。

模型的保存

  • 模型保存有两种形式,一种是保存模型的state_dict(),只是保存模型的参数。那么加载时需要先创建一个模型的实例model,之后通过torch.load()将保存的模型参数加载进来,得到dict,再通过model.load_state_dict(dict)将模型的参数更新。
  • 另一种是将整个模型保存下来,之后加载的时候只需要通过torch.load()将模型加载,即可返回一个加载好的模型。 具体可参考:PyTorch模型的保存与加载。

模型加载中的map_location参数

具体来说,

map_location

参数是用于重定向,比如此前模型的参数是在

cpu

中的,我们希望将其加载到

cuda:0

中。或者我们有多张卡,那么我们就可以将卡1中训练好的模型加载到卡2中,这在数据并行的分布式深度学习中可能会用到。

  • 首先定义一个AlexNet,并使用cuda:0将其训练了一个猫狗分类,之后把模型存储起来。

map_location=None

  • 我们先把state_dict加载进来。
model_path ="./cuda_model.pth"
model = torch.load(model_path)print(next(model.parameters()).device)

结果为:

cuda:0

因为保存的时候就是模型就是

cuda:0

的,所以加载进来也是。

map_location=torch.device()

model_path ="./cuda_model.pth"
model = torch.load(model_path, map_location=torch.device('cpu'))print(next(model.parameters()).device)

结果为:

cpu

模型从

cuda:0

变成了

cpu

map_location={xx:xx}

model_path ="./cuda_model.pth"
model = torch.load(model_path, map_location={'cuda:0':'cuda:1'})print(next(model.parameters()).device)

结果为:

cuda:1

模型从

cuda:0

变成了

cuda:1

model_path ="./cuda_model.pth"
model = torch.load(model_path, map_location={'cuda:2':'cpu'})print(next(model.parameters()).device)

结果为:

cuda:0

模型还是

cuda:0

,并没有变成

cpu

。因为这个

map_location

的映射是不对的,原始的模型就是

cuda:0

,而映射是

cuda:2

cpu

,是不对的。这种情况下,

map_location

返回

None

,也就是和不加

map_location

相同。


本文转载自: https://blog.csdn.net/qq_43219379/article/details/123675375
版权归原作者 eecspan 所有, 如有侵权,请联系我们删除。

“torch.load()加载模型及其map_location参数”的评论:

还没有评论