Score Function Estimator and Reparameterization Trick

The Problem

许多机器学习/深度学习问题的优化目标为如下形式: 训练时用蒙特卡洛采样近似期望: 能够这么做的原因在于 式是 式的无偏估计。但容易被忽略的一点是,由于常用的优化算法是基于梯度的,我们还应保证从样本中计算的梯度也是真实梯度的无偏估计!幸运的是,对于上述形式的问题而言,这一点是容易成立的。我们可以计算真实梯度为: 而基于样本计算的梯度为: 可见 式的确是 式的无偏估计。

然而,在有些问题中(变分推断、强化学习等),优化目标中的采样分布也由参数 决定: 此时之前的做法还能够无缝迁移过来吗?我们计算一下真实梯度: 可以看见,与 式相比, 式多了紫色一项。但是,基于样本计算梯度的结果却依旧是 式,二者对不上了!这是因为采样是一个不可导的操作,从 中采样本来应该对 的更新产生紫色部分的贡献,但是采样阻断了梯度的传播,导致这部分贡献丢失了。下面我们介绍两种解决这个问题的方法——Score Function Estimator 和重参数化技巧。

Score Function Estimator

Score Function Estimator 是一个常见的技巧,它巧妙地将紫色部分变换为了期望形式: 其中 称作 score function. 于是我们就可以通过蒙特卡洛采样近似 式: 这样就把丢失的紫色部分找了回来。这个方法的特点在于它直接去近似梯度,而非先近似损失函数、再通过自动微分机制算梯度。

实践中,人们发现使用 式做估计的方差很大,影响训练的收敛。为了减小方差,我们可以给 减去一个 baseline 上式成立的原因在于 score function 的期望为 0: 于是,我们只需要找到一个合适的 baseline 最小化方差即可: 强化学习的策略梯度算法就应用了这种减小方差的方式。

The Reparameterization Trick

重参数化技巧 (the reparameterization trick) 的基本思想是:我们先从无参数分布 中采样 ,再通过变换 得到 ,只需设计合适的 以保证 即可。这样,梯度就能够不经过采样操作传递给

现在的问题就是,怎样确定分布 和变换 ,使得变换后的结果满足 呢?这就得具体问题具体分析了,下面介绍两种常见的场景——高斯分布与类别分布。

Gaussian Distribution

是一个高斯分布,即: 这是一种较为简单的情形,只需取 ,并作下述变换即可: 事实上,这种重参数化对位置-尺度分布族 (location-scale distribution family) 都适用。

Categorical Distribution

是离散类别分布,即: 这种情形较为复杂,Gumbel Max 提供了一种巧妙的重参数化方式: 可以证明 .(事实上,在离散选择模型理论中,这种重参数化方法对应多项 logit 模型在 utility 函数选择为 时的特殊情况。)

证明:不妨设 输出为 ,这意味着: 略作化简: 因为 都是 上的均匀分布,所以在给定 的条件下,上式成立的条件概率就是: 因此采样结果为 的概率是: 所以说,依据 Gumbel Max 采样和依据 采样效果相同。

但是现在依旧有个问题,虽然 Gumbel Max 使得采样操作避开了求导,却又引入了 argmax 这个不可导操作!因此,我们还需要用可导的 softmax 对 argmax 做近似(或者更准确地说,是对 argmax 对应的那个 onehot 向量做近似): 其中 是温度参数,当 时 softmax 分布趋近于 onehot 分布。我们称 式为 Gumbel Softmax 分布。值得注意的是, 不再是一个整数 index,而是一个概率向量,因此要求 函数能够处理向量输入。

References

[1] 苏剑林. (Jun. 10, 2019). 《漫谈重参数:从正态分布到Gumbel Softmax 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/6705

[2] The Reparameterization Trick. https://gregorygundersen.com/blog/2018/04/29/reparameterization/

[3] PyTorch 32.Gumbel-Softmax Trick - 科技猛兽的文章 - 知乎 https://zhuanlan.zhihu.com/p/166632315

[4] 盘点深度学习中的不可导操作(次梯度和重参数化) - Houye的文章 - 知乎 https://zhuanlan.zhihu.com/p/97465608

[5] Jang, Eric, Shixiang Gu, and Ben Poole. Categorical reparameterization with gumbel-softmax. arXiv preprint arXiv:1611.01144 (2016).

[6] 【Learning Notes】Gumbel 分布及应用浅析. https://blog.csdn.net/jackytintin/article/details/79364490]

[7] [知识点] Reparametrization tricks重参数技巧讲解及应用 - 救命稻草人来了的文章 - 知乎 https://zhuanlan.zhihu.com/p/35218887


Score Function Estimator and Reparameterization Trick
https://xyfjason.github.io/blog-main/2022/09/06/Score-Function-Estimator-and-Reparameterization-Trick/
作者
xyfJASON
发布于
2022年9月6日
许可协议