少样本学习综述:技术、算法和模型
机器学习最近取得了很大的进展,但仍然有一个主要的挑战:需要大量的标记数据来训练模型。
有时这种数据在现实世界中是无法获得的。以医疗保健为例,我们可能没有足够的x光扫描来检查一种新的疾病。但是通过少样本学习可以让模型只从几个例子中学习到知识!
所以少样本学习(FSL)是机器学习的一个子领域,它解决了只用少量标记示例学习新任务的问题。FSL的全部意义在于让机器学习模型能够用一点点数据学习新东西,这在收集一堆标记数据太昂贵、花费太长时间或不实用的情况下非常有用。
少样本学习方法
支持样本/查询集:使用少量图片对查询集进行分类。
少样本学习中有三种主要方法需要了解:元学习、数据级和参数级。
- 元学习:元学习包括训练一个模型,学习如何有效地学习新任务;
- 数据级:数据级方法侧重于增加可用数据,以提高模型的泛化性能;
- 参数级:参数级方法旨在学习更健壮的特征表示,以便更好地泛化到新任务中
元学习
元学习(学习如何学习)。这种方法训练一个模型学习如何有效地学习新任务。这个模型是关于识别不同任务之间的共同点,并使用这些知识通过几个例子快速学习新东西。
元学习算法通常在一组相关任务上训练模型,并学习从可用数据中提取与任务无关的特征和特定于任务的特征。任务无关的特征捕获关于数据的一般知识,而任务特定的特征捕获当前任务的细节。在训练过程中,算法通过仅使用每个新任务的几个标记示例更新模型参数来学习适应新任务。这使得模型可以用很少的示例推广到新的任务。
数据级方法
数据级方法侧重于扩充现有数据,这样可以帮助模型更好地理解数据的底层结构,从而提高模型的泛化性能。
主要思想是通过对现有示例应用各种转换来创建新的示例,这可以帮助模型更好地理解数据的底层结构。
有两种类型的数据级方法:
- 数据增强:数据增强包括通过对现有数据应用不同的转换来创建新的示例;
- 数据生成:数据生成涉及使用生成对抗网络(GANs)从头生成新的示例。
数据级的方法:
参数级方法目标是学习更健壮的特征表示,可以更好地泛化到新的任务。
有两种参数级方法:
- 特征提取:特征提取涉及从数据中学习一组特征,可以用于新任务;
- 微调:微调包括通过学习最优参数使预训练的模型适应新任务。
例如,假设你有一个预先训练好的模型,它可以识别图像中的不同形状和颜色。通过在新数据集上微调模型,只需几个示例,它就可以快速学会识别新的类别。
元学习算法
元学习是FSL的一种流行方法,它涉及到在各种相关任务上训练模型,以便它能够学习如何有效地学习新任务。该算法学习从可用数据中提取任务无关和任务特定的特征,快速适应新的任务。
元学习算法可以大致分为两种类型:基于度量的和基于梯度的。
基于度量的元学习
基于度量的元学习算法学习一种特殊的方法来比较每个新任务的不同示例。他们通过将输入示例映射到一个特殊的特征空间来实现这一点,在这个空间中,相似的示例放在一起,而不同的示例则分开很远。模型可以使用这个距离度量将新的示例分类到正确的类别中。
一种流行的基于度量的算法是Siamese Network,它学习如何通过使用两个相同的子网络来测量两个输入示例之间的距离。这些子网络为每个输入示例生成特征表示,然后使用距离度量(如欧几里得距离或余弦相似度)比较它们的输出。
基于梯度元的学习
基于梯度的元学习学习如何更新他们的参数,以便他们能够快速适应新的挑战。
这些算法训练模型学习一组初始参数,只需几个例子就能快速适应新任务。MAML (model – agnostic元学习)是一种流行的基于梯度的元学习算法,它学习如何优化模型的参数以快速适应新任务。它通过一系列相关任务来训练模型,并使用每个任务中的一些示例来更新模型的参数。一旦模型学习到这些参数,它就可以使用当前任务中的其他示例对它们进行微调,提高其性能。
基于少样本学习的图像分类算法
FSL有几种算法,包括:
- 与模型无关的元学习(Model-Agnostic Meta-Learning):MAML是一种元学习算法,它为模型学习了一个良好的初始化,然后可以用少量的例子适应新的任务。
- 匹配网络 (Matching Networks):匹配网络通过计算相似度来学习将新例子与标记的例子匹配。
- 原型网络(Prototypical Networks):原型网络学习每个类的原型表示,根据它们与原型的相似性对新示例进行分类。
- 关系网络(Relation Networks):关系网络学会比较成对的例子,对新的例子做出预测。
与模型无关的元学习
MAML的关键思想是学习模型参数的初始化,这些参数可以通过一些示例适应新任务。在训练过程中,MAML接受一组相关任务,并学习仅使用每个任务的几个标记示例来更新模型参数。这一过程使模型能够通过学习模型参数的良好初始化来泛化到新的任务,这些参数可以快速适应新的任务。
匹配网络
匹配网络是另一种常用的少样本图像分类算法。它不是学习固定的度量或参数,而是基于当前支持集学习动态度量。这意味着用于比较查询图像和支持集的度量因每个查询图像而异。
匹配网络算法使用一种注意力机制来计算每个查询图像的支持集特征的加权和。权重是根据查询图像和每个支持集图像之间的相似性来学习的。然后将支持集特征的加权和与查询图像特征连接起来,得到的向量通过几个全连接的层来产生最终的分类。
原型网络
原型网络是一种简单有效的少样本图像分类算法。它学习图像的表示,并使用支持示例的嵌入特征的平均值计算每个类的原型。在测试过程中,计算查询图像与每个类原型之间的距离,并将原型最近的类分配给查询。
关系网络
关系网络学习比较支持集中的示例对,并使用此信息对查询示例进行分类。关系网络包括两个子网络:特征嵌入网络和关系网络。特征嵌入网络将支持集中的每个示例和查询示例映射到一个特征空间。然后关系网络计算查询示例和每个支持集示例之间的关系分数。最后使用这些关系分数对查询示例进行分类。
少样本学习的应用
少样本学习在不同的领域有许多应用,包括:
在各种计算机视觉任务中,包括图像分类、目标检测和分割。少样本学习可以识别图像中不存在于训练数据中的新对象。
在自然语言处理任务中,如文本分类、情感分析和语言建模,少样本学习有助于提高语言模型在低资源语言上的性能。
在机器人技术中使用少数次学习,使机器人能够快速学习新任务,适应新环境。例如,机器人只需要几个例子就可以学会捡起新物体。
少样本在医疗诊断领域可以在数据有限的情况下识别罕见疾病和异常,可以帮助个性化治疗和预测病人的结果。
总结
少样本学习是一种强大的技术,它使模型能够从少数例子中学习。它在各个领域都有大量的应用,并有可能彻底改变机器学习。随着不断的研究和开发,少样本学习可以为更高效和有效的机器学习系统铺平道路。