SF估计与重参数化技巧

问题引入

一般而言,许多机器学习/深度学习的优化目标长这样: \[ \mathbb E_{z\sim \mathcal P}[f_\theta(z)]\tag{1}\label{1} \] 训练时用采样近似上述期望: \[ \mathbb E_{z\sim\mathcal P}[f_\theta(z)]\approx\frac{1}{N}\sum_{n=1}^N f_\theta(z^{(n)}),\quad z^{(1)},\ldots,z^{(N)}\sim\mathcal P\tag{2}\label{2} \] 能够这么做的原因在于 \(\eqref{2}\) 式是 \(\eqref{1}\) 式的无偏估计,这一点相信大家都很清楚。但容易被忽略的一点是,由于常用的优化算法一般都是基于梯度的,我们还应保证从样本中计算的梯度也是真实梯度的无偏估计!幸运的是,对于上述形式的一般问题而言,这一点是容易成立的。可以计算真实梯度为: \[ \nabla_\theta\mathbb E_{z\sim \mathcal P}[f_\theta(z)]=\nabla_\theta\left[\int p(z)f_\theta(z)\mathrm dz\right]=\int p(z)\nabla_\theta f_\theta(z) \mathrm dz=\mathbb E_{z\sim \mathcal P}[\nabla_\theta f_\theta(z)]\tag{3}\label{3} \] 而基于样本计算的梯度为: \[ \nabla_\theta\left[\frac{1}{N}\sum_{n=1}^N f_\theta(z^{(n)})\right]=\frac{1}{N}\sum_{n=1}^N\nabla_\theta f_\theta(z^{(n)}),\quad z^{(1)},\ldots,z^{(N)}\sim\mathcal P\tag{4}\label{4} \] 显然 \(\eqref{4}\) 式的确是 \(\eqref{3}\) 式的无偏估计。

然而,在有些问题中(变分推断、强化学习等),优化目标中的采样分布也由参数 \(\theta\) 决定: \[ \mathbb E_{z\sim \mathcal P_\theta}[f_\theta(z)]\tag{5}\label{5} \] 此时一般问题的做法还能够无缝迁移过来吗?我们计算一下真实梯度: \[ \begin{align} \nabla_\theta \mathbb E_{z\sim \mathcal P_\theta}[f_\theta(z)]&=\nabla_\theta\left[\int p_\theta(z)f_\theta(z)\mathrm dz\right]\\ &=\int f_\theta(z)\nabla_\theta p_\theta(z)\mathrm dz+\int p_\theta(z)\nabla_\theta f_\theta(z)\mathrm dz\\ &={\color{purple}{\int f_\theta(z)\nabla_\theta p_\theta(z)\mathrm dz}}+\mathbb E_{z\sim \mathcal P_\theta}[\nabla_\theta f_\theta(z)] \end{align}\tag{6}\label{6} \] 可以看见,与 \(\eqref{3}\) 式相比,\(\eqref{6}\) 式多了紫色的这一项。但是,基于样本计算梯度的结果却依旧是 \(\eqref{4}\) 式,二者对不上了!换句话说,采样是一个不可导的操作,从 \(\mathcal P_\theta\) 中采样本来应该对 \(\theta\) 的更新产生紫色部分的贡献,但是采样阻断了梯度的传播,导致这部分贡献丢失了。下面我们介绍两种解决这个问题的方法——SF 估计和重参数化技巧。

Score Function 估计

