深度学习-蒸馏、剪枝、量化

蒸馏

Distilling the Knowledge in a Neural Network

概述

基本原理

教师模型与学生模型
  • 教师模型(Teacher Model,Net-T):通常是一个大型且高精度的模型,能够通过训练获得较强的性能。教师模型可以是任何一种复杂的模型,通常是深度神经网络。

  • 学生模型(Student Model,Net-S):是一个结构更小、计算需求更少的模型,旨在通过学习教师模型的知识来达到类似的性能,甚至在某些情况下超越教师模型。

知识迁移

蒸馏的核心思想是通过将教师模型的输出信息(通常是软标签,soft labels)传递给学生模型。软标签不同于硬标签(即分类任务中的目标标签),它包含了更多关于类别之间关系的信息。教师模型的输出通常会比学生模型的输出更平滑、更有信息,因此能够帮助学生模型学到更为细致的特征表示。

软标签(Soft Labels)与温度(Temperature)
  • 软标签:是教师模型经过软化处理后的预测结果。通常,教师模型的输出通过一个温度参数进行调整,使得原本很尖锐的概率分布变得更加平滑,从而为学生模型提供更多的训练信息。

  • 温度(T):在教师模型的输出层加入温度参数(T),通常使用Softmax函数生成概率分布时引入温度。较高的温度会使得教师模型的输出分布变得更加平滑,从而提供更多的类别之间的相关性。

蒸馏损失(Distillation Loss)

蒸馏损失函数通常由两部分组成:

  • 学生模型的标准损失:如交叉熵损失,基于训练数据的硬标签进行计算。
  • 教师模型与学生模型之间的差异:通过教师模型的软标签和学生模型输出之间的差异进行计算,常见的损失函数是Kullback-Leibler散度(KL Divergence)。

目标:通过蒸馏,学生模型能够从教师模型中获得深层的知识,尤其是关于类别之间关系、数据分布和决策边界的信息,从而在没有显著增加计算成本的情况下,获得较好的性能。

蒸馏流程

准备教师模型

首先训练一个高精度且较为复杂的教师模型。该模型通常在大规模数据集上训练,并且具有较强的预测能力。教师模型的结构可以是深度神经网络,也可以是其他类型的模型。

选择学生模型

选择一个结构更小、参数量更少的学生模型。学生模型可以是较小的神经网络,或者是压缩过的模型。目标是使学生模型尽可能在保持较低计算开销的同时,接近教师模型的性能。

软标签生成与温度设置

使用教师模型进行推理,获取每个训练样本的软标签。这些软标签通常是通过教师模型的输出概率分布获得,Softmax层通常会加入一个温度参数(T)。较大的T值使得输出分布更平滑,从而提供更多的类别之间的信息。

训练学生模型

学生模型同时优化两个目标:

  • 硬标签损失:学生模型根据真实标签进行训练,通常使用交叉熵损失。

  • 蒸馏损失:学生模型根据教师模型的软标签进行训练,通常使用Kullback-Leibler散度(KL散度)来衡量学生模型输出与教师模型输出之间的差异。

蒸馏损失函数通常定义为:

Ldistill=αDKL(q(T),p(T))+(1α)Lhard(y,y^)L_{distill} = \alpha \cdot D_{KL}(q(T), p(T)) + (1 - \alpha) \cdot L_{hard}(y, \hat{y})

其中:

  • DKL(q(T),p(T))D_{KL}(q(T), p(T)) 是教师和学生模型的输出之间的KL散度;
  • Lhard(y,y^)L_{hard}(y, \hat{y}) 是基于硬标签的损失(通常是交叉熵);
  • α\alpha 是调节硬标签损失与蒸馏损失之间权重的超参数。

微调与优化

学生模型训练过程中可能需要进行微调,帮助其更好地适应教师模型的知识,尤其是在损失函数中的平衡调节上。通过不断的训练和调整,学生模型能够逐渐学习到来自教师模型的深层次特征,并优化其性能。

验证和评估

在蒸馏过程完成后,对学生模型进行验证和评估,确保其在测试集上的表现达到预期,并且在推理速度和计算资源上相比教师模型有了明显的提升。

论文

Basic Idea

