Linear Attention

Softmax Attention

设有 \(\mathbf Q,\mathbf K,\mathbf V\in\mathbb R^{n\times d}\),其中 \(n\) 表示序列长度,\(d\) 表示 token 维度。标准 softmax attention 的计算机制为: \[ \mathbf O=\text{softmax}\left(\frac{\mathbf Q\mathbf K^{\mathsf T}}{\sqrt{d}}\right)\mathbf V\in\mathbb R^{n\times d} \] 由该式可以清晰地计算 softmax attention 的时间复杂度:\(\mathbf Q\mathbf K^\mathsf T\) 的复杂度为 \(\mathcal O(n^2d)\),结果为 \(n\times n\) 矩阵;softmax 运算将矩阵的每一行做归一化,复杂度为 \(\mathcal O(n^2)\);最后与 \(\mathbf V\) 相乘,复杂度为 \(\mathcal O(n^2d)\). 综上,softmax attention 的时间复杂度为 \(\mathcal O(n^2d)\),与 token 数量呈平方增长关系。当采用多头注意力时,模型维度 \(D\) 会被切分成 \(h\) 个 head,每个 head 的维度为 \(d=D/h\). 由于各个 head 是独立计算的,因此总复杂度为 \(\mathcal O(n^2d\cdot h)=\mathcal O(n^2D)\). 也就是说,softmax attention 的复杂度与 head 的数量无关。

接下来计算空间复杂度:\(\mathbf Q,\mathbf K,\mathbf V,\mathbf O\) 本身就占 \(\mathcal O(nd)\) 的空间,\(\mathbf Q\mathbf K^\mathsf T\)\(\mathcal O(n^2)\) 的空间,因此总空间复杂度为 \(\mathcal O(nd+n^2)\),仍然与 token 数量呈平方增长关系。对于多头注意力,空间复杂度要进一步乘以 head 的数量,为 \(\mathcal O(ndh+n^2h)=\mathcal O(nD+n^2h)\),因此 head 越多,占用空间越大。

Softmax Attention Single head Multi head
Time complexity \(\mathcal O(n^2d)\) \(\mathcal O(n^2D)\)
Space complexity \(\mathcal O(nd+n^2)\) \(\mathcal O(nD+n^2h)\)

我们可以将上述矩阵形式可以展开为分量形式。记 \(\mathbf q_i,\mathbf k_i,\mathbf v_i\in\mathbb R^d\) 表示第 \(i\) 个 token 对应的 query, key 和 value,那么它的输出是各个 \(\mathbf v_j\) 的线性组合,组合系数是 \(\mathbf q_i\) 与各个 \(\mathbf k_j\) 的缩放点积相似度,并经由 softmax 归一化: \[ \mathbf o_i=\frac{\sum_{j=1}^n\exp\left(\frac{\mathbf q_i^\mathsf T\mathbf k_j}{\sqrt{d}}\right)\mathbf v_j}{\sum_{j=1}^n\exp\left(\frac{\mathbf q_i^\mathsf T\mathbf k_j}{\sqrt{d}}\right)}\in\mathbb R^d \] 如若将 \(\exp\left(\frac{\mathbf q_i^\mathsf T\mathbf k_j}{\sqrt{d}}\right)\) 视作 \(\mathbf q_i,\mathbf k_j\) 之间的一种相似度函数,那么上式可以推广为: \[ \mathbf o_i=\frac{\sum_{j=1}^n\text{sim}(\mathbf q_i,\mathbf k_j)\mathbf v_j}{\sum_{j=1}^n\text{sim}(\mathbf q_i,\mathbf k_j)} \] 因此,我们更换不同的相似度函数,即可得到不同的 attention 机制。

Linear Attention

