Score Function Estimator and Reparameterization Trick
The Problem
一般而言,许多机器学习/深度学习的优化目标长这样: \[ \mathbb E_{z\sim \mathcal P}[f_\theta(z)]\tag{1}\label{1} \] 训练时用蒙特卡洛采样近似上述期望: \[ \mathbb E_{z\sim\mathcal P}[f_\theta(z)]\approx\frac{1}{N}\sum_{n=1}^N f_\theta(z^{(n)}),\quad z^{(1)},\ldots,z^{(N)}\sim\mathcal P\tag{2}\label{2} \] 能够这么做的原因在于 \(\eqref{2}\) 式是 \(\eqref{1}\) 式的无偏估计。但容易被忽略的一点是,由于常用的优化算法一般都是基于梯度的,我们还应保证从样本中计算的梯度也是真实梯度的无偏估计!幸运的是,对于上述形式的一般问题而言,这一点是容易成立的。可以计算真实梯度为: \[ \nabla_\theta\mathbb E_{z\sim \mathcal P}[f_\theta(z)]=\nabla_\theta\left[\int p(z)f_\theta(z)\mathrm dz\right]=\int p(z)\nabla_\theta f_\theta(z) \mathrm dz=\mathbb E_{z\sim \mathcal P}[\nabla_\theta f_\theta(z)]\tag{3}\label{3} \] 而基于样本计算的梯度为: \[ \nabla_\theta\left[\frac{1}{N}\sum_{n=1}^N f_\theta(z^{(n)})\right]=\frac{1}{N}\sum_{n=1}^N\nabla_\theta f_\theta(z^{(n)}),\quad z^{(1)},\ldots,z^{(N)}\sim\mathcal P\tag{4}\label{4} \] 显然 \(\eqref{4}\) 式的确是 \(\eqref{3}\) 式的无偏估计。
然而,在有些问题中(变分推断、强化学习等),优化目标中的采样分布也由参数 \(\theta\) 决定: \[ \mathbb E_{z\sim \mathcal P_{\color{magenta}\theta}}[f_\theta(z)]\tag{5}\label{5} \] 此时一般问题的做法还能够无缝迁移过来吗?我们计算一下真实梯度: \[ \begin{align} \nabla_\theta \mathbb E_{z\sim \mathcal P_\theta}[f_\theta(z)]&=\nabla_\theta\left[\int p_\theta(z)f_\theta(z)\mathrm dz\right]\\ &=\int f_\theta(z)\nabla_\theta p_\theta(z)\mathrm dz+\int p_\theta(z)\nabla_\theta f_\theta(z)\mathrm dz\\ &={\color{purple}{\int f_\theta(z)\nabla_\theta p_\theta(z)\mathrm dz}}+\mathbb E_{z\sim \mathcal P_\theta}[\nabla_\theta f_\theta(z)] \end{align}\tag{6}\label{6} \] 可以看见,与 \(\eqref{3}\) 式相比,\(\eqref{6}\) 式多了紫色的这一项。但是,基于样本计算梯度的结果却依旧是 \(\eqref{4}\) 式,二者对不上了!这是因为采样是一个不可导的操作,从 \(\mathcal P_\theta\) 中采样本来应该对 \(\theta\) 的更新产生紫色部分的贡献,但是采样阻断了梯度的传播,导致这部分贡献丢失了。下面我们介绍两种解决这个问题的方法——Score Function Estimator 和重参数化技巧。
Score Function Estimator
Score Function Estimator 是一个常见的技巧,它巧妙地将紫色部分变换为了期望形式: \[ \int f_\theta(z)\nabla_\theta p_\theta(z)\mathrm dz=\int f_\theta(z)p_\theta(z)\nabla_\theta\log p_\theta(z)\mathrm dz=\mathbb E_{z\sim\mathcal P_\theta}[f_\theta(z)\nabla_\theta\log p_\theta(z)]\tag{7}\label{7} \] 其中 \(\nabla_\theta\log p_\theta(z)\) 称作 score function. 于是我们就可以通过蒙特卡洛采样近似 \(\eqref{7}\) 式: \[ \mathbb E_{z\sim\mathcal P_\theta}[f_\theta(z)\nabla_\theta\log p_\theta(z)]\approx\frac{1}{N}\sum_{n=1}^Nf_\theta(z^{(n)})\nabla_\theta\log p_\theta(z^{(n)}),\quad z^{(1)},\ldots,z^{(N)}\sim\mathcal P\tag{8}\label{8} \] 这样就把丢失的紫色部分找了回来。注意这里我们直接近似了梯度,而非先近似损失函数、再通过自动微分机制算梯度。
实践中,人们发现使用 \(\eqref{8}\) 式做估计的方差很大,影响训练的收敛。为了减小方差,一种方法是给 \(f_\theta(z)\) 减去一个 baseline \(b\): \[ \mathbb E_{z\sim \mathcal P_\theta}[f_\theta(z)\nabla_\theta\log p_\theta(z)]=\mathbb E_{z\sim \mathcal P_\theta}[(f_\theta(z)-b)\nabla_\theta\log p_\theta(z)]\tag{9} \] 上式成立的原因在于 score function 的期望为 0: \[ \mathbb E_{z\sim\mathcal P_\theta}[\nabla_\theta\log p_\theta(z)]=\int p_\theta(z)\nabla_\theta\log p_\theta(z)\mathrm dz=\int \nabla_\theta p_\theta(z)\mathrm dz=\nabla_\theta\int p_\theta(z)\mathrm dz=\nabla_\theta 1=0\tag{10} \] 那么,只需要找到一个合适的 baseline \(b\) 最小化方差即可: \[ \min_b\ \mathbb E_{z\sim\mathcal P_\theta}\left[((f_\theta(z)-b)\nabla_\theta\log p_\theta(z))^2\right]-\left(\mathbb E_{z\sim\mathcal P_\theta}[(f_\theta(z)-b)\nabla_\theta\log p_\theta(z)]\right)^2\tag{11} \] 强化学习的策略梯度算法就应用了这种减小方差的方式。
The Reparameterization Trick
另一种解决方案是重参数化技巧 (the reparameterization trick),其基本思想是:我们先从无参数分布 \(\mathcal Q\) 中采样一个 \(\epsilon\),再通过变换 \(z=g_\theta(\epsilon)\) 得到 \(z\),只需保证 \(z\sim\mathcal P_\theta\) 即可。这样,梯度就能够不经过采样操作传递给 \(\theta\): \[ \nabla_\theta \mathbb E_{z\sim \mathcal P_\theta}[f_\theta(z)]=\nabla_\theta \mathbb E_{\epsilon\sim \mathcal Q}[f_\theta(g_\theta(\epsilon))]=\mathbb E_{\epsilon\sim \mathcal Q}[\nabla_\theta f_\theta(g_\theta(\epsilon))]\tag{12} \]
现在的问题就是,怎样确定分布 \(\mathcal Q\) 和变换 \(z=g_\theta(\epsilon)\),使得变换后的结果满足 \(z\sim \mathcal P_\theta\) 呢?这就得具体问题具体分析了,下面介绍两种常见的场景——高斯分布与类别分布。
Gaussian Distribution
设 \(\mathcal P_\theta\) 是一个高斯分布,即: \[ z\sim \mathcal P_\theta=\mathcal N(\mu, \sigma^2),\quad \theta=(\mu,\sigma^2) \] 这是一种较为简单的情形,只需取 \(\epsilon\sim \mathcal Q=\mathcal N(0, 1)\),并作下述变换即可: \[ z=g_\theta(\epsilon)=\sigma\epsilon+\mu,\quad\epsilon\sim\mathcal N(0,1) \] 事实上,这种重参数化对位置-尺度分布族 (location-scale distribution family) 都适用。
Categorical Distribution
设 \(\mathcal P_\theta\) 是离散类别分布: \[ z\sim \mathcal P_\theta=[p_1,p_2,\ldots,p_k]^T \] 其中 \(\sum_{i=1}^k p_i=1\),那么 Gumbel Max 提供了一种巧妙的重参数化方式: \[ \hat z=\mathop{\text{argmax}}_{i=1}^k\left[\log p_i-\log(-\log \epsilon_i)\right],\quad\epsilon_i\sim U[0,1]\tag{13} \] 可以证明 \(\hat z\sim \mathcal P_\theta=[p_1,p_2,\ldots,p_k]^T\).
证明:不妨设 \(\text{argmax}\) 输出为 \(1\),这意味着: \[\log p_1-\log(-\log \epsilon_1)>\log p_j-\log(-\log \epsilon_j)\quad\forall j\neq 1\] 略作化简: \[\epsilon_j<\epsilon_1^{p_j/p_1}\quad\forall j\neq 1\] 因为 \(\epsilon_i\) 都是 \([0,1]\) 上的均匀分布,所以在给定 \(\epsilon_1\) 的条件下,上式成立的条件概率就是: \[\prod_{j\neq 1}\epsilon_1^{p_j/p_1}=\epsilon_1^{1/p_1-1}\] 因此采样结果为 \(1\) 的概率是: \[\int_0^1 \epsilon_1^{1/p_1-1}\mathrm d \epsilon_1=p_1\cdot\left.\epsilon_1^{1/p_1}\right|_0^1=p_1\] 所以说,依据 Gumbel Max 采样和依据 \([p_1,p_2,\ldots,p_k]\) 采样效果相同。
但是这里有个问题,虽然 Gumbel Max 使得采样操作避开了求导,却又引入了 \(\text{argmax}\) 这个不可导操作!因此,我们还需要用可导的 \(\text{softmax}\) 对 \(\text{argmax}\) 做近似(或者更准确地说,是对 \(\text{argmax}\) 对应的那个 \(\text{onehot}\) 向量做近似)。我们将下式称为 Gumbel Softmax: \[ \tilde z=\text{softmax}\left((\log p_i-\log (-\log \epsilon_i))/\tau\right),\quad \epsilon_i\sim U[0,1]\tag{14}\label{14} \] 其中 \(\tau>0\) 是温度参数,\(\tau\to 0\) 时 \(\text{softmax}\to\text{onehot}\).
总结一下,利用 Gumbel Softmax 估计 \(\mathbb E_{z\sim \mathcal P_\theta}[f_\theta(z)]\) 的流程为:
- 采样 \(k\) 个服从 \(U[0,1]\) 的样本 \(\epsilon_i\).
- 计算 Gumbel Softmax \(\eqref{14}\) 式,得到一个 \(k\) 维向量 \(\tilde z\).
- 得到估计 \(\mathbb E_{z\sim P_\theta}[f_\theta(z)]\approx f_\theta(\tilde z)\).
References
[1] 苏剑林. (Jun. 10, 2019). 《漫谈重参数:从正态分布到Gumbel Softmax 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/6705
[2] The Reparameterization Trick. https://gregorygundersen.com/blog/2018/04/29/reparameterization/
[3] PyTorch 32.Gumbel-Softmax Trick - 科技猛兽的文章 - 知乎 https://zhuanlan.zhihu.com/p/166632315
[4] 盘点深度学习中的不可导操作(次梯度和重参数化) - Houye的文章 - 知乎 https://zhuanlan.zhihu.com/p/97465608
[5] Jang, Eric, Shixiang Gu, and Ben Poole. Categorical reparameterization with gumbel-softmax. arXiv preprint arXiv:1611.01144 (2016).
[6] 【Learning Notes】Gumbel 分布及应用浅析. https://blog.csdn.net/jackytintin/article/details/79364490]
[7] [知识点] Reparametrization tricks重参数技巧讲解及应用 - 救命稻草人来了的文章 - 知乎 https://zhuanlan.zhihu.com/p/35218887