Reverse-Time SDE
\[ \newcommand{\coloneqq}{\mathrel{\vcenter{:}}\mathrel{=}} \]
问题引入
考虑如下一维 SDE: \[ \mathrm dX_t=f(X_t,t)\mathrm dt+g(X_t,t)\mathrm dW_t \] 其 Fokker-Planck 方程 为: \[ \frac{\partial p(x,t)}{\partial t}=-\frac{\partial}{\partial x}\left[f(x,t)p(x,t)\right]+\frac{1}{2}\frac{\partial^2}{\partial x^2}\left[g^2(x,t)p(x,t)\right]\label{fp}\tag{1} \] FP 方程刻画了分布 \(p(x,t)\) 在 \(t\) 时刻的瞬时变化情况。如果两个 SDE 有着相同的 FP 方程和相同的初始分布 \(p(x,0)\),那么它们在任意时刻 \(t\) 的分布 \(p(x,t)\) 也都相同——在扩散模型中,这个结论被用于寻找对应于给定 SDE 的 Probability Flow ODE(因为 ODE 可以视为扩散项为零的 SDE,此时 FP 方程退化为连续性方程)。
本文关心一个不同但类似的问题——假设系统的总运行时间为 \(T\),能否找到一个 reverse-time SDE,其在 \(T-t\) 时刻的分布与原 SDE 在 \(t\) 时刻的分布相同?如果能够做到这一点,那么 reverse-time SDE 就是我们在扩散模型中要寻求的生成模型。
推导过程
变量代换
视 \(\eqref{fp}\) 式中的时间参数 \(t\) 为系统已运行的时间,考虑将其代换为距离结束还剩下的时间 \(s\coloneqq T-t\),有: \[ \frac{\partial p(x,T-s)}{\partial (T-s)}=-\frac{\partial}{\partial x}\left[f(x,T-s)p(x,T-s)\right]+\frac{1}{2}\frac{\partial^2}{\partial x^2}\left[g^2(x,T-s)p(x,T-s)\right] \] 为书写方便起见,引入记号: \[ \bar f(x,s)\coloneqq f(x,T-s),\quad \bar g(x,s)\coloneqq g(x,T-s),\quad q(x,s)\coloneqq p(x,T-s) \] 则上式改写作: \[ -\frac{\partial q(x,s)}{\partial s}=-\frac{\partial}{\partial x}\left[\bar f(x,s)q(x,s)\right]+\frac{1}{2}\frac{\partial^2}{\partial x^2}\left[\bar g^2(x,s)q(x,s)\right] \] 取个负号得: \[ \frac{\partial q(x,s)}{\partial s}=\frac{\partial}{\partial x}\left[\bar f(x,s)q(x,s)\right]-\frac{1}{2}\frac{\partial^2}{\partial x^2}\left[\bar g^2(x,s)q(x,s)\right]\label{fp-change}\tag{2} \] 注意,截至目前我们所做的只是一堆变量代换而已,还没有涉及任何与 reverse-time SDE 相关的东西。
大胆猜想
注意到 \(\eqref{fp-change}\) 式与 FP 方程的形式还挺像的,于是我们有了一个大胆的猜想:是否真的存在一个以 \(s\) 为时间参数的 SDE,其 FP 方程正好就是 \(\eqref{fp-change}\) 式呢?具体而言,设有如下 SDE: \[ \mathrm dY_s=b(Y_s,s)\mathrm ds+\sigma(Y_s,s)\mathrm dW_s \] 那么它的 FP 方程为: \[ \frac{\partial q(x,s)}{\partial s}=-\frac{\partial}{\partial x}\left[b(x,s)q(x,s)\right]+\frac{1}{2}\frac{\partial^2}{\partial x^2}\left[\sigma^2(x,s)q(x,s)\right]\label{fp-reverse}\tag{3} \]
如果猜想成立,那么 \(\eqref{fp-reverse}\) 式应与 \(\eqref{fp-change}\) 式相同,联立便可解出 \(b(x,s)\) 与 \(\sigma(x,s)\),即找到了我们想要的 SDE.
小心求解
联立 \(\eqref{fp-reverse}\) 式与 \(\eqref{fp-change}\) 式,得: \[ \frac{\partial}{\partial x}\left[\bar f(x,s)q(x,s)\right]-\frac{1}{2}\frac{\partial^2}{\partial x^2}\left[\bar g^2(x,s)q(x,s)\right]=-\frac{\partial}{\partial x}\left[b(x,s)q(x,s)\right]+\frac{1}{2}\frac{\partial^2}{\partial x^2}\left[\sigma^2(x,s)q(x,s)\right] \] 合并化简得: \[ \frac{\partial}{\partial x}\left[\left(\bar f(x,s)+b(x,s)\right)q(x,s)\right]=\frac{1}{2}\frac{\partial^2}{\partial x^2}\left[\left(\bar g^2(x,s)+\sigma^2(x,s)\right)q(x,s)\right] \]
左右乘上任一光滑紧支的测试函数 \(R(x)\) 并对 \(x\) 积分: \[ \int R(x)\frac{\partial}{\partial x}\left[\left(\bar f(x,s)+b(x,s)\right)q(x,s)\right]\mathrm dx=\frac{1}{2}\int R(x)\frac{\partial^2}{\partial x^2}\left[\left(\bar g^2(x,s)+\sigma^2(x,s)\right)q(x,s)\right]\mathrm dx \] 左右各自做一次分部积分,得: \[ -\int R'(x)\left(\bar f(x,s)+b(x,s)\right)q(x,s)\mathrm dx=-\frac{1}{2}\int R'(x)\frac{\partial}{\partial x}\left[\left(\bar g^2(x,s)+\sigma^2(x,s)\right)q(x,s)\right]\mathrm dx \] 这里假设了 \(q(x,s)\) 在无穷远处为零。根据 \(R(x)\) 的任意性,应有: \[ \left(\bar f(x,s)+b(x,s)\right)q(x,s)=\frac{1}{2}\frac{\partial}{\partial x}\left[\left(\bar g^2(x,s)+\sigma^2(x,s)\right)q(x,s)\right] \] 将右式按求导乘积法则展开,并将左边的 \(q(x,s)\) 除到右边,整理得: \[ b(x,s)=-\bar f(x,s)+\frac{1}{2}\frac{\partial}{\partial x}\left[\bar g^2(x,s)+\sigma^2(x,s)\right]+\frac{1}{2}\left(\bar g^2(x,s)+\sigma^2(x,s)\right)\frac{\partial}{\partial x}\log q(x,s) \] 因此我们要找的 SDE 为: \[ \mathrm dY_s=\left[-\bar f(Y_s,s)+\frac{1}{2}\frac{\partial}{\partial Y_s}\left[\bar g^2(Y_s,s)+\sigma^2(Y_s,s)\right]+\frac{1}{2}\left(\bar g^2(Y_s,s)+\sigma^2(Y_s,s)\right)\frac{\partial}{\partial Y_s}\log q(Y_s,s)\right]\mathrm ds+\sigma(Y_s,s)\mathrm dW_s\label{sde-reverse}\tag{4} \] 注意这是以 \(\sigma(Y_s,s)\) 为自由变量的一族 SDE,而不是一个特定的 SDE. 这是因为 SDE 与 FP 方程是多对一的关系,不同的 SDE 可以具有相同的 FP 方程。下面我们考察两个特解。
Reverse-Time SDE
当取 \(\sigma^2(Y_s,s)=\bar g^2(Y_s,s)\) 时,代入 \(\eqref{sde-reverse}\) 式得到一个特解: \[ \mathrm dY_s=\left[-\bar f(Y_s,s)+\frac{\partial}{\partial Y_s}\bar g^2(Y_s,s)+\bar g^2(Y_s,s)\frac{\partial}{\partial Y_s}\log q(Y_s,s)\right]\mathrm ds+\bar g(Y_s,s)\mathrm dW_s \] 将 \(\bar f,\bar g,q\) 代换回原本的记号 \(f,g,p\),有: \[ \mathrm dY_s=\left[-f(Y_s,T-s)+\frac{\partial}{\partial Y_s}g^2(Y_s,T-s)+g^2(Y_s,T-s)\frac{\partial}{\partial Y_s}\log p(Y_s,T-s)\right]\mathrm ds+g(Y_s,T-s)\mathrm dW_s \] 再将 \(s\) 代换回 \(T-t\),有: \[ \mathrm dY_{T-t}=\left[f(Y_{T-t},t)-\frac{\partial}{\partial Y_{T-t}}g^2(Y_{T-t},t)-g^2(Y_{T-t},t)\frac{\partial}{\partial Y_{T-t}}\log p(Y_{T-t},t)\right]\mathrm dt+g(Y_{T-t},t)\mathrm dW_{T-t} \]
引入记号: \[ \bar X_t\coloneqq Y_{T-t},\quad \bar W_t\coloneqq W_{T-t} \] 则有: \[ \mathrm d\bar X_t=\left[f(\bar X_t,t)-\frac{\partial}{\partial \bar X_t}g^2(\bar X_t,t)-g^2(\bar X_t,t)\frac{\partial}{\partial \bar X_t}\log p(\bar X_t,t)\right]\mathrm dt+g(\bar X_t,t)\mathrm d\bar W_t \] 这就是 reverse-time SDE.
Probability Flow ODE
当取 \(\sigma^2(Y_s,s)=0\) 时,代入 \(\eqref{sde-reverse}\) 式得到另一个特解: \[ \mathrm dY_s=\left[-\bar f(Y_s,s)+\frac{1}{2}\frac{\partial}{\partial Y_s}\bar g^2(Y_s,s)+\frac{1}{2}\bar g^2(Y_s,s)\frac{\partial}{\partial Y_s}\log q(Y_s,s)\right]\mathrm ds \] 将 \(\bar f,\bar g,q\) 代换回原本的记号 \(f,g,p\),有: \[ \mathrm dY_s=\left[-f(Y_s,T-s)+\frac{1}{2}\frac{\partial}{\partial Y_s}g^2(Y_s,T-s)+\frac{1}{2}g^2(Y_s,T-s)\frac{\partial}{\partial Y_s}\log p(Y_s,T-s)\right]\mathrm ds \] 再将 \(s\) 代换回 \(T-t\),有: \[ \mathrm dY_{T-t}=\left[f(Y_{T-t},t)-\frac{1}{2}\frac{\partial}{\partial Y_{T-t}}g^2(Y_{T-t},t)-\frac{1}{2}g^2(Y_{T-t},t)\frac{\partial}{\partial Y_{T-t}}\log p(Y_{T-t},t)\right]\mathrm dt \] 引入记号: \[ \bar X_t\coloneqq Y_{T-t} \] 则有: \[ \mathrm d\bar X_t=\left[f(\bar X_t,t)-\frac{1}{2}\frac{\partial}{\partial \bar X_t}g^2(\bar X_t,t)-\frac{1}{2}g^2(\bar X_t,t)\frac{\partial}{\partial \bar X_t}\log p(\bar X_t,t)\right]\mathrm dt \] 这就是 Probability Flow ODE. 由于 ODE 本来就是可逆的,所以上式与 forward-time 的 PF ODE 完全一致。
多维情形
上文推导的是一维情形,其结论可以推广到多维情形下。设有如下 \(\mathbb R^d\) 空间中的 SDE: \[ \mathrm d\mathbf x_t=\mathbf f(\mathbf x_t,t)\mathrm dt+\mathbf G(\mathbf x_t,t)\mathrm d\mathbf w_t \] 其中 \(\mathbf f(\cdot,t):\mathbb R^d\to\mathbb R^d\),\(\mathbf G(\cdot,t):\mathbb R^d\to\mathbb R^{d\times d}\),则 reverse-time SDE 为: \[ \mathrm d\bar{\mathbf x}_t=\left\{\mathbf f(\bar{\mathbf x}_t,t)-\nabla\cdot\left[\mathbf G(\bar{\mathbf x}_t,t)\mathbf G(\bar{\mathbf x}_t,t)^T\right]-\mathbf G(\bar{\mathbf x}_t,t)\mathbf G(\bar{\mathbf x}_t,t)^T\nabla_{\bar{\mathbf x}_t}\log p(\bar{\mathbf x}_t,t)\right\}\mathrm dt+\mathbf G(\bar{\mathbf x}_t,t)\mathrm d\bar{\mathbf w}_t \] 其中对于矩阵值函数 \(\mathbf F(\mathbf x)=(\mathbf f^1(\mathbf x),\mathbf f^2(\mathbf x),\ldots,\mathbf f^d(\mathbf x))^T\),符号 \(\nabla\cdot\mathbf F(\mathbf x)\coloneqq (\nabla\cdot\mathbf f^1(\mathbf x),\nabla\cdot\mathbf f^2(\mathbf x),\ldots,\nabla\cdot\mathbf f^d(\mathbf x))^T\).
特别地,倘若 \(\mathbf G(\mathbf x_t,t)\coloneqq g(t)\),即原 SDE 为: \[ \mathrm d\mathbf x_t=\mathbf f(\mathbf x_t,t)\mathrm dt+g(t)\mathrm d\mathbf w_t \] 则 reverse-time SDE 简化为: \[ \mathrm d\bar{\mathbf x}_t=\left[\mathbf f(\bar{\mathbf x}_t,t)-g^2(t)\nabla_{\bar{\mathbf x}_t}\log p(\bar{\mathbf x}_t,t)\right]\mathrm dt+g(t)\mathrm d\bar{\mathbf w}_t \] 这就是扩散模型中常用到的形式。
参考资料
- Ji-Ha Kim. Deriving Reverse-Time Stochastic Differential Equations (SDEs). https://jiha-kim.github.io/posts/deriving-reverse-time-stochastic-differential-equations-sdes/ ↩︎
- Song, Yang, et al. Score-Based Generative Modeling through Stochastic Differential Equations. International Conference on Learning Representations. ↩︎