mixup

mixup是基于邻域风险最小化(Vicinal Risk Minimization, VRM)原则的数据增强方法,使用线性插值得到新样本数据。

在邻域风险最小化原则下,根据特征向量线性插值将导致相关目标线性插值的先验知识,可得出简单且与数据无关的mixup公式:
在这里插入图片描述
其中(xn,yn)是插值生成的新数据,(xi,yi) 和 (xj,yj)是训练集中随机选取的两个数据,λ的取值满足贝塔分布,取值范围介于0到1,超参数α控制特征目标之间的插值强度。

mixup的实验丰富,实验结果表明可以改进深度学习模型在ImageNet数据集、CIFAR数据集、语音数据集和表格数据集中的泛化误差,降低模型对已损坏标签的记忆,增强模型对对抗样本的鲁棒性和训练对抗生成网络的稳定性。

PyTorch实现

def mixup(data, target, alpha):
    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_target = target[indices]

    lam = np.clip(np.random.beta(alpha, alpha),0.3,0.7)
    data = lam*data + (1-lam)*shuffled_data
    targets = (target, shuffled_target, lam)

    return data, targets

在这里插入图片描述

Reference

https://www.kaggle.com/virajbagal/mixup-cutmix-fmix-visualisations
https://www.jianshu.com/p/99450dbdadcf

Logo

一站式 AI 云服务平台

更多推荐