Reparameterization-Trick

之前在学蒸馏的时候接触了gumbel-softmax,顺势了解了一下重参数技巧,还是很有意思的一个东西

引入

重参数技巧主要是尝试对这样形式的一个东西求梯度 \[ \large L_{\theta} = E_{z\sim p_{\theta}(z)}[f_{\theta}(z)] \quad \quad(1) \] 其中\(z\sim p_{\theta}(z)\)表示随机变量\(z\)服从概率密度函数\(p_{\theta}(z)\),显然这个密度函数是跟模型参数\(\theta\)有关的\(f_{\theta}(z)\)一般可以表示模型某一层关于变量\(z\)的输出,显然它也跟模型参数\(\theta\)有关

不妨先来想想这个式子要如何处理。一个非常naive的思路:采样估计。但是如果直接采样的话,每次采样我们只能获得\(\nabla_\theta f_\theta(z)\),而不同样本之间的信息是无法共用的,我们也就无从得到\(\nabla_\theta L_\theta\)。所以我们想想看,有没有什么好的处理方法,能在估计出\((1)\)式的同时还能保留梯度信息

不妨先来做一个简化,我们先假设\(p_{\theta}(z)\)是一个跟\(\theta\)无关的概率密度函数,简记为\(p(z)\),我们很快注意到现在是可以采样估计梯度了: \[ \begin{flalign} \nabla_{\theta} L_{\theta} &= \nabla_{\theta}E_{z\sim p(z)}[f_{\theta}(z)] = \nabla_{\theta}[\int_z p(z)f_{\theta}(z)dz]\\ &=\int_z p(z)\nabla_{\theta}f_{\theta}(z)dz\\ &=E_{z\sim p(z)}[\nabla_{\theta}f_{\theta}(z)] \end{flalign} \] 从而 \[ \large \nabla_{\theta} L_{\theta} \approx \frac{1}{n}\sum_{i=1}^{n} \nabla_{\theta}f_{\theta}(z_i),z_i\sim p(z) \] 这是因为求梯度的操作成功转移到了\(f_\theta(z)\)上面 上述过程可以用一句话来总结:期望的梯度等于梯度的期望

那我们回到\(p_\theta(z)\),并尝试类似的步骤: \[ \large \begin{flalign} \nabla_{\theta} L_{\theta} &= \nabla_{\theta}E_{z\sim p_\theta(z)}[f_{\theta}(z)] = \nabla_{\theta}[\int_z p_\theta(z)f_{\theta}(z)dz]\\ &=\int_z p_\theta(z)\nabla_{\theta}f_{\theta}(z)dz+\int_z \nabla_{\theta}p_\theta(z)f_{\theta}(z)dz\\ &=E_{z\sim p_\theta(z)}[\nabla_{\theta}f_{\theta}(z)]+\underbrace{\int_z \nabla_{\theta}p_\theta(z)f_{\theta}(z)dz}_{???} \end{flalign} \] 前面一块还是可以仿照之前的处理的,但是后者就显得比较诡异了,求梯度操作转移到\(p_\theta(z)\)上面去,也就意味着我们无法将其整理成正常的关于某个东西的期望的形式。或许我们可以将\(\nabla_{\theta}p_\theta(z)\)求出来,但在大部分情况下这是不现实的。

此时就可以引入重参数技巧了

重参数

顾名思义,我们需要引入新的参数来处理上述问题: 考虑一个新的无参数分布 \[ \large \epsilon\sim{q(\epsilon)} \] 以及变换 \[ \large z = g_\theta(\epsilon) \] 保证变换之后得到的\(z\)服从\(p_\theta\) 那么对\((1)\)式求梯度可以变成: \[ \large \begin{flalign} \nabla_{\theta} L_{\theta} &= \nabla_{\theta}E_{z\sim p_\theta(z)}[f_{\theta}(z)] \\ &= E_{\epsilon\sim q(\epsilon)}[f_\theta(g_\theta(\epsilon))]\quad \quad (a)\\ &=E_{\epsilon\sim q(\epsilon)}[\nabla_{\theta}f_\theta(g_\theta(\epsilon))]\ \ \ (b) \end{flalign} \] 从而 \[ \large \nabla_{\theta} L_{\theta} \approx \frac{1}{n}\sum_{i=1}^{n} \nabla_{\theta}f_\theta(g_\theta(\epsilon_i)),\epsilon_i\sim q(\epsilon) \]

我们就成功实现了在采样的同时保持了梯度

