《为什么Adam的Update RMS是0.2?》中,我们用平均场近似估计了Adam的Update RMS。不久后,读者 @EIFY 指出相同的结果已经出现在论文《Rotational Equilibrium: How Weight Decay Balances Learning Across Neural Networks》中。阅读后,笔者发现其中不仅包含了Update RMS的估计,还包含了Weight RMS的估计。

也就是说,AdamW训出来的模型,其权重的RMS是可以事先估计出来一个渐近结果的。大家会不会觉得这个结论有点意外?反正笔者第一次看到它是颇为意外的,直觉上权重模长是模型根据训练集自己学出来的,结果它告诉我这已经隐藏在优化器的超参中,可谓很反直觉了。

这篇文章我们还是用平均场近似方法,来复现对Weight RMS的渐近估计。

滑动视角 #

首先还是来回顾AdamW的更新规则:
\begin{equation}\text{Adam}\color{skyblue}{\text{W}}:=\left\{\begin{aligned}
&\boldsymbol{m}_t = \beta_1 \boldsymbol{m}_{t-1} + \left(1 - \beta_1\right) \boldsymbol{g}_t\\
&\boldsymbol{v}_t = \beta_2 \boldsymbol{v}_{t-1} + \left(1 - \beta_2\right) \boldsymbol{g}_t^2\\
&\hat{\boldsymbol{m}}_t = \boldsymbol{m}_t\left/\left(1 - \beta_1^t\right)\right.\\
&\hat{\boldsymbol{v}}_t = \boldsymbol{v}_t\left/\left(1 - \beta_2^t\right)\right.\\
&\boldsymbol{u}_t =\hat{\boldsymbol{m}}_t\left/\left(\sqrt{\hat{\boldsymbol{v}}_t} + \epsilon\right)\right.\\
&\boldsymbol{\theta}_t = \boldsymbol{\theta}_{t-1} - \eta_t (\boldsymbol{u}_t \color{skyblue}{ + \lambda_t \boldsymbol{\theta}_{t-1}})
\end{aligned}\right.\end{equation}
再次说明,这里加粗符号默认都是$\mathbb{R}^d$的向量,向量的乘除(包括平方、开根号)默认都是Element-wise的Hadamard积/商。

《为什么Adam的Update RMS是0.2?》一样,我们考虑$t\to\infty$(对于$\beta_1,\beta_2$来说)和$\epsilon\to 0$,所以$\boldsymbol{u}_t=\boldsymbol{m}_t/\sqrt{\boldsymbol{v}_t}$。我们暂时先考虑$\eta_t,\lambda_t$都是常数的例子,所以它们的下标可以省略掉,并且记$\beta_3 = 1-\eta\lambda$,我们有
\begin{equation}\boldsymbol{\theta}_t = \beta_3\boldsymbol{\theta}_{t-1} + (1-\beta_3)(-\boldsymbol{u}_t/\lambda)\label{eq:ema-wd}\end{equation}
这个式子表明,我们可以从更新量的滑动平均(Exponential Moving Average,EMA)角度来理解Weight Decay。这是一个很有意义的视角转换,是《How to set AdamW’s weight decay as you scale model and dataset size》《Power Lines: Scaling Laws for Weight Decay and Batch Size in LLM Pre-training》等工作的基础。

加权平均 #

根据式$\eqref{eq:ema-wd}$,我们可以将$\boldsymbol{\theta}_t$展开为加权平均形式
\begin{equation}\boldsymbol{\theta}_t = \beta_3^t\boldsymbol{\theta}_0 + (1-\beta_3)\sum_{i=1}^t \beta_3^{t-i} (-\boldsymbol{u}_i/\lambda)\label{eq:theta-t}\end{equation}
同理,$\boldsymbol{m}_t$和$\boldsymbol{v}_t$也可以展开为
\begin{equation}\boldsymbol{m}_t = (1 - \beta_1)\sum_{i=1}^t \beta_1^{t-i}\boldsymbol{g}_i,\qquad \boldsymbol{v}_t = (1 - \beta_2)\sum_{i=1}^t \beta_2^{t-i}\boldsymbol{g}_i^2\label{eq:mv-roll}\end{equation}
这里有个小细节,$\boldsymbol{\theta}_t$的表达式我们保留了$\boldsymbol{\theta}_0$,但$\boldsymbol{m}_t$和$\boldsymbol{v}_t$的表达式我们没有保留$\boldsymbol{m}_0$和$\boldsymbol{v}_0$,原因有两个:1、$\boldsymbol{m}$和$\boldsymbol{v}$的初始化一般是零;2、即便它们初始化不是零,但对应的$\beta_1^t$和$\beta_2^t$也会足够接近于零,因此初始化的影响可以忽略。

然而,$\boldsymbol{\theta}$是模型权重,它的初始化通常不是零,并且$\beta_3$往往非常接近于1,对于整个训练周期而言,$\beta_3^t$不一定能充分接近于零,因此我们显式保留$\beta_3^t$和$\boldsymbol{\theta}_0$,按需取舍。

快速估计 #

我们的任务是估计Weight RMS,即$\Vert\boldsymbol{\theta}_t\Vert_{RMS}$,顾名思义,它是各个分量的Root Mean Square:
\begin{equation}\Vert\boldsymbol{\theta}\Vert_{RMS} = \sqrt{\frac{1}{d}\sum_{i=1}^d \theta_i^2},\qquad\qquad \text{其中 }\boldsymbol{\theta} = (\theta_1,\theta_2,\cdots,\theta_d)\end{equation}
它跟模长的区别就是多除了个$\sqrt{d}$,所以模长的大部分性质对RMS同样成立。对于$\Vert\boldsymbol{\theta}_t\Vert_{RMS}$,我们有一个快速但不是那么准确的推导方式:直接对式$\eqref{eq:ema-wd}$两边求$\Vert\cdot\Vert_{RMS}^2$,可以得到
\begin{equation}\begin{aligned}
\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 =&\, \Vert\beta_3\boldsymbol{\theta}_{t-1} + (1-\beta_3)(-\boldsymbol{u}_t/\lambda)\Vert_{RMS}^2 \\[5pt]
=&\, \beta_3^2\Vert\boldsymbol{\theta}_{t-1}\Vert_{RMS}^2 + (1-\beta_3)^2\Vert\boldsymbol{u}_t\Vert_{RMS}^2/\lambda^2 - 2\beta_3(1-\beta_3)\boldsymbol{\theta}_{t-1}\cdot\boldsymbol{u}_t/(\lambda d)
\end{aligned}\end{equation}
假设$\boldsymbol{\theta}_{t-1},\boldsymbol{u}_t$近乎正交,那么$\boldsymbol{\theta}_{t-1}\cdot\boldsymbol{u}_t\approx 0$,这在高维空间中通常是一个不错的近似(参考《n维空间下两个随机向量的夹角分布》),然后$\Vert\boldsymbol{u}_t\Vert_{RMS}$我们已经算过了,答案是约等于$\sqrt{\frac{1-\beta_1}{1+\beta_1}}$,最后我们考虑的是趋于稳态的结果,所以$\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2=\Vert\boldsymbol{\theta}_{t-1}\Vert_{RMS}^2$,于是有
\begin{equation}(1-\beta_3^2)\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx (1-\beta_3)^2 \frac{1-\beta_1}{1+\beta_1} /\lambda^2\qquad\Rightarrow\qquad \Vert\boldsymbol{\theta}_t\Vert_{RMS} \approx \sqrt{\frac{1-\beta_1}{1+\beta_1}\frac{\eta}{2\lambda}}\end{equation}
从左式到右式还用到了$\beta_3\approx 1$的近似。最后的结果会有些误差,因为$\boldsymbol{\theta}_t\cdot\boldsymbol{u}_t\approx 0$实际上并不那么成立,但$\Vert\boldsymbol{\theta}_t\Vert_{RMS}\propto \sqrt{\eta/\lambda}$的结论是正确的。类似的推导还出现在《Why Gradients Rapidly Increase Near the End of Training》

更好近似 #

很多情况下我们只需要知道$\Vert\boldsymbol{\theta}_t\Vert_{RMS}\propto \sqrt{\eta/\lambda}$就行了,这是一个比较通用的结论。而对于追求更准确结论的读者来说,我们可以用平均场方法得到一个更好的近似,代价是计算过程会复杂不少,但好处是我们可以获得更多更清晰的认知。

步骤之一 #

我们从式$\eqref{eq:theta-t}$出发,求和这一项,本身就具有加权平均的形式,所以我们先用第一次平均场:
\begin{equation}\underbrace{\frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i} \boldsymbol{u}_i}_{\text{记为}\bar{\boldsymbol{u}}_t} = \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i} \frac{\boldsymbol{m}_i}{\sqrt{\boldsymbol{v}_i}}\approx \frac{\bar{\boldsymbol{m}}_t \,\,\triangleq\,\, \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\boldsymbol{m}_i}{\sqrt{\bar{\boldsymbol{v}}_t \,\,\triangleq\,\, \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\boldsymbol{v}_i}}\label{eq:u-bar}\end{equation}
现在再次回到式$\eqref{eq:theta-t}$,由于$\boldsymbol{\theta}_0$是随机的初始化向量,因此可以假设$\boldsymbol{\theta}_0$与$\bar{\boldsymbol{u}}_t$正交,于是我们有
\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^t)^2 \lambda^{-2}\Vert \bar{\boldsymbol{u}}_t\Vert_{RMS}^2\end{equation}
现在我们要求$\Vert \bar{\boldsymbol{u}}_t\Vert_{RMS}^2$,根据之前的经验,我们需要假设$\boldsymbol{g}_j$独立同分布地服从$\mathcal{N}(\boldsymbol{\mu},\boldsymbol{\sigma}^2)$,然后求
\begin{equation}\mathbb{E}[\bar{\boldsymbol{u}}_t^2] \approx \mathbb{E}\left[\frac{\bar{\boldsymbol{m}}_t^2}{\bar{\boldsymbol{v}}_t}\right] \approx \frac{\mathbb{E}[\bar{\boldsymbol{m}}_t^2]}{\mathbb{E}[\bar{\boldsymbol{v}}_t]}\end{equation}
最后再对$\mathbb{E}[\bar{\boldsymbol{u}}_t^2]$的各个分量求平均,那么就可以作为$\Vert \bar{\boldsymbol{u}}_t\Vert_{RMS}^2$的近似。

