Score Function Estimator and Reparameterization Trick
The Problem
许多机器学习/深度学习问题的优化目标为如下形式: \[ \mathbb E_{z\sim \mathcal P}[f_\theta(z)]\tag{1}\label{1} \] 训练时用蒙特卡洛采样近似期望: \[ \mathbb E_{z\sim\mathcal P}[f_\theta(z)]\approx\frac{1}{N}\sum_{n=1}^N f_\theta(z^{(n)}),\quad z^{(1)},\ldots,z^{(N)}\sim\mathcal P\tag{2}\label{2} \] 能够这么做的原因在于 \(\eqref{2}\) 式是 \(\eqref{1}\) 式的无偏估计。但容易被忽略的一点是,由于常用的优化算法是基于梯度的,我们还应保证从样本中计算的梯度也是真实梯度的无偏估计!幸运的是,对于上述形式的问题而言,这一点是容易成立的。我们可以计算真实梯度为: \[ \nabla_\theta\mathbb E_{z\sim \mathcal P}[f_\theta(z)]=\nabla_\theta\left[\int p(z)f_\theta(z)\mathrm dz\right]=\int p(z)\nabla_\theta f_\theta(z) \mathrm dz=\mathbb E_{z\sim \mathcal P}[\nabla_\theta f_\theta(z)]\tag{3}\label{3} \] 而基于样本计算的梯度为: \[ \nabla_\theta\left[\frac{1}{N}\sum_{n=1}^N f_\theta(z^{(n)})\right]=\frac{1}{N}\sum_{n=1}^N\nabla_\theta f_\theta(z^{(n)}),\quad z^{(1)},\ldots,z^{(N)}\sim\mathcal P\tag{4}\label{4} \] 可见 \(\eqref{4}\) 式的确是 \(\eqref{3}\) 式的无偏估计。
然而,在有些问题中(变分推断、强化学习等),优化目标中的采样分布也由参数 \(\theta\) 决定: \[ \mathbb E_{z\sim \mathcal P_{\color{red}\theta}}[f_\theta(z)]\tag{5}\label{5} \] 此时之前的做法还能够无缝迁移过来吗?我们计算一下真实梯度: \[ \begin{align} \nabla_\theta \mathbb E_{z\sim \mathcal P_\theta}[f_\theta(z)]&=\nabla_\theta\left[\int p_\theta(z)f_\theta(z)\mathrm dz\right]\\ &=\int f_\theta(z)\nabla_\theta p_\theta(z)\mathrm dz+\int p_\theta(z)\nabla_\theta f_\theta(z)\mathrm dz\\ &={\color{purple}{\int f_\theta(z)\nabla_\theta p_\theta(z)\mathrm dz}}+\mathbb E_{z\sim \mathcal P_\theta}[\nabla_\theta f_\theta(z)] \end{align}\tag{6}\label{6} \] 可以看见,与 \(\eqref{3}\) 式相比,\(\eqref{6}\) 式多了紫色一项。但是,基于样本计算梯度的结果却依旧是 \(\eqref{4}\) 式,二者对不上了!这是因为采样是一个不可导的操作,从 \(\mathcal P_\theta\) 中采样本来应该对 \(\theta\) 的更新产生紫色部分的贡献,但是采样阻断了梯度的传播,导致这部分贡献丢失了。下面我们介绍两种解决这个问题的方法——Score Function Estimator 和重参数化技巧。
Score Function Estimator
Score Function Estimator 是一个常见的技巧,它巧妙地将紫色部分变换为了期望形式: \[ \int f_\theta(z)\nabla_\theta p_\theta(z)\mathrm dz=\int f_\theta(z)p_\theta(z)\nabla_\theta\log p_\theta(z)\mathrm dz=\mathbb E_{z\sim\mathcal P_\theta}[f_\theta(z)\nabla_\theta\log p_\theta(z)]\tag{7}\label{7} \] 其中 \(\nabla_\theta\log p_\theta(z)\) 称作 score function. 于是我们就可以通过蒙特卡洛采样近似 \(\eqref{7}\) 式: \[ \mathbb E_{z\sim\mathcal P_\theta}[f_\theta(z)\nabla_\theta\log p_\theta(z)]\approx\frac{1}{N}\sum_{n=1}^Nf_\theta(z^{(n)})\nabla_\theta\log p_\theta(z^{(n)}),\quad z^{(1)},\ldots,z^{(N)}\sim\mathcal P\tag{8}\label{8} \] 这样就把丢失的紫色部分找了回来。这个方法的特点在于它直接去近似梯度,而非先近似损失函数、再通过自动微分机制算梯度。
实践中,人们发现使用 \(\eqref{8}\) 式做估计的方差很大,影响训练的收敛。为了减小方差,我们可以给 \(f_\theta(z)\) 减去一个 baseline \(b\): \[ \mathbb E_{z\sim \mathcal P_\theta}[f_\theta(z)\nabla_\theta\log p_\theta(z)]=\mathbb E_{z\sim \mathcal P_\theta}[(f_\theta(z)-b)\nabla_\theta\log p_\theta(z)]\tag{9} \] 上式成立的原因在于 score function 的期望为 0: \[ \mathbb E_{z\sim\mathcal P_\theta}[\nabla_\theta\log p_\theta(z)]=\int p_\theta(z)\nabla_\theta\log p_\theta(z)\mathrm dz=\int \nabla_\theta p_\theta(z)\mathrm dz=\nabla_\theta\int p_\theta(z)\mathrm dz=\nabla_\theta 1=0\tag{10} \] 于是,我们只需要找到一个合适的 baseline \(b\) 最小化方差即可: \[ \min_b\ \mathbb E_{z\sim\mathcal P_\theta}\left[((f_\theta(z)-b)\nabla_\theta\log p_\theta(z))^2\right]-\left(\mathbb E_{z\sim\mathcal P_\theta}[(f_\theta(z)-b)\nabla_\theta\log p_\theta(z)]\right)^2\tag{11} \] 强化学习的策略梯度算法就应用了这种减小方差的方式。
The Reparameterization Trick
重参数化技巧 (the reparameterization trick) 的基本思想是:我们先从无参数分布 \(\mathcal Q\) 中采样 \(\epsilon\),再通过变换 \(z=g_\theta(\epsilon)\) 得到 \(z\),只需设计合适的 \(\mathcal Q\) 和 \(g_\theta\) 以保证 \(z\sim\mathcal P_\theta\) 即可。这样,梯度就能够不经过采样操作传递给 \(\theta\): \[
\nabla_\theta \mathbb E_{z\sim \mathcal P_\theta}[f_\theta(z)]=\nabla_\theta \mathbb E_{\epsilon\sim \mathcal Q}[f_\theta(g_\theta(\epsilon))]=\mathbb E_{\epsilon\sim \mathcal Q}[\nabla_\theta f_\theta(g_\theta(\epsilon))]\tag{12}
\]
现在的问题就是,怎样确定分布 \(\mathcal Q\) 和变换 \(z=g_\theta(\epsilon)\),使得变换后的结果满足 \(z\sim \mathcal P_\theta\) 呢?这就得具体问题具体分析了,下面介绍两种常见的场景——高斯分布与类别分布。
Gaussian Distribution
设 \(\mathcal P_\theta\) 是一个高斯分布,即: \[ z\sim \mathcal P_\theta=\mathcal N(\mu, \sigma^2),\quad \theta=(\mu,\sigma^2)\tag{13} \] 这是一种较为简单的情形,只需取 \(\epsilon\sim \mathcal Q=\mathcal N(0, 1)\),并作下述变换即可: \[ z=g_\theta(\epsilon)=\sigma\epsilon+\mu,\quad\epsilon\sim\mathcal N(0,1)\tag{14} \] 事实上,这种重参数化对位置-尺度分布族 (location-scale distribution family) 都适用。
Categorical Distribution
设 \(\mathcal P_\theta\) 是离散类别分布,即: \[ z\sim \mathcal P_\theta=[p_1,p_2,\ldots,p_k],\quad\text{where}\;\sum_{i=1}^k p_i=1\tag{15} \] 这种情形较为复杂,Gumbel Max 提供了一种巧妙的重参数化方式: \[ \begin{align} \hat z&=\mathop{\text{argmax}}_{i=1}^k\left[\log p_i+\gamma_i\right],\quad\gamma_i\sim\text{Gumbel}(0,1)\tag{16}\\ &=\mathop{\text{argmax}}_{i=1}^k\left[\log p_i-\log(-\log \epsilon_i)\right],\quad\epsilon_i\sim U[0,1]\tag{17} \end{align} \] 可以证明 \(\hat z\sim \mathcal P_\theta=[p_1,p_2,\ldots,p_k]\).(事实上,在离散选择模型理论中,这种重参数化方法对应多项 logit 模型在 utility 函数选择为 \(\log p_i\) 时的特殊情况。)
证明:不妨设 \(\text{argmax}\) 输出为 \(1\),这意味着: \[\log p_1-\log(-\log \epsilon_1)>\log p_j-\log(-\log \epsilon_j)\quad\forall j\neq 1\] 略作化简: \[\epsilon_j<\epsilon_1^{p_j/p_1}\quad\forall j\neq 1\] 因为 \(\epsilon_i\) 都是 \([0,1]\) 上的均匀分布,所以在给定 \(\epsilon_1\) 的条件下,上式成立的条件概率就是: \[\prod_{j\neq 1}\epsilon_1^{p_j/p_1}=\epsilon_1^{1/p_1-1}\] 因此采样结果为 \(1\) 的概率是: \[\int_0^1 \epsilon_1^{1/p_1-1}\mathrm d \epsilon_1=p_1\cdot\left.\epsilon_1^{1/p_1}\right|_0^1=p_1\] 所以说,依据 Gumbel Max 采样和依据 \([p_1,p_2,\ldots,p_k]\) 采样效果相同。
但是现在依旧有个问题,虽然 Gumbel Max 使得采样操作避开了求导,却又引入了 argmax 这个不可导操作!因此,我们还需要用可导的 softmax 对 argmax 做近似(或者更准确地说,是对 argmax 对应的那个 onehot 向量做近似): \[ \tilde z=\text{softmax}\left((\log p_i-\log (-\log \epsilon_i))/\tau\right),\quad \epsilon_i\sim U[0,1]\tag{18}\label{18} \] 其中 \(\tau>0\) 是温度参数,当 \(\tau\to 0\) 时 softmax 分布趋近于 onehot 分布。我们称 \(\eqref{18}\) 式为 Gumbel Softmax 分布。值得注意的是,\(\tilde z\) 不再是一个整数 index,而是一个概率向量,因此要求 \(f_\theta\) 函数能够处理向量输入。
References
[1] 苏剑林. (Jun. 10, 2019). 《漫谈重参数:从正态分布到Gumbel Softmax 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/6705
[2] The Reparameterization Trick. https://gregorygundersen.com/blog/2018/04/29/reparameterization/
[3] PyTorch 32.Gumbel-Softmax Trick - 科技猛兽的文章 - 知乎 https://zhuanlan.zhihu.com/p/166632315
[4] 盘点深度学习中的不可导操作(次梯度和重参数化) - Houye的文章 - 知乎 https://zhuanlan.zhihu.com/p/97465608
[5] Jang, Eric, Shixiang Gu, and Ben Poole. Categorical reparameterization with gumbel-softmax. arXiv preprint arXiv:1611.01144 (2016).
[6] 【Learning Notes】Gumbel 分布及应用浅析. https://blog.csdn.net/jackytintin/article/details/79364490]
[7] [知识点] Reparametrization tricks重参数技巧讲解及应用 - 救命稻草人来了的文章 - 知乎 https://zhuanlan.zhihu.com/p/35218887