Linear Attention
Softmax Attention
设有 \(\mathbf Q,\mathbf K,\mathbf V\in\mathbb R^{n\times d}\),其中 \(n\) 表示序列长度,\(d\) 表示 token 维度。标准 softmax attention 的计算机制为: \[ \mathbf O=\text{softmax}\left(\frac{\mathbf Q\mathbf K^{\mathsf T}}{\sqrt{d}}\right)\mathbf V\in\mathbb R^{n\times d} \] 由该式可以清晰地计算 softmax attention 的时间复杂度:\(\mathbf Q\mathbf K^\mathsf T\) 的复杂度为 \(\mathcal O(n^2d)\),结果为 \(n\times n\) 矩阵;softmax 运算将矩阵的每一行做归一化,复杂度为 \(\mathcal O(n^2)\);最后与 \(\mathbf V\) 相乘,复杂度为 \(\mathcal O(n^2d)\). 综上,softmax attention 的时间复杂度为 \(\mathcal O(n^2d)\),与 token 数量呈平方增长关系。当采用多头注意力时,模型维度 \(D\) 会被切分成 \(h\) 个 head,每个 head 的维度为 \(d=D/h\). 由于各个 head 是独立计算的,因此总复杂度为 \(\mathcal O(n^2d\cdot h)=\mathcal O(n^2D)\). 也就是说,softmax attention 的时间复杂度与 head 的数量无关。
接下来计算空间复杂度:\(\mathbf Q,\mathbf K,\mathbf V,\mathbf O\) 本身就占 \(\mathcal O(nd)\) 的空间,\(\mathbf Q\mathbf K^\mathsf T\) 占 \(\mathcal O(n^2)\) 的空间,因此总空间复杂度为 \(\mathcal O(nd+n^2)\),仍然与 token 数量呈平方增长关系。对于多头注意力,空间复杂度要进一步乘以 head 的数量,为 \(\mathcal O(ndh+n^2h)=\mathcal O(nD+n^2h)\),因此 head 越多,占用空间越大。
| Softmax Attention | Single Head | Multi Head |
|---|---|---|
| Time Complexity | \(\mathcal O(n^2d)\) | \(\mathcal O(n^2D)\) |
| Space Complexity | \(\mathcal O(nd+n^2)\) | \(\mathcal O(nD+n^2h)\) |
我们可以将上述矩阵形式可以展开为分量形式。记 \(\mathbf q_i,\mathbf k_i,\mathbf v_i\in\mathbb R^{1\times d}\) 表示 \(\mathbf Q,\mathbf K,\mathbf V\) 的第 \(i\) 行,那么它的输出是各个 \(\mathbf v_j\) 的线性组合,组合系数是 \(\mathbf q_i\) 与各个 \(\mathbf k_j\) 的缩放点积相似度,并经由 softmax 归一化: \[ \mathbf o_i=\frac{\sum_{j=1}^n\exp(\mathbf q_i\mathbf k_j^\mathsf T/\sqrt{d})\mathbf v_j}{\sum_{j=1}^n\exp(\mathbf q_i\mathbf k_j^\mathsf T/\sqrt{d})}\in\mathbb R^{1\times d} \] 如若将 \(\exp(\mathbf q_i\mathbf k_j^\mathsf T/\sqrt{d})\) 视作 \(\mathbf q_i,\mathbf k_j\) 之间的一种相似度函数,那么上式可以推广为: \[ \mathbf o_i=\frac{\sum_{j=1}^n\text{sim}(\mathbf q_i,\mathbf k_j)\mathbf v_j}{\sum_{j=1}^n\text{sim}(\mathbf q_i,\mathbf k_j)} \] 因此,我们更换不同的相似度函数,即可得到不同的 attention 机制。
Linear Attention
Linear attention 的关键在于将相似度函数 \(\text{sim}(\cdot,\cdot)\) 约束为某种核函数 \(k(\cdot,\cdot)\): \[ \mathbf o_i=\frac{\sum_{j=1}^nk(\mathbf q_i,\mathbf k_j)\mathbf v_j}{\sum_{j=1}^nk(\mathbf q_i,\mathbf k_j)} \] 于是,若设核函数对应的特征映射函数为 \(\phi(\cdot)\),也即满足 \(k(\mathbf x,\mathbf y)=\phi(\mathbf x)\phi(\mathbf y)^\mathsf T\),那么有: \[ \mathbf o_i=\frac{\sum_{j=1}^n\left(\phi(\mathbf q_i)\phi(\mathbf k_j)^\mathsf T\right)\mathbf v_j}{\sum_{j=1}^n\phi(\mathbf q_i)\phi(\mathbf k_j)^\mathsf T}=\frac{\phi(\mathbf q_i)\left(\sum_{j=1}^n\phi(\mathbf k_j)^\mathsf T\mathbf v_j\right)}{\phi(\mathbf q_i)\left(\sum_{j=1}^n\phi(\mathbf k_j)^\mathsf T\right)} \] 由此可见,不同的 \(\mathbf q_i\) 共享 \(\left(\sum_{j=1}^n\phi(\mathbf k_j)^\mathsf T\mathbf v_j\right)\) 和 \(\left(\sum_{j=1}^n\phi(\mathbf k_j)^\mathsf T\right)\),因此这两部分可以只计算一次,它们的时间复杂度分别为 \(\mathcal O(nd^2)\) 和 \(\mathcal O(nd)\). 计算结果与各个 \(\phi(\mathbf q_i)\) 相乘,总时间复杂度为 \(\mathcal O(nd^2)\). 因此,linear attention 机制的时间复杂度就是 \(\mathcal O(nd^2)\),与 token 数量呈线性增长关系。当采用多头注意力时,仍设模型维度为 \(D\),head 数量为 \(h\),各个 head 的维度为 \(d=D/h\),那么总复杂度为 \(\mathcal O(nd^2\cdot h)=\mathcal O(nD^2/h)\). 可见与 softmax attention 不同,linear attention 的复杂度与 head 的数量相关,且 head 数越多,复杂度越低。
对于空间复杂度,输入输出矩阵 \(\mathbf Q,\mathbf K,\mathbf V,\mathbf O\) 本身占 \(\mathcal O(nd)\),而中间计算向量外积占 \(\mathcal O(d^2)\),因此总复杂度为 \(\mathcal O(nd+d^2)\). 对于多头注意力,其复杂度乘以 head 的数量,为 \(\mathcal O(ndh+d^2h)=\mathcal O(nD+D^2/h)\),因此 head 越多,占用空间反而越少。
| Linear Attention | Single Head | Multi Head |
|---|---|---|
| Time Complexity | \(\mathcal O(nd^2)\) | \(\mathcal O(nD^2/h)\) |
| Space Complexity | \(\mathcal O(nd+d^2)\) | \(\mathcal O(nD+D^2/h)\) |
上文中我们用分量形式进行推导,但实践中需要用矩阵形式并行计算。记 \(\phi(\mathbf Q),\phi(\mathbf K)\) 表示对 \(\mathbf Q,\mathbf K\) 的各行施加特征映射函数,则输出矩阵的分子分母部分可分别写作: \[
\begin{gather}
\text{numerator}(\mathbf O)=\phi(\mathbf Q)\left(\phi(\mathbf K)^\mathsf T\mathbf V\right)\in\mathbb R^{n\times d}\\
\text{denominator}(\mathbf O)=\phi(\mathbf Q)\left(\phi(\mathbf K)^\mathsf T\mathbf 1_n\right)\in\mathbb R^{n\times 1}
\end{gather}
\] 其中 \(\mathbf 1_n\in\mathbb R^{n\times 1}\) 是一个全 1 列向量。最后利用 torch 的广播机制将分子分母相除即可。
值得注意的是,linear attention 虽然对 token 数量 \(n\) 呈线性复杂度,但对维度 \(d\) 呈平方复杂度。因此,当 \(d\ll n\) 时 linear attention 才比 softmax attention 更高效;倘若 \(d\approx n\) 或 \(d>n\),则 linear attention 不能带来明显的效率提升。
Causal Masking
上文讨论的 softmax / linear attention 都是双向的,而在自回归生成中,我们需要因果注意力。对于 softmax attention,这是通过给 attention map 加 causal mask 实现的: \[ \mathbf O=\text{softmax}\left(\frac{\mathbf Q\mathbf K^{\mathsf T}}{\sqrt{d}}+\log\mathbf M\right)\mathbf V,\quad\mathbf M_{ij}=\begin{cases}1,&i\geq j\\0,&i<j\end{cases} \] 其中定义 \(\log 0=-\infty\). 加上负无穷的元素在经过 softmax 后变成了 \(0\),因此不参与后续计算。在分量形式中,加 causal mask 对应于将从 \(1\) 到 \(n\) 的求和改为从 \(1\) 到 \(i\) 的求和,使得 \(\mathbf o_i\) 只依赖于 \(j=1,2,\ldots,i\) 的输入: \[ \mathbf o_i=\frac{\sum_{j=1}^i\text{sim}(\mathbf q_i,\mathbf k_j)\mathbf v_j}{\sum_{j=1}^i\text{sim}(\mathbf q_i,\mathbf k_j)} \] 基于同样的道理,我们修改求和上限即可将 linear attention 改造为 causal 形式: \[ \mathbf o_i=\frac{\phi(\mathbf q_i)\left(\sum_{j=1}^i\phi(\mathbf k_j)^\mathsf T\mathbf v_j\right)}{\phi(\mathbf q_i)\left(\sum_{j=1}^i\phi(\mathbf k_j)^\mathsf T\right)} \] 回忆 linear attention 达成线性复杂度的关键是括号中的两部分可以只计算一次,而在 causal 形式中,尽管这两部分不能一次性计算出来,但它们可以在常数复杂度内递推出来。具体而言,记: \[ \mathbf S_i=\sum_{j=1}^i\phi(\mathbf k_j)^\mathsf T\mathbf v_j\in\mathbb R^{d\times d},\quad \mathbf z_i=\sum_{j=1}^i\phi(\mathbf k_j)^\mathsf T\in\mathbb R^{d\times 1} \] 则有: \[ \begin{cases} \mathbf S_i=\mathbf S_{i-1}+\phi(\mathbf k_i)^\mathsf T\mathbf v_i\\ \mathbf S_0=\mathbf 0 \end{cases},\quad \begin{cases} \mathbf z_i=\mathbf z_{i-1}+\phi(\mathbf k_i)^\mathsf T\\ \mathbf z_0=\mathbf 0 \end{cases},\quad \mathbf o_i=\frac{\phi(\mathbf q_i)\mathbf S_i}{\phi(\mathbf q_i)\mathbf z_i} \] 由于每次递推复杂度是 \(\mathcal O(d^2)\),共递推 \(n\) 次,因此总时间复杂度仍为 \(\mathcal O(nd^2)\). 特别注意到,这种递推形式满足 RNN 的思想,即给定输入(\(\mathbf q_i,\mathbf k_i,\mathbf v_i\)),修改内部状态(\(\mathbf S_{i-1}\to\mathbf S_i,\,\mathbf z_{i-1}\to\mathbf z_i\)),预测输出(\(\mathbf o_i\)),因此 linear attention 的论文标题直言道 "Transformers are RNNs".
在后续研究中,RetNet 发现取 \(\phi(\cdot)\) 为恒等映射并去掉归一化要求在实践中依旧表现得很好,此时有递推形式: \[ \mathbf S_0=\mathbf 0,\quad \mathbf S_i=\mathbf S_{i-1}+\mathbf k_i^\mathsf T\mathbf v_i,\quad\mathbf o_i=\mathbf q_i\mathbf S_i \] 后文中我们默认采用该简化形式。我们还可以看到 causal linear attention 的一个特性:推理时,其内部状态 \(\mathbf S_i\in\mathbb R^{d\times d}\) 的大小是固定的,不随序列长度的增长而变化;相反,causal softmax attention 需要维护的 KV-cache 随序列长度线性增长,会为内存带来不少压力。
Chunkwise Parallel Form
尽管上一节中我们得到的 causal linear attention 的递推形式 (recurrent form) 适合推理,但无法转化为矩阵并行形式 (parallel form),因此不适合在 GPU 上训练。然而,由于 causal 的约束,我们不得不回到二次复杂度才能并行计算: \[ \mathbf O=\left((\mathbf Q\mathbf K^\mathsf T)\odot\mathbf M\right)\mathbf V\in\mathbb R^{n\times d},\quad\mathbf M_{ij}=\begin{cases}1,&i\geq j\\0,&i<j\end{cases} \] 为了解决这个问题,人们提出了 chunkwise parallel form 作为折中。设每个 chunk 长为 \(C\),并记: \[ \begin{align} \mathbf Q_{[i]}&=\mathbf Q_{iC+1:(i+1)C+1}\in\mathbb R^{C\times d}\\ \mathbf K_{[i]}&=\mathbf K_{iC+1:(i+1)C+1}\in\mathbb R^{C\times d}\\ \mathbf V_{[i]}&=\mathbf V_{iC+1:(i+1)C+1}\in\mathbb R^{C\times d}\\ \mathbf O_{[i]}&=\mathbf O_{iC+1:(i+1)C+1}\in\mathbb R^{C\times d}\\ \mathbf S_{[i]}&=\mathbf S_{iC}\in\mathbb R^{d\times d} \end{align} \] 其中 \(i\in\{0,1,\ldots,\frac{n}{C}-1\}\). 那么,chunk 之间内部状态 \(\mathbf S\) 的递推公式为: \[ \mathbf S_{[0]}=\mathbf0,\quad\mathbf S_{[i+1]}=\mathbf S_{[i]}+\sum_{j=iC+1}^{(i+1)C+1}\mathbf k_j^\mathsf T\mathbf v_j=\mathbf S_{[i]}+\mathbf K_{[i]}^\mathsf T\mathbf V_{[i]}\in\mathbb R^{d\times d} \] 第 \(i+1\) 个 chunk 的输出由前一个 chunk 的递推结果和当前 chunk 内部的并行计算结果共同决定: \[ \mathbf O_{[i+1]}=\mathbf Q_{[i+1]}\mathbf S_{[i]}+\left((\mathbf Q_{[i+1]}\mathbf K_{[i+1]}^\mathsf T)\odot\mathbf M\right)\mathbf V_{[i+1]}\in\mathbb R^{C\times d} \] 计算单个 chunk 的时间复杂度为 \(\mathcal O(C^2d+Cd^2)\),计算整个序列的总复杂度为 \(\mathcal O(\frac{n}{C}(C^2d+Cd^2))=\mathcal O(nCd+nd^2)\),关于序列长度呈次二次 (subquadratic) 复杂度增长。特别地,parallel form 和 recurrent form 分别可视作 \(C=n\) 和 \(C=1\) 时的特殊情形。
实践中,Flash Linear Attention 库提供了 chunkwise parallel form 的 causal linear attention 的 I/O-aware 实现,相比纯 PyTorch 实现在速度和显存上都有明显优势。
References
- Katharopoulos, Angelos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. Transformers are rnns: Fast autoregressive transformers with linear attention. In International conference on machine learning, pp. 5156-5165. PMLR, 2020. ↩︎
- Wang, Jiahao, Ning Kang, Lewei Yao, Mengzhao Chen, Chengyue Wu, Songyang Zhang, Shuchen Xue et al. LiT: Delving into a Simple Linear Diffusion Transformer for Image Generation. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 16068-16078. 2025. ↩︎
- Sun, Yutao, Li Dong, Shaohan Huang, Shuming Ma, Yuqing Xia, Jilong Xue, Jianyong Wang, and Furu Wei. Retentive network: A successor to transformer for large language models. arXiv preprint arXiv:2307.08621 (2023). ↩︎
- Yang, Songlin, Bailin Wang, Yikang Shen, Rameswar Panda, and Yoon Kim. Gated Linear Attention Transformers with Hardware-Efficient Training. In International Conference on Machine Learning, pp. 56501-56523. PMLR, 2024. ↩︎