注意,在这个过程中最重要的一步转化就是: \[ \large L_\theta = E_{\epsilon\sim q(\epsilon)}[f_\theta(g_\theta(\epsilon))] \] 它将随机性从参数\(\theta\)转移到了内部无参数的\(\epsilon\)上面,从而可以利用我们之前讨论过的对无参数分布(或者说无可变参数)而言成立的“期望的梯度等于梯度的期望”这一性质来处理

例子

不妨就取\(p_\theta(z)\)是一个正态分布,即 \[ \large p_\theta(z) = N(\mu_\theta,\sigma_\theta^2) \] 那么\(q(\epsilon)\)我们就取标准正态分布 \[ \large q(\epsilon) = N(0,1) \] 那么显然有 \[ \large \sigma_\theta\epsilon+\mu_\theta \sim N(\mu_\theta,\sigma_\theta^2) \] 所以我们就取 \[ \large g_\theta(\epsilon) = \sigma_\theta\epsilon+\mu_\theta \] 最后有 \[ \large E_{z\sim N(\mu_\theta,\sigma_\theta^2)}[f_{\theta}(z)] = E_{\epsilon\sim N(0,1)}[f_\theta(\sigma_\theta\epsilon+\mu_\theta)] \]

离散情况的重参数处理

上述过程处理的是分布为连续密度函数的情况,但我们也经常遇到离散分布的情况,这种该如何处理? 为做区分,我们换一种写法: \[ \large L_{\theta} = E_{y\sim p_{\theta}(y)}[f_{\theta}(y)] = \sum_{y}p_\theta(y)f_\theta(y) \quad \quad (2) \] 一般来说,此时\(y\)是可枚举的,它在大部分情况下都对应了一个k分类问题,也就是说,\(y\)可以表示为 \[ \large p_\theta(y) = softmax(o_1,o_2,...o_k)_y = \frac{1}{\sum e^{o_i}}e^{o_y}\quad \quad(3) \] 其中\(o_i\)一般就是模型的logits,它当然也是关于参数\(\theta\)的函数

还是同一个问题,\((2)\)式直接用求和的形式是没法计算梯度的,我们还是得试试重参数方法。

所以现在问题就变成了: >找到一个合适的无参数分布\(q(\epsilon)\)以及对应的变换\(g_\theta(\epsilon)\)保证它服从\(p_\theta\)这个分布

事实上也确实已经有对应的成果了,它叫做

Gumbel Max

\[ \large \epsilon\sim U(0,1) \] 对应的\(q_{\theta}(\epsilon)\)为: \[ \large argmax_i(log p_i-log(-log \epsilon_i))_{i=1}^{k}\quad \quad (4) \] 这里第\(p_{\theta}(i)\)简记为\(p_i\)了 我们只需证明\((3)\)式与\((4)\)式是同一个分布,即\((4)\)式输出数字\(i\)的概率为\(p_i\)

不失一般性地,我们考虑\((4)\)式输出数字1的概率: 此时意味着\(log p_1-log(-log \epsilon_1)\)\(1-k\)中最大的,即 \[ \large log p_1-log(-log \epsilon_1)\geq log p_i-log(-log \epsilon_i) ,\forall i\in (1,k] \] 得到 \[ \large \epsilon_i\leq \epsilon_1^{p_i/p_1}\leq 1,\forall i\in (1,k] \]\(e_i\sim U(0,1)\),从而 \[ \large P(\epsilon_i\leq \epsilon_1^{p_i/p_1})=\epsilon_1^{p_i/p_1},\forall i\in (1,k] \] 从而\((4)\)式输出1的概率为 \[ \large P(\epsilon_2\leq \epsilon_1^{p_2/p_1},\epsilon_3\leq \epsilon_1^{p_3/p_1},...\epsilon_k\leq \epsilon_1^{p_k/p_1}) = \prod_{i=2}^{k}\epsilon_1^{p_i/p_1}=\epsilon_1^{(1-p_1)/p_1} \]\(\epsilon_1\)的所有情况求个平均,得到 \[ \large \int_0^1 \epsilon_1^{(1-p_1)/p_1}d\epsilon_1 = p_1 \] 这就是\((4)\)式输出1的概率,它恰好为\(p_1\) 从而我们证明了\((4)\)式与\((3)\)式确实是同分布,所以我们就成功找到了合理的无参数分布\(q(\epsilon)\)以及对应的变换\(g_\theta(\epsilon)\) \(\square\) 那么所有过程似乎到这里就圆满结束了。

