半监督学习(SSL)提供了一种利用无标签数据提高模型性能的有效方法,这一领域最近取得了快速进展,但以往的算法需要借助复杂的损失函数和大量难以调整的超参数。本文介绍了谷歌的研究团队提出的FixMatch[1],这是一种大大简化现有 SSL 方法的算法。FixMatch是SSL的两种方法的组合:一致性正则和伪标签。
如图所示为FixMatch的流程图。FixMatch的新颖之处在于,对于无标签的样本:
FixMatch首先对弱增强的无标签图像预测伪标签,对于给定的图像,只有当模型产生高于阈值的预测时,才会保留作为伪标签;
再对同一图像的强增强版本预测出分类概率,通过交叉熵损失衡量强弱二者的预测的一致性。
我们首先回顾一下利用了一致性正则和伪标签方法的经典SSL算法,然后再详细描述FixMatch算法,并穿插关键部分的飞桨代码实现(采用飞桨最新稳定版本)。
一致性正则
一致性正则是许多SSL算法的重要组成部分。一致性正则的思想是——即使在无标签的样本被注入噪声之后,分类器也应该为其输出相同的类分布概率。即强制一个无标签的样本
伪标签
指使用模型本身为无标签数据获取标签的方法。具体而言,将模型输出的softmax概率分布视为软伪标签;或将经过argmax或者one_hot得到的预测视为硬伪标签。利用这些伪标签作为监督损失进一步训练模型。
SSL算法的比较
下表提到了与生成伪标签(Artificial label)相关的SSL算法。其中列出了用于伪标签的数据增强、模型的预测以及应用于伪标签的后处理。
FixMatch的核心是一致性正则和伪标签方法的简单组合,无标签模型预测与UDA一样采用RandAugment[3]进行强增强,详细实现见AI Studio项目。
FixMatch
FixMatch的损失函数
FixMatch的损失函数由两个交叉熵损失项组成:一个是应用于有标签数据的全监督损失,另一个是用于无标签数据的一致性正则损失。
令表示batch size为
令是batch size为
对于有标签样本,FixMatch均采用弱增强,其损失函数为:
"""
将有 / 无标签的 batch 拼接后输入模型
:inputs_x: 有标签数据
:inputs_u_w: 无标签数据的弱增强
:inputs_u_s: 无标签数据的强增强
"""
inputs = interleave(
paddle.concat((inputs_x, inputs_u_w, inputs_u_s)), 2 * args.mu + 1)
# 模型输出(全连接层分类预测)
logits = model(inputs)
logits = de_interleave(logits, 2 * args.mu + 1)
# 有标签数据的模型输出
logits_x = logits[:batch_size]
# 有标签预测的交叉熵损失
Lx = F.cross_entropy(logits_x, targets_x, reduction='mean')
对于无标签样本,FixMatch为每个无标签样本预测一个伪标签,然后用于计算交叉熵损失。为了获得一个伪标签,首先输入无标签图像的弱增强版本
其中,
# 弱增强和强增强模型预测
logits_u_w, logits_u_s = logits[batch_size:].chunk(2)
# 对弱增强的模型输出使用 softmax + argmax 得到伪标签 targets_u
pseudo_label = F.softmax(logits_u_w.detach() / args.T, axis=-1)
targets_u = paddle.argmax(pseudo_label, axis=-1) # 利用 argmax 得到硬伪标签
# 通过阈值筛选伪标签
max_probs = paddle.max(pseudo_label, axis=-1)
mask = paddle.greater_equal(
max_probs,
paddle.to_tensor(args.threshold)).astype(paddle.float32)
# 无标签预测的交叉熵损失(一致性损失)
Lu = (F.cross_entropy(logits_u_s, targets_u,
reduction='none') * mask).mean()
# 两个损失加权相加
loss = Lx + args.lambda_u * Lu
最终,FixMatch的总损失是两个损失函数的加权和:
FixMatch的简洁之处
FixMatch和前面提到的SSL方法的关键区别在于,伪标签是基于弱增强图像预测的硬伪标签,而对于强增强图像的模型输出的全连接层预测直接计算损失(不进行 argmax),这对FixMatch的成功至关重要。
UDA和MixMatch中用了sharpen构建软伪标签,sharpen 引入了一个超参数
另外,在Mean-Teacher、MixMatch等SSL算法中,在训练期间会增加无标签损失项的权重(
FixMatch的“强弱调和”
FixMatch利用了两种数据增强:“弱”和“强”。弱增强是标准的随机翻转和移位的数据增强策略。
对于弱增强,FIxMatch在有标签数据样本上以50%的概率进行水平翻转图像;以12.5%的概率在垂直和水平方向上随机平移图像;
对于强增强,FixMatch与UDA一样利用了RandAugment为每个无标签样本随机选择变换。
论文还研究了弱增强和强增强的不同组合对伪标签生成的影响:
当将预测伪标签的弱增强替换为强增强时,实验发现模型在训练早期就出现了分歧;
相反,当用无增强替换弱增强时,该模型会过度拟合无标签数据;
使用弱增强代替原先的强增强时,只能达到45%的准确率峰值,但不稳定,并逐渐下降到12%,表明了强增强的重要性。
FixMatch的优化器
FixMatch使用简单的weight decay模型参数正则化。论文做了消融实验,相较于使用Momentum优化器,使用Adam优化器会导致更差的性能。对于Momentum优化器参数的设置,momentum=0.9,weight_decay=0.0005,use_nesterov=True。
由于优化器采用了weight_decay,需要剔除设置了bias=True参数的网络层和BatchNorm层。
no_decay = ['bias', 'bn']
scheduler = get_cosine_schedule_with_warmup(args.lr, args.warmup, args.total_steps)
grouped_parameters = [
# 若网络层不包含 bias 或 BatchNorm,则应用 weight_decay
{'params': [p for n, p in model.named_parameters() if not any(
nd in n for nd in no_decay)], 'weight_decay': args.wdecay},
# 反之,则不用 weight_decay
{'params': [p for n, p in model.named_parameters() if any(
nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = optim.Momentum(learning_rate=scheduler,
momentum=0.9,
parameters=grouped_parameters,
use_nesterov=args.nesterov)
使用余弦学习速率衰减,衰减策略设置为,其中
def get_cosine_schedule_with_warmup(learning_rate, num_warmup_steps,
num_training_steps,
num_cycles=7. / 16.,
last_epoch=-1):
"""
借助 LambdaDecay 实现余弦学习率衰减
"""
def _lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
no_progress = float(current_step - num_warmup_steps) / \
float(max(1, num_training_steps - num_warmup_steps))
return max(0., math.cos(math.pi * num_cycles * no_progress))
return LambdaDecay(learning_rate=learning_rate,
lr_lambda=_lr_lambda,
last_epoch=last_epoch)
最后,FixMatch使用训练过程中每个eval_step的模型参数的指数移动平均值(EMA)作为最终的预测模型。
backbone网络架构默认为 Wide ResNet-28-2[4],详细实现见AI Studio项目。训练的超参数如下:
无标签损失权重
初始学习率
优化器 momentum 参数
伪标签阈值
有 / 无标签样本比例 1: 7:
batch_size
总训练,EMA eval_step =1024
尽管FixMatch非常简单,但它在各种标准的半监督学习benchmark上都达到了SOTA,在CIFAR-10[5]上仅有250个标签时的准确率为94.93%,在40个标签时的准确率为88.61%(每类仅4个标签)。下表为五折交叉验证得出的FixMatch及其baselines在CIFAR-10数据集上的错误率:
CIFAR-10数据集在飞桨复现版本的精度如下:
结论
在半监督学习算法日益复杂的发展中,FixMatch以出人意料的简单获得了SOTA性能——在有标签和无标签的数据上只使用标准的交叉熵损失,FixMatch的训练只需几行代码即可完成。
论文指出,由于这种简单性,我们能够彻底研究FixMatch是如何发挥作用的。我们发现某些设计选择很重要(而且往往被低估)——最重要的是weight decay和优化器的选择。总的来说,我们相信,这种简单但性能良好的半监督机器学习算法的存在,将有助于机器学习被应用到越来越多的标签价格昂贵或难以获得的实际领域。
Github:
https://github.com/S-HuaBomb/FixMatch-Paddle
AIStudio:
https://aistudio.baidu.com/aistudio/projectdetail/2509943?contributionType=1
扫码报名,赢取万元奖金
相关阅读
关注【飞桨PaddlePaddle】公众号
获取更多技术内容~