FP8 使用指南#

JAX 支持多种 FP8 格式,包括 E4M3 (jnp.float8_e4m3fn) 和 E5M2 (jnp.float8_e5m2)。由于 FP8 数据类型的范围有限,必须对高精度数据进行缩放以使其适合 FP8 可表示的范围,此过程称为量化 (Q)。相反,反量化 (DQ) 将 FP8 数据重新缩放回其原始类型。

尽管 jnp.dot 支持 FP8 输入,但某些限制使其在实际应用中不切实际。或者,我们的编译器 XLA 可以识别诸如以下模式:->DQ->Dot 并随后调用 FP8 后端(例如,用于 GPU 的 cublasLt)。FLAX 将此类模式封装到 nn.fp8_ops.Fp8DotGeneralOp 模块中,允许用户轻松地为现有层(例如,nn.Dense)配置它。

本教程将引导您了解如何使用它的基础知识。

设置我们的环境#

在这里,我们提供为我们的笔记本设置环境所必需的代码。此外,我们定义一个函数来检查经过 XLA 优化的 HLO 是否真的会在底层调用 FP8 点运算。

注意:本教程依赖于 XLA-FP8 功能,该功能仅在 NVIDIA Hopper GPU 或更高版本上受支持。

import flax
import jax
import re
import pprint
from jax import random
from jax import numpy as jnp
from jax._src import test_util as jtu
from flax import linen as nn
from flax.linen import fp8_ops

e4m3 = jnp.float8_e4m3fn
e5m2 = jnp.float8_e5m2
f32 = jnp.float32
E4M3_MAX = jnp.finfo(e4m3).max.astype(f32)

assert jtu.is_cuda_compute_capability_at_least("9.0")

def check_fp8_call(lowered):
  hlo = lowered.compile()
  if re.search(r"custom-call\(f8e4m3fn.*, f8e4m3fn.*", hlo.as_text()):
    print("Fp8 call detected!")
  else:
    print("No Fp8 call!")

FLAX 低级 API#

JAX 点运算(例如 jnp.dot)支持 FP8 dtype 输入。因此,进行以下调用是合法的

key = random.key(0)
A = random.uniform(key, (16, 32))
B = random.uniform(key, (32, 64))
@jax.jit
def dot_fp8(A, B):
  return jnp.dot(A.astype(e4m3), B.astype(e4m3), preferred_element_type=f32)
check_fp8_call(dot_fp8.lower(A, B))

但是,此方法存在两个主要问题。首先,jnp.dot 不接受操作数的缩放因子,默认缩放因子为 1.0。其次,它不支持混合 FP8 数据类型的操作数。例如,当操作数为 E5M2 和 E4M3 时,点积使用提升的 FP16 数据类型执行。

在实际场景中,必须指定缩放因子,无论是来自推理的校准还是训练期间用户定义的算法。此外,通常的做法是为梯度使用 E5M2,为激活和内核使用 E4M3。这些限制使得此方法在实际应用中不太实用。

为了解决这些限制并创建更通用的 FP8 点积,我们建议利用 XLA-FP8。让我们从一个简单的缩放策略开始。

当前缩放#

缩放因子通常定义为 scale = amax(x) / MAX,其中 amax 是查找张量绝对最大值的操作,MAX 是目标 dtype 的可表示范围的最大值。这种缩放方法允许我们直接从点积的当前操作数张量导出缩放因子。

@jax.jit
def dot_fp8(A, B):
  A_scale = jnp.max(jnp.abs(A)) / E4M3_MAX
  B_scale = jnp.max(jnp.abs(B)) / E4M3_MAX
  A = fp8_ops.quantize_dequantize(A, e4m3, A_scale, f32)
  B = fp8_ops.quantize_dequantize(B, e4m3, B_scale, f32)

  C = jnp.dot(A, B)
  return C

C = dot_fp8(A, B)
check_fp8_call(dot_fp8.lower(A, B))