Linear attention 的关键在于将相似度函数 \(\text{sim}(\cdot,\cdot)\) 约束为某种核函数 \(k(\cdot,\cdot)\)\[ \mathbf o_i=\frac{\sum_{j=1}^nk(\mathbf q_i,\mathbf k_j)\mathbf v_j}{\sum_{j=1}^nk(\mathbf q_i,\mathbf k_j)} \] 于是,若设核函数对应的特征映射函数为 \(\phi(\cdot)\),也即 \(k(\mathbf x,\mathbf y)=\phi(\mathbf x)^\mathsf T\phi(\mathbf y)\),那么有: \[ \mathbf o_i=\frac{\sum_{j=1}^n\left(\phi(\mathbf q_i)^\mathsf T\phi(\mathbf k_j)\right)\mathbf v_j}{\sum_{j=1}^n\phi(\mathbf q_i)^\mathsf T\phi(\mathbf k_j)} \] 对于分子部分,注意括号中为标量,因此可以作如下变形: \[ \sum_{j=1}^n\left(\phi(\mathbf q_i)^\mathsf T\phi(\mathbf k_j)\right)\mathbf v_j=\sum_{j=1}^n\mathbf v_j\left(\phi(\mathbf q_i)^\mathsf T\phi(\mathbf k_j)\right)=\sum_{j=1}^n\mathbf v_j\left(\phi(\mathbf k_j)^\mathsf T\phi(\mathbf q_i)\right)=\left(\sum_{j=1}^n\mathbf v_j\phi(\mathbf k_j)^\mathsf T\right)\phi(\mathbf q_i) \] 分母可作类似变形,代回得: \[ \mathbf o_i=\frac{\left(\sum_{j=1}^n\mathbf v_j\phi(\mathbf k_j)^\mathsf T\right)\phi(\mathbf q_i)}{\left(\sum_{j=1}^n\phi(\mathbf k_j)^\mathsf T\right)\phi(\mathbf q_i)} \] 由此可见,不同的 \(\mathbf q_i\) 共享 \(\left(\sum_{j=1}^n\mathbf v_j\phi(\mathbf k_j)^\mathsf T\right)\)\(\left(\sum_{j=1}^n\phi(\mathbf k_j)^\mathsf T\right)\),因此这两部分可以只计算一次,它们的时间复杂度分别为 \(\mathcal O(nd^2)\)\(\mathcal O(nd)\). 计算结果与各个 \(\phi(\mathbf q_i)\) 相乘,总时间复杂度为 \(\mathcal O(nd^2)\). 因此,linear attention 机制的时间复杂度就是 \(\mathcal O(nd^2)\),与 token 数量呈线性增长关系。当采用多头注意力时,仍设模型维度为 \(D\),head 数量为 \(h\),各个 head 的维度为 \(d=D/h\),那么总复杂度为 \(\mathcal O(nd^2\cdot h)=\mathcal O(nD^2/h)\). 可见与 softmax attention 不同,linear attention 的复杂度与 head 的数量相关,且 head 数越多,复杂度越低。

对于空间复杂度,输入输出矩阵 \(\mathbf Q,\mathbf K,\mathbf V,\mathbf O\) 本身占 \(\mathcal O(nd)\),而中间计算向量外积占 \(\mathcal O(d^2)\),因此总复杂度为 \(\mathcal O(nd+d^2)\). 对于多头注意力,其复杂度乘以 head 的数量,为 \(\mathcal O(ndh+d^2h)=\mathcal O(nD+D^2/h)\),因此 head 越多,占用空间反而越少。

Linear Attention Single head Multi head
Time complexity \(\mathcal O(nd^2)\) \(\mathcal O(nD^2/h)\)
Space complexity \(\mathcal O(nd+d^2)\) \(\mathcal O(nD+D^2/h)\)

上文中我们用分量形式进行推导,但实践中需要用矩阵形式进行计算。记 \(\phi(\mathbf Q),\phi(\mathbf K)\) 表示对 \(\mathbf Q,\mathbf K\) 的各行施加特征映射函数,则上式的分子分母部分可以分别转换为: \[ \begin{gather} \text{numerator}(\mathbf O)=\phi(\mathbf Q)\left(\phi(\mathbf K)^\mathsf T\mathbf V\right)\in\mathbb R^{n\times d}\\ \text{denominator}(\mathbf O)=\phi(\mathbf Q)\left(\phi(\mathbf K)^\mathsf T\mathbf 1_n\right)\in\mathbb R^n \end{gather} \] 其中 \(\mathbf 1_n\in\mathbb R^n\) 是一个全 1 向量。最后利用 torch 的广播机制将分子分母相除即可。