现有机器学习中, 任何算法都可以用 Ensemble 的方法来提升性能, 但这样做会花费昂贵的计算资源, 并且不利于部署到真实场景中。

作者尝试提出一种把大模型知识尽可能的压缩进单个小模型的方法。

Distillation

人们希望得到的模型并不是在单一的数据集上拟合完美, 而是要求模型具有强大的泛化能力。 在某个问题的某个具体数据及上, 通常训练出的模型与真实问题会存在偏差, 存在一点点过拟合。

那么对于一个有能力的大模型, 就有希望直接利用大模型的知识, 训练一个具有更强泛化能力的小模型, 让小模型直接学习大模型的泛化能力。

从泛化能力的角度来考虑, 知识蒸馏非常像一种正则化手段。

T与Softmax

训练时经常所采用的标签是独热编码, Softmax 会刻意放大 Logits 之间的差距。 这使得模型输出的类别概率在某一类是非常大的 (文中也称为 Hard Target), 其他类别的概率都非常小。

但不同类别之间的相对概率仍然很重要, 例如猫的图片可能与狗有一定相似, 它一定比和苹果的相似性要低。 这种类别概率差异仍然可能存在着一些隐含的知识, 但它会被 Softmax 所抹除掉, 所以需要一些手段把这种知识传授给小模型。

一种可以尝试的方法是把大模型 (Teacher) 的预测结果和大模型的知识作为小模型 (Student) 的 Target, 即将处理过后的大模型 Logits 作为 Label 或 Label 的一部分训练小模型。

既然是 Softmax 抹除了不同类别之间的差异, 那么可以对 Softmax 改动, 弱化其对隐含知识的影响。

假设神经网络在没有经过 Softmax 前的 Logits 记为 ziz_i, 我们可以添加” 温度TT 来弱化影响, 记结果为 qiq_i:

qi=exp(zi)jexp(zj)qi=exp(zi/T)jexp(zj/T)q_{i}=\frac{\exp \left(z_{i}\right)}{\sum_{j} \exp \left(z_{j}\right)} \quad \rightarrow \quad q_{i}=\frac{\exp \left(z_{i} / T\right)}{\sum_{j} \exp \left(z_{j} / T\right)}

T=1T=1 时, 就是正常的 Softmax。 当 T>1T > 1 时, 原来的 Softmax 将变得更加软化, 不同类别之间的差距, 不再显得类别间的差距那么绝对。 Teacher 中的不同类别之间的暗含知识得到一定保留。 因为标签变得软化了, 所以熵更大, 也保存了更多的信息。

如果这个式子不够直观体现出它的作用, 我做出了不同 TT 对 Target qiq_i 的影响变化曲线图:

随着 TT 的增大, 生成的 Soft Target 之间的差距会越来越小, 变得 Softer

Student 将使用 Hard Target 和 Soft Target 共同训练自己, Teacher 软化后的知识将作为损失函数的一部分调节 Student 的参数:

L=Lhard+λLsoft\mathcal{L} = \mathcal{L}_{hard} + \lambda \mathcal{L}_{soft}

λ\lambda 为超参, 用于调节 Teacher Soft Target 的影响占比。 具体来说, Student 应以常温 T=1T=1 用 Hard Target 训练自己, 即损失的第一项。 同时, Teacher 和 Student 的蒸馏时,对 Softmax 加以高温 TT, 即损失的第二项。

蒸馏过程

蒸馏过程的示意图如下:

高温蒸馏过程的目标函数由 distill loss(对应 soft target) 和 student loss(对应 hard target) 加权得到。示意图如上。

L=αLsoft+βLhard=Lhard+λLsoftL=\alpha L_{soft}+\beta L_{hard}=\mathcal{L}_{hard} + \lambda \mathcal{L}_{soft}

  • viv_i : Net-T 的 logits
  • ziz_i: Net-S 的 logits
  • piTp^T_i: Net-T 的在温度 = T 下的 softmax 输出在第 i 类上的值
  • qiTq^T_i: Net-S 的在温度 = T 下的 softmax 输出在第 i 类上的值
  • cic_i: 在第 i 类上的 ground truth 值, ci{0,1}c_i\in\{0,1\}, 正标签取 1,负标签取 0。
  • NN: 总标签数量
  • Net-T 和 Net-S 同时输入 transfer set (这里可以直接复用训练 Net-T 用到的 training set), 用 Net-T 产生的 softmax distribution (with high temperature) 来作为 soft target,Net-S 在相同温度 T 条件下的 softmax 输出和 soft target 的 cross entropy 就是 Loss 函数的第一部分 LsoftL_{soft}

