所有文章 > 正文

CVPR 2023 | 大模型流行之下,SN-Net给出一份独特的答卷

作者: PaperWeekly

时间: 2023-03-21 22:41

本文介绍我们组在 CVPR 2023 的工作:Stitchable Neural Networks,下文简称 SN-Net。一种全新的模型部署方法,利用现有的model family直接做少量 epoch finetune 就可以得到大量插值般存在的子网络,运行时任意切换网络结构满足不同 resource constraint。

40-Do6ZyAvGyn.png

论文链接:Stitchable Neural Networks - AMiner

41-Mnv8POMUGS.png

背景

去年一次组会上,在和导师们讨论未来的 research 方向的时候,偶然聊到一个问题:

视频网站的视频播放会自动根据网络带宽调整画质,如网速好的时候到4K,网速差就720P甚至更低。那同一个神经网络能不能随时根据计算资源的变化调整推理速度?

从 2012 的 AlexNet 到 2023 年火出圈的 ChatGPT, AI/ML 这一社区在十年间少说已经训练了上百万个模型。截至这篇文章写作时,HuggingFace 上可以直接下载的模型就有 14 万个,涵盖各个模态和任务。每个模型各司其职,用自己在训练中学到的知识去处理某一种场景,互不叨扰。

42-Wpm95ZlLLq.jpg

模型虽然越来越多,但是资源浪费也越来越严重。训练一个模型的成本很高,尤其是大模型训练,耗费数个节点和几天的算力才能得到一个好权重,但最后却受限于应用场景只能重新调整结构,然后再重新训练,如网络 backbone 设计中通常会有不同 scale 来满足不同的推理速度要求: ResNet-18/50/101,DeiT-Ti/S/B,Swin-Ti/S/B 等等。

传统方法当然能加速模型推理,如 pruning,distillation,quantization。但问题是这些方法一次大都只能针对一个模型,一个资源场景。我们也可以用 NAS 搜出来若干个子网络来满足不同推理速度需求,即使如此,NAS 中训练一个 Supernet 的成本也是巨大的,典型的如 OFA 和 BigNAS,花费上千 GPU hours 才得到一个好网络,资源消耗巨大。 

看着 huggingface 上这么大的 model zoo,我们不禁想,整个社区花了大量时间,金钱和人力资源去训练网络,得到了这么多的 pretrained model,但是能不能有效利用起来?况且这些模型已经训练好了,当需要他们的时候,能不能用少量计算资源就可以满足目标场景? 

对这一问题的思考也是随着模型被工业界越推越大引出的。几年前一张 1080 就能跑完的实验,现在 8 张卡都很难 train 得动一个 model,特别是 Transformer 出来之后。最新的 ViT 已经 scale 到 22B,BAAI 的 EVA 也把 ViT 扩展到了 1B 的参数级别。留给小组的空间越来越小,在资源有限(缺卡)的场景下,我们需要寻求新的突破方向。

43-EbokOqYvVG.png

Stitchable Neural Networks

Industry 和 Academia 所关注的问题可以有些区别。既然大模型不是所有人能做得起的,那我们不如去利用好已有的 pretrained model。现在我们有了一组训练好的 model family,如 DeiT-Tiny/Small/Base。不同模型有不同大小,推理速度,显存占用。那么能不能利用这些已有的 weights 和结构快速得到一批新网络来满足不同的资源场景? 

我们在 CVPR 2023 最新的工作 Stitchable Neural Network (SN-Net) 给出了一个非常具有潜力的方案。

44-iKg8UYT881.jpg

SN-Net 的主要思想是:在一组已经训练好的 model family 中插入若干个 stitching layer (即 1x1 conv), 使得 forward 时 activation 可以在模型间的不同位置游走。当模型在不同位置缝合的时候,一个个新网络结构就出来了!!! 

此时,我们把原先 model family 中的网络叫做 anchors,缝合出来的新网络叫做 stitches。单个 SN-Net 可以 cover 众多 FLOPs-accuracy 的 trade-off,如在基于 Swin 的实验中,一个 SN-Net 的可以挑战 timm 中 200 个独立的模型,整个实验不过是 50 epochs,八张 V100 上训练不到一天。

45-SKKUUuzIfe.png

下面会介绍详细的做法,以及我们当时方法设计时候的考虑。想直接看效果的朋友可以移步最后的结果展示。


1. 模型这么多,怎么去选择