如代码所示,我们对点积的操作数执行伪量化 (fp8_ops.quantize_dequantize)。尽管 jnp.dot 仍然处理更高精度的输入,但 XLA 会检测到此模式并将点运算重写为 FP8 点调用(例如,GPU 的 cublasLt 调用)。此方法有效地模仿了第一个示例,但提供了更大的灵活性。我们可以控制输入 dtypes(此处都设置为 E4M3,但我们可以使用混合的 E4M3 和 E5M2)并定义缩放因子,XLA 可以检测到这些缩放因子并在点后端中使用。

当前缩放方法的一个主要问题是计算 A_scaleB_scale 引入的开销,这需要额外加载操作数张量。为了克服这个问题,我们建议采用延迟缩放。

延迟缩放#

在延迟缩放中,我们使用与 amax 历史记录关联的缩放因子。缩放因子仍然是标量,但 amax 历史记录是一个列表,存储最近步骤中的 amax 值(例如,1024 步)。这两个张量都是从之前的步骤计算出来的,并保存在模型参数中。

激活和权重的延迟缩放伪量化由 fp8_ops.in_qdq 提供,梯度的延迟缩放伪量化由 fp8_ops.out_qdq 提供。

a_scale = jnp.array(1.0)
b_scale = jnp.array(1.0)
g_scale = jnp.array(1.0)
a_amax_hist = jnp.zeros((1024,))
b_amax_hist = jnp.zeros((1024,))
g_amax_hist = jnp.zeros((1024,))

@jax.jit
def dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist,
            g_scale, g_amax_hist):
  a = fp8_ops.in_qdq(f32, e4m3, a, a_scale, a_amax_hist)
  b = fp8_ops.in_qdq(f32, e4m3, b, b_scale, b_amax_hist)
  
  c = jnp.dot(a, b)
  c = fp8_ops.out_qdq(f32, e5m2, c, g_scale, g_amax_hist)
  return c

C = dot_fp8(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,
            g_scale, g_amax_hist)
check_fp8_call(dot_fp8.lower(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,
                             g_scale, g_amax_hist))

在此示例中,我们首先准备三对缩放因子和 amax 历史记录,将它们视为从先前步骤计算出的结果。然后,我们将 fp8_ops.in_qdq 应用于 jnp.dot 的输入操作数,然后将 fp8_ops.out_qdq 应用于 jnp.dot 的输出。请注意,fp8_ops.out_qdq 将通过 custom_vjp 函数对输出的梯度应用伪量化。新的缩放因子和 amax 历史记录将通过它们的梯度返回,这将在下一节中介绍。

FLAX 高级 API#

使用 FLAX 库,将 FP8 操作合并到现有的 FLAX 层中是一个无缝的过程。用户无需操作低级 API 进行量化。相反,他们可以使用直接的“代码注入”方法将提供的自定义 FP8 点 (fp8_ops.Fp8DotGeneralOp) 集成到 FLAX 层中。此自定义操作封装了所有与 FP8 相关的任务,包括量化-反量化操作的位置、更新缩放因子的算法以及正向和反向传播的 FP8 dtype 组合的选择。

考虑以下示例

model = nn.Dense(features=64, dot_general_cls=fp8_ops.Fp8DotGeneralOp)
params = model.init(key, A)

@jax.jit
def train_step(var, a): 
  c = model.apply(var, a)
  return jnp.sum(c)

check_fp8_call(train_step.lower(params, A))

在此示例中,我们只需设置 dot_general_cls=fp8_ops.Fp8DotGeneralOp 即可使 Dense 层能够利用 FP8 点运算。模型的使用方式与之前几乎相同。主要区别在于添加了一类新的参数:缩放因子和 amax 历史记录集。在下一节中,我们将探讨如何更新这些参数。

操作 FP8 参数#

让我们首先检查 params 的数据结构。在下面的代码中,我们删除参数值,然后显示 PyTree 结构。

