1. save()和save_weights()的基本区别
在Keras中,保存已经训练好的模型有两种方法:save()和save_weights()。它们之间的主要区别在于保存的内容。
1.1 save()
save()方法会将模型的结构、权重和训练配置全部保存下来。具体来说,它会包括:
模型的架构,包括每层的类型和连接方式。
每个层的权重。
模型的训练配置,例如优化器的参数、损失函数的名称等。
优化器的状态,包括动量、梯度累积等信息。
模型的编译信息,例如损失函数的名称、指标的名称等。
可能自定义的层和损失函数等。
在使用save()方法保存模型时,我们可以通过指定文件名来保存模型,也可以使用默认的文件名"model.h5"。例如:
model.save("my_model.h5")
1.2 save_weights()
相比于save()方法,save_weights()方法只保存模型的权重。具体来说,它会只包括所有层的权重,不包括模型的结构和训练配置。
在使用save_weights()方法保存模型权重时,我们也可以通过指定文件名来保存模型权重,也可以使用默认的文件名"weights.h5"。例如:
model.save_weights("my_weights.h5")
2. 使用场景
2.1 save()
save()方法适用于以下场景:
需要保存完整的模型结构、权重和训练配置。
需要将模型部署到其他地方进行使用。
需要将模型分享给他人,让他们可以直接加载使用。
2.2 save_weights()
save_weights()方法适用于以下场景:
需要仅保存模型的权重,而不保存模型的结构和训练配置。
需要将模型的权重保存下来,以便在将来的训练中继续使用。
需要将模型的权重转移到不同的模型架构中。
3. 示例
下面通过一个具体的示例来演示save()和save_weights()方法的用法。
import keras
from keras.models import Sequential
from keras.layers import Dense
# 创建模型
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=100))
model.add(Dense(64, activation='relu'))
model.add(Dense(10, activation='softmax'))
# 编译模型
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
# 模型训练
model.fit(x_train, y_train, epochs=10, batch_size=32)
# 保存模型
model.save("my_model.h5")
model.save_weights("my_weights.h5")
上述代码中,首先定义一个简单的多层感知器模型,并编译模型。然后进行模型训练,最后分别调用save()方法和save_weights()方法保存模型。可以看到,save()方法保存的文件是完整的模型文件,而save_weights()方法保存的文件仅包含了模型的权重。
4. 使用已保存的模型
当我们需要使用之前保存的模型时,可以使用Keras的load_model()方法加载完整的模型,或者使用load_weights()方法加载模型的权重。
4.1 load_model()
使用load_model()方法加载完整的模型时,可以直接获取保存的模型、权重和训练配置。
from keras.models import load_model
# 加载模型
loaded_model = load_model("my_model.h5")
# 使用模型进行预测
predictions = loaded_model.predict(x_test)
4.2 load_weights()
如果只需要加载模型的权重,可以使用load_weights()方法。但是在使用前需要先定义一个与保存的模型结构相同的模型,并使用load_weights()方法加载权重。
from keras.models import Sequential
from keras.layers import Dense
# 定义模型
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=100))
model.add(Dense(64, activation='relu'))
model.add(Dense(10, activation='softmax'))
# 加载权重
model.load_weights("my_weights.h5")
# 使用模型进行预测
predictions = model.predict(x_test)
5. 总结
在Keras中,save()和save_weights()方法都可以用于保存已训练好的模型,但它们保存的内容有所不同。save()方法保存的是完整的模型,包括模型的结构、权重和训练配置,而save_weights()方法仅保存模型的权重。根据使用场景的不同,我们可以选择使用适当的方法保存和加载模型。