这里主要考虑了几个地方: 

    • 不同模型结构在网络中各层学习到的 representation 会有较大差别,缝合出来的网络不一定保证较好的 performance;
    • 不同数据集学到的东西差别也很大,为了保证性能最好保持在相同 pretrained 的 dataset 下;
    • 不同网络的实现和训练方式有差别,工程上很难权衡超参和 data augmentation 的选择。


而同一个结构通常在一个 repo 里,更容易实现。 因此,我们初步关注在相同 dataset 上训练好的 model family 上, 即结构相似,但是模型 scale 不一样,如 DeiT-Ti/S/B。 

不同 family 能不能缝合?也能,我们 paper 里有展示结果,但是工程上会比较麻烦,需要 combine 不同 repo 并且权衡超参。


2. 怎么去做缝合?

model stitching 在原先工作中大都是以研究 representation similarity 的形式呈现的,如:

    • Lenc, Karel, and Andrea Vedaldi. "Understanding image representations by measuring their equivariance and equivalence." CVPR 2015. 
    • Kornblith, Simon, et al. "Similarity of neural network representations revisited." ICML, 2019. 
    • Csiszárik, Adrián, et al. "Similarity and matching of neural network representations." NeurIPS 2021. 


总结过去这些工作:同一个网络,用不同 seed 训练之后可以在某些位置缝合起来,此时性能不会掉的很离谱。后续的研究发现结构不一样的网络甚至也能缝合。 

而 stitching 能够 work 在于,假设前一个网络出来的 feature map 属于 activation 空间 A,而另一个网络在此位置的输入 feature map 属于 activation 空间 B,那么 stitching layer 做的事情就是把 feature map 从 A 空间映射到 B 空间,使得此时的 feature map 能模拟下一网络在这个位置的输入。 

当网络是已经是 pretrained,那么 stitching 这一过程完全可以 formulate 成一个求解 least squares 的问题。也就是说 stitching layer 这个 weights 的 matrix 是可以直接求出来的 (参考 Csiszárik, Adrián, et al 这篇)。所以此时求解出来的 matrix 可以天然作为 stitching layer 的初始化。

3. 缝合方向的设定

46-XBR6lpdEeU.png

现在我们有一个大模型:性能好但是推理速度慢,还有一个小模型:性能差点但是推理速度快。我们怎么决定谁 stitch 到谁呢?我们主要考虑了两个方面: 

    • 参考当前 backbone 设计的惯例,随着网络不断深入,channel dimension 是在不断增大的。Fast-to-Slow 这方向比较符合常见的网络设计;
    • 实验验证 Fast-to-Slow 得到的 curve 要比 Slow-to-Fast 要 smooth 一点,详见论文。 


所以目前 SN-Net 在方向上是从小模型缝合到大模型。同时我们提出一个 constraint: nearest stitching,限制 stitching 只在复杂度 (FLOPs) 相邻的两个 anchor 之间。如补充材料中的 Figure 10 所示,以 DeiT-Ti/S/B 为例,我们的方法目前限制在 (a), (b) 两个 case。

47-17W2UQcAN2.png

这个限制是因为我们发现 anchor 的 gap 比较大的时候,缝合出来的网络并不在一个 optimal 的区间。实验部分也证明直接 stitch DeiT-Ti 和 DeiT-B 效果不如中间加一个 DeiT-S。


4. 怎么配置Stitching Layer

48-ixvnX7jc47.png

网络设计地千奇百怪,怎么去缝合是个问题。 


我们以 DeiT 为例,在相同 depth 的缝合实验上采取了 Paired Stitching 这种策略。这种策略的启发来自于过去一些工作发现:相邻 layer 之间的 representation 是有较高的相似度的。所以我们选择在 DeiT 得相邻 blocks 中 share 同一个 stitching layer,如滑窗一般进行 stitching。 

share 的情况下,原先的初始化方法就是简单地对不同 solution 得到的 matrix 做一个 average。选择 share stitching layer 还有其他好处,如减少过多 stitching layer 带来的参数量,同时扩大缝合出来的结构数量,即扩大 stitching space。 

另外一种情况是两个模型的 depth 不一样,小模型一般比较浅,block 的数量要比大模型少。比如 Swin-Ti 的第三个 stage 只有 6 个 block,而 Swin-S 在第三个 stage 有 18 个 block。此时我们进行 Unpaired Stitching,每个小模型的 block 都 stitch 到大模型的若干个 block 中。这样两个 case 就都解决了。

