400-100-5265

预约演示

大模型算子的数学底座

2026-06-18

大模型 Infra 优化里,有一类东西特别有意思:表面看是 CUDA Kernel、HBM 带宽、Tensor Core 利用率,往下拆一层,很多优化其实是数学变形。

RMSNorm 砍掉均值,Softmax 减最大值,Attention 里除以 sqrt(d_k),FlashAttention 用 Online Softmax 避免中间矩阵落 HBM,采样阶段用 Gumbel-Max 把前缀和变成 argmax。它们看起来分散在不同模块里,但背后的工程逻辑很一致:在不改变结果,或者只做可接受近似的前提下,把计算改写成更适合硬件执行的形态。

这里的关键不是“数学很优雅”,而是它能不能少读一次显存、少做一次同步、少占一点寄存器,能不能让 GPU 少等内存、多跑计算。大模型推理优化,很多时候就是这么朴素。

一、RMSNorm 省掉了什么

大模型层数一深,隐藏状态经过连续矩阵乘法、残差相加、非线性变换后,数值尺度很容易飘。尺度一旦失控,训练时梯度会出问题,推理时也可能遇到溢出、NaN,最后表现成输出质量异常,甚至直接崩掉。

Normalization 的作用,就是把每层传递的张量拉回一个比较安全的尺度。

LayerNorm 做得更完整:

[ \text{LayerNorm}(x)=\frac{x-\mu}{\sqrt{\sigma^2 \epsilon}}\cdot \gamma \beta ]

它先减均值,再除以标准差。这个形式很像统计里的 Z-score 标准化:让数据均值接近 0,方差接近 1。

问题在于,从 GPU 算子角度看,LayerNorm 并不便宜。它至少需要维护两个统计量:均值和方差。即使用高性能 Kernel 在一次 HBM load 中同时累计 sum(x)sum(x^2),也仍然有额外的规约状态、寄存器压力和 element-wise 减均值操作。

RMSNorm 的判断更激进一点:归一化最核心的收益来自缩放,而不是平移。

所以它直接砍掉均值:

[ \text{RMSNorm}(x)=\frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d}x_i^2 \epsilon}}\cdot \gamma ]

也就是只计算均方根 RMS,不再计算 x - mean(x)

这一步看似只是少了一个均值,工程上却很实在:

算子 统计量 主要 element-wise 操作 工程影响
LayerNorm mean、variance 减均值、除标准差、缩放、偏置 状态更多,寄存器/ALU 压力更高
RMSNorm RMS 除 RMS、缩放 更容易融合,指令更少,访存更友好

现代 LLM 里,RMSNorm 往往还会去掉 bias,只保留 gamma。这不是数学定理强制要求,而是和无 bias Linear、Pre-Norm、SwiGLU 等结构一起形成的一套经验组合。

说得更工程一点:它有效,而且 Infra 友好。

在生产系统里,这类选择很常见。不是每个设计都能从单一数学原则推出,很多时候是效果、稳定性、参数量、访存、Kernel 融合空间共同权衡后的结果。

二、Pre-Norm 让深层网络能训下去

Normalization 放在哪里,也会影响大模型能不能稳定扩展。

早期 Transformer 和 ResNet v1 类似,常用 Post-Norm:

[ y=\text{Norm}(x F(x)) ]

也就是先做子层计算,再残差相加,最后归一化。

这个结构在层数不太深时能跑。但模型堆到几十层、上百层时,问题会变明显:残差主干路径被 Norm 一层层打断,梯度反传时每层都要经过归一化的导数,容易出现梯度衰减或放大。

所以早期很多 Post-Norm Transformer 需要较长 warm-up。warm-up 本质上是在训练初期用很小学习率避免参数更新过猛,给网络一个逐渐稳定的过程。

Pre-Norm 换了顺序:

[ y=x F(\text{Norm}(x)) ]

主干路径上的 x 直接一路往后加,反向传播时也有一条更干净的 identity path。这个设计让深层模型训练稳定很多,现代大模型基本都采用 Pre-Norm 或它的变体。

