跳转至

提升分类模型 acc(一):BatchSize&LARS

【GiantPandaCV 导读】在使用大的 bs 训练情况下,会对精度有一定程度的损失,本文探讨了训练的 bs 大小对精度的影响,同时探究 LARS 是否可以有效的提升精度。

论文链接:https://arxiv.org/abs/1708.03888

论文代码: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py

知乎专栏: https://zhuanlan.zhihu.com/p/406882110

一、引言

如何提升业务分类模型的性能,一直是个难题,毕竟没有 99.999% 的性能都会带来一定程度的风险,所以很多时候我们只能通过控制阈值来调整准召以达到想要的效果。本系列主要探究哪些模型 trick 和数据的方法可以大幅度让你的分类性能更上一层楼,不过要注意一点的是,tirck 不一定是适用于不同的数据场景的,但是数据处理方法是普适的。本篇文章主要是对于大的 bs 下训练分类模型的情况,如果 bs 比较小的可以忽略,直接看最后的结论就好了 (这个系列以后的文章讲述的方法是通用的,无论 bs 大小都可以用)。

二、实验配置

  • 模型:ResNet50
  • 数据:ImageNet1k
  • 环境:8xV100

三、BatchSize 对精度的影响

所有的实验的超参都是统一的, warmup 5 个 epoch,训练 90 个 epoch,StepLR 进行衰减,学习率的设置和 bs 线性相关,公式为lr=BatchSize256baselr,优化器使用带有 0.9 的动量的 SGD,baselr 为 0.1(如果采用 Adam 或者 AdamW 优化器的话,公式需要调整为lr=BatchSize512baselr),训练的数据增强只有RandomCropResize,RandomFlip,验证的数据增强为ResizeCenterCrop

训练情况如下:

  • lr 调整曲线如下: lr

  • 训练曲线如下: train

  • 验证曲线如下: val

我这里设计了 4 组对照实验,256, 1024, 2048 和 4096 的 batchsize,开了 FP16 也只能跑到了 4096 了。采用的是分布式训练,所以单张卡的 bs 就是bs = total_bs / ngpus_per_node。这里我没有使用跨卡 bn,对于 bs 64 单卡来说理论上已经很大了,bn 的作用是约束数据分布,64 的 bs 已经可以表达一个分布的 subset 了,再大的 bs 还是同分布的,意义不大,跨卡 bn 的速度也更慢,所以大的 bs 基本可以忽略这个问题。但是对于检测的任务,跨卡 bn 还是有价值的,毕竟输入的分辨率大,单卡的 bs 比较小,一般 4,8,16,这时候统计更大的 bn 会对模型收敛更好。

实验结果如下:

模型 epoch LR batchsize dataaug acc@top1
ResNet50 90 0.1 256 randomcropresize,randomflip 76.422%
ResNet50 90 0.4 1024 randomcropresize,randomflip 76.228%
ResNet50 90 0.8 2048 randomcropresize,randomflip 76.132%
ResNet50 90 1.6 4096 randomcropresize,randomflip 75.75%

很明显可以看出来,当 bs 增加到 4k 的时候,acc 下降了将近 0.8% 个点,1k 的时候,下降了 0.2% 个点,所以,通常我们用大的 bs 训练的时候,是没办法达到最优的精度的。个人建议,使用 1k 的 bs 和 0.4 的学习率最优。

四、LARS(Layer-wise Adaptive Rate Scaling)

1. 理论分析

由于 bs 的增加,在同样的 epoch 的情况下,会使网络的 weights 更新迭代的次数变少,所以需要对 LR 随着 bs 的增加而线性增加,但是这样会导致上面我们看到的问题,过大的 lr 会导致最终的收敛不稳定,精度有所下降。

LARS 的出发点则是各个层的更新参数使用的学习率应该根据自己的情况有所调整,而不是所有层使用相同的学习率,也就是每层有自己的 local lr,所以有:

λl=η||wl||||L(wl)||+β||wl||

这里,l 表示的是第几层,η 表示的是超参数,这个超参数远小于 1,表示每层会改变参数的 confidence,局部学习率可以很方便的替换每层的全局学习率,参数的更新大小为:

wlt=γλlL(wlt)

与 SGD 联合使用的算法如下:

LARS

LARS 代码如下:

class LARC(object):
    def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):
        self.optim = optimizer
        self.trust_coefficient = trust_coefficient
        self.eps = eps
        self.clip = clip

    def step(self):
        with torch.no_grad():
            weight_decays = []
            for group in self.optim.param_groups:
                # absorb weight decay control from optimizer
                weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
                weight_decays.append(weight_decay)
                group['weight_decay'] = 0
                for p in group['params']:
                    if p.grad is None:
                        continue
                    param_norm = torch.norm(p.data)
                    grad_norm = torch.norm(p.grad.data)

                    if param_norm != 0 and grad_norm != 0:
                        # calculate adaptive lr + weight decay
                        adaptive_lr = self.trust_coefficient * (param_norm) / (
                                    grad_norm + param_norm * weight_decay + self.eps)

                        # clip learning rate for LARC
                        if self.clip:
                            # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`
                            adaptive_lr = min(adaptive_lr / group['lr'], 1)

                        p.grad.data += weight_decay * p.data
                        p.grad.data *= adaptive_lr

        self.optim.step()
        # return weight decay control to optimizer
        for i, group in enumerate(self.optim.param_groups):
            group['weight_decay'] = weight_decays[i]
这里有一个超参数,trust_coefficient,也就是公式里面所提到的η, 这个参数对精度的影响比较大,实验部分我们会给出结论。

2. 实验结论

模型 epoch LR batchsize dataaug acc@top1 trust_confidence
ResNet50 90 0.4 1024 randomcropresize,randomflip 75.146% 1e-3
ResNet50 90 0.8 2048 randomcropresize,randomflip 73.946% 1e-3
ResNet50 90 1.6 4096 randomcropresize,randomflip 72.396% 1e-3
ResNet50 90 0.4 1024 randomcropresize,randomflip 76.234% 2e-2
ResNet50 90 0.8 2048 randomcropresize,randomflip 75.898% 2e-2
ResNet50 90 1.6 4096 randomcropresize,randomflip 75.842% 2e-2

可以很明显发现,使用了 LARS, 设置 turst_confidence 为 1e-3 的情况下,有着明显的掉点,设置为 2e-2 的时候,在 1k 和 4k 的情况下,有着明显的提升,但是 2k 的情况下有所下降。

LARS 一定程度上可以提升精度,但是强依赖超参,还是需要细致的调参训练。

四、结论

  • 8 卡进行分布式训练,使用 1k 的 bs 可以很好的平衡 acc&speed。
  • LARS 一定程度上可以提升精度,但是需要调参,做业务可以不用考虑,刷点的话要好好训练。

五、结束语

本文是提升分类模型 acc 系列的第一篇,后续会讲解一些通用的 trick 和数据处理的方法,敬请关注。


本文总阅读量388