缓解交叉熵过度自信的一个简明方案
By 苏剑林 | 2023-03-14 | 30279位读者 |众所周知,分类问题的常规评估指标是正确率,而标准的损失函数则是交叉熵,交叉熵有着收敛快的优点,但它并非是正确率的光滑近似,这就带来了训练和预测的不一致性问题。另一方面,当训练样本的预测概率很低时,交叉熵会给出一个非常巨大的损失(趋于$-\log 0^{+}=\infty$),这意味着交叉熵会特别关注预测概率低的样本——哪怕这个样本可能是“脏数据”。所以,交叉熵训练出来的模型往往有过度自信现象,即每个样本都给出较高的预测概率,这会带来两个副作用:一是对脏数据的过度拟合带来的效果下降,二是预测的概率值无法作为不确定性的良好指标。
围绕交叉熵的改进,学术界一直都有持续输出,目前这方面的研究仍处于“八仙过海,各显神通”的状态,没有标准答案。在这篇文章中,我们来学习一下论文《Tailoring Language Generation Models under Total Variation Distance》给出的该问题的又一种简明的候选方案。
结果简介 #
顾名思义,原论文的改动是针对文本生成任务的,理论基础是Total Variation距离(参考《Designing GANs:又一个GAN生产车间》)。但事实上,经过原论文的一系列放缩和简化后,最终结果已经跟Total Variation距离没有明显联系,并且理论上也不限于文本生成任务。所以,本文将它作为一般分类任务的损失函数来看待。
对于数据对$(x,y)$,交叉熵给出的损失函数为
\begin{equation}-\log p_{\theta}(y|x)\end{equation}
原论文的改动很简单,改为
\begin{equation}-\frac{\log \big[\gamma + (1 - \gamma)p_{\theta}(y|x)\big]}{1-\gamma}\label{eq:gamma-ce}\end{equation}
其中$\gamma\in[0,1]$。当$\gamma=0$时,就是普通的交叉熵;当$\gamma=1$时,按极限来算,结果是$-p_{\theta}(y|x)$。
在原论文的实验中,不同任务的$\gamma$选取差别比较大,比如语言模型任务中取到了$\gamma=10^{-7}$,机器翻译任务中取了$\gamma=0.1$,文本摘要任务中取了$\gamma=0.8$。一个可以参考的规律是,如果是从零训练,那么需要选择比较接近于0的$\gamma$,如果是微调训练,那么可以考虑相对大一点的$\gamma$。此外,还有一种比较直观的方案,就是将$\gamma$视为动态参数,从$\gamma=0$开始,随着训练的推进慢慢转向$\gamma=1$,但这样就多了个schedule要调试。
效果上,由于多了个可调的$\gamma$参数,并且原本的交叉熵也包含在里边,所以只要用心去调,一般总有机会调出比交叉熵更好的结果的,这个倒不用太担心。
个人推导 #
怎么理解式$\eqref{eq:gamma-ce}$呢?在《函数光滑化杂谈:不可导函数的可导逼近》中的“正确率”一节,我们推导过正确率的光滑近似是
\begin{equation}\mathbb{E}_{(x,y)\sim \mathcal{D}}[p_{\theta}(y|x)]\end{equation}
所以,如果我们的评估指标是正确率,那么直觉上以$-p_{\theta}(y|x)$为损失函数才对,因为这时候损失函数跟正确率的变化更加同步。然而,事实上是交叉熵的表现往往更好。但交叉熵的出发点只是“更好训练”,所以有时候就会“训过头”了,导致过拟合。所以一个直观的想法就是能否将两个结果“插值”一下,以兼顾两者的优点。
为此,我们考虑两者的梯度【准确率指的是它的负光滑近似$-p_{\theta}(y|x)$】:
\begin{equation}\begin{aligned}
\text{准确率:}&\,\quad-\nabla_{\theta} p_{\theta}(y|x) \\
\text{交叉熵:}&\,\quad-\frac{1}{p_{\theta}(y|x)}\nabla_{\theta} p_{\theta}(y|x)
\end{aligned}\end{equation}
两者就差个$\frac{1}{p_{\theta}(y|x)}$。怎么把$\frac{1}{p_{\theta}(y|x)}$变为1呢?原论文的方案是:
\begin{equation}\frac{1}{\gamma + (1 - \gamma)p_{\theta}(y|x)}\end{equation}
当然这个构造方式不是唯一的,原论文选的这个方式,尽可能地保留了交叉熵的梯度特性,也就尽可能保留了交叉熵收敛快的特点。根据这个构造,我们就希望新损失函数的梯度为
\begin{equation}-\frac{\nabla_{\theta}p_{\theta}(y|x)}{\gamma + (1 - \gamma)p_{\theta}(y|x)} = \nabla_{\theta}\left(-\frac{\log \big[\gamma + (1 - \gamma)p_{\theta}(y|x)\big]}{1-\gamma}\right)\label{eq:gamma-ce-g}\end{equation}
这就找出了损失函数$\eqref{eq:gamma-ce}$,在这个过程中,我们先设计新的梯度,然后通过积分找原函数的方式找到了对应的损失函数。
多扯几句 #
为什么要从梯度角度去设计损失函数呢?大概有两方面的原因。
第一,很多损失函数求了梯度后会得到简化,所以在梯度空间设计,往往有更多的灵感和自由度,比如本文的例子中,在梯度空间设计$\frac{1}{p_{\theta}(y|x)}$与$1$的过渡函数$\frac{1}{\gamma + (1 - \gamma)p_{\theta}(y|x)}$不算太复杂,但如果直接在损失函数空间设计$p_{\theta}(y|x)$和$\log p_{\theta}(y|x)$的过渡函数$\frac{\log \big[\gamma + (1 - \gamma)p_{\theta}(y|x)\big]}{1-\gamma}$就复杂多了。
第二,目前使用的优化器都是基于梯度的,所以很多时候我们设计好梯度就行了,甚至都不必要找出原函数。论文的原始结果实际上就是只给出了梯度:
\begin{equation}-\max\left(b, \frac{p_{\theta}(y|x)}{\gamma + (1 - \gamma)p_{\theta}(y|x)}\right)\nabla_{\theta}\log p_{\theta}(y|x)\end{equation}
当$b=0$时,它就等价于式$\eqref{eq:gamma-ce}$。也就是说,原论文在设计梯度的时候还加了个阈值,这时候就很难写出简单的原函数了。但上式实现上并不困难,只要考虑损失函数
\begin{equation}-\max\left(b, \frac{p_{\theta}(y|x)}{\gamma + (1 - \gamma)p_{\theta}(y|x)}\right)_{\text{stop_grad}}\log p_{\theta}(y|x)\end{equation}
这里边的$\text{stop_grad}$就是直接断掉这部分结果的梯度,在tensorflow中对应着tf.stop_gradient
算子。
文章小结 #
本文主要介绍了缓解交叉熵过度自信的一个简明方案。
转载到请包括本文地址:https://www.spaces.ac.cn/archives/9526
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Mar. 14, 2023). 《缓解交叉熵过度自信的一个简明方案 》[Blog post]. Retrieved from https://www.spaces.ac.cn/archives/9526
@online{kexuefm-9526,
title={缓解交叉熵过度自信的一个简明方案},
author={苏剑林},
year={2023},
month={Mar},
url={\url{https://www.spaces.ac.cn/archives/9526}},
}
March 14th, 2023
苏神大佬,钻个小牛角,八仙过海,各显神通呀~
晕了,这个都能打错,感谢感谢,已经更正~
March 15th, 2023
苏神,这种方案与标签平滑以及focal loss的出发点是不是都是一样的?即,降低噪声标签的影响。 那么他们的核心区别是什么?
本质上来说,这些方法都只是one hot与某种先验的某种平滑方式(加权平均),只不过平滑的方式不一样,这篇文章的思路,大致上可以理解为梯度平滑?
April 13th, 2023
对于一个二分类问题的数据,均值可以完全确定其分布性质。也就是任何一个可收敛的loss,其最优化的点都可以用均值来直接计算。也就是可以用交叉熵学一个数,然后用一个固定的函数变换得到结果,完全等价,中间并没有信息增量。
这样一想,在loss上整的这些花活是不是可以用一个确定的变换来代替啊,省事很多,都不用训模型了。
关键是你固定的函数怎么确定?以及这是梯度下降优化。两个结果即便在数学上等价,在梯度下降优化之下未必等价。
September 6th, 2023
苏神,光滑准确率似乎和交叉熵+off-policy correction是等价的。AlphaCode论文里的GOLD off-policy correction形式如下
$\nabla_\theta L_{gold} (\theta) = - p_\theta(y|x) \nabla_\theta \log p_\theta(y|x)$
正好等于光滑准确率
不知道苏神有没有测试用光滑准确率SFT呢?
而且,$\gamma = 0.8$ 的时候公式(2)对$p_\theta(y|x)$的导数非常接近光滑准确率的$-1$了,在$[-1.25, -1]$之间,感觉差距不大,这意味着微调阶段是不是可以直接用光滑准确率了?
好久不做这类任务了,以前试过用光滑准确率微调的,效果相比交叉熵似乎差点(CLUE榜单)
大概会差多少
印象中多数在1%左右。
October 13th, 2023
[...]众所周知,分类任务的标准损失是交叉熵(Cross Entropy,等价于最大似然MLE,即Maximum Likelihood Estimation),它有着简单高效的特点,但在某些场景下也暴露出一些问题,如偏离评价指标、过度自信等,相应的改进工作也有很多,此前我们也介绍过一些,比如《再谈类别不平衡问题:调节权重与魔改Loss的对比联系》、《如何训练你的准确率?》、《缓解交叉熵过度自信的一个简明方[...]
November 20th, 2023
"当$\gamma = 1$时,按极限来算,结果是$-p\theta(y|x)$",这个地方的$-$号是不是错了
不好意思,少看了式子最前面的$-$号