浅谈keras保存模型中的save()和save_weights()区别

1. save()和save_weights()的基本区别

在Keras中,保存已经训练好的模型有两种方法:save()和save_weights()。它们之间的主要区别在于保存的内容。

1.1 save()

save()方法会将模型的结构、权重和训练配置全部保存下来。具体来说,它会包括:

模型的架构,包括每层的类型和连接方式。

每个层的权重。

模型的训练配置,例如优化器的参数、损失函数的名称等。

优化器的状态,包括动量、梯度累积等信息。

模型的编译信息,例如损失函数的名称、指标的名称等。

可能自定义的层和损失函数等。

在使用save()方法保存模型时,我们可以通过指定文件名来保存模型,也可以使用默认的文件名"model.h5"。例如:

model.save("my_model.h5")

1.2 save_weights()

相比于save()方法,save_weights()方法只保存模型的权重。具体来说,它会只包括所有层的权重,不包括模型的结构和训练配置。

在使用save_weights()方法保存模型权重时,我们也可以通过指定文件名来保存模型权重,也可以使用默认的文件名"weights.h5"。例如:

model.save_weights("my_weights.h5")

2. 使用场景

2.1 save()

save()方法适用于以下场景:

需要保存完整的模型结构、权重和训练配置。

需要将模型部署到其他地方进行使用。

需要将模型分享给他人,让他们可以直接加载使用。

2.2 save_weights()

save_weights()方法适用于以下场景:

需要仅保存模型的权重,而不保存模型的结构和训练配置。

需要将模型的权重保存下来,以便在将来的训练中继续使用。

需要将模型的权重转移到不同的模型架构中。

3. 示例

下面通过一个具体的示例来演示save()和save_weights()方法的用法。

import keras

from keras.models import Sequential

from keras.layers import Dense

# 创建模型

model = Sequential()

model.add(Dense(64, activation='relu', input_dim=100))

model.add(Dense(64, activation='relu'))

model.add(Dense(10, activation='softmax'))

# 编译模型

model.compile(optimizer='rmsprop',

loss='categorical_crossentropy',

metrics=['accuracy'])

# 模型训练

model.fit(x_train, y_train, epochs=10, batch_size=32)

# 保存模型

model.save("my_model.h5")

model.save_weights("my_weights.h5")

上述代码中,首先定义一个简单的多层感知器模型,并编译模型。然后进行模型训练,最后分别调用save()方法和save_weights()方法保存模型。可以看到,save()方法保存的文件是完整的模型文件,而save_weights()方法保存的文件仅包含了模型的权重。

4. 使用已保存的模型

当我们需要使用之前保存的模型时,可以使用Keras的load_model()方法加载完整的模型,或者使用load_weights()方法加载模型的权重。

4.1 load_model()

使用load_model()方法加载完整的模型时,可以直接获取保存的模型、权重和训练配置。

from keras.models import load_model

# 加载模型

loaded_model = load_model("my_model.h5")

# 使用模型进行预测

predictions = loaded_model.predict(x_test)

4.2 load_weights()

如果只需要加载模型的权重,可以使用load_weights()方法。但是在使用前需要先定义一个与保存的模型结构相同的模型,并使用load_weights()方法加载权重。

from keras.models import Sequential

from keras.layers import Dense

# 定义模型

model = Sequential()

model.add(Dense(64, activation='relu', input_dim=100))

model.add(Dense(64, activation='relu'))

model.add(Dense(10, activation='softmax'))

# 加载权重

model.load_weights("my_weights.h5")

# 使用模型进行预测

predictions = model.predict(x_test)

5. 总结

在Keras中,save()和save_weights()方法都可以用于保存已训练好的模型,但它们保存的内容有所不同。save()方法保存的是完整的模型,包括模型的结构、权重和训练配置,而save_weights()方法仅保存模型的权重。根据使用场景的不同,我们可以选择使用适当的方法保存和加载模型。

后端开发标签