步骤之二 #

结合式$\eqref{eq:mv-roll}$,我们得到
\begin{gather}
\sum_{i=1}^t \beta_3^{t-i}\boldsymbol{m}_i = (1 - \beta_1)\sum_{i=1}^t \beta_3^{t-i} \sum_{j=1}^i \beta_1^{i-j}\boldsymbol{g}_j = (1 - \beta_1)\sum_{j=1}^t \frac{\beta_3^{t-j+1} - \beta_1^{t-j+1}}{\beta_3 - \beta_1}\boldsymbol{g}_j\\
\sum_{i=1}^t \beta_3^{t-i}\boldsymbol{v}_i = (1 - \beta_2)\sum_{i=1}^t \beta_3^{t-i} \sum_{j=1}^i \beta_2^{i-j}\boldsymbol{g}_j^2 = (1 - \beta_2)\sum_{j=1}^t \frac{\beta_3^{t-j+1} - \beta_2^{t-j+1}}{\beta_3 - \beta_2}\boldsymbol{g}_j^2\\
\end{gather}
最后一个双重求和化简,如果大家没有思路,可以交给Kimi完成(参考链接)。由上式可知$\bar{\boldsymbol{m}}_t,\bar{\boldsymbol{v}}_t$分别是梯度和梯度平方的加权平均,所以求$\Vert \bar{\boldsymbol{u}}_t\Vert_{RMS}^2$跟《为什么Adam的Update RMS是0.2?》求$\Vert \boldsymbol{u}_t\Vert_{RMS}^2$本质上是一样的,只不过加权系数不同。