值得注意的是,linear attention 虽然对 token 数量 \(n\) 呈线性复杂度,但对维度 \(d\) 呈平方复杂度。通常情况下我们有 \(d\ll n\),因此 linear attention 才比 softmax attention 更高效;倘若 \(d\approx n\)\(d>n\),那么 linear attention 则不能带来效率提升。

Causal Attention and RNNs

上文我们讨论的 softmax / linear attention 都是双向的,而在自回归生成中,我们需要因果注意力。对于 softmax attention,这是通过给 attention map 加 causal mask 实现的: \[ \mathbf O=\text{softmax}\left(\frac{\mathbf Q\mathbf K^{\mathsf T}}{\sqrt{d}}+\log\mathbf M\right)\mathbf V \] 其中 \(\mathbf M\in\{0,1\}^{n\times n}\) 是一个下三角取 \(1\),上三角取 \(0\) 的矩阵,并定义 \(\log 0=-\infty\). 加上负无穷的元素在经过 softmax 后变成了 \(0\),因此不参与后续计算。在分量形式中,加 causal mask 只需将从 \(1\)\(n\) 的求和改为从 \(1\)\(i\) 的求和即可: \[ \mathbf o_i=\frac{\sum_{j=1}^i\text{sim}(\mathbf q_i,\mathbf k_j)\mathbf v_j}{\sum_{j=1}^i\text{sim}(\mathbf q_i,\mathbf k_j)} \] 基于同样的道理,我们修改求和上限即可将 linear attention 改造为 causal 形式: \[ \mathbf o_i=\frac{\left(\sum_{j=1}^i\mathbf v_j\phi(\mathbf k_j)^\mathsf T\right)\phi(\mathbf q_i)}{\left(\sum_{j=1}^i\phi(\mathbf k_j)^\mathsf T\right)\phi(\mathbf q_i)} \] 回忆 linear attention 达成线性复杂度的关键是括号中的两部分可以只计算一次,而在 causal 形式中,尽管这两部分不能一次性计算出来,但它们可以在常数复杂度内递推出来。具体而言,记 \(S_i=\sum_{j=1}^i\phi(\mathbf k_j)\mathbf v_j^\mathsf T,\,Z_i=\sum_{j=1}^i\phi(\mathbf k_j)\),则有: \[ \begin{cases} S_i=S_{i-1}+\phi(\mathbf k_i)\mathbf v_i^\mathsf T\\ S_0=0 \end{cases},\quad \begin{cases} Z_i=Z_{i-1}+\phi(\mathbf k_i)\\ Z_0=0 \end{cases},\quad \mathbf o_i=\frac{S_i^{\mathsf T}\phi(\mathbf q_i)}{Z_i^{\mathsf T}\phi(\mathbf q_i)} \] 由于递推复杂度是 \(\mathcal O(d^2)\),因此总时间复杂度仍为 \(\mathcal O(nd^2)\). 特别注意到,这种递推形式满足 RNN 的思想:给定一个输入(\(\mathbf q_i,\mathbf k_i,\mathbf v_i\)),修改内在状态(\(S_{i-1}\to S_i,\,Z_{i-1}\to Z_i\)),预测一个输出(\(\mathbf o_i\)),因此 linear attention 的论文标题直言到 "Transformers are RNNs".

References

  1. Katharopoulos, Angelos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. Transformers are rnns: Fast autoregressive transformers with linear attention. In International conference on machine learning, pp. 5156-5165. PMLR, 2020. ↩︎
  2. Wang, Jiahao, Ning Kang, Lewei Yao, Mengzhao Chen, Chengyue Wu, Songyang Zhang, Shuchen Xue et al. LiT: Delving into a Simple Linear Diffusion Transformer for Image Generation. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 16068-16078. 2025. ↩︎

Linear Attention
https://xyfjason.github.io/blog-main/2026/02/15/Linear-Attention/
作者
xyfJASON
发布于
2026年2月15日
许可协议