三招提升数据不平衡模型的性能(附python代码)

简介: 本文的主要目标是处理数据不平衡问题。文中描述了用来克服数据不平衡问题的三种技术,分别是集成交叉验证、类别权重以及过大预测 。

       对于深度学习而言,数据集非常重要,但在实际项目中,或多或少会碰见数据不平衡问题。什么是数据不平衡呢?举例来说,现在有一个任务是判断西瓜是否成熟,这是一个二分类问题——西瓜是生的还是熟的,该任务的数据集由两部分数据组成,成熟西瓜与生西瓜,假设生西瓜的样本数量远远大于成熟西瓜样本的数量,针对这样的数据集训练出来的算法“偏向”于识别新样本为生西瓜,存心让你买不到甜的西瓜以解夏天之苦,这就是一个数据不平衡问题。针对数据不平衡问题有相应的处理办法,比如对多数样本进行采样使得其样本数量级与少样本数相近,或者是对少数样本重复使用等。最近恰好在面试中遇到一个数据不平衡问题,这也是面试中经常会出现的问题之一,现向读者分享此次解决问题的心得。

1_jpeg

数据集

       训练数据中有三个标签,分别标记为[1、2、3],这意味着该问题是一个多分类问题。训练数据集有17个特征以及38829个独立数据点。而在测试数据中,有16个没有标签的特征和16641个数据点。该训练数据集非常不平衡,大部分数据是1类(95%),而2类和3类分别有3.0%和0.87%的数据,如下图所示。

2

算法

       经过初步观察,决定采用随机森林(RF)算法,因为它优于支持向量机、Xgboost以及LightGBM算法。在这个项目中选择RF还有几个原因:

  • 1机森林对过拟合具有很强的鲁棒性;
  • 2.参数化仍然非常直观;
  • 3.在这个项目中,有许多成功的用例将随机森林算法用于高度不平衡的数据集;
  • 4.个人有先前的算法实施经验;
           为了找到最佳参数,使用scikit-sklearn实现的GridSearchCV对指定的参数值执行网格搜索,更多细节可以在本人的Github上找到。

为了处理数据不平衡问题,使用了以下三种技术:

A.使用集成交叉验证(CV):

       在这个项目中,使用交叉验证来验证模型的鲁棒性。整个数据集被分成五个子集。在每个交叉验证中,使用其中的四个子集用于训练,剩余的子集用于验证模型,此外模型还对测试数据进行了预测。在交叉验证结束时,会得到五个测试预测概率。最后,对所有类别的概率取平均值。模型的训练表现稳定,每个交叉验证上具有稳定的召回率和f1分数。这项技术也帮助我在Kaggle比赛中取得了很好的成绩(前1%)。以下部分代码片段显示了集成交叉验证的实现:

for j, (train_idx, valid_idx) in enumerate(folds):
                
                X_train = X[train_idx]
                Y_train = y[train_idx]
                X_valid = X[valid_idx]
                Y_valid = y[valid_idx]
                
                clf.fit(X_train, Y_train)
                
                valid_pred = clf.predict(X_valid)
                recall  = recall_score(Y_valid, valid_pred, average='macro')
                f1 = f1_score(Y_valid, valid_pred, average='macro')
                
                recall_scores[i][j] = recall
                f1_scores[i][j] = f1
                
                train_pred[valid_idx, i] = valid_pred
                test_pred[:, test_col] = clf.predict(T)
                test_col += 1
                
                ## Probabilities
                valid_proba = clf.predict_proba(X_valid)
                train_proba[valid_idx, :] = valid_proba
                test_proba  += clf.predict_proba(T)
                
            test_proba /= self.n_splits

B.设置类别权重/重要性:

       代价敏感学习是使随机森林更适合从非常不平衡的数据中学习的方法之一。随机森林有倾向于偏向大多数类别。因此,对少数群体错误分类施加昂贵的惩罚可能是有作用的。由于这种技术可以改善模型性能,所以我给少数群体分配了很高的权重(即更高的错误分类成本)。然后将类别权重合并到随机森林算法中。我根据类别1中数据集的数量与其它数据集的数量之间的比率来确定类别权重。例如,类别1和类别3数据集的数目之间的比率约为110,而类别1和类别2的比例约为26。现在我稍微对数量进行修改以改善模型的性能,以下代码片段显示了不同类权重的实现:

from sklearn.ensemble import RandomForestClassifier
class_weight = dict({1:1.9, 2:35, 3:180})

rdf = RandomForestClassifier(bootstrap=True,
            class_weight=class_weight, 
            criterion='gini',
            max_depth=8, max_features='auto', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=4, min_samples_split=10,
            min_weight_fraction_leaf=0.0, n_estimators=300,
            oob_score=False,
            random_state=random_state,
            verbose=0, warm_start=False)

