大模型显存:张量类型的底层原理

大模型显存:张量类型的底层原理

这篇文章回答两个问题:

  1. 浮点数核心结构、浮点误差的根本原因?
  2. 大模型显存:FP32、FP16、BF16等,是什么?

大模型训练与推理的显存占用,本质上就是海量张量(Tensor)的存储开销。理解 FP32/FP16/BF16 这些精度格式,需要从浮点数的底层二进制结构讲起。


一、浮点数的核心结构(IEEE 754)

所有浮点数在显存中都是以二进制位存储的,遵循 符号位 + 指数位 + 尾数位 的三段式结构:

$$
\text{Value} = (-1)^{\text{sign}} \times 2^{\text{exponent}-\text{bias}} \times (1.\text{mantissa})_2
$$

字段 作用 类比科学计数法
符号位 (Sign) 0为正,1为负 正负号
指数位 (Exponent) 表示数量级范围 $10^3$ 中的指数 3
尾数位 (Mantissa/Significand) 表示有效数字精度 $1.234$ 中的小数部分

以 FP32(单精度浮点)为例:

  • 32 bit = 1 bit 符号 + 8 bit 指数 + 23 bit 尾数
  • 指数偏移量 (Bias) = 127
  • 可表示范围:约 $\pm 3.4 \times 10^{38}$
  • 精度:约 7 位有效十进制数字

二、浮点误差的根本原因

大模型中的梯度消失、loss spike、参数更新不稳定,很多时候根源在浮点精度限制

1. 舍入误差(Rounding Error)

尾数位有限,无法精确表示大多数实数。例如:

  • 十进制的 $0.1$ 转为二进制是无限循环小数 $0.0001100110011…$
  • 只能截断存储,导致 $0.1 + 0.2 \neq 0.3$ 的奇观

2. 大数吃小数(Absorption)

当数量级差异过大时相加,小数会被”吞没”:

# FP32 示例
1e8 + 1 = 100000000.0  # 1 被吸收了

大模型中表现为:小梯度更新被大参数掩盖,底层参数几乎不更新。

3. 下溢(Underflow)与上溢(Overflow)

  • 下溢:绝对值太小,低于最小正规格化数,直接变成 0
    • FP16 最小正规格化数:$6.1 \times 10^{-5}$
    • 训练时很小的梯度直接变 0,参数停止更新
  • 上溢:绝对值太大,变成 Inf/NaN,导致训练崩溃

4. 非结合性

$(a + b) + c \neq a + (b + c)$,导致并行计算(如多卡 all-reduce)的累加顺序不同,结果有微小差异。这也是大模型训练难以完全复现的底层原因之一。


三、FP16 vs BF16 vs FP32 详解

FP32(Single Precision)

  • 存储:32 bit = 1 + 8 + 23
  • 显存占用:4 字节/参数
  • 适用场景:训练主权重(Master Weights)、高精度检查点保存
  • 问题:显存占用大,计算吞吐量低(Tensor Core 利用率不如低精度)

FP16(Half Precision / IEEE 754-2008)

  • 存储:16 bit = 1 + 5 + 10
  • 显存占用:2 字节/参数(相比 FP32 减半)
  • 问题
    • 动态范围极小:指数仅 5 bit,范围约 $6 \times 10^{-5}$ 到 $6.5 \times 10^4$
    • 容易上溢:大模型梯度很容易超过 $6.5 \times 10^4$
    • 容易下溢:很小的数直接变 0
  • 应对:需要 Loss Scaling(梯度放大保存,更新时缩回)来弥补下溢问题

BF16(Brain Floating Point / Google Brain 提出)

  • 存储:16 bit = 1 + 8 + 7
  • 显存占用:2 字节/参数(与 FP16 相同)
  • 关键设计指数位与 FP32 相同(8 bit),尾数砍半
  • 优势
    • 动态范围与 FP32 几乎一致:不易上溢/下溢
    • 无需 Loss Scaling:训练稳定性大幅提升
    • 硬件友好:现代 GPU(A100/H100)、TPU 原生支持
  • 代价:精度比 FP16 低(7 bit vs 10 bit 尾数),但深度学习对参数精度不敏感,对范围更敏感
格式 总位数 符号位 指数位 尾数位 动态范围 精度 显存/参数
FP32 32 1 8 23 4 字节
FP16 16 1 5 10 很小 2 字节
BF16 16 1 8 7 2 字节

四、大模型显存与张量底层原理

1. 显存占用的构成

以训练为例,显存中存储的张量包括:

  • 模型参数(Weights):$P \times \text{bytes\_per\_param}$
  • 梯度(Gradients):与参数同shape
  • 优化器状态(Optimizer States):如 Adam 需要存储一阶/二阶动量($2 \times P$)
  • 激活值(Activations):前向传播中间结果(与 batch size、序列长度相关)

混合精度训练(AMP)的典型显存分布

  • 主权重:FP32
  • 梯度/激活/临时buffer:FP16/BF16
  • Adam 状态:FP32

所以一个参数实际占用 > 16 字节(FP32 权重 4B + FP32 动量1 4B + FP32 动量2 4B + FP16 梯度 2B + …)。

2. 为什么 BF16 成为大模型训练标配?

  • 稳定性 > 精度:神经网络对微小扰动具有鲁棒性,但范围不足会直接让训练崩溃
  • 硬件加速:NVIDIA Ampere/Hopper 架构的 Tensor Core 对 BF16 的吞吐量是 FP32 的 8-16 倍
  • 显存减半:同样网络,BF16 可比 FP32 支持 2 倍大 的 batch size 或模型

3. 量化(INT8/INT4)的延伸

在推理阶段,进一步将 FP16/BF16 量化为 INT8/INT4:

  • 不再存储浮点结构,而是整数 + 缩放因子(Scale)
  • 核心思想:用线性映射 $x_{int} = \text{round}(x_{fp} / \text{scale})$ 压缩动态范围
  • 显存再减 50%-75%,但精度损失需要校准(Calibration)补偿

总结

  1. 浮点误差不可消除,只能 trade-off:用范围换精度(BF16)或用精度换范围(FP16)
  2. BF16 是目前大模型训练的最优解:保持 FP32 的动态范围,显存和吞吐量减半
  3. 显存瓶颈本质是位宽瓶颈:从 FP32 → FP16/BF16 → INT8 → INT4,每一步都是通过压缩单参数位宽来换取模型规模

如果你正在做训练调优,优先使用 BF16 + 梯度检查点(Gradient Checkpointing) ,这是性价比最高的显存优化组合。

参考链接:

感谢你的阅读


欢迎评论交流


忽如一夜春风来,千树万树梨花开。

暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