但 Pre-Norm 也不是免费午餐。

随着层数加深,主干残差不断累加,x 的尺度可能越来越大,而每一层新产生的 F(Norm(x)) 因为经过 Norm,尺度相对稳定。到深层以后,新特征对主干的影响可能变弱,这就是常说的 representation collapse。

这也是工程里的典型取舍:

  • Post-Norm 前向表达更“干净”,但深层训练更难;
  • Pre-Norm 训练稳定,利于 scaling,但可能牺牲一部分深层表达效率;
  • DeepNorm 等方案试图修补 Post-Norm,但主流 Infra 已经围绕 Pre-Norm RMSNorm 做了大量优化。

很多时候,架构不是单纯选“理论最优”,而是选“在当前训练规模、硬件生态、算子实现里最划算”。

三、Softmax 的数值稳定不是小细节

Softmax 把 logits 转成概率:

[ \text{Softmax}(z_i)=\frac{e^{z_i}}{\sum_j e^{z_j}} ]

它有两个很重要的性质:

  1. 输出非负;
  2. 所有概率和为 1;
  3. 指数会放大 logits 之间的差异。

但直接算这个公式很危险。如果某个 logit 很大,比如 1000,exp(1000) 很容易溢出成 Inf。

工程实现里一定会做 Safe Softmax:

[ \text{Softmax}(z_i)=\frac{e^{z_i-m}}{\sum_j e^{z_j-m}}, \quad m=\max(z) ]

因为 Softmax 只关心相对差值,所有 logits 同时减去最大值,不会改变结果。

这个变形很简单,但非常关键:

  • 最大 logit 变成 0;
  • 其他 logit 都小于等于 0;
  • exp(x) 不再向上溢出;
  • 非最大项太小时即便下溢到 0,也通常是可接受的。

这里能看到 Infra 优化的一个基本套路:数学等价变换先保证结果不变,再让硬件执行更安全。

四、Attention 为什么要除以 sqrt(d)

Attention 中,Query 和 Key 的点积是:

[ S=QK^T ]

单个 attention score 本质上是两个长度为 (d_k) 的向量点积:

[ q\cdot k=\sum_{i=1}^{d_k}q_i k_i ]

如果假设 (q_i)、(k_i) 独立,均值为 0,方差为 1,那么每个乘积项 (q_i k_i) 的方差约为 1。累加 (d_k) 项后,点积结果的方差会变成:

[ \text{Var}(q\cdot k)=d_k ]

标准差就是:

[ \sqrt{d_k} ]

所以 Transformer 在 Softmax 前做缩放:

[ \text{Attention}(Q,K,V)=\text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V ]

否则 (d_k) 越大,logits 越容易变得极端,Softmax 会快速接近 one-hot。训练时,Softmax 导数和 (p(1-p)) 有关,当概率接近 0 或 1,梯度就会变小。

当然,真实神经网络里的 Q/K 不可能严格满足独立、均值 0、方差 1。这里的推导只是一个理想化模型。

但它在初始化阶段很有用:至少让 logits 一开始处在一个比较合理的尺度里,避免训练开局就被 Softmax 饱和拖死。后面分布逐渐偏离假设,模型参数会自己适应。

这就是大模型里很常见的一类“理论”:它未必完全精确,但能给工程系统一个稳定起点。

五、Causal Mask 的本质是概率归零

自回归语言模型生成第 (t) 个 token 时,只能看历史 token,不能看未来 token。

Attention 里通常通过 Causal Mask 实现这个约束。对未来位置 (j>i),把对应 logit 加上负无穷:

[ S_{ij}= \begin{cases} S_{ij}, & j\le i \ -\infty, & j>i \end{cases} ]

经过 Softmax 后:

[ e^{-\infty}=0 ]

未来 token 的注意力权重就变成 0。

朴素实现会构造一个 (N\times N) 的 mask 矩阵,再和 attention score 相加。长上下文下,这会带来明显的 HBM 访存压力。

高性能 Attention Kernel 不会这么干。