步骤之三 #

我们先求分母
\begin{equation}\begin{aligned}
\mathbb{E}[\bar{\boldsymbol{v}}_t] =&\, \frac{(1 - \beta_3)(1 - \beta_2)}{1 - \beta_3^t}\sum_{j=1}^t \frac{\beta_3^{t-j+1} - \beta_2^{t-j+1}}{\beta_3 - \beta_2}\mathbb{E}[\boldsymbol{g}_j^2] \\
=&\, \frac{(1 - \beta_3)(1 - \beta_2)}{1 - \beta_3^t}\sum_{j=1}^t \frac{\beta_3^{t-j+1} - \beta_2^{t-j+1}}{\beta_3 - \beta_2}(\boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2) \\
=&\, \frac{(1 - \beta_3)(1 - \beta_2)}{(1 - \beta_3^t)(\beta_3 - \beta_2)}\left(\frac{\beta_3 - \beta_3^{t+1}}{1 - \beta_3} - \frac{\beta_2 - \beta_2^{t+1}}{1 - \beta_2}\right)(\boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2) \\[5pt]
\approx &\, \boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2
\end{aligned}\end{equation}
最后一步的约等号,是因为实际训练中,$\beta_3$会足够接近于1,而$\beta_2^{t+1}$会足够接近于0,但$\beta_3^{t+1}$不一定,所以我们将$\beta_2^{t+1}$替换成零,并在化简之后将独立的$\beta_3$替换成$1$,最后再加上近似$\beta_3^{t+1}\approx \beta_3^t$。