5. SN-Net能缝出来多少网络?

这个由多种因素决定。

看选择的 model family,即 anchors 的 depth。显然 anchor 越深,那么能 stitch 的位置就越多,新网络结构也会更多。

相同 depth 下看 stitching 时 sliding window 的设置。 

不加 nearest stitching 的时候得到的网络更多 (DeiT 上的实验是十倍的差距,71 vs. 731)。但是此时不 optimal。后续潜力尚待挖掘。 

对比 NAS 中级别的s earch space,SN-Net 在基于同一组 model family 得到的网络数量是有限的。但有一点不得不提,纵使 search space 再大,真正需要的时候也只是用 pareto frontier 上的网络结构,而 SN-Net 缝合出来的网络几乎天然落在 pareto frontier 上,同时部署的时候完全可以直接查表,几乎没有什么 search cost。 

另外一点是,SN-Net 的潜力在于整个 pretrained model zoo。有多少 model familiy,就有多少潜在的 SN-Net 变种。这是 NAS 的单一 supernet 所不能比拟的。

这意味着我们可以轻易缝合已有的 model family 达到 NAS 耗费大量计算资源搜出来的网络性能,比如简单缝合两个 LeViT 就可以用更低的 FLOPs(977M vs. 1040M)达到媲美 BigNASModel-XL 的性能(80.7% vs. 80.9%),如下图所示。

49-ZwRwUlcvR3.jpg

6. 简单的训练策略

训练 SN-Net 尤为简单。先提前把所有需要训练的 stitches 定义好,训练中每次 iteration 都随机 sample 出来一个 stitch,后面和正常的训练一样进行 loss回传,梯度下降。为了进一步提升 stitches 的性能,我们初步实验同时采用了 RegNetY-160 作为 teacher model 去做 distillation。

50-kk59JS2tX6.png

51-x273XEQcrz.png

结果展示


为了验证 Joint Training 和原有网络从头 train 的差距,我们选择了若干个和 stitches 相同的网络结构,然后在 ImageNet 上训满 300 epochs。从下表可以看到,对比用了大量计算资源训练出来的网络,SN-Net 利用已有的 DeiT family 只用 50 个 epoch 就可以得到比肩甚至更好的性能。同时整个网络只要 118.4M 的参数,而这 71 个 stitches 的总量如果单独训练需要 2630M,耗费  71 × 300 epochs,和 SN-Net 比是 22 倍的差距。

52-xjAcnknGdx.png

基于 DeiT 和 Swin Transformer,我们验证了缝合 plain ViT 和 hierarchical ViT 的可行性。性能曲线如在 anchors 中进行插值一般。

53-jwBoEbVaFD.png

值得一提的是,图中不同点所表示的子网络,即 stitch,是可以在运行时随时切换的。这意味着网络在 runtime 完全可以依靠查表进行瞬时推理速度调整。这个是诸多网络无法实现的,但颇具现实意义。比如现在很多手机都有省电模式,一旦进行 power saving, 手机掉帧,系统运行速度变慢,而此时 neural network 也可以调整推理速度,做一个 speed-accuracy 的 trade-off。 

我们当然也尝试了 stitch cnn,甚至不同的 family,结果非常 promising。

54-olaXRd1dbp.png

55-E5eW3CQUMc.png

SN-Net的可扩展空间

SN-Net 生于 large model zoo 的时代。我们初版方法给出了一个最简单的 baseline,相信未来有很大的扩展空间,比如:

1. 当前的训练策略比较简单,每次 iteration sample 出来一个 stitch,但是当 stitches 特别多的时候,可能导致某些 stitch 训练的不够充分,除非增加训练时间。所以训练策略上可以继续改进;

2. anchor 的 performance 会比之前下降一些,虽然不大。直觉上,在 joint training 过程中,anchor 为了保证众多 stitches 的性能在自身 weights 上做了一些 trade-off。目前补充材料里发现 finetune 更多 epoch 可以把这部分损失补回来;

3. 不用 nearest stitching 可以明显扩大 space,但此时大部分网络不在 pareto frontier 上,未来可以结合训练策略进行改进,或者在其他地方发现 advantage;

4. 未来能否有个更好方法和统一的框架去缝合任意网络。到那时,整个 model zoo 就像积木一样,可操作空间更大,玩法更多,这一点 NUS 的 Xingyi Yang 之前有尝试,参考 Deep Model Reassembly. 

二维码 扫码微信阅读
推荐阅读 更多