对比学习是机器学习(ML)中的一项强大技术,在自我监督学习(SSL)中尤为突出。它不依赖于精心标注的数据,而是通过教授一个模型来区分相似和不相似的数据点,从而学习有意义的表征。其核心思想很简单:在嵌入空间中将 "相似 "示例的表征拉近,同时将 "不相似 "示例的表征推远。这种方法可以让模型从大量未标记的数据中学习丰富的特征,然后通过微调使其适用于各种下游任务。
对比学习如何发挥作用
这一过程通常包括以下步骤:
- 数据增强:从未标明的数据点(如图像)开始。为该数据点创建两个或多个增强版本。这些增强版本形成 "正对",因为它们来自同一来源,应被视为相似。常见的数据增强技术包括随机裁剪、颜色抖动、旋转或添加噪音。
- 负采样:从数据集中(或当前批次)选择与原始数据点不同的其他数据点。这些数据点与原始数据点的增强数据形成 "负对"。
- 编码:将正样本和负样本通过编码器神经网络 (NN),通常是用于图像的卷积神经网络 (CNN),或用于文本或图像的变换器(视觉变换器 (ViT))。该网络将输入数据转换为低维表示,即嵌入。
- 损失计算:应用对比损失函数,如 InfoNCE(噪声对比估计)或三重损失。该函数根据嵌入式之间的距离计算分数。它鼓励正向数据对的嵌入式数据接近(低距离/高相似性),而反向数据对的嵌入式数据相距甚远(高距离/低相似性)。
- 优化:使用随机梯度下降(SGD)或亚当(Adam )等优化算法,根据计算出的损失更新编码器权重,通过反向传播迭代改进所学表征的质量。
对比学习与相关术语
对比学习与其他 ML 范式不同:
- 监督学习:要求每个数据点都有明确的标签(如 "猫"、"狗")。对比学习主要使用无标签数据,通过正负配对产生自己的监督信号。
- 无监督学习(聚类): K-Means等方法根据固有结构对数据进行分组。对比学习明确地训练一个模型,以创建一个表示空间,其中的相似性是由正/负对定义的,重点是学习判别特征。
- 生成模型: GANs或扩散模型等模型通过学习生成与训练数据相似的新数据。对比学习侧重于学习判别表征,而不是生成数据。
实际应用
对比学习擅长学习能很好地迁移到其他任务中的表征:
优势与挑战
好处
- 减少标签依赖性:利用大量未标记数据,减少了昂贵而耗时的数据标记需求。
- 鲁棒性表征:与纯粹的监督式方法相比,它通常能学习不受干扰变化影响的特征。
- 有效的预训练:为特定下游任务的微调提供了绝佳的起点,往往能带来更好的性能,尤其是在标注数据有限的情况下(少量学习)。
挑战: