大模型 Infra 优化里,有一类东西特别有意思:表面看是 CUDA Kernel、HBM 带宽、Tensor Core 利用率,往下拆一层,很多优化其实是数学变形。
RMSNorm 砍掉均值,Softmax 减最大值,Attention 里除以 sqrt(d_k),FlashAttention 用 Online Softmax 避免中间矩阵落 HBM,采样阶段用 Gumbel-Max 把前缀和变成 argmax。它们看起来分散在不同模块里,但背后的工程逻辑很一致:在不改变结果,或者只做可接受近似的前提下,把计算改写成更适合硬件执行的形态。
这里的关键不是“数学很优雅”,而是它能不能少读一次显存、少做一次同步、少占一点寄存器,能不能让 GPU 少等内存、多跑计算。大模型推理优化,很多时候就是这么朴素。
一、RMSNorm 省掉了什么
大模型层数一深,隐藏状态经过连续矩阵乘法、残差相加、非线性变换后,数值尺度很容易飘。尺度一旦失控,训练时梯度会出问题,推理时也可能遇到溢出、NaN,最后表现成输出质量异常,甚至直接崩掉。
Normalization 的作用,就是把每层传递的张量拉回一个比较安全的尺度。
LayerNorm 做得更完整:
[ \text{LayerNorm}(x)=\frac{x-\mu}{\sqrt{\sigma^2 \epsilon}}\cdot \gamma \beta ]
它先减均值,再除以标准差。这个形式很像统计里的 Z-score 标准化:让数据均值接近 0,方差接近 1。
问题在于,从 GPU 算子角度看,LayerNorm 并不便宜。它至少需要维护两个统计量:均值和方差。即使用高性能 Kernel 在一次 HBM load 中同时累计 sum(x) 和 sum(x^2),也仍然有额外的规约状态、寄存器压力和 element-wise 减均值操作。
RMSNorm 的判断更激进一点:归一化最核心的收益来自缩放,而不是平移。
所以它直接砍掉均值:
[ \text{RMSNorm}(x)=\frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d}x_i^2 \epsilon}}\cdot \gamma ]
也就是只计算均方根 RMS,不再计算 x - mean(x)。
这一步看似只是少了一个均值,工程上却很实在:
| 算子 | 统计量 | 主要 element-wise 操作 | 工程影响 |
|---|---|---|---|
| LayerNorm | mean、variance | 减均值、除标准差、缩放、偏置 | 状态更多,寄存器/ALU 压力更高 |
| RMSNorm | RMS | 除 RMS、缩放 | 更容易融合,指令更少,访存更友好 |
现代 LLM 里,RMSNorm 往往还会去掉 bias,只保留 gamma。这不是数学定理强制要求,而是和无 bias Linear、Pre-Norm、SwiGLU 等结构一起形成的一套经验组合。
说得更工程一点:它有效,而且 Infra 友好。
在生产系统里,这类选择很常见。不是每个设计都能从单一数学原则推出,很多时候是效果、稳定性、参数量、访存、Kernel 融合空间共同权衡后的结果。
二、Pre-Norm 让深层网络能训下去
Normalization 放在哪里,也会影响大模型能不能稳定扩展。
早期 Transformer 和 ResNet v1 类似,常用 Post-Norm:
[ y=\text{Norm}(x F(x)) ]
也就是先做子层计算,再残差相加,最后归一化。
这个结构在层数不太深时能跑。但模型堆到几十层、上百层时,问题会变明显:残差主干路径被 Norm 一层层打断,梯度反传时每层都要经过归一化的导数,容易出现梯度衰减或放大。
所以早期很多 Post-Norm Transformer 需要较长 warm-up。warm-up 本质上是在训练初期用很小学习率避免参数更新过猛,给网络一个逐渐稳定的过程。
Pre-Norm 换了顺序:
[ y=x F(\text{Norm}(x)) ]
主干路径上的 x 直接一路往后加,反向传播时也有一条更干净的 identity path。这个设计让深层模型训练稳定很多,现代大模型基本都采用 Pre-Norm 或它的变体。
但 Pre-Norm 也不是免费午餐。
随着层数加深,主干残差不断累加,x 的尺度可能越来越大,而每一层新产生的 F(Norm(x)) 因为经过 Norm,尺度相对稳定。到深层以后,新特征对主干的影响可能变弱,这就是常说的 representation collapse。
这也是工程里的典型取舍:
- Post-Norm 前向表达更“干净”,但深层训练更难;
- Pre-Norm 训练稳定,利于 scaling,但可能牺牲一部分深层表达效率;
- DeepNorm 等方案试图修补 Post-Norm,但主流 Infra 已经围绕 Pre-Norm RMSNorm 做了大量优化。
很多时候,架构不是单纯选“理论最优”,而是选“在当前训练规模、硬件生态、算子实现里最划算”。
三、Softmax 的数值稳定不是小细节
Softmax 把 logits 转成概率:
[ \text{Softmax}(z_i)=\frac{e^{z_i}}{\sum_j e^{z_j}} ]
它有两个很重要的性质:
- 输出非负;
- 所有概率和为 1;
- 指数会放大 logits 之间的差异。
但直接算这个公式很危险。如果某个 logit 很大,比如 1000,exp(1000) 很容易溢出成 Inf。
工程实现里一定会做 Safe Softmax:
[ \text{Softmax}(z_i)=\frac{e^{z_i-m}}{\sum_j e^{z_j-m}}, \quad m=\max(z) ]
因为 Softmax 只关心相对差值,所有 logits 同时减去最大值,不会改变结果。
这个变形很简单,但非常关键:
- 最大 logit 变成 0;
- 其他 logit 都小于等于 0;
exp(x)不再向上溢出;- 非最大项太小时即便下溢到 0,也通常是可接受的。
这里能看到 Infra 优化的一个基本套路:数学等价变换先保证结果不变,再让硬件执行更安全。
四、Attention 为什么要除以 sqrt(d)
Attention 中,Query 和 Key 的点积是:
[ S=QK^T ]
单个 attention score 本质上是两个长度为 (d_k) 的向量点积:
[ q\cdot k=\sum_{i=1}^{d_k}q_i k_i ]
如果假设 (q_i)、(k_i) 独立,均值为 0,方差为 1,那么每个乘积项 (q_i k_i) 的方差约为 1。累加 (d_k) 项后,点积结果的方差会变成:
[ \text{Var}(q\cdot k)=d_k ]
标准差就是:
[ \sqrt{d_k} ]
所以 Transformer 在 Softmax 前做缩放:
[ \text{Attention}(Q,K,V)=\text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V ]
否则 (d_k) 越大,logits 越容易变得极端,Softmax 会快速接近 one-hot。训练时,Softmax 导数和 (p(1-p)) 有关,当概率接近 0 或 1,梯度就会变小。
当然,真实神经网络里的 Q/K 不可能严格满足独立、均值 0、方差 1。这里的推导只是一个理想化模型。
但它在初始化阶段很有用:至少让 logits 一开始处在一个比较合理的尺度里,避免训练开局就被 Softmax 饱和拖死。后面分布逐渐偏离假设,模型参数会自己适应。
这就是大模型里很常见的一类“理论”:它未必完全精确,但能给工程系统一个稳定起点。
五、Causal Mask 的本质是概率归零
自回归语言模型生成第 (t) 个 token 时,只能看历史 token,不能看未来 token。
Attention 里通常通过 Causal Mask 实现这个约束。对未来位置 (j>i),把对应 logit 加上负无穷:
[ S_{ij}= \begin{cases} S_{ij}, & j\le i \ -\infty, & j>i \end{cases} ]
经过 Softmax 后:
[ e^{-\infty}=0 ]
未来 token 的注意力权重就变成 0。
朴素实现会构造一个 (N\times N) 的 mask 矩阵,再和 attention score 相加。长上下文下,这会带来明显的 HBM 访存压力。
高性能 Attention Kernel 不会这么干。
FlashAttention 这类实现通常在 tile 调度层面处理 mask:根据当前 Q block 和 K block 的行列位置判断是否需要计算、是否部分遮挡、是否整块跳过。这样就不需要在全局显存里真的生成 mask 矩阵。
这里还有一个容易忽略的点:Causal Mask 主要出现在训练和推理 Prefill 阶段。Decode 阶段每次只有一个新 Query,它要看的 KV Cache 全是历史 token,因果关系天然成立,一般不需要显式做 Causal Mask。
当然,多 token speculative decoding、Medusa、Eagle 这类验证场景另说。只要一次处理多个候选 token,就又可能需要局部因果约束。
六、Online Softmax 如何打掉 HBM 中间矩阵
Attention 的朴素计算会产生两个巨大的中间矩阵:
[ S=QK^T ]
[ P=\text{Softmax}(S) ]
最后再算:
[ O=PV ]
如果序列长度是 (N),这些中间矩阵就是 (N^2) 级别。长上下文下,真正卡住系统的往往不是 FLOPs,而是 HBM 读写。
FlashAttention 的核心不是“近似 Attention”,而是 exact attention 的 IO-aware 实现。它通过分块和 Online Softmax,让 (S) 和 (P) 尽量只存在于寄存器或 SRAM 里,不落 HBM。
Online Softmax 的关键,是维护两个状态:
- 当前最大值 (m)
- 当前指数和 (\ell)
当新 block 的 logits 到来时,更新:
[ m_{new}=\max(m_{old}, \max(S_{local})) ]
旧的指数和需要根据新最大值重缩放:
[ \ell_{new}=\ell_{old}\cdot e^{m_{old}-m_{new}} \sum e^{S_{local}-m_{new}} ]
这个缩放因子是整件事的核心。如果新的 block 出现了更大的 max,之前按旧 max 算出来的指数和并没有作废,只需要乘一个修正因子。
FlashAttention 把这个思想扩展到完整 Attention:
# 示意代码:省略具体 block layout、mask、scale、向量化细节
m = -float("inf")
l = 0.0
o = 0.0
for kv_block in kv_blocks:
s = q @ kv_block.k.T # 当前 block 的 logits
m_new = max(m, max(s))
alpha = exp(m - m_new) # 修正历史状态
p = exp(s - m_new)
l = l * alpha sum(p)
o = o * alpha p @ kv_block.v
m = m_new
out = o / l
注意这里的 o 不是最终归一化后的输出,而是未归一化累积值。直到遍历完所有 KV block,拿到全局分母以后,才做最后一次除法。
这类优化的收益很直接:
- 不把 (S) 写回 HBM;
- 不把 (P) 写回 HBM;
- Q/K/V 分块进入 SRAM;
- 中间状态留在寄存器;
- 最终只写回 (O) 和少量元数据。
很多 Attention 优化的本质都在围绕这条线展开:用更多片上计算,换更少 HBM 访问。
七、LSE 是 FlashAttention 的压缩元数据
FlashAttention 常会保存 LSE,也就是 Log-Sum-Exp:
[ \text{LSE}(z)=\log\sum_j e^{z_j} ]
结合 Safe Softmax 的最大值 (m):
[ \text{LSE}(z)=m \log\sum_j e^{z_j-m} ]
如果令:
[ \ell=\sum_j e^{z_j-m} ]
那么:
[ \text{LSE}=m \log \ell ]
它的作用不只是“数学上好看”。
Softmax 可以写成:
[ \text{Softmax}(z_i)=e^{z_i-\text{LSE}} ]
这意味着,只要保存一个 LSE 标量,后面需要重算 Softmax 时,就可以用 logits 和 LSE 恢复概率,而不用保存完整的 (N\times N) Softmax 矩阵。
这对反向传播尤其重要。
传统 Attention 反传需要前向的 Softmax 矩阵 (P)。如果把它完整存 HBM,长上下文显存压力非常大。FlashAttention 选择前向只保存 LSE,反向时重算局部 (QK^T),再用:
[ P_i=e^{S_i-\text{LSE}} ]
恢复需要的概率块。
这是典型的 recomputation trade-off:
- 多做一点计算;
- 少存大量中间结果;
- 用算力换显存和带宽。
在今天的 GPU 上,这个方向通常是划算的。因为大多数 LLM 推理和训练瓶颈越来越偏向 memory-bound,而不是单纯缺 FLOPs。
LSE 在 Split-K、长上下文解码、Ring Attention 里也有价值。不同 shard 可以各自算局部 LSE 和局部输出,最后通过 LSE 做全局归并,而不是把完整概率矩阵搬来搬去。
八、FlashAttention 的演进不是只换循环顺序
FlashAttention v1 到 v2,一个重要变化是 work partitioning。
FA1 更偏向外层遍历 KV block、内层处理 Q block,容易带来中间状态写回和并行度不足的问题。
FA2 改成外层固定 Q tile,内层流式遍历 KV block。每个 CTA 负责一个 Q tile,把对应输出状态长驻寄存器,直到所有 KV 遍历完才写回。
这个变化带来的效果很现实:
- Q tile 的输出只写回一次;
- Online Softmax 状态可以更自然地留在寄存器;
- 并行维度扩展到 sequence block;
- 减少跨 warp/CTA 的中间归约。
当然,FA2 也有代价。不同 Q tile 的 CTA 可能会重复读取相同 KV block。实际性能依赖 L2 cache 命中,如果同一 wave 内 CTA 调度足够接近,后面的 CTA 可以从 L2 命中 KV,HBM 压力不会按理论上限放大。
但这也意味着它一定程度上依赖硬件调度行为。长上下文、L2 容量不足、SM 占用不均时,性能曲线会劣化。
后续 FA3、FA4、TRT-LLM、vLLM 高性能后端都在继续往更可控的方向做:更细的流水线、更明确的软件调度、更充分利用 TMA、WGMMA、Tensor Memory、多播能力。
这个方向挺符合 Infra 优化的演进规律:早期吃硬件 cache 红利,后期把不确定性收回到软件调度里。
九、采样也能从前缀和变成 argmax
生成阶段,模型输出 logits 后要采样下一个 token。
Temperature 实际上就是作用在 Softmax 前:
[ p_i=\text{Softmax}\left(\frac{z_i}{\tau}\right) ]
- (\tau) 越小,分布越尖锐,越接近贪心;
- (\tau) 越大,分布越平,随机性越强。
得到概率后,传统 multinomial sampling 会做前缀和:
- 生成随机数 (u\in[0,1])
- 计算概率累计和
- 找第一个超过 (u) 的位置
这个逻辑很直观,但对 GPU 不友好。词表现在动辄 128K、256K,前缀和有同步和串行依赖,尤其在张量并行切词表时,跨卡采样成本也不低。
Gumbel-Max Trick 提供了另一个等价形式:
[ \text{sample}=\arg\max_i(\log p_i g_i) ]
其中 (g_i) 是标准 Gumbel 噪声。
vLLM 里常见的实现变体更硬件友好:
# 示意代码
probs = logits.softmax(dim=-1)
q = torch.empty_like(probs)
q.exponential_() # q ~ Exp(1)
token = probs.div_(q).argmax(dim=-1)
为什么 p / q 可以?
指数分布变量可以由均匀分布生成:
[ q=-\log u ]
而:
[ \arg\max \frac{p_i}{q_i} ]
等价于:
[ \arg\max(\log p_i - \log q_i) ]
代入 (q=-\log u):
[ -\log q=-\log(-\log u) ]
这正是 Gumbel 噪声形式。所以它和按概率分布采样是等价的。
工程收益很明显:
| 采样方式 | 核心操作 | GPU 友好度 |
|---|---|---|
| Multinomial | prefix-sum search | 有串行依赖,同步重 |
| Gumbel-Max | element-wise noise argmax | 并行友好,规约简单 |
更重要的是,在张量并行下,Gumbel-Max 让采样退化成分布式 max reduce:
- 每张卡在本地 vocab shard 上算局部最大值和索引;
- 跨卡做一次 max-with-index all-reduce;
- 得到全局 token。
通信量从全量 logits 级别,降到每卡一个候选值和索引。这个差距在大词表下非常实在。
十、SFU 也是瓶颈
很多人看算子性能时,会盯着 Tensor Core 和 HBM,但 Softmax、Norm、RoPE 这类操作里还有一个容易被忽略的硬件单元:SFU,Special Function Unit。
像:
expsin/cosrsqrt
这类非线性函数通常不走普通 FP32 ALU 的 FMA 流水线,而是由 SFU 执行。
问题在于,SFU 数量远少于普通 ALU。以现代 NVIDIA GPU 为例,每个 SM 的 FP32 单元很多,但 SFU 通常少得多。于是 Softmax 里的 exp、Norm 里的 rsqrt、RoPE 的 sin/cos 都可能在局部成为瓶颈。
所以 vLLM 这类系统不会在 RoPE Kernel 里实时算 sin/cos,而是在初始化时预计算 cos_sin_cache。运行时只按 position 取表,再做逐元素乘加。
这是经典的空间换时间:
- 多存一张表;
- 少做大量 SFU 调用;
- 把超越函数变成普通访存和 FMA。
FlashAttention-4 甚至开始对 exp 做部分软件模拟:一部分仍走硬件 SFU,一部分用 ALU 通过多项式近似分担压力。
这不是因为软件算 exp 更优雅,而是 SFU 吞吐不够时,适当让 FMA 管线分担,整体吞吐反而更高。
当然不能 100% 软算。全软算会增加寄存器压力,可能导致 spill,最后反噬性能。这里仍然是权衡:SFU、ALU、寄存器、吞吐之间找平衡点。
十一、这些优化的共同逻辑
把 RMSNorm、Safe Softmax、Causal Mask、Online Softmax、LSE、Gumbel-Max 放在一起看,会发现它们不是孤立技巧。
它们的共同目标很清楚:
- 数值稳定例如 Softmax 减最大值,Attention 除以
sqrt(d_k),Norm 控制隐藏状态尺度。 - 减少 HBM 访问例如 FlashAttention 不落中间矩阵,反向用 LSE 重算 Softmax。
- 提高并行度例如 Gumbel-Max 把采样从 prefix-sum 改成 argmax。
- 降低同步和中间状态例如 FA2 让 Q tile 状态长驻寄存器,减少跨 block 归约。
- 用近似或经验结构换工程收益 例如 RMSNorm 去均值、去 bias,不一定是数学唯一解,但实际有效且硬件友好。
可以用一句话概括:
大模型 Infra 优化的核心,经常是在数学等价、数值稳定重写和可控近似之间找空间,把模型计算改造成更适合 GPU 的执行形态。
这里面没有银弹。每个优化都带着边界条件。
RMSNorm 适合现代 LLM 的结构组合,但不代表所有网络都该无脑替换 LayerNorm。FlashAttention 对长序列收益明显,但具体性能受 head dim、block size、GPU 架构、cache、调度影响。Gumbel-Max 对分布式采样很友好,但也要求随机数生成、数值精度和并行规约实现足够可靠。
工程上真正麻烦的地方,往往不是公式写不出来,而是公式落到硬件后,各种瓶颈会互相转移:省了 HBM,可能卡 SFU;减少同步,可能增加寄存器;提高并行度,可能降低 cache locality。
这也是做 AI Infra 最现实的部分。
当前 Transformer 架构在效果上已经足够强,Full Attention、MoE、KV Cache、Test-Time Scaling 这些方向都还在继续演进。但大量性能优化并不是重新发明模型,而是在不明显损失效果的前提下,把每个算子、每次访存、每个同步点压到极致。
大模型当然有很多不确定性,训练结果也常带实验科学的味道。但支撑它跑起来的这些数学和系统逻辑,反而非常确定。



























































