keras 回调函数Callbacks 断点ModelCheckpoint教程

一、回调函数Callbacks的介绍

在深度学习模型训练过程中,我们可能需要在一定的训练轮次后对训练结果进行监控、调整模型参数或者保存模型等操作,这时就需要使用回调函数Callbacks了。回调函数可以在训练过程中的不同阶段(如每轮训练开始前、结束后等)自动调用,当满足某些条件时进行特定操作,比如:

定期保存模型

在训练过程中动态修改学习率

在训练过程中动态修改损失函数的权重

在训练过程中输出训练过程中的某些指标,比如准确率

回调函数是Keras里非常有用的一项功能,对于许多深度学习任务来说,只有通过它们才能成功地训练高质量的模型。

二、ModelCheckpoint回调函数的使用

ModelCheckpoint是Keras中提供的一种回调函数,可在训练过程中自动定期保存最佳模型。我们可以指定一个文件名模板,ModelCheckpoint会自动将最佳模型保存到文件中。以下是一个基本示例:

from keras.callbacks import ModelCheckpoint

checkpoint = ModelCheckpoint(filepath='best_model.h5', verbose=1, save_best_only=True)

在这个示例中,我们使用了ModelCheckpoint回调函数,并指定了相关参数。其中:

filepath:我们要保存模型的文件名模板。注意,如果设置了save_best_only=True,则文件名中应该包含{val_loss}等模型评估指标。

verbose:日志输出提醒。如果设置为1,则在每个保存点时将打印一条消息。

save_best_only:如果设置为True,则只保存在验证集上性能最好的模型。

在实际应用中,我们可以通过在模型训练时添加回调函数来进行模型的定期保存。以下是一个完整的代码示例:

from keras.callbacks import ModelCheckpoint

from keras.models import Sequential

from keras.layers import Dense

import numpy as np

import os

# 构造样本数据

x_train = np.random.random((1000, 1))

y_train = x_train * 2 + np.random.random((1000, 1)) * 0.1

x_val = np.random.random((100, 1))

y_val = x_val * 2 + np.random.random((100, 1)) * 0.1

# 构造模型

model = Sequential()

model.add(Dense(1, input_shape=(1,)))

model.compile(optimizer='sgd',loss='mse')

# 定义回调函数

best_model_path = 'model/best_model.h5'

if not os.path.exists(os.path.dirname(best_model_path)):

os.makedirs(os.path.dirname(best_model_path))

checkpoint = ModelCheckpoint(filepath=best_model_path, monitor='val_loss', verbose=1, save_best_only=True, mode='min')

callbacks_list = [checkpoint]

# 执行模型训练

model.fit(x_train, y_train,

validation_data=(x_val, y_val),

epochs=50, batch_size=32,

callbacks=callbacks_list)

1.定义回调函数

首先,我们需要定义ModelCheckpoint回调函数,将其保存到文件中。代码如下:

best_model_path = 'model/best_model.h5'

if not os.path.exists(os.path.dirname(best_model_path)):

os.makedirs(os.path.dirname(best_model_path))

checkpoint = ModelCheckpoint(filepath=best_model_path, monitor='val_loss', verbose=1, save_best_only=True, mode='min')

在这段代码中,我们指定了以下参数:

filepath:指定保存模型的文件名模板。

monitor:用来监控模型的评估指标(这里选择了验证集上的损失函数)。

verbose:日志输出提醒。如果设置为1,则在每个保存点时将打印一条消息。

save_best_only:如果设置为True,则只保存在验证集上性能最好的模型。

mode:用于指定监控模式。这里使用'min'表示验证集上的损失函数应该最小化。

2.执行模型训练

接下来,我们对模型进行训练,将回调函数添加到训练过程中。注意,我们要将回调函数传递给fit方法的callbacks参数中,以便在训练过程中调用。

model.fit(x_train, y_train, 

validation_data=(x_val, y_val),

epochs=50, batch_size=32,

callbacks=callbacks_list)

在这段代码中,我们指定了以下参数:

x_train:训练集数据。

y_train:训练集标签。

validation_data:验证集数据和标签。

epochs:训练轮次。

batch_size:批量大小。

callbacks:回调函数。

三、ModelCheckpoint回调函数的高级使用

当我们在训练深度神经网络时,通常需要进行很长时间的训练,这时使用ModelCheckpoint回调函数非常有用。但是,对于某些灵敏的模型(比如GAN和强化学习模型),我们可能需要定期保存较小的权重或其他一些后处理信息,这时候可以结合调整参数来进行优化。