params_structure = flax.core.unfreeze(params).copy()
params_structure = flax.traverse_util.flatten_dict(params_structure, sep='/')
for key, value in params_structure.items():
    params_structure[key] = '*'
params_structure = flax.traverse_util.unflatten_dict(params_structure, sep='/')
pprint.pprint(params_structure)

输出如下

{'_overwrite_with_gradient': {'Fp8DotGeneralOp_0': {'input_amax_history': '*',
                                                    'input_scale': '*',
                                                    'kernel_amax_history': '*',
                                                    'kernel_scale': '*',
                                                    'output_grad_amax_history': '*',
                                                    'output_grad_scale': '*'}},
 'params': {'bias': '*', 'kernel': '*'}}

除了预期的 params 之外,还有一个名为 _overwrite_with_gradient 的附加类别。此类别包括三对分别用于激活、内核和点梯度的 amax_historyscale

更新 FP8 参数的梯度#

现在,我们执行一个训练步骤以获得梯度,并了解如何使用它们来更新参数。

step_fn = jax.jit(jax.grad(train_step, (0, 1)))

grads = step_fn(params, A)

params = flax.core.unfreeze(params)
params = flax.traverse_util.flatten_dict(params, sep='/')
grads = flax.traverse_util.flatten_dict(grads[0], sep='/')

for key, value in params.items():
  if key.startswith('params'):
    params[key] = value + 0.01 * grads[key]
  if key.startswith('_overwrite_with_gradient'):
    params[key] = grads[key]

params = flax.traverse_util.unflatten_dict(params, sep='/')
params = flax.core.freeze(params)

上面的代码演示了如何更新 params_overwrite_with_gradient。对于 params,我们使用公式 new_param = old_param + 0.01 * grads,其中 0.01 是学习率(或者用户可以使用 optax 中的任何优化器)。对于 _overwrite_with_gradient,我们只需使用梯度覆盖旧值。

请注意,flax.training.train_state.TrainState 方便地支持 _overwrite_with_gradient 类别,因此如果用户不使用自定义 TrainState,则无需修改其脚本。

累积 FP8 参数的梯度#

当同一个参数以分支方式使用时,自动求导机制会将这些分支的梯度相加。这在诸如流水线并行等场景中很常见,在这些场景中,每个微批次共享同一组用于小批次的参数。然而,对于 _overwrite_with_gradient 参数,这种通过加法进行的累积是没有意义的。相反,我们更倾向于通过取最大值进行自定义累积。

为了解决这个问题,我们引入了一个自定义的数据类型 fp8_ops.fp32_max_grad。基本用法如下所示

fmax32 = fp8_ops.fp32_max_grad

def reuse_fp8_param(x, y, scale, amax_history):
  scale = scale.astype(fmax32)
  amax_history = amax_history.astype(fmax32)

  x = fp8_ops.in_qdq(f32, e4m3, x, scale, amax_history)
  y = fp8_ops.in_qdq(f32, e4m3, y, scale, amax_history)
  return x + y

reuse_fp8_param_fn = jax.grad(reuse_fp8_param, (0, 1, 2, 3))
reuse_fp8_param_fn = jax.jit(reuse_fp8_param_fn)

_, _, new_ah, new_sf = reuse_fp8_param_fn(2.0, 3.0, a_scale, a_amax_hist)
print(new_ah, new_sf)

在这个例子中,我们首先将 scaleamax_history 转换为 fp8_ops.fp32_max_grad,然后使用同一对 scaleamax_history 调用两次 fp8_ops.in_qdq。在自动求导过程中,它们来自每个分支的梯度将取最大值,从而得到正确的结果

1.0 [3. 0. 0. ... 0. 0. 0.]

如果我们不进行类型转换,我们会得到以下结果,这意味着两个分支的梯度被相加了

2.0 [5. 0. 0. ... 0. 0. 0.]

如果用户选择使用高级 API,则此转换已包含在内。