C.过大预测标签而不是过小预测(Over-Predict a Label than Under-Predict):

       这项技术是可选的,通过实践发现,这种方法对提高少数类别的表现非常有效。简而言之,如果将模型错误分类为类别3,则该技术能最大限度地惩罚该模型,对于类别2和类别1惩罚力度稍差一些。 为了实施该方法,我改变了每个类别的概率阈值,将类别3、类别2和类别1的概率设置为递增顺序(即,P3= 0.25,P2= 0.35,P1= 0.50),以便模型被迫过度预测类别。该算法的详细实现可以在Github上找到。

最终结果

       以下结果表明,上述三种技术如何帮助改善模型性能:
1.使用集成交叉验证的结果:

3


2.使用集成交叉验证+类别权重的结果:

4


3.使用集成交叉验证+类别权重+过大预测标签的结果:

5

结论

       由于在实施过大预测技术方面的经验很少,因此最初的时候处理起来非常棘手。但是,研究该问题有助于提升我解决问题的能力。对于每个任务而言,起初可能确实是陌生的,这个时候不要害怕,一次次尝试就好。由于时间的限制(48小时),无法将精力分散于模型的微调以及特征工程,存在改进的地方还有很多,比如删除不必要的功能并添加一些额外功能。此外,也尝试过LightGBM和XgBoost算法,但在实践过程中发现,随机森林的效果优于这两个算法。在后面的研究中,可以进一步尝试一些其他算法,比如神经网络、稀疏编码等。

数十款阿里云产品限时折扣中,赶紧点击领劵开始云上实践吧!

作者信息

Sabber Ahamed,计算地球物理学、机器学习爱好者
个人主页:https://www.linkedin.com/in/sabber-ahamed/
本文由阿里云云栖社区组织翻译。
文章原标题《Three techniques to improve machine learning model performance with imbalanced datasets》,译者:海棠,审校:Uncle_LLD。
文章为简译,更为详细的内容,请查看原文

相关文章
|
1天前
|
Python 数据可视化 索引
PYTHON用GARCH、离散随机波动率模型DSV模拟估计股票收益时间序列与蒙特卡洛可视化
PYTHON用GARCH、离散随机波动率模型DSV模拟估计股票收益时间序列与蒙特卡洛可视化
12 0
PYTHON用GARCH、离散随机波动率模型DSV模拟估计股票收益时间序列与蒙特卡洛可视化
|
1天前
|
存储 机器学习/深度学习 数据可视化
Python面板时间序列数据预测:格兰杰因果关系检验Granger causality test药品销售实例与可视化
Python面板时间序列数据预测:格兰杰因果关系检验Granger causality test药品销售实例与可视化
24 6
|
1天前
|
数据可视化 算法 Python
python中的copula:Frank、Clayton和Gumbel copula模型估计与可视化
python中的copula:Frank、Clayton和Gumbel copula模型估计与可视化
|
1天前
|
机器学习/深度学习 数据采集 供应链
从数据到决策:scikit-learn在业务分析中的应用
【4月更文挑战第17天】本文探讨了scikit-learn在业务分析中的应用,包括数据预处理、分类、回归和聚类模型的构建,以及模型评估与优化。通过使用scikit-learn,企业能有效处理数据、预测趋势、客户细分并制定决策,从而提升经营效率和市场策略。随着机器学习的发展,scikit-learn在业务分析领域的潜力将持续释放,创造更多价值。
|
1天前
|
算法 数据可视化 Python
Python中LARS和Lasso回归之最小角算法Lars分析波士顿住房数据实例
Python中LARS和Lasso回归之最小角算法Lars分析波士顿住房数据实例
|
1天前
|
运维 数据可视化 Python
PYTHON中用PROPHET模型对天气时间序列进行预测与异常检测
PYTHON中用PROPHET模型对天气时间序列进行预测与异常检测
15 1
|
1天前
|
Python
Python随机波动率(SV)模型对标普500指数时间序列波动性预测
Python随机波动率(SV)模型对标普500指数时间序列波动性预测
|
1天前
|
数据安全/隐私保护 Python
Python中的装饰器:提升代码可读性与灵活性
Python中的装饰器是一种强大的工具,可以在不改变函数原有逻辑的情况下,为函数添加额外的功能。本文将介绍装饰器的基本概念和用法,并通过实例演示如何利用装饰器提升代码的可读性和灵活性,使代码更加简洁、易于维护。
|
1天前
|
BI 开发者 数据格式
Python代码填充数据到word模板中
【4月更文挑战第16天】
|
5月前
|
Python
python 股票数据分析、绘制K线图、价格走势图、收益率计算 完整代码+数据 可直接运行
python 股票数据分析、绘制K线图、价格走势图、收益率计算 完整代码+数据 可直接运行
91 0
python 股票数据分析、绘制K线图、价格走势图、收益率计算 完整代码+数据 可直接运行

热门文章

最新文章