事实上,紫色部分是可以处理为期望形式的: \[ \int f_\theta(z)\nabla_\theta p_\theta(z)\mathrm dz=\int f_\theta(z)p_\theta(z)\nabla_\theta\log p_\theta(z)\mathrm dz=\mathbb E_{z\sim\mathcal P_\theta}[f_\theta(z)\nabla_\theta\log p_\theta(z)]\tag{7}\label{7} \] 其中 \(\nabla_\theta\log p_\theta(z)\) 称作 score function. 于是可以通过采样近似 \(\eqref{7}\) 式: \[ \mathbb E_{z\sim\mathcal P_\theta}[f_\theta(z)\nabla_\theta\log p_\theta(z)]\approx\frac{1}{N}\sum_{n=1}^Nf_\theta(z^{(n)})\nabla_\theta\log p_\theta(z^{(n)}),\quad z^{(1)},\ldots,z^{(N)}\sim\mathcal P\tag{8}\label{8} \] 这样就把丢失的紫色部分找了回来。但是实践中,人们发现使用 \(\eqref{8}\) 式做估计的方差往往很大,影响训练的收敛。为了减小方差,一种方法是给 \(f_\theta(z)\) 减去一个 baseline: \[ \mathbb E_{z\sim \mathcal P_\theta}[f_\theta(z)\nabla_\theta\log p_\theta(z)]=\mathbb E_{z\sim \mathcal P_\theta}[(f_\theta(z)-b)\nabla_\theta\log p_\theta(z)]\tag{9} \] 上式成立的原因在于 score function 的期望为 0: \[ \mathbb E_{z\sim\mathcal P_\theta}[\nabla_\theta\log p_\theta(z)]=\int p_\theta(z)\nabla_\theta\log p_\theta(z)\mathrm dz=\int \nabla_\theta p_\theta(z)\mathrm dz=\nabla_\theta\int p_\theta(z)\mathrm dz=\nabla_\theta 1=0\tag{10} \] 那么,只需要找到一个合适的 baseline \(b\) 最小化方差即可: \[ \min_b\ \mathbb E_{z\sim\mathcal P_\theta}\left[((f_\theta(z)-b)\nabla_\theta\log p_\theta(z))^2\right]-\left(\mathbb E_{z\sim\mathcal P_\theta}[(f_\theta(z)-b)\nabla_\theta\log p_\theta(z)]\right)^2\tag{11} \] 这种减小方差的方式在强化学习的策略梯度算法中有所应用。

重参数化技巧

相比 SF 估计利用 score function 把紫色部分估计了出来,重参数化技巧则直接绕过了从 \(\mathcal P_\theta\) 中采样这一不可导操作。我们先从无参数分布 \(\mathcal Q\) 中采样一个 \(\epsilon\),再通过变换 \(z=g_\theta(\epsilon)\) 得到 \(z\),保证 \(z\sim\mathcal P_\theta\) 即可。这样,梯度就能够不经过采样操作传递给 \(\theta\)\[ \nabla_\theta \mathbb E_{z\sim \mathcal P_\theta}[f_\theta(z)]=\nabla_\theta \mathbb E_{\epsilon\sim \mathcal Q}[f_\theta(g_\theta(\epsilon))]=\mathbb E_{\epsilon\sim \mathcal Q}[\nabla_\theta f_\theta(g_\theta(\epsilon))]\tag{12} \] 采样操作在计算图之外

现在的问题就是,怎样确定分布 \(\mathcal Q\) 和变换 \(z=g_\theta(\epsilon)\),使得变换后的结果满足 \(z\sim \mathcal P_\theta\) 呢?这就得具体问题具体分析了。

高斯分布情形

在 VAE 中,\(\mathcal P_\theta\) 是一个高斯分布,即: \[ z\sim \mathcal P_\theta=\mathcal N(\mu_\theta, \sigma^2_\theta) \] 其中 \(\mu_\theta, \sigma^2_\theta\) 由一个 encoder 网络输出得到,\(\theta\) 是这个 encoder 网络的参数。

这是一种较为简单的情形,很容易想到取 \(\epsilon\sim \mathcal Q=\mathcal N(0, 1)\),并作下述变换即可: \[ z=g_\theta(\epsilon)=\sigma_\theta \epsilon+\mu_\theta \]

离散分布情形