1.设置断点

为了在训练期间对模型进行断点登录,我们可以使用ModelCheckpoint回调函数的save_weights_only=True选项。这将仅保存当前权重而不是整个模型(包括权重和架构)。可以使用以下方式定义回调函数:

checkpoint = ModelCheckpoint(filepath='weights.{epoch:02d}-{val_loss:.2f}.h5', save_weights_only=True)

在这个示例中,我们使用了ModelCheckpoint回调函数,并指定了以下参数:

filepath:权重文件保存的路径和文件名模板。

save_weights_only:如果设置为True,则只保存权重而不是整个模型。

在训练过程中,权重文件将在每个epoch之后保存,每个权重文件的命名将包括当前epoch的编号、验证集损失值等信息。例如,第2个epoch的权重文件将被保存为weights.02-0.80.h5。

2.使用预训练权重

在某些情况下,我们希望从之前的训练中加载模型的权重,继续进行训练,或在现有模型的基础上进行训练。这时,我们可以通过将ModelCheckpoint的参数save_weights_only设置为False来实现:

checkpoint = ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.2f}.h5', save_weights_only=False)

在这个示例中,我们使用了ModelCheckpoint回调函数,并指定了以下参数:

filepath:模型文件保存的路径和文件名模板。

save_weights_only:如果设置为False,则保存整个模型(包括权重和架构)。

在训练过程中,模型文件将在每个epoch之后保存,每个模型文件的命名将包括当前epoch的编号、验证集损失值等信息。例如,第2个epoch的模型文件将被保存为model.02-0.80.h5。

3.继续训练模型

有时,我们可能需要从之前训练的checkpoint文件中加载模型继续进行训练。在这种情况下,我们可以使用ModelCheckpoint回调函数的load_weights()方法加载权重:

from keras.models import load_model

model = create_model() # 创建模型

model.load_weights(filepath) # 使用ModelCheckpoint保存的权重文件进行加载

在这个示例中,我们使用了Keras的load_model()函数和我们自己定义的create_model()函数来加载整个模型的权重。请确保在创建新模型时模型的结构与之前的结构完全相同。这样才能正确加载预训练的权重。

4.完整示例

下面是一个综合示例,其中模型被训练到评估指标停滞时退出。我们设置了callbacks以周期性地保存模型,每隔n个epoch检查一次验证集上的评估指标(这里选择了损失函数),如果指标没有改善,则使用提前停止技术来终止训练过程。

from keras.callbacks import ModelCheckpoint, EarlyStopping

from keras.models import Sequential

from keras.layers import Dense

import numpy as np

import os

# 构造样本数据

x_train = np.random.random((1000, 1))

y_train = x_train * 2 + np.random.random((1000, 1)) * 0.1

x_val = np.random.random((100, 1))

y_val = x_val * 2 + np.random.random((100, 1)) * 0.1

# 构造模型

model = Sequential()

model.add(Dense(1, input_shape=(1,)))

model.compile(optimizer='sgd',loss='mse')

# 定义回调函数

best_model_path = 'model/best_model.h5'

if not os.path.exists(os.path.dirname(best_model_path)):

os.makedirs(os.path.dirname(best_model_path))

checkpoint = ModelCheckpoint(filepath=best_model_path, monitor='val_loss', verbose=1, save_best_only=True, mode='min')

early_stopping = EarlyStopping(monitor='val_loss', patience=10, verbose=1, mode='min')

callbacks_list = [checkpoint, early_stopping]

# 执行模型训练

model.fit(x_train, y_train,

validation_data=(x_val, y_val),

epochs=200, batch_size=32,

callbacks=callbacks_list)

在上面的示例中,我们定义了ModelCheckpoint和EarlyStopping两个回调函数。EarlyStopping是一种早期停止技术,它可以在训练过程中监测验证集的评估指标,如果指标在n个epoch内没有改善,则提前停止训练。这样可以有效地防止过拟合。以上是完整可运行的短代码示例。

免责声明:本文来自互联网,本站所有信息(包括但不限于文字、视频、音频、数据及图表),不保证该信息的准确性、真实性、完整性、有效性、及时性、原创性等,版权归属于原作者,如无意侵犯媒体或个人知识产权,请来电或致函告之,本站将在第一时间处理。猿码集站发布此文目的在于促进信息交流,此文观点与本站立场无关,不承担任何责任。

后端开发标签