步骤之四 #

然后是$\mathbb{E}[\bar{\boldsymbol{m}}_t^2] = \mathbb{E}[\bar{\boldsymbol{m}}_t]^2 + \mathbb{V}ar[\bar{\boldsymbol{m}}_t]$,$\mathbb{E}[\bar{\boldsymbol{m}}_t]$的计算跟$\mathbb{E}[\bar{\boldsymbol{v}}_t]$类似,结果是$\boldsymbol{\mu}$,$\mathbb{V}ar[\bar{\boldsymbol{m}}_t]$的计算我们则利用方差的平方可加性:
\begin{equation}\begin{aligned}
\mathbb{V}ar[\bar{\boldsymbol{m}}_t] =&\, \frac{(1 - \beta_3)^2(1 - \beta_1)^2}{(1-\beta_3^t)^2}\sum_{j=1}^t \left(\frac{\beta_3^{t-j+1} - \beta_1^{t-j+1}}{\beta_3 - \beta_1}\right)^2\mathbb{V}ar[\boldsymbol{g}_j] \\
=&\, \frac{(1 - \beta_3)^2(1 - \beta_1)^2}{(1-\beta_3^t)^2}\sum_{j=1}^t \left(\frac{\beta_3^{t-j+1} - \beta_1^{t-j+1}}{\beta_3 - \beta_1}\right)^2 \boldsymbol{\sigma}^2 \\
=&\, \frac{(1 - \beta_3)^2(1 - \beta_1)^2}{(1-\beta_3^t)^2(\beta_3 - \beta_1)^2}\left(\frac{\beta_3^2 - \beta_3^{2(t+1)}}{1 - \beta_3^2} + \frac{\beta_1^2 - \beta_1^{2(t+1)}}{1 - \beta_1^2} - 2\frac{\beta_1\beta_3 - \beta_1^{t+1}\beta_3^{t+1}}{1 - \beta_1\beta_3}\right) \boldsymbol{\sigma}^2 \\[5pt]
\approx &\, (1 - \beta_3)(1 + \beta_3^t)\boldsymbol{\sigma}^2/2(1 - \beta_3^t)
\end{aligned}\end{equation}
取约等号的理由同上。

