将“Softmax+交叉熵”推广到多标签分类问题
By 苏剑林 | 2020-04-25 | 328445位读者 |(注:本文的相关内容已整理成论文《ZLPR: A Novel Loss for Multi-label Classification》,如需引用可以直接引用英文论文,谢谢。)
一般来说,在处理常规的多分类问题时,我们会在模型的最后用一个全连接层输出每个类的分数,然后用softmax激活并用交叉熵作为损失函数。在这篇文章里,我们尝试将“Softmax+交叉熵”方案推广到多标签分类场景,希望能得到用于多标签分类任务的、不需要特别调整类权重和阈值的loss。
单标签到多标签 #
一般来说,多分类问题指的就是单标签分类问题,即从$n$个候选类别中选$1$个目标类别。假设各个类的得分分别为$s_1,s_2,
\dots,s_n$,目标类为$t\in\{1,2,\dots,n\}$,那么所用的loss为
\begin{equation}-\log \frac{e^{s_t}}{\sum\limits_{i=1}^n e^{s_i}}= - s_t + \log \sum\limits_{i=1}^n e^{s_i}\label{eq:log-softmax}\end{equation}
这个loss的优化方向是让目标类的得分$s_t$变为$s_1,s_2,\dots,s_t$中的最大值。关于softmax的相关内容,还可以参考《寻求一个光滑的最大值函数》、《函数光滑化杂谈:不可导函数的可导逼近》等文章。
现在我们转到多标签分类问题,即从$n$个候选类别中选$k$个目标类别。这种情况下我们一种朴素的做法是用sigmoid激活,然后变成$n$个二分类问题,用二分类的交叉熵之和作为loss。显然,当$n\gg k$时,这种做法会面临着严重的类别不均衡问题,这时候需要一些平衡策略,比如手动调整正负样本的权重、focal loss等。训练完成之后,还需要根据验证集来进一步确定最优的阈值。
这时候,一个很自然的困惑就是:为什么“$n$选$k$”要比“$n$选$1$”多做那么多工作?
笔者认为这是很不科学的事情,毕竟直觉上$n$选$k$应该只是$n$选$1$自然延伸,所以不应该要比$n$要多做那么多事情,就算$n$选$k$要复杂一些,难度也应该是慢慢过渡的,但如果变成多个二分类的话,$n$选$1$反而是最难的,因为这时候类别最不均衡。而从形式上来看,单标签分类比多标签分类要容易,就是因为单标签有“Softmax+交叉熵”可以用,它不会存在类别不平衡的问题,而多标签分类中的“sigmoid+交叉熵”就存在不平衡的问题。
所以,理想的解决办法应该就是将“Softmax+交叉熵”推广到多标签分类上去。
众里寻她千百度 #
为了考虑这个推广,笔者进行了多次尝试,也否定了很多结果,最后确定了一个相对来说比较优雅的方案:构建组合形式的softmax来作为单标签softmax的推广。在这部分内容中,我们会先假设$k$是一个固定的常数,然后再讨论一般情况下$k$的自动确定方案,最后确实能得到一种有效的推广形式。
组合softmax #
首先,我们考虑$k$是一个固定常数的情景,这意味着预测的时候,我们直接输出得分最高的$k$个类别即可。那训练的时候呢?作为softmax的自然推广,我们可以考虑用下式作为loss:
\begin{equation}-\log \frac{e^{s_{t_1}+s_{t_2}+\dots+s_{t_k}}}{\sum\limits_{1\leq i_1 < i_2 < \cdots < i_k\leq n}e^{s_{i_1}+s_{i_2}+\dots+s_{i_k}}}=\log Z_k - (s_{t_1}+s_{t_2}+\dots+s_{t_k})\end{equation}
其中$t_1,t_2,\dots,t_k$是$k$个目标标签,$Z_k = \sum\limits_{1\leq i_1 < i_2 < \cdots < i_k\leq n}e^{s_{i_1}+s_{i_2}+\dots+s_{i_k}}$是配分函数。很显然,上式是以任何$k$个类别总得分$s_{i_1}+s_{i_2}+\dots+s_{i_k}$为基本单位所构造的softmax,所以它算是单标签softmax的合理推广。又或者理解为还是一个单标签分类问题,只不过这是$C_n^k$选$1$问题。
在这个方案之中,比较困难的地方是$Z_k$的计算,它是$C_n^k$项总得分的指数和。不过,我们可以利用牛顿恒等式来帮助我们递归计算。设$S_k = \sum\limits_{i=1}^n e^{k s_i}$,那么
\begin{equation}\begin{aligned}
Z_1 =&\, S_1\\
2Z_2 =&\, Z_1 S_1 - S_2\\
3Z_3 = &\, Z_2 S_1 - Z_1 S_2 + S_3\\
\vdots\\
k Z_k = &\, Z_{k-1} S_1 - Z_{k-2} S_2 + \dots + (-1)^{k-2} Z_1 S_{k-1} + (-1)^{k-1} S_k
\end{aligned}\end{equation}
所以为了计算$Z_k$,我们只需要递归计算$k$步,这可以在合理的时间内计算出来。预测阶段,则直接输出分数最高的$k$个类就行。
自动确定阈值 #
上述讨论的是输出数目固定的多标签分类问题,但一般的多标签分类的目标标签数是不确定的。为此,我们确定一个最大目标标签数$K\geq k$,并添加一个$0$标签作为填充标签,此时loss变为
\begin{equation}\log \overline{Z}_K - (s_{t_1}+s_{t_2}+\dots+s_{t_k}+\underbrace{s_0+\dots+s_0}_{K-k\text{个}})\end{equation}
而
\begin{equation}\begin{aligned}
\overline{Z}_K =&\, \sum\limits_{1\leq i_1 < i_2 < \cdots < i_K\leq n}e^{s_{i_1}+s_{i_2}+\dots+s_{i_K}} + \sum\limits_{0 = i_1 = \dots = i_j < i_{j+1} < \cdots < i_K\leq n}e^{s_{i_1}+s_{i_2}+\dots+s_{i_K}}\\
=&\, Z_K + e^{s_0} \overline{Z}_{K-1}
\end{aligned}\end{equation}
看上去很复杂,其实很简单,还是以$K$个类别总得分为基本单位,但是允许且仅允许$0$类重复出现。预测的时候,仍然是输出分数最大的$K$个类,但允许重复输出$0$类,等价的效果是以$s_0$为阈值,只输出得分大于$s_0$的类。最后的式子显示$\overline{Z}_K$也可以通过递归来计算,所以实现上是没有困难的。
暮然回首阑珊处 #
看上去“众里寻她千百度”终究是有了结果:理论有了,实现也不困难,接下来似乎就应该做实验看效果了吧?效果好的话,甚至可以考虑发paper了吧?看似一片光明前景呢!然而~
幸运或者不幸,在验证了它的有效性的同时,笔者请教了一些前辈大神,在他们的提示下翻看了之前没细看的Circle Loss,看到了它里边统一的loss形式(原论文的公式(1)),然后意识到了这个统一形式蕴含了一个更简明的推广方案。
所以,不幸的地方在于,已经有这么一个现成的更简明的方案了,所以不管如何“众里寻她千百度”,都已经没有太大意义了;而幸运的地方在于,还好找到了这个更好的方案,要不然屁颠屁颠地把前述方案写成文章发出来,还不如现成的方案简单有效,那时候丢人就丢大发了~
统一的loss形式 #
让我们换一种形式看单标签分类的交叉熵$\eqref{eq:log-softmax}$:
\begin{equation}-\log \frac{e^{s_t}}{\sum\limits_{i=1}^n e^{s_i}}=-\log \frac{1}{\sum\limits_{i=1}^n e^{s_i-s_t}}=\log \sum\limits_{i=1}^n e^{s_i-s_t}=\log \left(1 + \sum\limits_{i=1,i\neq t}^n e^{s_i-s_t}\right)\end{equation}
为什么这个loss会有效呢?在文章《寻求一个光滑的最大值函数》、《函数光滑化杂谈:不可导函数的可导逼近》中我们都可以知道,$\text{logsumexp}$实际上就是$\max$的光滑近似,所以我们有:
\begin{equation}\log \left(1 + \sum\limits_{i=1,i\neq t}^n e^{s_i-s_t}\right)\approx \max\begin{pmatrix}0 \\ s_1 - s_t \\ \vdots \\ s_{t-1} - s_t \\ s_{t+1} - s_t \\ \vdots \\ s_n - s_t\end{pmatrix}\end{equation}
这个loss的特点是,所有的非目标类得分$\{s_1,\cdots,s_{t-1},s_{t+1},\cdots,s_n\}$跟目标类得分$\{s_t\}$两两作差比较,它们的差的最大值都要尽可能小于零,所以实现了“目标类得分都大于每个非目标类的得分”的效果。
所以,假如是有多个目标类的多标签分类场景,我们也希望“每个目标类得分都不小于每个非目标类的得分”,所以下述形式的loss就呼之欲出了:
\begin{equation}\log \left(1 + \sum\limits_{i\in\Omega_{neg},j\in\Omega_{pos}} e^{s_i-s_j}\right)=\log \left(1 + \sum\limits_{i\in\Omega_{neg}} e^{s_i}\sum\limits_{j\in\Omega_{pos}} e^{-s_j}\right)\label{eq:unified}\end{equation}
其中$\Omega_{pos},\Omega_{neg}$分别是样本的正负类别集合。这个loss的形式很容易理解,就是我们希望$s_i < s_j$,就往$\log$里边加入$e^{s_i - s_j}$这么一项。如果补上缩放因子$\gamma$和间隔$m$,就得到了Circle Loss论文里边的统一形式:
\begin{equation}\log \left(1 + \sum\limits_{i\in\Omega_{neg},j\in\Omega_{pos}} e^{\gamma(s_i-s_j + m)}\right)=\log \left(1 + \sum\limits_{i\in\Omega_{neg}} e^{\gamma (s_i + m)}\sum\limits_{j\in\Omega_{pos}} e^{-\gamma s_j}\right)\end{equation}
说个题外话,上式就是Circle Loss论文的公式(1),但原论文的公式(1)不叫Circle Loss,原论文的公式(4)才叫Circle Loss,所以不能把上式叫做Circle Loss。但笔者认为,整篇论文之中最有意思的部分还数公式(1)。
用于多标签分类 #
$\gamma$和$m$一般都是度量学习中才会考虑的,所以这里我们还是只关心式$\eqref{eq:unified}$。如果$n$选$k$的多标签分类中$k$是固定的话,那么直接用式$\eqref{eq:unified}$作为loss就行了,然后预测时候直接输出得分最大的$k$个类别。
对于$k$不固定的多标签分类来说,我们就需要一个阈值来确定输出哪些类。为此,我们同样引入一个额外的$0$类,希望目标类的分数都大于$s_0$,非目标类的分数都小于$s_0$,而前面已经已经提到过,“希望$s_i < s_j$就往$\log$里边加入$e^{s_i - s_j}$”,所以现在式$\eqref{eq:unified}$变成:
\begin{equation}\begin{aligned}
&\log \left(1 + \sum\limits_{i\in\Omega_{neg},j\in\Omega_{pos}} e^{s_i-s_j}+\sum\limits_{i\in\Omega_{neg}} e^{s_i-s_0}+\sum\limits_{j\in\Omega_{pos}} e^{s_0-s_j}\right)\\
=&\log \left(e^{s_0} + \sum\limits_{i\in\Omega_{neg}} e^{s_i}\right) + \log \left(e^{-s_0} + \sum\limits_{j\in\Omega_{pos}} e^{-s_j}\right)\\
\end{aligned}\end{equation}
如果指定阈值为0,那么就简化为
\begin{equation}\log \left(1 + \sum\limits_{i\in\Omega_{neg}} e^{s_i}\right) + \log \left(1 + \sum\limits_{j\in\Omega_{pos}} e^{-s_j}\right)\label{eq:final}\end{equation}
这便是我们最终得到的Loss形式了——“softmax + 交叉熵”在多标签分类任务中的自然、简明的推广,它没有类别不均衡现象,因为它不是将多标签分类变成多个二分类问题,而是变成目标类别得分与非目标类别得分的两两比较,并且借助于$\text{logsumexp}$的良好性质,自动平衡了每一项的权重。
这里给出Keras下的参考实现:
def multilabel_categorical_crossentropy(y_true, y_pred):
"""多标签分类的交叉熵
说明:y_true和y_pred的shape一致,y_true的元素非0即1,
1表示对应的类为目标类,0表示对应的类为非目标类。
警告:请保证y_pred的值域是全体实数,换言之一般情况下y_pred
不用加激活函数,尤其是不能加sigmoid或者softmax!预测
阶段则输出y_pred大于0的类。如有疑问,请仔细阅读并理解
本文。
"""
y_pred = (1 - 2 * y_true) * y_pred
y_pred_neg = y_pred - y_true * 1e12
y_pred_pos = y_pred - (1 - y_true) * 1e12
zeros = K.zeros_like(y_pred[..., :1])
y_pred_neg = K.concatenate([y_pred_neg, zeros], axis=-1)
y_pred_pos = K.concatenate([y_pred_pos, zeros], axis=-1)
neg_loss = K.logsumexp(y_pred_neg, axis=-1)
pos_loss = K.logsumexp(y_pred_pos, axis=-1)
return neg_loss + pos_loss
所以,结论就是 #
所以,最终结论就是式$\eqref{eq:final}$,它就是本文要寻求的多标签分类问题的统一loss,欢迎大家测试并报告效果。笔者也实验过几个多标签分类任务,均能媲美精调权重下的二分类方案。
要提示的是,除了标准的多标签分类问题外,还有一些常见的任务形式也可以认为是多标签分类,比如基于0/1标注的序列标注,典型的例子是笔者的“半指针-半标注”标注设计。因此,从这个角度看,能被视为多标签分类来测试式$\eqref{eq:final}$的任务就有很多了,笔者也确实在之前的三元组抽取例子task_relation_extraction.py中尝试了$\eqref{eq:final}$,最终能取得跟这里一致的效果。
当然,最后还是要说明一下,虽然理论上式$\eqref{eq:final}$作为多标签分类的损失函数能自动地解决很多问题,但终究是不存在绝对完美、保证有提升的方案,所以当你用它替换掉你原来多标签分类方案时,也不能保证一定会有提升,尤其是当你原来已经通过精调权重等方式处理好类别不平衡问题的情况下,式$\eqref{eq:final}$的收益是非常有限的。毕竟式$\eqref{eq:final}$的初衷,只是让我们在不需要过多调参的的情况下达到大部分的效果。
转载到请包括本文地址:https://www.spaces.ac.cn/archives/7359
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Apr. 25, 2020). 《将“Softmax+交叉熵”推广到多标签分类问题 》[Blog post]. Retrieved from https://www.spaces.ac.cn/archives/7359
@online{kexuefm-7359,
title={将“Softmax+交叉熵”推广到多标签分类问题},
author={苏剑林},
year={2020},
month={Apr},
url={\url{https://www.spaces.ac.cn/archives/7359}},
}
January 17th, 2022
[...]苏神在将“softmax+交叉熵”推广到多标签分类问题3中详细推导了交叉熵的计算,并且给出了代码实现。我也是看了苏神的博客,对比了原论文,感受到Circle Loss确实非常漂亮,笔者觉得有两个地方可圈可点。[...]
January 21st, 2022
1、计算损失代码的第一行:y_pred = (1 - 2 * y_true) * y_pred 这个公式是什么意思呢;
2、原来的公式里在log前有加1,为什么实现的时候没有呢
1、y_pred = (1 - 2 * y_true) * y_pred就是这个公式的运算本身的意思,没什么物理含义;
2、实现的时候有加1,要不然为什么要concatenate一个zeros呢?
苏神,这个位置y_pred = (1 - 2 * y_true) * y_pred 是为了将pos和neg区分的更开做的扩大倍数运算是吗?另外y_pred_pos对应的y_true为1位置算出来是个稍小于0的负数,跟之前以0为区分点似乎有点矛盾。
1、y_pred = (1 - 2 * y_true) * y_pred没有扩大倍数,只是根据y_true来修改y_pred的正负号;
2、“y_pred_pos对应的y_true为1位置算出来是个稍小于0的负数”这句没看懂。
嗯,2是我自己的疏忽,不一定是负数。谢谢苏神
只是根据y_true来修改y_pred的正负号;请教一下,这里为神魔需要修改成y_pred 正负号呢?没有理解如此操作的意义
如此操作的意义就是公式是这样写的
March 4th, 2022
苏哥,我有个疑问,比如我的标签总数是n,我可以从下标为0到n-1开始作为正确的标签吗?还是得从1到n标,然后额外引入一个0作为0类?
直接从0~n-1。
好哒!
March 4th, 2022
苏神,我用了一下你那个多标签分类的交叉熵损失函数,结果如下:模型的P值挺高的,但是R值太低了,总体的F值比用多个二分类的损失函数低了2.7个百分点。(PS:S:阈值设定的为0)
就这信息我也无法猜测原因是啥。另外,用这个交叉熵虽然默认阈值为0,但是必要情况下阈值也是可以自己调的,如果recall太低那就降低一点阈值。
March 24th, 2022
苏神,我用了 多标签分类的交叉熵。训练结果很好,但是预测时 有许多图片没有结果,也就是说很多图片的每个标签的预测值都小于0,这种情况有什么好的解决方法吗?
那训练结果很好是怎么体现的?如果有必要的话,可以在预测的时候降低一点阈值。
谢谢,训练结果是从评价指标体现的,我用了acc(标签完全相等)=0.95, MacroF1=0.94。但验证集上的效果这2个值为0.35。
如果是其他方案效果如何呢?(比如多个二分类交叉熵)
一开始使用的“多个二分类交叉熵”,效果和 “多标签分类的交叉熵”差不多。现在改变了数据集,之前的结果也没做保留。
March 30th, 2022
我有一个疑问,就是为什么多分类用softmax+交叉熵不存在类别不平衡,而多标签分类用sigmoid+交叉熵就存在。可不可以理解为多标签分类用sigmoid+交叉熵的时候,其实本质上每个类别是单独做任务的,最后把给各自的损失加起来,所以里面样本占比少的类别在损失函数和中不占优势,可能极端情况下得不到什么学习。而采用softmax+交叉熵的多分类或者说采用(11)式的多标签分类任务,类别和类之间在损失函数中其实在交互拉扯,不再是各自独立了,所以即使占比少也能得到学习,这样类别不平衡问题就不存在了。
可以这么理解。本质上来说,$\text{logsumexp}$是$\max$的光滑近似,所以它相当于每次都取效果最差的(loss最大)的部分学习,所以也就不会因为容易样本太多而造成loss快速趋于0了。
我的理解是softmax+ce里面,对所有正类和负类的梯度是等大(但反向)的。这从本贴公式(6)能很容易证出来,因为S_pos和S_neg是总成对出现的。
然而,在sigmoid+bce中,正类和负类的梯度是独立的,如果负类别数过多,那对于负类的梯度也会远大于正类,导致模型预测值倾向负类。
欢迎讨论和指正。
这种理解没什么问题。
June 1st, 2022
苏神,当前circle loss中多个正标签label都是1,则假定他们都是同等相关的。如果多个正标签之间存在序关系,这种关系怎么在这版loss里面体现出来?使得更相关label的logit更大?最近也试过你刚发表的软标签版本,但是概率p值的设定较难,如果只有序关系,loss有优化的空间吗?
标签的相关性我也思考过,但是还没想到好的解决方案。
June 21st, 2022
[...]不平衡数据的机器学习 - 内存网 内存网首页精品教程数据结构时间复杂度空间复杂度树二叉查找树满二叉树完全二叉树平衡二叉树红黑树B树图队列散列表链表算法基础算法排序算法贪心算法递归算法动态规划分治算法回溯法分支限界法拓扑排序字符串相关算法数组相关算法链表相关算法树相关算法二叉树相关算法LeetCodeOnline Judge剑指offer架构设计设计模式创建型单例模式工厂模式原型模式建造者模式结[...]
September 10th, 2022
苏神,看文章中说该loss也可以用来做0/1序列标注,也就是二分类任务是吗?我在实现过程中二分类任务的y_pred的shape一般为[batch_size,num_labels],y_true的shape为[batch_size],就不能达到y_pred与y_true的shape一致,也就没用成这个loss。不知道是不是理解错了,该loss只能做多标签分类吗?小白问的问题比较低级,忘苏神解惑。
如果二分类的y_pred的shape为[batch_size,num_labels],那么就是说有num_labels个二分类问题,所以你的y_true的shape凭什么是[batch_size]?num_labels个二分类问题,只给1个标签?
感谢苏神回复。我的y_pred的shape为[batch_size,num_labels]是因为还没有经过argmax,是从BERT输出的[batch_size,768]经过linear层[768,num_labels]得到的logits。此时能够使用交叉熵与y_true进行比较从而更新网络,但是如果用您的多标签分类loss的话,我将argmax得到y_pred([batch_size])与y_true进行比较,由于argmax不连续又无法更新网络。所以很困惑该loss是否只能多标签分类。
哦哦,是啊,这就是个多标签分类的loss,一般不适合普通单标签多分类
好的 我还以为0/1序列标注和TO标注一样就是个单标签二分类问题
TO标注又是啥?单标签二分类问题这个没错啊,问题是普通的多标签分类怎么可能在num_labels一维取argmax?
September 14th, 2022
[...]A Survey on Deep Learning for Named Entity Recognition最通俗易懂的BILSTM-CRF的CRF层介绍简明条件随机场CRF介绍 | 附带纯Keras实现keras实现源码BERT标注为何不使用CRFNER综述命名实体识别NERNER论文大礼包nlp中的实体关系抽取方法总结将”softmax+交叉熵”推广到多标签分类问题GlobalPointer:[...]