深度学习任务面临非平衡数据问题?试试这个简单方法

简介: 对于数据科学或机器学习研究者而言,当解决任何机器学习问题时,可能面临的最大问题之一就是训练数据不平衡的问题。本文使用最简单的过采样方法解决驼背鲸鱼分类任务,取得了很好的效果。

       对于数据科学或机器学习研究者而言,当解决任何机器学习问题时,可能面临的最大问题之一就是训练数据不平衡的问题。本文将尝试使用图像分类问题来揭示训练数据中不平衡类别的奥秘。

1

数据不平衡问题是什么?

       在一个分类问题中,当你想要预测一个或多个类中的样本数量极少时,可能会遇到数据中类不平衡的问题,即部分类的样本数量远远大于其它类中的样本数量。

例子

  • 欺诈预测(真实交易的欺诈数量要低得多);
  • 自然灾害预测(坏事件发生的频率将远远低于好事);
  • 识别图像分类中的恶性肿瘤(具有肿瘤的图像将比训练样本内的无肿瘤的图像少得多);

为什么这会是个问题?

       不平衡课程造成问题主要是由于以下两个原因:

  • 由于模型/算法从来没有充分地查看全部类别信息,对于实时不平衡的类别没有得到最优化的结果;
  • 由于少数样本类的观察次数极少,这会产生一个验证或测试样本的问题,即很难在类中进行表示;

解决这个问题的方法有哪些?

       解决这个问题的方法主要有三种,三种各有各自的优缺点:

  • 下采样(Undersampling):随机删除具有足够观察多样本的类,以便数据中类的数量比较平衡。虽然这种方法非常简单,但很有可能删除的数据中可能包含有关预测的重要信息。
  • 过采样(Oversampling):对于不平衡类(样本数少的类),随机地增加观测样本的数量,这些观测样本只是现有样本的副本,虽然增加了样本的数量,但过采样可能导致训练数据过拟合。
  • 合成取样(SMOT):该技术要求综合地制造不平衡类的样本,类似于使用最近邻分类。问题是当观察的数目是极其罕见的类时不知道怎么做。
           尽管每种方法都有各自的优点,但没有什么固定的使用方式,需要根据实际问题不断自己尝试。现在将使用深度学习特定的图像分类问题来详细研究这个问题。

图像分类中的不平衡类

       在本节中,将分析一个图像分类问题(其中存在不平衡类问题),然后使用一种简单有效的技术来解决它。
       问题:在kaggle上选择了“驼背鲸识别挑战”任务,期望解决不平衡类别的挑战(理想情况下,所分类的鲸鱼数量少于未分类的鲸类)。
       Kagele上任务说明:在这场比赛中,面临的挑战是要建立一个算法来识别图像中的鲸鱼种类。将分析Happy Whale数据库(包含25,000多张图像),这些数据来自研究机构和公共贡献者。通过竞赛,你将有助于为全球海洋哺乳动物种群动态开启丰富的理解领域。

查看Happy Whale数据集

       由于这是一个多标签图像分类问题,首先想要检查数据是如何在类中分布的。

2


       上图表明,在4251张训练图像中,每个类只有一张图像的超过了2000张。还有一些类只有2~5张图像。可见这是一个严重的不平衡类问题。我们不能期望深度学习模型每个类别仅使用一张图像进行训练。这也会产生一个问题,即如何在训练和验证样本之间创建一个分界线,理想情况下希望每个类都在训练样本和验证样本中都有表示。
接下来应该做什么?
       本文考虑了两个特别的选项:
  • 选项1:对训练样本进行严格的数据增强(只需要针对特定类的数据增强,单这可能无法完全解决本文的问题)。
  • 选项2:类似于之前提到的过采样技术。只是使用不同的图像增强技术将不平衡类的图像复制到训练数据中15次。
           在开始使用选项2处理数据之前,可以从训练样本中查看少量图像。

3


       从图像中可以看到,图像是特定于鲸鱼的尾巴,因此,识别将可能与图像的方向有关。同时注意到数据中有很多图像是特定的黑白或只有R/G/B通道。
       根据这些观察结果,使用以下代码对训练样本中不平衡类的图像进行小幅改动并保存:
import os
from PIL import Image
from PIL import ImageFilter
filelist = train['Image'].loc[(train['cnt_freq']<10)].tolist()
for count in range(0,2):
  
  for imagefile in filelist:
    os.chdir('/home/paperspace/fastai/courses/dl1/data/humpback/train')
    im=Image.open(imagefile)
    im=im.convert("RGB")
    r,g,b=im.split()
    r=r.convert("RGB")
    g=g.convert("RGB")
    b=b.convert("RGB")
    im_blur=im.filter(ImageFilter.GaussianBlur)
    im_unsharp=im.filter(ImageFilter.UnsharpMask)
    
    os.chdir('/home/paperspace/fastai/courses/dl1/data/humpback/copy')
    r.save(str(count)+'r_'+imagefile)
    g.save(str(count)+'g_'+imagefile)
    b.save(str(count)+'b_'+imagefile)
    im_blur.save(str(count)+'bl_'+imagefile)
    im_unsharp.save(str(count)+'un_'+imagefile)

       以上代码对不平衡类中的每张图像(频率小于10)都进行如下处理:

  • 将每张图像的增强副本保存为R / B&G ;
  • 保存每张图像的增强副本;
  • 保存每张图像未锐化的增强副本;
           在上面的代码中可以看到,使用pillow库来严格执行此练习,现在已经为所有不平衡的类分配了至少10个样本。接下来进行训练。