FlashAttention 这类实现通常在 tile 调度层面处理 mask:根据当前 Q block 和 K block 的行列位置判断是否需要计算、是否部分遮挡、是否整块跳过。这样就不需要在全局显存里真的生成 mask 矩阵。

这里还有一个容易忽略的点:Causal Mask 主要出现在训练和推理 Prefill 阶段。Decode 阶段每次只有一个新 Query,它要看的 KV Cache 全是历史 token,因果关系天然成立,一般不需要显式做 Causal Mask。

当然,多 token speculative decoding、Medusa、Eagle 这类验证场景另说。只要一次处理多个候选 token,就又可能需要局部因果约束。

六、Online Softmax 如何打掉 HBM 中间矩阵

Attention 的朴素计算会产生两个巨大的中间矩阵:

[ S=QK^T ]

[ P=\text{Softmax}(S) ]

最后再算:

[ O=PV ]

如果序列长度是 (N),这些中间矩阵就是 (N^2) 级别。长上下文下,真正卡住系统的往往不是 FLOPs,而是 HBM 读写。

FlashAttention 的核心不是“近似 Attention”,而是 exact attention 的 IO-aware 实现。它通过分块和 Online Softmax,让 (S) 和 (P) 尽量只存在于寄存器或 SRAM 里,不落 HBM。

Online Softmax 的关键,是维护两个状态:

  • 当前最大值 (m)
  • 当前指数和 (\ell)

当新 block 的 logits 到来时,更新:

[ m_{new}=\max(m_{old}, \max(S_{local})) ]

旧的指数和需要根据新最大值重缩放:

[ \ell_{new}=\ell_{old}\cdot e^{m_{old}-m_{new}} \sum e^{S_{local}-m_{new}} ]

这个缩放因子是整件事的核心。如果新的 block 出现了更大的 max,之前按旧 max 算出来的指数和并没有作废,只需要乘一个修正因子。

FlashAttention 把这个思想扩展到完整 Attention:

# 示意代码:省略具体 block layout、mask、scale、向量化细节
m = -float("inf")
l = 0.0
o = 0.0

for kv_block in kv_blocks:
    s = q @ kv_block.k.T              # 当前 block 的 logits
    m_new = max(m, max(s))

    alpha = exp(m - m_new)            # 修正历史状态
    p = exp(s - m_new)

    l = l * alpha   sum(p)
    o = o * alpha   p @ kv_block.v

    m = m_new

out = o / l

注意这里的 o 不是最终归一化后的输出,而是未归一化累积值。直到遍历完所有 KV block,拿到全局分母以后,才做最后一次除法。

这类优化的收益很直接:

  • 不把 (S) 写回 HBM;
  • 不把 (P) 写回 HBM;
  • Q/K/V 分块进入 SRAM;
  • 中间状态留在寄存器;
  • 最终只写回 (O) 和少量元数据。

很多 Attention 优化的本质都在围绕这条线展开:用更多片上计算,换更少 HBM 访问。

七、LSE 是 FlashAttention 的压缩元数据

FlashAttention 常会保存 LSE,也就是 Log-Sum-Exp:

[ \text{LSE}(z)=\log\sum_j e^{z_j} ]

结合 Safe Softmax 的最大值 (m):

[ \text{LSE}(z)=m \log\sum_j e^{z_j-m} ]

如果令:

[ \ell=\sum_j e^{z_j-m} ]

那么:

[ \text{LSE}=m \log \ell ]

它的作用不只是“数学上好看”。

Softmax 可以写成:

[ \text{Softmax}(z_i)=e^{z_i-\text{LSE}} ]

这意味着,只要保存一个 LSE 标量,后面需要重算 Softmax 时,就可以用 logits 和 LSE 恢复概率,而不用保存完整的 (N\times N) Softmax 矩阵。

这对反向传播尤其重要。

传统 Attention 反传需要前向的 Softmax 矩阵 (P)。如果把它完整存 HBM,长上下文显存压力非常大。FlashAttention 选择前向只保存 LSE,反向时重算局部 (QK^T),再用:

