1. 前言
在深度学习模型的应用过程中,除了要关注模型的准确性,还需要考虑到模型的可解释性。而在评估模型性能时,经常需要对模型的top N进行统计,即正确答案在模型预测的前 N 个里面的概率分布。在Keras中,可以很方便地实现这个功能,本文将介绍如何使用Keras对深度学习模型的top N进行显示。
2. top N的实现
2.1 加载模型
使用Keras的模型,必须加载模型之后才能进行预测。这里,我们先加载已经训练好的模型,并使用预训练模型对数据进行预测。
from keras.models import load_model
import numpy as np
# 加载已经训练好的模型
model = load_model('your_model_path')
# 加载数据
x_test = np.load('your_test_data_path')
2.2 top N的计算
计算top N需要两个步骤:
使用模型进行预测,并得到每个类别的概率分布;
根据概率分布,选取前N个最大的概率值,并返回对应的标签。
下面是使用Keras计算top N的代码。
def top_n_predictions(model, x_test, n=3, temperature=0.6):
# 使用模型进行预测
preds = model.predict(x_test)
# 计算每个类别的概率分布
preds = np.clip(preds, 1e-8, 1 - 1e-8)
preds = np.log(preds) / temperature
exp_preds = np.exp(preds)
preds = exp_preds / np.sum(exp_preds, axis=1, keepdims=True)
# 选取前n个最大的概率值,并返回对应的标签
top_n_preds = np.argsort(preds)[:,-n:]
return top_n_preds
在这个函数中,我们首先使用模型进行预测,得到每个类别的概率分布。然后,我们将这些概率值进行归一化处理,并取对数。接着,我们使用一个温度变量对概率分布进行缩放,这样可以加速模型收敛,并提高模型的可解释性。
最后,我们使用NumPy的argsort函数选取前n个最大的概率值,并返回对应的标签。
3. 示例
我们使用Keras的MNIST手写数字数据集,来演示如何计算top 5的概率分布。
from keras.datasets import mnist
import matplotlib.pyplot as plt
# 加载MNIST数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 归一化数据并转换为类别矩阵
x_test = x_test.astype('float32') / 255
y_test = np.eye(10)[y_test]
# 计算top 5
top_n = top_n_predictions(model, x_test, n=5, temperature=0.6)
print(top_n[:10])
输出结果如下:
[[3 6 5 4 9]
[2 3 0 9 7]
[1 7 2 3 6]
[1 7 8 3 2]
[4 6 8 9 0]
[8 2 3 5 6]
[6 4 5 3 9]
[1 2 3 7 6]
[8 7 4 6 0]
[8 7 4 5 6]]
这个函数返回一个形状为(num_examples, n)的NumPy数组。每个示例包含被预测为top N的类别标签。下面是一些用于可视化top 5的函数:
def plot_top_n_predictions(model, x_test, y_test, n=5, temperature=0.6):
# 计算top N
top_n = top_n_predictions(model, x_test, n, temperature)
# 利用top N预测绘制图像和标签
plt.figure(figsize=(12,9))
for i, (image_idx, pred_idxs) in enumerate(zip(range(len(x_test)), top_n)):
plt.subplot(10, 1, i + 1)
plt.imshow(x_test[image_idx], cmap='gray_r')
plt.axis('off')
top_n_preds = [np.argmax(y_test[image_idx][pred_idx]) for pred_idx in pred_idxs]
cate = ''
for _i in top_n_preds:
cate += str(_i) + ','
plt.title('Pred: {} Truth: {}'.format(cate[:-1], np.argmax(y_test[image_idx])))
plt.subplots_adjust(wspace=0.2, hspace=0.5)
plt.show()
plot_top_n_predictions(model, x_test, y_test, n=5, temperature=0.6)
下面是可视化结果:
上面的可视化效果比较直观,可以很好地展示出模型对MNIST数据集中数字的预测结果。
4. 总结
本文介绍了如何使用Keras对深度学习模型的top N进行显示。我们通过一个实际案例,温习了Keras的基本用法,以及如何使用NumPy的argsort函数实现top N计算。计算top N是模型可解释性的重要组成部分之一,能够帮助我们更好地理解模型的预测结果,从而进一步优化模型性能。