图像增强:只想确保模型能够获得鲸鱼fluke的详细视图。为此,将缩放合并成图像增强。

4


        学习率设定:从图中可以看到,将学习率定为0.01时效果最好。

5


       使用Resnet50模型(第一层参数不变)进行了很少的迭代训练就能取得很好的效果,这是由于imagenet数据库中也有鲸鱼图像。
epoch    trn_loss   val_loss   accuracy                     
    0      1.827677   0.492113   0.895976  
    1      0.93804    0.188566   0.964128                      
    2      0.844708   0.175866   0.967555                      
    3      0.571255   0.126632   0.977614                      
    4      0.458565   0.116253   0.979991                      
    5      0.410907   0.113607   0.980544                      
    6      0.42319    0.109893   0.981097

测试数据集上效果如何?

       在kaggle排行榜上可以看到模型在测试集上的效果,本文提出的解决方案在本次比赛中排名34,平均精度均值(MAP)为0.41928。

6

结论

有时候,最简单的方法是最合乎逻辑的(如果你没有更多的数据,只需要复制现有的数据,并有轻微的变化即可),也是最有效的。

数十款阿里云产品限时折扣中,赶紧点击领劵开始云上实践吧!

作者信息

Shubrashankh Chatterjee,深度学习和数据科学爱好者
个人主页:https://www.linkedin.com/in/shubrashankh-chatterjee/
本文由阿里云云栖社区组织翻译。
文章原标题《Deep Learning Tips and Tricks》,译者:海棠,审校:Uncle_LLD
文章为简译,更为详细的内容,请查看原文

相关文章
|
1月前
|
机器学习/深度学习 数据采集 算法
构建高效图像分类模型:深度学习在处理大规模视觉数据中的应用
随着数字化时代的到来,海量的图像数据被不断产生。深度学习技术因其在处理高维度、非线性和大规模数据集上的卓越性能,已成为图像分类任务的核心方法。本文将详细探讨如何构建一个高效的深度学习模型用于图像分类,包括数据预处理、选择合适的网络架构、训练技巧以及模型优化策略。我们将重点分析卷积神经网络(CNN)在图像识别中的运用,并提出一种改进的训练流程,旨在提升模型的泛化能力和计算效率。通过实验验证,我们的模型能够在保持较低计算成本的同时,达到较高的准确率,为大规模图像数据的自动分类和识别提供了一种有效的解决方案。
|
2月前
|
机器学习/深度学习 人工智能 自动驾驶
深度学习-数据增强与扩充
深度学习-数据增强与扩充
75 1
|
3月前
|
机器学习/深度学习 人工智能 自然语言处理
大数据分析的技术和方法:从深度学习到机器学习
大数据时代的到来,让数据分析成为了企业和组织中不可或缺的一环。如何高效地处理庞大的数据集并且从中发现潜在的价值是每个数据分析师都需要掌握的技能。本文将介绍大数据分析的技术和方法,包括深度学习、机器学习、数据挖掘等方面的应用,以及如何通过这些技术和方法来解决实际问题。
52 2
|
4月前
|
机器学习/深度学习 算法 TensorFlow
【Python深度学习】Tensorflow对半环形数据分类、手写数字识别、猫狗识别实战(附源码)
【Python深度学习】Tensorflow对半环形数据分类、手写数字识别、猫狗识别实战(附源码)
57 0
|
6月前
|
机器学习/深度学习 数据采集 PyTorch
使用自定义 PyTorch 运算符优化深度学习数据输入管道
使用自定义 PyTorch 运算符优化深度学习数据输入管道
37 0
|
3天前
|
机器学习/深度学习 传感器 数据可视化
MATLAB用深度学习长短期记忆 (LSTM) 神经网络对智能手机传感器时间序列数据进行分类
MATLAB用深度学习长短期记忆 (LSTM) 神经网络对智能手机传感器时间序列数据进行分类
16 1
MATLAB用深度学习长短期记忆 (LSTM) 神经网络对智能手机传感器时间序列数据进行分类
|
7天前
|
机器学习/深度学习 数据可视化 测试技术
深度学习:Keras使用神经网络进行简单文本分类分析新闻组数据
深度学习:Keras使用神经网络进行简单文本分类分析新闻组数据
21 0
|
8天前
|
机器学习/深度学习 API 算法框架/工具
R语言深度学习:用keras神经网络回归模型预测时间序列数据
R语言深度学习:用keras神经网络回归模型预测时间序列数据
16 0
|
8天前
|
机器学习/深度学习 数据采集 TensorFlow
R语言KERAS深度学习CNN卷积神经网络分类识别手写数字图像数据(MNIST)
R语言KERAS深度学习CNN卷积神经网络分类识别手写数字图像数据(MNIST)
27 0
|
1月前
|
机器学习/深度学习 算法 计算机视觉
SISR深度学习主要方法简述
SISR深度学习主要方法简述
12 0