[ P_i=e^{S_i-\text{LSE}} ]

恢复需要的概率块。

这是典型的 recomputation trade-off:

  • 多做一点计算;
  • 少存大量中间结果;
  • 用算力换显存和带宽。

在今天的 GPU 上,这个方向通常是划算的。因为大多数 LLM 推理和训练瓶颈越来越偏向 memory-bound,而不是单纯缺 FLOPs。

LSE 在 Split-K、长上下文解码、Ring Attention 里也有价值。不同 shard 可以各自算局部 LSE 和局部输出,最后通过 LSE 做全局归并,而不是把完整概率矩阵搬来搬去。

八、FlashAttention 的演进不是只换循环顺序

FlashAttention v1 到 v2,一个重要变化是 work partitioning。

FA1 更偏向外层遍历 KV block、内层处理 Q block,容易带来中间状态写回和并行度不足的问题。

FA2 改成外层固定 Q tile,内层流式遍历 KV block。每个 CTA 负责一个 Q tile,把对应输出状态长驻寄存器,直到所有 KV 遍历完才写回。

这个变化带来的效果很现实:

  • Q tile 的输出只写回一次;
  • Online Softmax 状态可以更自然地留在寄存器;
  • 并行维度扩展到 sequence block;
  • 减少跨 warp/CTA 的中间归约。

当然,FA2 也有代价。不同 Q tile 的 CTA 可能会重复读取相同 KV block。实际性能依赖 L2 cache 命中,如果同一 wave 内 CTA 调度足够接近,后面的 CTA 可以从 L2 命中 KV,HBM 压力不会按理论上限放大。

但这也意味着它一定程度上依赖硬件调度行为。长上下文、L2 容量不足、SM 占用不均时,性能曲线会劣化。

后续 FA3、FA4、TRT-LLM、vLLM 高性能后端都在继续往更可控的方向做:更细的流水线、更明确的软件调度、更充分利用 TMA、WGMMA、Tensor Memory、多播能力。

这个方向挺符合 Infra 优化的演进规律:早期吃硬件 cache 红利,后期把不确定性收回到软件调度里。

九、采样也能从前缀和变成 argmax

生成阶段,模型输出 logits 后要采样下一个 token。

Temperature 实际上就是作用在 Softmax 前:

[ p_i=\text{Softmax}\left(\frac{z_i}{\tau}\right) ]

  • (\tau) 越小,分布越尖锐,越接近贪心;
  • (\tau) 越大,分布越平,随机性越强。

得到概率后,传统 multinomial sampling 会做前缀和:

  1. 生成随机数 (u\in[0,1])
  2. 计算概率累计和
  3. 找第一个超过 (u) 的位置

这个逻辑很直观,但对 GPU 不友好。词表现在动辄 128K、256K,前缀和有同步和串行依赖,尤其在张量并行切词表时,跨卡采样成本也不低。

Gumbel-Max Trick 提供了另一个等价形式:

[ \text{sample}=\arg\max_i(\log p_i g_i) ]

其中 (g_i) 是标准 Gumbel 噪声。

vLLM 里常见的实现变体更硬件友好:

# 示意代码
probs = logits.softmax(dim=-1)

q = torch.empty_like(probs)
q.exponential_()          # q ~ Exp(1)

token = probs.div_(q).argmax(dim=-1)

为什么 p / q 可以?

指数分布变量可以由均匀分布生成:

[ q=-\log u ]

而:

[ \arg\max \frac{p_i}{q_i} ]

等价于:

[ \arg\max(\log p_i - \log q_i) ]

代入 (q=-\log u):

[ -\log q=-\log(-\log u) ]

这正是 Gumbel 噪声形式。所以它和按概率分布采样是等价的。

工程收益很明显:

采样方式 核心操作 GPU 友好度
Multinomial prefix-sum search 有串行依赖,同步重
Gumbel-Max element-wise noise argmax 并行友好,规约简单