但是!但是,这里还是有点问题:argmax这个运算本身也是无法求导的... 也就是说,我们将求梯度运算转移到了\(argmax\)运算上面,结果它还是没有办法求梯度? 不过没关系,这一步其实并不是很难处理。我们知道\(argmax\)其实可以扩展成\(one\_hot(argmax)\),而后者的一个光滑近似就是\(softmax\):对于这一点,我相信接触过蒸馏的同学肯定是很清楚的,我们只需要调整蒸馏的温度就能使得\(softmax\)无限趋近于\(ont\_hot\)\(softmax\)显然是可以求梯度的,我们就顺利解决了这个遗留的问题。 这种策略被称为

Gumbel Softmax

具体来说,我们的\(g_\theta(\epsilon)\)要改成: \[ \large softmax_i((log p_i-log(-log \epsilon_i))/\tau)_{i=1}^{k}\quad \quad (5) \] 其中\(\tau\)就是蒸馏的温度,当\(\tau\rightarrow 0\)的时候,\(softmax\)就可以看成\(ont\_hot\),当然此时梯度消失现象也会很严重。 由此我们也可以得到训练策略:对参数\(\tau\)进行退火,最后得到接近于\(ont\_hot\)形式对应的结果。常见的一个退火策略为: \[ \large \tau_p = \tau_0(\tau_p/\tau_0)^{p/P} \] 其中\(\tau_p\)是第\(p\)次训练的温度,\(\tau_0\)是初始温度,\(P\)是总轮数。


总结一下,对于总体的\(k\)个情况,我们从0到1的均匀分布中取\(k\)个值,利用Gumbel softmax得到一个\(k\)维向量\(\tilde{p}\), 那么 \[ \sum_y \tilde{p}_yf_\theta(y) \] 就是\(L_\theta\)的一个良好估计,并且它成功保留了梯度信息

需要指出的是,Gumbel Max是原式的等价形式,但是Gumbel Softmax并不是,它是Gumbel Max的一个光滑近似,当\(\tau\)足够小的时候,它可以近似看成Gumbel Max

顺便提一嘴这个东西为啥叫Gumbel Max/Softmax:

我们仔细观察\((5)\)式: \[ \large softmax_i((log p_i-log(-log \epsilon_i))/\tau)_{i=1}^{k} \] 按照原本的思路,我们可以先从均匀分布里采样\(\epsilon\),然后再做log运算,再做log运算,再与\(logp_i\)做差,不过实际上实际从一个\(-log(-log \epsilon)\)服从的分布里直接采样也是完全OK的,那我们就来看看这个分布长什么样子: 记 \[ x = -log(-log \epsilon) \] 那么 \[ F_X(x) = P_X(X\leq x) = P_\epsilon(-log(-log \epsilon)\leq x) = P_\epsilon(\epsilon\leq e^{-e^{-x}}) = F_\epsilon(e^{-e^{-x}}) \] 从而 \[ F_X(x) = exp(-exp(-x)) \] 这就是这个分布的累积分布函数,它就被称为Gumbel分布。实际上Gumbel分布还带有另外两个参数 \[ F_X(x,\mu,\beta) = exp(-exp(-\frac{x-\mu}{\beta})) \] 也就是说这里是\(\mu=\beta=0\)的特殊情况。不过这一点不必细讲,感兴趣的读者可以再去了解一下。

最后讲一个实现细节: 在求原分布\(q_\theta\)的时候,我们需要从\(\{o_i\}\)出发做softmax得到\(\{p_i\}\),但是实际上\((5)\)式可以直接替换为 \[ \large softmax_i((o_i-log(-log \epsilon_i))/\tau)_{i=1}^{k} \] 那么我们就不必去做softmax了 至于证明其实也很简单: \[ \large log p_i = log(softmax(o_i)) = log(\frac{e^{o_i}}{\sum_j e^{o_j}}) \] 从而 \[ logp_i = o_i-C \] 从而 \[ softmax((logp_i+g_i)/\tau) = \frac{e^{(logp_i+g_i)/\tau}}{\sum_j e^{(logp_j+g_j)/\tau}} = \frac{e^{(o_i-C+g_i)/\tau}}{\sum_j e^{(o_j-C+g_j)/\tau}} \] 显然可以将常数\(C\)对应的部分提出来 \[ = \frac{e^{(o_i+g_i)/\tau}}{\sum_j e^{(o_j+g_j)/\tau}} = softmax((o_i+g_i)/\tau) \] 这里\(g_i\)就指之前讲的Gumbel分布

总结

以上就是重参数在连续和离散两个场景的应用了,它最初也是最多的应用应该是在VAE里面,我以后应该也会接触,到时候也许会对这篇文章加以补充。


Reparameterization-Trick
https://sophilex.github.io/posts/9c7e35cb/
作者
Sophilex
发布于
2024年6月11日
许可协议