大模型显存:张量类型的底层原理
这篇文章回答两个问题:
- 浮点数核心结构、浮点误差的根本原因?
- 大模型显存:FP32、FP16、BF16等,是什么?
大模型训练与推理的显存占用,本质上就是海量张量(Tensor)的存储开销。理解 FP32/FP16/BF16 这些精度格式,需要从浮点数的底层二进制结构讲起。
一、浮点数的核心结构(IEEE 754)
所有浮点数在显存中都是以二进制位存储的,遵循 符号位 + 指数位 + 尾数位 的三段式结构:
$$
\text{Value} = (-1)^{\text{sign}} \times 2^{\text{exponent}-\text{bias}} \times (1.\text{mantissa})_2
$$
| 字段 | 作用 | 类比科学计数法 |
|---|---|---|
| 符号位 (Sign) | 0为正,1为负 | 正负号 |
| 指数位 (Exponent) | 表示数量级范围 | $10^3$ 中的指数 3 |
| 尾数位 (Mantissa/Significand) | 表示有效数字精度 | $1.234$ 中的小数部分 |
以 FP32(单精度浮点)为例:
- 32 bit = 1 bit 符号 + 8 bit 指数 + 23 bit 尾数
- 指数偏移量 (Bias) = 127
- 可表示范围:约 $\pm 3.4 \times 10^{38}$
- 精度:约 7 位有效十进制数字
二、浮点误差的根本原因
大模型中的梯度消失、loss spike、参数更新不稳定,很多时候根源在浮点精度限制。
1. 舍入误差(Rounding Error)
尾数位有限,无法精确表示大多数实数。例如:
- 十进制的 $0.1$ 转为二进制是无限循环小数 $0.0001100110011…$
- 只能截断存储,导致 $0.1 + 0.2 \neq 0.3$ 的奇观
2. 大数吃小数(Absorption)
当数量级差异过大时相加,小数会被”吞没”:
# FP32 示例
1e8 + 1 = 100000000.0 # 1 被吸收了
大模型中表现为:小梯度更新被大参数掩盖,底层参数几乎不更新。
3. 下溢(Underflow)与上溢(Overflow)
- 下溢:绝对值太小,低于最小正规格化数,直接变成 0
- FP16 最小正规格化数:$6.1 \times 10^{-5}$
- 训练时很小的梯度直接变 0,参数停止更新
- 上溢:绝对值太大,变成 Inf/NaN,导致训练崩溃
4. 非结合性
$(a + b) + c \neq a + (b + c)$,导致并行计算(如多卡 all-reduce)的累加顺序不同,结果有微小差异。这也是大模型训练难以完全复现的底层原因之一。
三、FP16 vs BF16 vs FP32 详解
FP32(Single Precision)
- 存储:32 bit = 1 + 8 + 23
- 显存占用:4 字节/参数
- 适用场景:训练主权重(Master Weights)、高精度检查点保存
- 问题:显存占用大,计算吞吐量低(Tensor Core 利用率不如低精度)
FP16(Half Precision / IEEE 754-2008)
- 存储:16 bit = 1 + 5 + 10
- 显存占用:2 字节/参数(相比 FP32 减半)
- 问题:
- 动态范围极小:指数仅 5 bit,范围约 $6 \times 10^{-5}$ 到 $6.5 \times 10^4$
- 容易上溢:大模型梯度很容易超过 $6.5 \times 10^4$
- 容易下溢:很小的数直接变 0
- 应对:需要 Loss Scaling(梯度放大保存,更新时缩回)来弥补下溢问题
BF16(Brain Floating Point / Google Brain 提出)
- 存储:16 bit = 1 + 8 + 7
- 显存占用:2 字节/参数(与 FP16 相同)
- 关键设计:指数位与 FP32 相同(8 bit),尾数砍半
- 优势:
- 动态范围与 FP32 几乎一致:不易上溢/下溢
- 无需 Loss Scaling:训练稳定性大幅提升
- 硬件友好:现代 GPU(A100/H100)、TPU 原生支持
- 代价:精度比 FP16 低(7 bit vs 10 bit 尾数),但深度学习对参数精度不敏感,对范围更敏感
| 格式 | 总位数 | 符号位 | 指数位 | 尾数位 | 动态范围 | 精度 | 显存/参数 |
|---|---|---|---|---|---|---|---|
| FP32 | 32 | 1 | 8 | 23 | 大 | 高 | 4 字节 |
| FP16 | 16 | 1 | 5 | 10 | 很小 | 中 | 2 字节 |
| BF16 | 16 | 1 | 8 | 7 | 大 | 低 | 2 字节 |
四、大模型显存与张量底层原理
1. 显存占用的构成
以训练为例,显存中存储的张量包括:
- 模型参数(Weights):$P \times \text{bytes\_per\_param}$
- 梯度(Gradients):与参数同shape
- 优化器状态(Optimizer States):如 Adam 需要存储一阶/二阶动量($2 \times P$)
- 激活值(Activations):前向传播中间结果(与 batch size、序列长度相关)
混合精度训练(AMP)的典型显存分布:
- 主权重:FP32
- 梯度/激活/临时buffer:FP16/BF16
- Adam 状态:FP32
所以一个参数实际占用 > 16 字节(FP32 权重 4B + FP32 动量1 4B + FP32 动量2 4B + FP16 梯度 2B + …)。
2. 为什么 BF16 成为大模型训练标配?
- 稳定性 > 精度:神经网络对微小扰动具有鲁棒性,但范围不足会直接让训练崩溃
- 硬件加速:NVIDIA Ampere/Hopper 架构的 Tensor Core 对 BF16 的吞吐量是 FP32 的 8-16 倍
- 显存减半:同样网络,BF16 可比 FP32 支持 2 倍大 的 batch size 或模型
3. 量化(INT8/INT4)的延伸
在推理阶段,进一步将 FP16/BF16 量化为 INT8/INT4:
- 不再存储浮点结构,而是整数 + 缩放因子(Scale)
- 核心思想:用线性映射 $x_{int} = \text{round}(x_{fp} / \text{scale})$ 压缩动态范围
- 显存再减 50%-75%,但精度损失需要校准(Calibration)补偿
总结
- 浮点误差不可消除,只能 trade-off:用范围换精度(BF16)或用精度换范围(FP16)
- BF16 是目前大模型训练的最优解:保持 FP32 的动态范围,显存和吞吐量减半
- 显存瓶颈本质是位宽瓶颈:从 FP32 → FP16/BF16 → INT8 → INT4,每一步都是通过压缩单参数位宽来换取模型规模
如果你正在做训练调优,优先使用 BF16 + 梯度检查点(Gradient Checkpointing) ,这是性价比最高的显存优化组合。