更重要的是,在张量并行下,Gumbel-Max 让采样退化成分布式 max reduce:

  • 每张卡在本地 vocab shard 上算局部最大值和索引;
  • 跨卡做一次 max-with-index all-reduce;
  • 得到全局 token。

通信量从全量 logits 级别,降到每卡一个候选值和索引。这个差距在大词表下非常实在。

十、SFU 也是瓶颈

很多人看算子性能时,会盯着 Tensor Core 和 HBM,但 Softmax、Norm、RoPE 这类操作里还有一个容易被忽略的硬件单元:SFU,Special Function Unit。

像:

  • exp
  • sin/cos
  • rsqrt

这类非线性函数通常不走普通 FP32 ALU 的 FMA 流水线,而是由 SFU 执行。

问题在于,SFU 数量远少于普通 ALU。以现代 NVIDIA GPU 为例,每个 SM 的 FP32 单元很多,但 SFU 通常少得多。于是 Softmax 里的 exp、Norm 里的 rsqrt、RoPE 的 sin/cos 都可能在局部成为瓶颈。

所以 vLLM 这类系统不会在 RoPE Kernel 里实时算 sin/cos,而是在初始化时预计算 cos_sin_cache。运行时只按 position 取表,再做逐元素乘加。

这是经典的空间换时间:

  • 多存一张表;
  • 少做大量 SFU 调用;
  • 把超越函数变成普通访存和 FMA。

FlashAttention-4 甚至开始对 exp 做部分软件模拟:一部分仍走硬件 SFU,一部分用 ALU 通过多项式近似分担压力。

这不是因为软件算 exp 更优雅,而是 SFU 吞吐不够时,适当让 FMA 管线分担,整体吞吐反而更高。

当然不能 100% 软算。全软算会增加寄存器压力,可能导致 spill,最后反噬性能。这里仍然是权衡:SFU、ALU、寄存器、吞吐之间找平衡点。

十一、这些优化的共同逻辑

把 RMSNorm、Safe Softmax、Causal Mask、Online Softmax、LSE、Gumbel-Max 放在一起看,会发现它们不是孤立技巧。

它们的共同目标很清楚:

  1. 数值稳定例如 Softmax 减最大值,Attention 除以 sqrt(d_k),Norm 控制隐藏状态尺度。
  2. 减少 HBM 访问例如 FlashAttention 不落中间矩阵,反向用 LSE 重算 Softmax。
  3. 提高并行度例如 Gumbel-Max 把采样从 prefix-sum 改成 argmax。
  4. 降低同步和中间状态例如 FA2 让 Q tile 状态长驻寄存器,减少跨 block 归约。
  5. 用近似或经验结构换工程收益 例如 RMSNorm 去均值、去 bias,不一定是数学唯一解,但实际有效且硬件友好。

可以用一句话概括:

大模型 Infra 优化的核心,经常是在数学等价、数值稳定重写和可控近似之间找空间,把模型计算改造成更适合 GPU 的执行形态。

这里面没有银弹。每个优化都带着边界条件。

RMSNorm 适合现代 LLM 的结构组合,但不代表所有网络都该无脑替换 LayerNorm。FlashAttention 对长序列收益明显,但具体性能受 head dim、block size、GPU 架构、cache、调度影响。Gumbel-Max 对分布式采样很友好,但也要求随机数生成、数值精度和并行规约实现足够可靠。

工程上真正麻烦的地方,往往不是公式写不出来,而是公式落到硬件后,各种瓶颈会互相转移:省了 HBM,可能卡 SFU;减少同步,可能增加寄存器;提高并行度,可能降低 cache locality。

这也是做 AI Infra 最现实的部分。

当前 Transformer 架构在效果上已经足够强,Full Attention、MoE、KV Cache、Test-Time Scaling 这些方向都还在继续演进。但大量性能优化并不是重新发明模型,而是在不明显损失效果的前提下,把每个算子、每次访存、每个同步点压到极致。

大模型当然有很多不确定性,训练结果也常带实验科学的味道。但支撑它跑起来的这些数学和系统逻辑,反而非常确定。

创作声明:本内容包含AI辅助创作,观点仅供参考。