Lsoft=jNpjTlog(qjT)piT=exp(vi/T)kNexp(vk/T)qiT=exp(zi/T)kNexp(zk/T)L_{soft}=-\sum_j^N p^T_j\log(q^T_j),p^T_i=\frac{\exp(v_i/T)}{\sum_k^N \exp(v_k/T)},q^T_i=\frac{\exp(z_i/T)}{\sum_k^N \exp(z_k/T)}

KL散度(Kullback-Leibler Divergence)通常定义为两个概率分布 PPQQ 之间的差异,公式为:

DKL(PQ)=iP(i)logP(i)Q(i)D_{KL}(P \| Q) = \sum_{i} P(i) \log \frac{P(i)}{Q(i)}

在上面公式中:

  • piTp^T_iqiTq^T_i 分别代表概率分布 PPQQ,它们是通过 softmax 函数计算得到的。
  • 你表达的 Lsoft=jNpjTlog(qjT)L_{soft} = -\sum_j^N p^T_j \log(q^T_j)PPQQ 之间的负对数似然。

如果我们将 LsoftL_{soft} 修改为以 pp 为基础的 KL 散度的形式,可以写成:

Lsoft=DKL(PQ)+H(P)L_{soft} = D_{KL}(P \| Q) + H(P)

其中 H(P)H(P) 是分布 PP 的熵。这样可以看出 LsoftL_{soft} 是 KL 散度的一种变体。

  • Net-S 在 T=1 的条件下的 softmax 输出和 ground truth 的 cross entropy 就是 Loss 函数的第二部分 LhardL_{hard}

Lhard=jNcjlog(qj1)qi1=exp(zi)kNexp(zk)L_{hard}=-\sum_j^N c_j\log(q^1_j),q^1_i=\frac{\exp(z_i)}{\sum_k^N \exp(z_k)}

  • 第二部分 Loss LhardL_{hard} 的必要性其实很好理解: Net-T 也有一定的错误率,使用 ground truth 可以有效降低错误被传播给 Net-S 的可能。打个比方,老师虽然学识远远超过学生,但是他仍然有出错的可能,而这时候如果学生在老师的教授之外,可以同时参考到标准答案,就可以有效地降低被老师偶尔的错误 “带偏” 的可能性。

实验发现第二部分所占比重比较小的时候,能产生最好的结果,这是一个经验的结论。一个可能的原因是,由于 soft target 产生的 gradient 与 hard target 产生的 gradient 之间有与 TT 相关的比值。

在训练时, 必须保证 Teacher 和 Student 的温度一致, 当训练完成后, Student 预测不再使用 TT, 或者说训练完成后的推断设置 T=1T=1

同时, 由于温度 TT 的影响, 梯度均缩小了 T2T^2 倍 (详见下一小节最后), 所以在设置 λ\lambda 时, 需要让其尽可能大一些, 或者乘 T2T^2 倍, 才能保证两种损失的贡献度相同。

Matching Logits is a Special Case of Distillation

作者下面证明了直接让 Student 学 Teacher 的 Logits 只是蒸馏的一种特殊情况。

假设我们处理的问题所采用的损失函数是交叉熵 CC, 梯度为 Czi\frac{\partial C}{\partial z_i}, Teacher 模型的 Logits 为 viv_i, 以及其对应的概率为 pip_i, 则有:

Czi=1T(qipi)=1T(ezi/Tjezj/Tevi/Tjevj/T)\frac{\partial C}{\partial z_{i}}=\frac{1}{T}\left(q_{i}-p_{i}\right)=\frac{1}{T}\left(\frac{e^{z_{i} / T}}{\sum_{j} e^{z_{j} / T}}-\frac{e^{v_{i} / T}}{\sum_{j} e^{v_{j} / T}}\right)

