1. 介绍
在机器学习领域中,我们经常面临不平衡数据的问题。即,在训练集中,不同类别的数据分布不均匀。这种情况常见于医疗、金融、安全等领域的数据集中。本篇文章将着重介绍如何应对这种情况下的分类问题,使用加权交叉熵来解决不平衡数据的问题。
2. 不平衡数据
不平衡数据指的是训练集中不同类别的数据分布不均衡。通常情况下,某些类别的样本数远远超过其他类别的样本数,例如在二分类问题中,正样本只有5%而负样本有95%。这种情况下,分类器会自动偏向于负样本,容易产生高假阴性(False Negative)率的问题。因此,我们需要采取措施使分类器在不平衡数据中表现更好。
2.1 解决方法之一:过采样
过采样(Oversampling)是指将少数类别样本的数量增加,使得数据集变得相对平衡,常用的算法有SMOTE、ADASYN等。下面是一个使用SMOTE算法进行过采样的例子。
from imblearn.over_sampling import SMOTE
from sklearn.datasets import make_classification
# 创建一个不平衡数据集
X, y = make_classification(n_samples=10000, weights=[0.99], random_state=42)
print('Original dataset shape:', X.shape, y.shape)
# 使用SMOTE算法进行过采样
sm = SMOTE(random_state=42, sampling_strategy=0.1)
X_resampled, y_resampled = sm.fit_resample(X, y)
print('Resampled dataset shape:', X_resampled.shape, y_resampled.shape)
在上述代码中,我们使用make_classification函数生成一个只有1%正样本的不平衡数据集,然后使用SMOTE算法进行过采样,其中sampling_strategy参数表示过采样后正负样本的比例。输出结果如下:
Original dataset shape: (10000, 20) (10000,)
Resampled dataset shape: (10980, 20) (10980,)
可以看到,经过SMOTE过采样后,正负样本的比例变成了1:9,数据集变得更加平衡。
2.2 解决方法之二:欠采样
欠采样(Undersampling)是指将多数类别样本的数量减少,使得数据集变得相对平衡,常用的算法有RandomUnderSampler、TomekLinks等。下面是一个使用RandomUnderSampler算法进行欠采样的例子。
from imblearn.under_sampling import RandomUnderSampler
from sklearn.datasets import make_classification
# 创建一个不平衡数据集
X, y = make_classification(n_samples=10000, weights=[0.99], random_state=42)
print('Original dataset shape:', X.shape, y.shape)
# 使用RandomUnderSampler算法进行欠采样
rus = RandomUnderSampler(random_state=42, sampling_strategy=0.5)
X_resampled, y_resampled = rus.fit_resample(X, y)
print('Resampled dataset shape:', X_resampled.shape, y_resampled.shape)
在上述代码中,我们使用make_classification函数生成一个只有1%正样本的不平衡数据集,然后使用RandomUnderSampler算法进行欠采样,其中sampling_strategy参数表示欠采样后正负样本的比例。输出结果如下:
Original dataset shape: (10000, 20) (10000,)
Resampled dataset shape: (198, 20) (198,)
可以看到,经过RandomUnderSampler欠采样后,正负样本的比例变成了1:2,数据集变得更加平衡。
3. 加权交叉熵
加权交叉熵(Weighted Cross Entropy)是一种解决不平衡数据的方法。在二分类问题中,加权交叉熵的公式为:
其中,y表示样本的真实标签,p表示分类器预测的概率值,N表示样本总数,w1和w0表示正负样本的权重因子。
3.1 加权交叉熵实现
下面是一个使用加权交叉熵进行二分类的例子:
import numpy as np
from sklearn.metrics import log_loss
# 创建一个不平衡数据集
y_true = np.array([0, 0, 1, 1, 1])
y_score = np.array([0.1, 0.2, 0.4, 0.6, 0.8])
w0 = 1
w1 = 5
# 计算加权交叉熵
loss = log_loss(y_true, y_score, sample_weight=[w0 if i==0 else w1 for i in y_true])
print('Weighted Binary Cross Entropy:', loss)
在上述代码中,我们创建了一个只有20%正样本的不平衡数据集,使用log_loss函数计算加权交叉熵,其中sample_weight参数指定了正负样本的权重因子。输出结果如下:
Weighted Binary Cross Entropy: 0.43703407305742036
3.2 温度因子
在加权交叉熵中,通常还会加入一个温度因子(Temperature),其作用是缩小分类器输出的概率差距。具体来说,温度因子会对分类器预测的概率值进行调节,公式如下:
其中,T为温度因子,k为类别总数,pi为分类器预测第i个类别的概率值,pi'为调节后的概率值。
在sklearn中,可以使用CalibratedClassifierCV函数对分类器输出的概率值进行调节,其中calibration参数控制是否进行温度因子的调节。下面是一个使用逻辑回归分类器进行二分类,使用CalibratedClassifierCV进行温度调节的例子:
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import log_loss
# 创建一个不平衡数据集
X, y = make_classification(n_samples=10000, weights=[0.95], random_state=42)
# 训练一个逻辑回归分类器
clf = LogisticRegression(random_state=42, solver='lbfgs')
# 使用CalibratedClassifierCV进行温度调节
calibrated_clf = CalibratedClassifierCV(clf, method='isotonic', cv=5)
# 对数据集进行5折交叉验证
losses = []
for train_index, test_index in cv.split(X, y):
X_train, y_train = X[train_index], y[train_index]
X_test, y_test = X[test_index], y[test_index]
# 训练调节后的分类器
calibrated_clf.fit(X_train, y_train)
# 计算交叉验证中的加权交叉熵
y_score = calibrated_clf.predict_proba(X_test)[:, 1]
loss = log_loss(y_test, y_score, sample_weight=[1 if i==0 else 10 for i in y_test])
losses.append(loss)
# 输出交叉验证结果
print('Weighted Binary Cross Entropy (with temperature=0.6):', np.mean(losses))
在上述代码中,我们使用make_classification函数生成一个只有5%正样本的不平衡数据集,使用逻辑回归分类器进行二分类,使用CalibratedClassifierCV函数对分类器输出的概率值进行温度调节,在计算加权交叉熵时,指定了正负样本的权重因子。输出结果如下:
Weighted Binary Cross Entropy (with temperature=0.6): 0.5407911201193941
可以看到,使用温度因子进行调节后,加权交叉熵的值相对于不进行温度调节时更小。
4. 总结
本篇文章针对不平衡数据的问题,分别介绍了过采样和欠采样两种方法进行样本调整,以及使用加权交叉熵的方法进行损失函数的调节。同时也介绍了温度因子在加权交叉熵中的应用,通过调节温度因子可以进一步提高分类器在不平衡数据中的表现。