假若 \(z\) 是离散随机变量: \[ z\sim \mathcal P_\theta=[p_1,p_2,\ldots,p_k]^T \] 其中 \(\sum_{i=1}^k p_i=1\),那么 Gumbel Max 提供了一种重参数化方式: \[ \hat z=\mathop{\text{argmax}}_{i=1}^k\left[\log p_i-\log(-\log \epsilon_i)\right]\quad\quad \epsilon_i\sim U[0,1]\tag{13} \] 可以证明 \(\hat z\sim \mathcal P_\theta=[p_1,p_2,\ldots,p_k]^T\)​​​.

证明:不妨设 \(\text{argmax}\) 输出为 \(1\),这意味着: \[\log p_1-\log(-\log \epsilon_1)>\log p_j-\log(-\log \epsilon_j)\quad\forall j\neq 1\] 略作化简: \[\epsilon_j<\epsilon_1^{p_j/p_1}\quad\forall j\neq 1\] 因为 \(\epsilon_i\) 都是 \([0,1]\) 上的均匀分布,所以在给定 \(\epsilon_1\) 的条件下,上式成立的条件概率就是: \[\prod_{j\neq 1}\epsilon_1^{p_j/p_1}=\epsilon_1^{1/p_1-1}\] 因此采样结果为 \(1\) 的概率是: \[\int_0^1 \epsilon_1^{1/p_1-1}\mathrm d \epsilon_1=p_1\cdot\left.\epsilon_1^{1/p_1}\right|_0^1=p_1\] 所以说,依据 Gumbel Max 采样和依据 \([p_1,p_2,\ldots,p_k]\) 采样效果相同。

但是这里有个问题,虽然 Gumbel Max 使得采样操作避开了求导,却又引入了 \(\text{argmax}\) 这个不可导操作!因此,我们需要进一步地用可导的 \(\text{softmax}\)\(\text{argmax}\) 做近似(或者更准确地说,是对 \(\text{argmax}\) 对应的那个 \(\text{onehot}\) 向量做近似),我们将下式称为 Gumbel Softmax\[ \text{softmax}\left(\frac{\log p_i-\log (-\log \epsilon_i)}{\tau}\right)\quad\quad \epsilon_i\sim U[0,1]\tag{14}\label{14} \] 其中 \(\tau>0\) 是温度参数,\(\tau\to 0\)\(\text{softmax}\to\text{onehot}\).

source:[5]

总结一下,利用 Gumbel Softmax 估计 \(\mathbb E_{z\sim \mathcal P_\theta}[f_\theta(z)]\) 的流程为:

  1. 采样 \(k\) 个服从 \(U[0,1]\) 的样本 \(\epsilon_i\)
  2. 计算 Gumbel Softmax \(\eqref{14}\) 式,得到一个 \(k\) 维向量 \(\tilde z\)
  3. 得到估计 \(\mathbb E_{z\sim P_\theta}[f_\theta(z)]\approx f_\theta(\tilde z)\)

参考资料

[1] 苏剑林. (Jun. 10, 2019). 《漫谈重参数:从正态分布到Gumbel Softmax 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/6705

[2] The Reparameterization Trick. https://gregorygundersen.com/blog/2018/04/29/reparameterization/

[3] PyTorch 32.Gumbel-Softmax Trick - 科技猛兽的文章 - 知乎 https://zhuanlan.zhihu.com/p/166632315

[4] 盘点深度学习中的不可导操作(次梯度和重参数化) - Houye的文章 - 知乎 https://zhuanlan.zhihu.com/p/97465608

[5] Jang, Eric, Shixiang Gu, and Ben Poole. Categorical reparameterization with gumbel-softmax. arXiv preprint arXiv:1611.01144 (2016).

[6] 【Learning Notes】Gumbel 分布及应用浅析. https://blog.csdn.net/jackytintin/article/details/79364490]

[7] [知识点] Reparametrization tricks重参数技巧讲解及应用 - 救命稻草人来了的文章 - 知乎 https://zhuanlan.zhihu.com/p/35218887


SF估计与重参数化技巧
https://xyfjason.github.io/blog-main/2022/06/22/SF估计与重参数化技巧/
作者
xyfJASON
发布于
2022年6月22日
许可协议