TT 相较于 Logits 充分大的时候, 可以使用泰勒展开, 有 ex/T1+x/Te^{x/T}\approx1+x/T:

Czi1T(1+zi/TN+jzj/T1+vi/TN+jvj/T)\frac{\partial C}{\partial z_{i}} \approx \frac{1}{T}\left(\frac{1+z_{i} / T}{N+\sum_{j} z_{j} / T}-\frac{1+v_{i} / T}{N+\sum_{j} v_{j} / T}\right)

当对 Logits 做了零均值假设后, 有 jzj=jvj=0\sum_jz_j=\sum_jv_j=0, 结合上式有:

Czi1NT2(zivi)\frac{\partial C}{\partial z_{i}} \approx \frac{1}{N T^{2}}\left(z_{i}-v_{i}\right)

因此, 在较高的温度 TT 设置下, 蒸馏等价于最小化 12(zivi)2\frac{1}{2}(z_i - v_i)^2, 也就是直接把 Teacher 和 Student 的 Logits 匹配, 所以匹配 Logits 是一种蒸馏的特殊情况。

当温度较低时, 对负样本的关注就比较少, 可能滤去关键信息, 但实际上这有利有弊。 有些负样本的 Logits 应该是非常小的负值, 这种极小的负值在高温时的作用会被放大, 作为强大的噪声影响 Student。 在低温时, 这种噪声将被滤去。

所以温度的选取一般依赖于经验, 不要太高也不要太低

分母上有 T2T^2, 所以在知识蒸馏时, Lsoft\mathcal{L}_{soft} 的影响被缩小了 T2T^2, 所以需要在设置损失项时平衡回来。

关于 “温度” 的讨论

我们都知道 “蒸馏” 需要在高温下进行,那么这个 “蒸馏” 的温度代表了什么,又是如何选取合适的温度?

在回答这个问题之前,先讨论一下温度 T 的特点

1。 原始的 softmax 函数是 T=1T=1 时的特例, T<1T<1 时,概率分布比原始更 “陡峭”, T>1T>1 时,概率分布比原始更 “平缓”。
2。 温度越高,softmax 上各个值的分布就越平均(思考极端情况: (i) T=T=\infty , 此时 softmax 的值是平均分布的;(ii) T0T\rightarrow0,此时 softmax 的值就相当于 argmaxargmax , 即最大的概率处的值趋近于 1,而其他值趋近于 0)
3。 不管温度 T 怎么取值,Soft target 都有忽略相对较小的 pip_i 携带的信息的倾向

温度代表了什么,如何选取合适的温度?

温度的高低改变的是 Net-S 训练过程中对负标签的关注程度: 温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;而温度较高时,负标签相关的值会相对增大,Net-S 会相对多地关注到负标签。

实际上,负标签中包含一定的信息,尤其是那些值显著高于平均值的负标签。但由于 Net-T 的训练过程决定了负标签部分比较 noisy,并且负标签的值越低,其信息就越不可靠。因此温度的选取比较 empirical,本质上就是在下面两件事之中取舍:

1。 从有部分信息量的负标签中学习 --> 温度要高一些
2。 防止受负标签中噪声的影响 --> 温度要低一些

总的来说,T 的选择和 Net-S 的大小有关,Net-S 参数量比较小的时候,相对比较低的温度就可以了(因为参数量小的模型不能 capture all knowledge,所以可以适当忽略掉一些负标签的信息)

Summary

蒸馏为何有效, 人们还没有彻底摸清其中的作用原理。

甚至单个模型的自蒸馏也是有效的… 这点非常诡异, 为什么模型单单依靠样本本身却无法达到自蒸馏后的效果? 样本之间隐含的差异居然需要自己产生的产物重新喂给自己才能吸收 (反刍)?

读完本论文后, 自然会产生进一步的想法。 直接把 Logits 蒸给小模型效果如何? 能蒸 Logits 为什么不直接蒸 Feature 呢? 要是蒸 Feature 也不够直接的话把参数蒸给小模型是不是也可以? 这些想法确实都可以, 或多或少都有效果。

Reference

https://adaning.github.io/posts/39586.html

https://zhuanlan.zhihu.com/p/102038521