步骤之五 #

代入上两节的计算结果,我们有
\begin{equation}\mathbb{E}[\bar{\boldsymbol{u}}_t^2] \approx \frac{\boldsymbol{\mu}^2 + (1 - \beta_3)(1 + \beta_3^t)\boldsymbol{\sigma}^2/2(1 - \beta_3^t)}{\boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2}\end{equation}
那么
\begin{equation}\Vert\bar{\boldsymbol{u}}_t\Vert_{RMS}^2 \approx \frac{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + (1 - \beta_3)(1 + \beta_3^t)/2(1 - \beta_3^t)}{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + 1} \end{equation}
最终有
\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^t)^2 \frac{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + (1 - \beta_3)(1 + \beta_3^t)/2(1 - \beta_3^t)}{\lambda^2(\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + 1)}\label{eq:theta-rms}\end{equation}

结果浅析 #

式$\eqref{eq:theta-rms}$看起来比较复杂,我们观察几个特例。首先考虑$\boldsymbol{\mu}=\boldsymbol{0}$这个例子,此时
\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^{2t}) (1 - \beta_3)/2\lambda^2 = \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^{2t}) \eta/2\lambda\label{eq:theta-rms-mu0}\end{equation}
特别地,如果考虑$t\to\infty$,或者$\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2$就初始化为$\eta/2\lambda$,那么就有
\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS} \approx \sqrt{\frac{\eta}{2\lambda}}\label{eq:theta-rms-simple}\end{equation}
这便是论文《Rotational Equilibrium: How Weight Decay Balances Learning Across Neural Networks》给出的结果,跟原论文的假设一致,它是零均值下的随机游走的稳态结果。如果不考虑$t\to\infty$,而是考虑$\lambda\to 0$的极限,那么由式$\eqref{eq:theta-rms-mu0}$我们将得到
\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + \eta^2 t\end{equation}
这表明在没有Weight Decay的时候,$\Vert\boldsymbol{\theta}_t\Vert_{RMS}$大致按照$\eta\sqrt{t}$的速度增长,这也表明在没有Weight Decay时,我们可以通过设置特定的学习率Schedule来实现Weight RMS的稳定性。另一方面,如果Batch Size足够大,导致信噪比项$\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2$占主导,那么由式$\eqref{eq:theta-rms}$得
\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^t)^2 \frac{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2}{\lambda^2(\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + 1)}\end{equation}
这个可能适用于模型需要主动增加Weight RMS的特殊情形。不过从经验来看,这种情况发生的概率一般比较小。

模拟实验 #

我们可以用如下模拟脚本,来简单验证上述的准确性:

import numpy as np

N, T = 10000, 100000
beta1, beta2 = 0.9, 0.95
m, v = 0, 0
w = np.random.randn(N) * 0.1
for i in range(T):
    g = np.random.randn(N)
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    w = w - 0.001 * (m / v**0.5 + 0.1 * w)

weight_rms = (w**2).mean()**0.5
print(weight_rms)

大家可以自行改变权重的初始化或梯度的均值方差等,看最终结果跟式$\eqref{eq:theta-rms}$的吻合程度,笔者自行试了一波,整体来说还是比较靠谱的。

符号版本 #

