为什么Pre Norm的效果不如Post Norm?
By 苏剑林 | 2022-03-29 | 82952位读者 |Pre Norm与Post Norm之间的对比是一个“老生常谈”的话题了,本博客就多次讨论过这个问题,比如文章《浅谈Transformer的初始化、参数化与标准化》、《模型优化漫谈:BERT的初始标准差为什么是0.02?》等。目前比较明确的结论是:同一设置之下,Pre Norm结构往往更容易训练,但最终效果通常不如Post Norm。Pre Norm更容易训练好理解,因为它的恒等路径更突出,但为什么它效果反而没那么好呢?
笔者之前也一直没有好的答案,直到前些时间在知乎上看到 @唐翔昊 的一个回复后才“恍然大悟”,原来这个问题竟然有一个非常直观的理解!本文让我们一起来学习一下。
基本结论 #
Pre Norm和Post Norm的式子分别如下:
\begin{align}
\text{Pre Norm: } \quad \boldsymbol{x}_{t+1} = \boldsymbol{x}_t + F_t(\text{Norm}(\boldsymbol{x}_t))\\
\text{Post Norm: }\quad \boldsymbol{x}_{t+1} = \text{Norm}(\boldsymbol{x}_t + F_t(\boldsymbol{x}_t))
\end{align}
在Transformer中,这里的$\text{Norm}$主要指Layer Normalization,但在一般的模型中,它也可以是Batch Normalization、Instance Normalization等,相关结论本质上是通用的。
在笔者找到的资料中,显示Post Norm优于Pre Norm的工作有两篇,一篇是《Understanding the Difficulty of Training Transformers》,一篇是《RealFormer: Transformer Likes Residual Attention》。另外,笔者自己也做过对比实验,显示Post Norm的结构迁移性能更加好,也就是说在Pretraining中,Pre Norm和Post Norm都能做到大致相同的结果,但是Post Norm的Finetune效果明显更好。
可能读者会反问《On Layer Normalization in the Transformer Architecture》不是显示Pre Norm要好于Post Norm吗?这是不是矛盾了?其实这篇文章比较的是在完全相同的训练设置下Pre Norm的效果要优于Post Norm,这只能显示出Pre Norm更容易训练,因为Post Norm要达到自己的最优效果,不能用跟Pre Norm一样的训练配置(比如Pre Norm可以不加Warmup但Post Norm通常要加),所以结论并不矛盾。
直观理解 #
为什么Pre Norm的效果不如Post Norm?知乎上 @唐翔昊 给出的答案是:Pre Norm的深度有“水分”!也就是说,一个$L$层的Pre Norm模型,其实际等效层数不如$L$层的Post Norm模型,而层数少了导致效果变差了。
具体怎么理解呢?很简单,对于Pre Norm模型我们迭代得到:
\begin{equation}\begin{aligned}
\boldsymbol{x}_{t+1} =&\,\boldsymbol{x}_t + F_t(\text{Norm}(\boldsymbol{x}_t)) \\
=&\, \boldsymbol{x}_{t-1} + F_{t-1}(\text{Norm}(\boldsymbol{x}_{t-1})) + F_t(\text{Norm}(\boldsymbol{x}_t)) \\
=&\, \cdots \\
=&\, \boldsymbol{x}_0 + F_0 (\text{Norm}(\boldsymbol{x}_0)) + \cdots + F_{t-1}(\text{Norm}(\boldsymbol{x}_{t-1})) + F_t(\text{Norm}(\boldsymbol{x}_t))
\end{aligned}\end{equation}
其中每一项都是同一量级的,那么有$\boldsymbol{x}_{t+1}=\mathcal{O}(t+1)$,也就是说第$t+1$层跟第$t$层的差别就相当于$t+1$与$t$的差别,当$t$较大时,两者的相对差别是很小的,因此
\begin{equation}\begin{aligned}
&\,F_t(\text{Norm}(\boldsymbol{x}_t)) + F_{t+1}(\text{Norm}(\boldsymbol{x}_{t+1})) \\
\approx&\,F_t(\text{Norm}(\boldsymbol{x}_t)) + F_{t+1}(\text{Norm}(\boldsymbol{x}_t)) \\
=&\, \begin{pmatrix} 1 & 1\end{pmatrix}\begin{pmatrix} F_t \\ F_{t+1}\end{pmatrix}(\text{Norm}(\boldsymbol{x}_t))
\end{aligned}\end{equation}
这个意思是说,当$t$比较大时,$\boldsymbol{x}_t,\boldsymbol{x}_{t+1}$相差较小,所以$F_{t+1}(\text{Norm}(\boldsymbol{x}_{t+1}))$与$F_{t+1}(\text{Norm}(\boldsymbol{x}_t))$很接近,因此原本一个$t$层的模型与$t+1$层和,近似等效于一个更宽的$t$层模型,所以在Pre Norm中多层叠加的结果更多是增加宽度而不是深度,层数越多,这个层就越“虚”。
说白了,Pre Norm结构无形地增加了模型的宽度而降低了模型的深度,而我们知道深度通常比宽度更重要,所以是无形之中的降低深度导致最终效果变差了。而Post Norm刚刚相反,在《浅谈Transformer的初始化、参数化与标准化》中我们就分析过,它每Norm一次就削弱一次恒等分支的权重,所以Post Norm反而是更突出残差分支的,因此Post Norm中的层数更加“足秤”,一旦训练好之后效果更优。
相关工作 #
前段时间号称能训练1000层Transformer的DeepNet想必不少读者都听说过,在其论文《DeepNet: Scaling Transformers to 1,000 Layers》中对Pre Norm的描述是:
However, the gradients of Pre-LN at bottom layers tend to be larger than at top layers, leading to a degradation in performance compared with Post-LN.
不少读者当时可能并不理解这段话的逻辑关系,但看了前一节内容的解释后,想必会有新的理解。
简单来说,所谓“the gradients of Pre-LN at bottom layers tend to be larger than at top layers”,就是指Pre Norm结构会过度倾向于恒等分支(bottom layers),从而使得Pre Norm倾向于退化(degradation)为一个“浅而宽”的模型,最终不如同一深度的Post Norm。这跟前面的直观理解本质上是一致的。
文章小结 #
本文主要分享了“为什么Pre Norm的效果不如Post Norm”的一个直观理解。
转载到请包括本文地址:https://www.spaces.ac.cn/archives/9009
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Mar. 29, 2022). 《为什么Pre Norm的效果不如Post Norm? 》[Blog post]. Retrieved from https://www.spaces.ac.cn/archives/9009
@online{kexuefm-9009,
title={为什么Pre Norm的效果不如Post Norm?},
author={苏剑林},
year={2022},
month={Mar},
url={\url{https://www.spaces.ac.cn/archives/9009}},
}
September 21st, 2023
您好,今天才发现您的文章,感觉收获颇丰。有一点想请问 - Scaling Laws的其中一个implications是Transformer的loss与参数量相关,而与架构无关(即深度或宽度无所谓,只和总参数量有关)。这样一来是否还能对pre vs post LN得出同样的推论呢?
我并不认为Scaling Law是架构无关的,你确定一层非常大的Attention,效果跟多层总参数量相同的小Attention一样?
详情请参考Section3.1 Approximate Transformer Shape and Hyperparameter Independence,和Section 3.2 Performance with Non-Embedding Parameter Count N。关于您的specific的例子,这句话(来自Section 3.1)似乎给出了回答:“Transformer performance depends very weakly on the shape parameters n_layer, n_heads, and d_ff when we hold the total non-embedding parameter count N fixed.”。不知您的看法为何?
而且更进一步说,如果paper中的看法是正确的,即performance只与(N,D,C)相关 —— parameters N (excluding embeddings), the size of the dataset D, and the amount of compute C used for training,那么由于pre-norm相较于post更容易训练,是否可以得出一个相反的结论 —— 即pre-norm is always preferred?
不好意思我竟然忘了附上paper link:https://arxiv.org/pdf/2001.08361.pdf
问题就出在weakly究竟有多weakly。
我只说一下我看到的。很明显,从GPT、GPT到GPT3,或者从LLAMA-7B到LLAMA-70B,这些模型参数量的增大,同时也伴随着层数的增加,虽然没有消融实验,但大家都这样做,一定程度上表明大家也认可增加深度是有必要的,而不是完全的架构无关。
虽然GPT或者LLAMA系列模型的层数其实也就是几十层,相比之前一些研究中的深(数百层甚至上千层)还差得远,但毕竟不是个位数的一两层,所以还是那句话,很难想象“一层非常大的Attention,效果跟多层总参数量相同的小Attention一样”。
此外,Scaling Law本身就只是一个渐近的规律,而深度的加减本身带来的变化也许是1%级别的,我不清楚这个程度的变化是否能体现在Scaling Law上。
的确,这个架构无关(或者按原话,very weakly dependent)的结论非常地反常识,也许是因此才给我留下深刻的印象了吧。在看到您这篇分析之前,我也对其半信半疑。但当您指出pre-norm的本质与增加宽度相似,再结合业界的模型(包含您所提及的GPT, GPT2, GPT2, LLAMA, LLAMA2),会有一种非常搞笑的可能是是大家都在做宽模型而不自知。
当然这只是我被您inspired的胡思乱想。感谢您的回复,并希望继续交流 :)
I see. 这个思考确实非常有意思,实际上目前scaling law实验的架构共性都非常大(pre norm、数十层的宽度、transformer等),在这些共性之下得出架构弱相关的结果,我认为不算十分意外的事情。
December 25th, 2023
(3) 式里,如果对post norm做类似的展开,就会得到
$x_{t+1} = x_0 + Norm(F_{0}(x_0))+\dots+Norm(F_t(x_t))$
这样看似乎$x_{t+1}$和$x_t$也差距不大?
不好意思,公式写错了~
post还是不一样的,$F_k$会嵌套在$t-k$个Norm里
是的
February 6th, 2024
苏老师您好,对于公式(3)的描述“其中每一项都是同一量级的”该如何理解?这里的”同一量级“是指的什么?
直观理解的话,就是输出的每个分量数量级大致相同,非要用一个指标描述的话,可以考虑输出的向量模长。
February 19th, 2024
"其中每一项都是同一量级的,那么有xt+1=O(t+1),也就是说第t+1层跟第t层的差别就相当于t+1与t的差别"; 这个结论怎么来的??? 什么叫同一量级?,然后就相差不大了?凭什么?
输入经过了Norm之后,基本上能保持同一量级,然后Attention、MLP这些运算,一般不会大幅改动输入数值的量级(否则容易梯度消失或者爆炸),因此输出的范围也大致相同。
这些都是追求一个直观的理解,没法深究。如果追求严谨,可以尝试去定量化证明它,如果觉得不适,那么点击左上角或者右上角的关闭按钮即可。
July 3rd, 2024
式(4),在不考虑激活函数的情况下,F代表的其实是矩阵的线性变换,那么由倒数第二步推导出最后一步应该是AC + BC = (A+B)C,而不应该是A与B矩阵的直和吧?
如果上述描述正确,那么,“因此原本一个t层的模型与t+1层和,近似等效于一个更宽的t层模型”,深度确实会被弱化,但是不是因为变宽(变宽是不是意味着权重矩阵的某个维度变大了?),而是因为深层的不同层的权重矩阵之间发生了矩阵相加操作F1+F2。
这里$F_t\oplus F_{t+1}$确实容易引起歧义。主要想表达的是一个更大的层,我调整一下描述。
September 3rd, 2024
我有个问题想请教下苏老师,如果不同的 layer 用不同的 norm 方式会如何呢,比如前面的层用 post norm 深层 layer 用 pre norm?
这就太不优雅了吧,没考虑过~