FP8 使用指南#
JAX 支持多种 FP8 格式,包括 E4M3 (jnp.float8_e4m3fn) 和 E5M2 (jnp.float8_e5m2)。由于 FP8 数据类型的范围有限,必须对高精度数据进行缩放以使其适合 FP8 可表示的范围,此过程称为量化 (Q)。相反,反量化 (DQ) 将 FP8 数据重新缩放回其原始类型。
尽管 jnp.dot 支持 FP8 输入,但某些限制使其在实际应用中不切实际。或者,我们的编译器 XLA 可以识别诸如以下模式:
本教程将引导您了解如何使用它的基础知识。
设置我们的环境#
在这里,我们提供为我们的笔记本设置环境所必需的代码。此外,我们定义一个函数来检查经过 XLA 优化的 HLO 是否真的会在底层调用 FP8 点运算。
注意:本教程依赖于 XLA-FP8 功能,该功能仅在 NVIDIA Hopper GPU 或更高版本上受支持。
import flax
import jax
import re
import pprint
from jax import random
from jax import numpy as jnp
from jax._src import test_util as jtu
from flax import linen as nn
from flax.linen import fp8_ops
e4m3 = jnp.float8_e4m3fn
e5m2 = jnp.float8_e5m2
f32 = jnp.float32
E4M3_MAX = jnp.finfo(e4m3).max.astype(f32)
assert jtu.is_cuda_compute_capability_at_least("9.0")
def check_fp8_call(lowered):
hlo = lowered.compile()
if re.search(r"custom-call\(f8e4m3fn.*, f8e4m3fn.*", hlo.as_text()):
print("Fp8 call detected!")
else:
print("No Fp8 call!")
FLAX 低级 API#
JAX 点运算(例如 jnp.dot
)支持 FP8 dtype 输入。因此,进行以下调用是合法的
key = random.key(0)
A = random.uniform(key, (16, 32))
B = random.uniform(key, (32, 64))
@jax.jit
def dot_fp8(A, B):
return jnp.dot(A.astype(e4m3), B.astype(e4m3), preferred_element_type=f32)
check_fp8_call(dot_fp8.lower(A, B))
但是,此方法存在两个主要问题。首先,jnp.dot
不接受操作数的缩放因子,默认缩放因子为 1.0。其次,它不支持混合 FP8 数据类型的操作数。例如,当操作数为 E5M2 和 E4M3 时,点积使用提升的 FP16 数据类型执行。
在实际场景中,必须指定缩放因子,无论是来自推理的校准还是训练期间用户定义的算法。此外,通常的做法是为梯度使用 E5M2,为激活和内核使用 E4M3。这些限制使得此方法在实际应用中不太实用。
为了解决这些限制并创建更通用的 FP8 点积,我们建议利用 XLA-FP8。让我们从一个简单的缩放策略开始。
当前缩放#
缩放因子通常定义为 scale = amax(x) / MAX
,其中 amax
是查找张量绝对最大值的操作,MAX
是目标 dtype 的可表示范围的最大值。这种缩放方法允许我们直接从点积的当前操作数张量导出缩放因子。
@jax.jit
def dot_fp8(A, B):
A_scale = jnp.max(jnp.abs(A)) / E4M3_MAX
B_scale = jnp.max(jnp.abs(B)) / E4M3_MAX
A = fp8_ops.quantize_dequantize(A, e4m3, A_scale, f32)
B = fp8_ops.quantize_dequantize(B, e4m3, B_scale, f32)
C = jnp.dot(A, B)
return C
C = dot_fp8(A, B)
check_fp8_call(dot_fp8.lower(A, B))
如代码所示,我们对点积的操作数执行伪量化 (fp8_ops.quantize_dequantize
)。尽管 jnp.dot
仍然处理更高精度的输入,但 XLA 会检测到此模式并将点运算重写为 FP8 点调用(例如,GPU 的 cublasLt 调用)。此方法有效地模仿了第一个示例,但提供了更大的灵活性。我们可以控制输入 dtypes(此处都设置为 E4M3,但我们可以使用混合的 E4M3 和 E5M2)并定义缩放因子,XLA 可以检测到这些缩放因子并在点后端中使用。
当前缩放方法的一个主要问题是计算 A_scale
和 B_scale
引入的开销,这需要额外加载操作数张量。为了克服这个问题,我们建议采用延迟缩放。
延迟缩放#
在延迟缩放中,我们使用与 amax 历史记录关联的缩放因子。缩放因子仍然是标量,但 amax 历史记录是一个列表,存储最近步骤中的 amax 值(例如,1024 步)。这两个张量都是从之前的步骤计算出来的,并保存在模型参数中。
激活和权重的延迟缩放伪量化由 fp8_ops.in_qdq
提供,梯度的延迟缩放伪量化由 fp8_ops.out_qdq
提供。
a_scale = jnp.array(1.0)
b_scale = jnp.array(1.0)
g_scale = jnp.array(1.0)
a_amax_hist = jnp.zeros((1024,))
b_amax_hist = jnp.zeros((1024,))
g_amax_hist = jnp.zeros((1024,))
@jax.jit
def dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist,
g_scale, g_amax_hist):
a = fp8_ops.in_qdq(f32, e4m3, a, a_scale, a_amax_hist)
b = fp8_ops.in_qdq(f32, e4m3, b, b_scale, b_amax_hist)
c = jnp.dot(a, b)
c = fp8_ops.out_qdq(f32, e5m2, c, g_scale, g_amax_hist)
return c
C = dot_fp8(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,
g_scale, g_amax_hist)
check_fp8_call(dot_fp8.lower(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,
g_scale, g_amax_hist))
在此示例中,我们首先准备三对缩放因子和 amax 历史记录,将它们视为从先前步骤计算出的结果。然后,我们将 fp8_ops.in_qdq
应用于 jnp.dot
的输入操作数,然后将 fp8_ops.out_qdq
应用于 jnp.dot
的输出。请注意,fp8_ops.out_qdq
将通过 custom_vjp 函数对输出的梯度应用伪量化。新的缩放因子和 amax 历史记录将通过它们的梯度返回,这将在下一节中介绍。
FLAX 高级 API#
使用 FLAX 库,将 FP8 操作合并到现有的 FLAX 层中是一个无缝的过程。用户无需操作低级 API 进行量化。相反,他们可以使用直接的“代码注入”方法将提供的自定义 FP8 点 (fp8_ops.Fp8DotGeneralOp
) 集成到 FLAX 层中。此自定义操作封装了所有与 FP8 相关的任务,包括量化-反量化操作的位置、更新缩放因子的算法以及正向和反向传播的 FP8 dtype 组合的选择。
考虑以下示例
model = nn.Dense(features=64, dot_general_cls=fp8_ops.Fp8DotGeneralOp)
params = model.init(key, A)
@jax.jit
def train_step(var, a):
c = model.apply(var, a)
return jnp.sum(c)
check_fp8_call(train_step.lower(params, A))
在此示例中,我们只需设置 dot_general_cls=fp8_ops.Fp8DotGeneralOp
即可使 Dense 层能够利用 FP8 点运算。模型的使用方式与之前几乎相同。主要区别在于添加了一类新的参数:缩放因子和 amax 历史记录集。在下一节中,我们将探讨如何更新这些参数。
操作 FP8 参数#
让我们首先检查 params
的数据结构。在下面的代码中,我们删除参数值,然后显示 PyTree 结构。
params_structure = flax.core.unfreeze(params).copy()
params_structure = flax.traverse_util.flatten_dict(params_structure, sep='/')
for key, value in params_structure.items():
params_structure[key] = '*'
params_structure = flax.traverse_util.unflatten_dict(params_structure, sep='/')
pprint.pprint(params_structure)
输出如下
{'_overwrite_with_gradient': {'Fp8DotGeneralOp_0': {'input_amax_history': '*',
'input_scale': '*',
'kernel_amax_history': '*',
'kernel_scale': '*',
'output_grad_amax_history': '*',
'output_grad_scale': '*'}},
'params': {'bias': '*', 'kernel': '*'}}
除了预期的 params
之外,还有一个名为 _overwrite_with_gradient
的附加类别。此类别包括三对分别用于激活、内核和点梯度的 amax_history
和 scale
。
更新 FP8 参数的梯度#
现在,我们执行一个训练步骤以获得梯度,并了解如何使用它们来更新参数。
step_fn = jax.jit(jax.grad(train_step, (0, 1)))
grads = step_fn(params, A)
params = flax.core.unfreeze(params)
params = flax.traverse_util.flatten_dict(params, sep='/')
grads = flax.traverse_util.flatten_dict(grads[0], sep='/')
for key, value in params.items():
if key.startswith('params'):
params[key] = value + 0.01 * grads[key]
if key.startswith('_overwrite_with_gradient'):
params[key] = grads[key]
params = flax.traverse_util.unflatten_dict(params, sep='/')
params = flax.core.freeze(params)
上面的代码演示了如何更新 params
和 _overwrite_with_gradient
。对于 params
,我们使用公式 new_param = old_param + 0.01 * grads
,其中 0.01
是学习率(或者用户可以使用 optax
中的任何优化器)。对于 _overwrite_with_gradient
,我们只需使用梯度覆盖旧值。
请注意,flax.training.train_state.TrainState
方便地支持 _overwrite_with_gradient
类别,因此如果用户不使用自定义 TrainState
,则无需修改其脚本。
累积 FP8 参数的梯度#
当同一个参数以分支方式使用时,自动求导机制会将这些分支的梯度相加。这在诸如流水线并行等场景中很常见,在这些场景中,每个微批次共享同一组用于小批次的参数。然而,对于 _overwrite_with_gradient
参数,这种通过加法进行的累积是没有意义的。相反,我们更倾向于通过取最大值进行自定义累积。
为了解决这个问题,我们引入了一个自定义的数据类型 fp8_ops.fp32_max_grad
。基本用法如下所示
fmax32 = fp8_ops.fp32_max_grad
def reuse_fp8_param(x, y, scale, amax_history):
scale = scale.astype(fmax32)
amax_history = amax_history.astype(fmax32)
x = fp8_ops.in_qdq(f32, e4m3, x, scale, amax_history)
y = fp8_ops.in_qdq(f32, e4m3, y, scale, amax_history)
return x + y
reuse_fp8_param_fn = jax.grad(reuse_fp8_param, (0, 1, 2, 3))
reuse_fp8_param_fn = jax.jit(reuse_fp8_param_fn)
_, _, new_ah, new_sf = reuse_fp8_param_fn(2.0, 3.0, a_scale, a_amax_hist)
print(new_ah, new_sf)
在这个例子中,我们首先将 scale
和 amax_history
转换为 fp8_ops.fp32_max_grad
,然后使用同一对 scale
和 amax_history
调用两次 fp8_ops.in_qdq
。在自动求导过程中,它们来自每个分支的梯度将取最大值,从而得到正确的结果
1.0 [3. 0. 0. ... 0. 0. 0.]
如果我们不进行类型转换,我们会得到以下结果,这意味着两个分支的梯度被相加了
2.0 [5. 0. 0. ... 0. 0. 0.]
如果用户选择使用高级 API,则此转换已包含在内。