只需要将前述证明调整一下,就可以适用于“SignSGDM + Weight Decay”的组合:
\begin{equation}\text{SignSGDM}\color{skyblue}{\text{W}}:=\left\{\begin{aligned}
&\boldsymbol{m}_t = \beta_1 \boldsymbol{m}_{t-1} + \left(1 - \beta_1\right) \boldsymbol{g}_t\\
&\boldsymbol{u}_t = \newcommand{sign}{\mathop{\text{sign}}}\sign(\boldsymbol{m}_t)\\
&\boldsymbol{\theta}_t = \boldsymbol{\theta}_{t-1} - \eta_t (\boldsymbol{u}_t \color{skyblue}{ + \lambda_t \boldsymbol{\theta}_{t-1}})
\end{aligned}\right.\end{equation}
修改的地方是由于$\sign(\boldsymbol{m}_t)=\boldsymbol{m}_t/\sqrt{\boldsymbol{m}_t^2}$,所以要将$\bar{\boldsymbol{v}}_t$的定义改为
\begin{equation}\bar{\boldsymbol{v}}_t \triangleq \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\boldsymbol{m}_i^2\end{equation}
那么
\begin{equation}\mathbb{E}[\bar{\boldsymbol{v}}_t] = \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\mathbb{E}[\boldsymbol{m}_i^2] \approx \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\mathbb{E}\left(\boldsymbol{\mu}^2 + \frac{1-\beta_1}{1 + \beta_1}\boldsymbol{\sigma}^2\right) = \boldsymbol{\mu}^2 + \frac{1-\beta_1}{1 + \beta_1}\boldsymbol{\sigma}^2\end{equation}
其中$\mathbb{E}[\boldsymbol{m}_i^2]$的计算我们参考《为什么Adam的Update RMS是0.2?》《重新思考学习率与Batch Size(四):EMA》都行。利用上述结果,我们得到
\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^t)^2 \frac{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + (1 - \beta_3)(1 + \beta_3^t)/2(1 - \beta_3^t)}{\lambda^2\left(\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + \frac{1-\beta_1}{1 + \beta_1}\right)}\end{equation}
特别地,考虑$\boldsymbol{\mu}=0,t\to\infty$的极限,我们有
\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \sqrt{\frac{\eta}{2\lambda}\frac{1+\beta_1}{1 - \beta_1}}\end{equation}
这个结果也很合理,因为SignSGDMW的Update RMS是AdamW的$\sqrt{\frac{1+\beta_1}{1 - \beta_1}}$倍,所以同样$\eta,\lambda$下它的Weight RMS也是$\sqrt{\frac{1+\beta_1}{1 - \beta_1}}$倍。

相关分析 #

前面说了,结果$\eqref{eq:theta-rms-simple}$跟论文《Rotational Equilibrium: How Weight Decay Balances Learning Across Neural Networks》是一致的,但我们的推导方法是完全不同的,并且能得到更一般的$\eqref{eq:theta-rms}$。不过,原论文也有一些很有意思的地方,比如它所提的 Total Update Contribution (TUC) 概念,就值得赏析一番。

