1. 保存模型的目的和方法
在深度学习中,我们需要将模型训练完毕后保存下来,以便后续进行测试、推理或者继续训练。PyTorch提供了保存和加载模型的功能,使得我们可以方便地保存模型的参数和结构。
PyTorch保存模型的方法有两种:一种是保存整个模型,另一种是只保存模型的参数。
2. 保存模型用于测试
2.1 保存整个模型
当我们保存模型用于测试时,一般都会选择保存整个模型。
保存整个模型的方式是将模型的参数、模型的结构和模型的其他属性都保存起来,这样在加载模型时,我们可以直接获得完整的模型对象。
使用PyTorch保存整个模型的方法是调用模型对象的save()
方法。
torch.save(model, 'model.pt')
其中model
是我们要保存的模型对象,model.pt
是保存模型的文件路径。
2.2 加载整个模型
当我们要加载整个模型时,只需要调用torch.load()
函数即可。
model = torch.load('model.pt')
加载的结果是一个完整的模型对象,我们可以直接使用。
3. 保存模型用于继续训练
3.1 保存和加载模型参数
有时我们希望只保存模型的参数,而不保存模型的结构和其他属性,这样可以降低存储空间的消耗。
保存和加载模型参数的方法是先调用模型对象的state_dict()
方法获得模型参数的字典形式表示,然后将该字典保存到文件中。
torch.save(model.state_dict(), 'model_params.pt')
加载模型参数的方法是先创建一个与模型结构相同的空模型,再通过调用load_state_dict()
方法加载模型参数。
model = Model()
model.load_state_dict(torch.load('model_params.pt'))
3.2 冻结部分参数
在继续训练模型时,有时候我们只想更新部分参数,而保持其他参数不变。PyTorch提供了冻结部分参数的方法,即将特定参数的requires_grad
属性设置为False
。
for param in model.parameters():
param.requires_grad = False
model.fc.requires_grad = True
4. 设置temperature为0.6
temperature是在模型进行推理时使用的一个超参数,它用来控制输出概率分布的平滑程度。当temperature为1时,输出概率分布保持原始的分布;当temperature大于1时,输出概率分布趋向于均匀分布;当temperature小于1时,输出概率分布趋向于集中在概率最大的类别上。
要设置temperature为0.6,可以在模型推理时,在输出结果前使用softmax函数对模型输出进行归一化,然后将模型输出除以temperature再取指数,最后再将结果归一化。
import torch.nn.functional as F
output = model(input)
output = F.softmax(output / temperature, dim=1)
这样就得到了经过temperature处理后的输出结果。
5. 总结
PyTorch提供了保存和加载模型的功能,可以保存整个模型用于测试,也可以只保存模型参数用于继续训练。
设置temperature为0.6可以调节模型输出的平滑程度,从而影响模型的推理结果。