TUC的思想是这样的:由于动量机制的存在,当前的梯度$\boldsymbol{g}_t$不止停留在当前步骤,它还会影响到未来的步骤(但会打个“折扣”),所以假设训练步数趋于无穷,我们可以考虑当前梯度$\boldsymbol{g}_t$对整个训练过程的总贡献。具体来说,对于Adam我们有$\boldsymbol{u}_t=\boldsymbol{m}_t/\sqrt{\boldsymbol{v}_t}$,当前$\boldsymbol{g}_t$对$\boldsymbol{u}_t$的贡献是$(1-\beta_1)\boldsymbol{g}_t/\sqrt{\boldsymbol{v}_t}$,下一步$\boldsymbol{g}_t$将会打个折扣(乘以$\beta_1$),而且分母改为$\boldsymbol{v}_{t+1}$,依此类推,所以可以定义总贡献为
\begin{equation}\tilde{\boldsymbol{u}}_t = \sum_{k=t}^{\infty} (1-\beta_1)\beta_1^{k-t}\frac{\boldsymbol{g}_t}{\sqrt{\boldsymbol{v}_k}}\end{equation}
这样我们就将更新$\boldsymbol{u}_1,\boldsymbol{u}_2,\boldsymbol{u}_3,\cdots$分解为更新$\tilde{\boldsymbol{u}}_1,\tilde{\boldsymbol{u}}_2,\tilde{\boldsymbol{u}}_3,\cdots$,这样的好处是每个$\tilde{\boldsymbol{u}}$只有单步梯度,那么我们就可以重复快速估计一节的推导:
\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 = \Vert\beta_3\boldsymbol{\theta}_{t-1} + (1-\beta_3)(-\tilde{\boldsymbol{u}}_t/\lambda)\Vert_{RMS}^2 \approx \beta_3^2\Vert\boldsymbol{\theta}_{t-1}\Vert_{RMS}^2 + (1-\beta_3)^2\Vert\tilde{\boldsymbol{u}}_t\Vert_{RMS}^2/\lambda^2 \label{eq:tilde-u-rms}\end{equation}
最后的近似依赖于$\boldsymbol{\theta}_{t-1}\cdot\tilde{\boldsymbol{u}}_t\approx 0$,我们断言$\boldsymbol{\theta}_{t-1}\cdot\tilde{\boldsymbol{u}}_t$比$\boldsymbol{\theta}_{t-1}\cdot\boldsymbol{u}_t$更接近于零,因为$\tilde{\boldsymbol{u}}_t$只依赖于当前梯度$\boldsymbol{g}_t$,而$\boldsymbol{\theta}_{t-1}$还没接触到$\boldsymbol{g}_t$,所以它们是独立的变量,假设$\boldsymbol{g}_t$具有零均值时,$\boldsymbol{\theta}_{t-1}\cdot\tilde{\boldsymbol{u}}_t\approx 0$往往就容易成立了。而为了估计$\Vert\tilde{\boldsymbol{u}}_t\Vert_{RMS}^2$,原论文直接假设$\boldsymbol{g}_t/\sqrt{\boldsymbol{v}_k}$具有相同方向并且单位RMS,于是
\begin{equation}\Vert\tilde{\boldsymbol{u}}_t\Vert_{RMS} = \sum_{k=t}^{\infty} (1-\beta_1)\beta_1^{k-t}\left\Vert\frac{\boldsymbol{g}_t}{\sqrt{\boldsymbol{v}_k}}\right\Vert_{RMS} = \sum_{k=t}^{\infty} (1-\beta_1)\beta_1^{k-t} = 1\end{equation}
代入式$\eqref{eq:tilde-u-rms}$,结合快速估计一节同样的近似处理,解得
\begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS} \approx \sqrt{\frac{\eta}{2\lambda}}\end{equation}
然而,如果局限在原论文看,我们会发现有很多近似是莫名其妙的,比如$\boldsymbol{v}_t$中也有$\boldsymbol{g}_t$,所以说$\tilde{\boldsymbol{u}}_t$只包含当前$\boldsymbol{g}_t$的影响是不大准确的,还有$\Vert\boldsymbol{g}_t/\sqrt{\boldsymbol{v}_k}\Vert_{RMS}=1$的断言也显得比较生硬。但如果放到本文来看,我们会发现在平均场近似下,原论文的各种操作会显得很合理,所以原论文其实已经隐含地用到了平均场方法。

文章小结 #

这篇文章我们用平均场近似推导了一个有趣且可能让人意外的结论:AdamW训出来的模型,其权重的RMS也是可以渐近估计出来的,一般情况下,它只依赖于学习率和Weight Decay。

转载到请包括本文地址:https://www.spaces.ac.cn/archives/11307

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Oct. 01, 2025). 《AdamW的Weight RMS的渐近估计(上) 》[Blog post]. Retrieved from https://www.spaces.ac.cn/archives/11307

@online{kexuefm-11307,
        title={AdamW的Weight RMS的渐近估计(上)},
        author={苏剑林},
        year={2025},
        month={Oct},
        url={\url{https://www.spaces.ac.cn/archives/11307}},
}