在多设备上扩展 Flax 模块#

本指南演示如何使用 jax.jit(以前为 experimental.pjit)和 flax.linen 在多设备和主机上扩展 Flax 模块

扩展的 Flax 和 jax.jit#

jax.jit 遵循 单程序多数据 (SPMD) 范式,并自动编译代码以在多个设备上运行。您只需要指定希望如何对代码的输入和输出进行分区,编译器将计算出如何:1) 对内部的所有内容进行分区;2) 编译设备间通信。

Flax 提供了多种功能,可以帮助您在 Flax 模块 上使用自动 SPMD,包括

  1. 在定义 flax.linen.Module 时指定数据分区的接口。

  2. 用于生成 jax.jit 运行所需的分片信息的实用函数。

  3. 一个用于自定义轴名称的接口,称为“逻辑轴注释”,以解耦模块代码和分区计划,从而更轻松地尝试不同的分区布局。

您可以在 JAX 的文档站点上的 多进程环境中的 JAX分布式数组和自动并行化 中了解有关用于扩展的 jax.jit API 的更多信息。

设置#

导入一些必要的依赖项。

注意: 本指南使用 --xla_force_host_platform_device_count=8 标志来模拟 Google Colab/Jupyter Notebook 中 CPU 环境中的多个设备。如果您已经在使用多设备 TPU 环境,则不需要此标志。

# Once Flax v0.6.10 is released, there is no need to do this.
# ! pip3 install -qq "git+https://github.com/google/flax.git@main#egg=flax"
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import functools
from typing import Optional, Callable

import numpy as np
import jax
from jax import lax, random, numpy as jnp

import flax
from flax import struct, traverse_util, linen as nn
from flax.core import freeze, unfreeze
from flax.training import train_state, checkpoints

import optax # Optax for common losses and optimizers.
2024-12-17 20:00:37.852447: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1734465637.871995     915 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1734465637.877913     915 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
print(f'We have 8 fake JAX devices now: {jax.devices()}')
We have 8 fake JAX devices now: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]

以下代码演示了如何按照 JAX 的 分布式数组和自动并行化 指南导入和设置 JAX 级别的设备 API

  1. 使用 JAX 的 mesh_utils.create_device_mesh 启动一个 2x4 设备 mesh(8 个设备)。此布局与 TPU v3-8 的布局相同。

  2. 使用 jax.sharding.Mesh 中的 axis_names 参数为每个轴添加名称注释。注释轴名称的典型方法是 axis_name=('data', 'model'),其中

  • 'data':用于对输入和激活的批处理维度进行数据并行分片的网格维度。

  • 'model':用于在设备之间对模型参数进行分片的网格维度。

  1. 创建一个简单的实用函数 mesh_sharding,用于从网格和任何布局生成分片对象。

from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.lax import with_sharding_constraint
from jax.experimental import mesh_utils
# Create a mesh and annotate each axis with a name.
device_mesh = mesh_utils.create_device_mesh((2, 4))
print(device_mesh)

mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
print(mesh)

def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
  return NamedSharding(mesh, pspec)
[[CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)]
 [CpuDevice(id=4) CpuDevice(id=5) CpuDevice(id=6) CpuDevice(id=7)]]
Mesh('data': 2, 'model': 4)

定义层#

在定义一个简单模型之前,创建一个名为 DotReluDot 的示例层(通过继承 flax.linen.Module)。该层创建两个参数 W1W2 用于点积乘法,并使用 jax.nn.relu (ReLU) 激活函数进行分隔。

为了有效地分割参数,应用以下 API 来注释参数和中间变量

  1. 在创建子层或原始参数时,使用 flax.linen.with_partitioning 修饰初始化函数。

  2. 应用 jax.lax.with_sharding_constraint(以前为 pjit.with_sharding_constraint)来注释中间变量,如 yz,以便在已知理想约束时强制使用特定的分片模式。

  • 此步骤是可选的,但有时可以帮助自动 SPMD 有效地分区。在下面的示例中,不需要此调用,因为 XLA 将为 yz 计算出相同分片布局,而不管怎样。

class DotReluDot(nn.Module):
  depth: int
  dense_init: Callable = nn.initializers.xavier_normal()
  @nn.compact
  def __call__(self, x):

    y = nn.Dense(self.depth,
                 kernel_init=nn.with_partitioning(self.dense_init, (None, 'model')),
                 use_bias=False,  # or overwrite with `bias_init`
                 )(x)

    y = jax.nn.relu(y)
    # Force a local sharding annotation.
    y = with_sharding_constraint(y, mesh_sharding(PartitionSpec('data', 'model')))

    W2 = self.param(
        'W2',
        nn.with_partitioning(self.dense_init, ('model', None)),
        (self.depth, x.shape[-1]))

    z = jnp.dot(y, W2)
    # Force a local sharding annotation.
    z = with_sharding_constraint(z, mesh_sharding(PartitionSpec('data', None)))

    # Return a tuple to conform with the API `flax.linen.scan` as shown in the cell below.
    return z, None

请注意,设备轴名称(如 'data''model'None)将传递到 flax.linen.with_partitioningjax.lax.with_sharding_constraint API 调用中。这指的是应如何分割此数据的每个维度 — 跨设备网格维度之一,或者根本不分割。

例如

  • 当您定义形状为 (x.shape[-1], self.depth) 并注释为 (None, 'model')W1

    • 第一个维度(长度为 x.shape[-1])将在所有设备上复制。

    • 第二个维度(长度为 self.depth)将通过设备网格的 'model' 轴进行分片。这意味着在此维度上,W1 将在设备 (0, 4)(1, 5)(2, 6)(3, 7) 上进行 4 向分片。

  • 当您将输出 z 注释为 ('data', None)

    • 第一个维度,即批次维度,将在 'data' 轴上进行分片。这意味着一半的批次将在设备 0-3(前四个设备)上处理,另一半则在设备 4-7(其余四个设备)上处理。

    • 第二个维度,即数据深度维度,将在所有设备上进行复制。

定义一个使用 flax.linen.scan 提升转换的模型#

创建了 DotReluDot 之后,您现在可以定义 MLP 模型(通过继承 flax.linen.Module)作为多个 DotReluDot 层。

要复制相同的层,您可以使用 flax.linen.scan,也可以使用 for 循环

  • flax.linen.scan 可以提供更快的编译时间。

  • for 循环在运行时可能更快。

以下代码展示了如何应用这两种方法,并默认使用 for 循环,以便所有参数都是二维的,并且您可以可视化它们的分片。

flax.linen.scan 代码只是为了展示此 API 如何与 Flax 提升转换一起使用。

class MLP(nn.Module):
  num_layers: int
  depth: int
  use_scan: bool
  @nn.compact
  def __call__(self, x):
    if self.use_scan:
      x, _ = nn.scan(DotReluDot, length=self.num_layers,
                     variable_axes={"params": 0},
                     split_rngs={"params": True},
                     metadata_params={nn.PARTITION_NAME: None}
                     )(self.depth)(x)
    else:
      for i in range(self.num_layers):
        x, _ = DotReluDot(self.depth)(x)
    return x

现在,创建一个 model 实例和一个样本输入 x

# MLP hyperparameters.
BATCH, LAYERS, DEPTH, USE_SCAN = 8, 4, 1024, False
# Create fake inputs.
x = jnp.ones((BATCH, DEPTH))
# Initialize a PRNG key.
k = random.key(0)

# Create an Optax optimizer.
optimizer = optax.adam(learning_rate=0.001)
# Instantiate the model.
model = MLP(LAYERS, DEPTH, USE_SCAN)

指定分片#

接下来,您需要告诉 jax.jit 如何在设备之间分片数据。

输入的sharding#

对于数据并行,您可以通过将批次轴表示为 'data',在 data 轴上分片批次的 *输入* x。然后,使用 jax.device_put 将其放置在正确的 device 上。

x_sharding = mesh_sharding(PartitionSpec('data', None)) # dimensions: (batch, length)
x = jax.device_put(x, x_sharding)
jax.debug.visualize_array_sharding(x)
                                                                                
                                                                                
                                  CPU 0,1,2,3                                   
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                  CPU 4,5,6,7                                   
                                                                                
                                                                                
                                                                                

输出的sharding#

您需要编译 model.init()(即 flax.linen.Module.init()),并将其输出作为参数的 pytree。此外,您有时可能需要使用 flax.training.train_state 对其进行包装,以跟踪其他变量(例如优化器状态),这将使输出成为更复杂的 pytree。

为了实现这一点,幸运的是,您不必手动硬编码输出的 sharding。相反,您可以

  1. 使用 jax.eval_shape 抽象地评估 model.init(在本例中,它是它的包装器)。

  2. 使用 flax.linen.get_sharding 自动生成 jax.sharding.NamedSharding

def init_fn(k, x, model, optimizer):
  variables = model.init(k, x) # Initialize the model.
  state = train_state.TrainState.create( # Create a `TrainState`.
    apply_fn=model.apply,
    params=variables['params'],
    tx=optimizer)
  return state
# Create an abstract closure to wrap the function before feeding it in
# because `jax.eval_shape` only takes pytrees as arguments.
abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=model, optimizer=optimizer), k, x)

# This `state_sharding` has the same pytree structure as `state`, the output
# of the `init_fn`.
state_sharding = nn.get_sharding(abstract_variables, mesh)
state_sharding
TrainState(step=NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(), memory_kind=unpinned_host), apply_fn=<bound method Module.apply of MLP(
    # attributes
    num_layers = 4
    depth = 1024
    use_scan = False
)>, params={'DotReluDot_0': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}, 'DotReluDot_1': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}, 'DotReluDot_2': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}, 'DotReluDot_3': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7fce0c117b50>, update=<function chain.<locals>.update_fn at 0x7fce0c117eb0>), opt_state=(ScaleByAdamState(count=NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(), memory_kind=unpinned_host), mu={'DotReluDot_0': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}, 'DotReluDot_1': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}, 'DotReluDot_2': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}, 'DotReluDot_3': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}}, nu={'DotReluDot_0': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}, 'DotReluDot_1': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}, 'DotReluDot_2': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}, 'DotReluDot_3': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}}), EmptyState()))

编译代码#

现在,您可以将 jax.jit 应用于您的 init_fn,但需要两个额外的参数:in_shardingsout_shardings

运行它以获取 initialized_state,其中的参数将按照指示进行分片

jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3),
                      in_shardings=(mesh_sharding(()), x_sharding),  # PRNG key and x
                      out_shardings=state_sharding)

initialized_state = jit_init_fn(k, x, model, optimizer)

# for weight, partitioned in initialized_state.params['DotReluDot_0'].items():
#     print(f'Sharding of {weight}: {partitioned.names}')
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
                                    
                                    
                                    
                                    
                                    
 CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 
                                    
                                    
                                    
                                    
                                    
                         
         CPU 0,4         
                         
                         
         CPU 1,5         
                         
                         
         CPU 2,6         
                         
                         
         CPU 3,7         
                         

检查模块输出#

请注意,在 initialized_state 的输出中,params W1W2 的类型为 flax.linen.Partitioned。这是对实际 jax.Array 的包装,允许 Flax 记录与其关联的轴名称。

您可以通过在字典上调用 flax.linen.meta.unbox() 或在单个变量上调用 .value 来访问原始的 jax.Array。您还可以使用 flax.linen.meta.replace_boxed() 来更改底层 jax.Array,而无需修改 sharding 注释。

print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel']))
print(type(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value))
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].names)
print(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.shape)
<class 'flax.core.meta.Partitioned'>
<class 'jaxlib.xla_extension.ArrayImpl'>
(None, 'model')
(1024, 1024)
# Say for some unknown reason you want to make the whole param tree all-zero
unboxed_params = nn.meta.unbox(initialized_state.params)
all_zero = jax.tree.map(jnp.zeros_like, unboxed_params)
all_zero_params = nn.meta.replace_boxed(initialized_state.params, all_zero)
assert jnp.sum(nn.meta.unbox(all_zero_params['DotReluDot_0']['Dense_0']['kernel'])) == 0

您还可以检查每个参数的底层 jax.sharding,它现在比 NamedSharding 更内部。请注意,诸如 initialized_state.step 之类的数字在所有设备上都是复制的。

initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value.sharding
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)
print(initialized_state.step)
initialized_state.step.sharding
0
NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(), memory_kind=unpinned_host)

您可以使用 jax.tree_util.tree_map 对一组 boxed 参数的字典进行大规模计算,就像对一组 JAX 数组的字典进行大规模计算一样。

diff = jax.tree_util.tree_map(
    lambda a, b: a - b,
    initialized_state.params['DotReluDot_0'], initialized_state.params['DotReluDot_0'])
print(jax.tree_util.tree_map(jnp.shape, diff))
diff_array = diff['Dense_0']['kernel'].value
print(type(diff_array))
print(diff_array.shape)
{'Dense_0': {'kernel': Partitioned(value=(1024, 1024), names=(None, 'model'), mesh=None)}, 'W2': Partitioned(value=(1024, 1024), names=('model', None), mesh=None)}
<class 'jaxlib.xla_extension.ArrayImpl'>
(1024, 1024)

编译训练步骤和推理#

按如下方式创建一个 jit 化的训练步骤

@functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding),
                   out_shardings=state_sharding)
def train_step(state, x):
  # A fake loss function.
  def loss_unrolled(params):
    y = model.apply({'params': params}, x)
    return y.sum()
  grad_fn = jax.grad(loss_unrolled)
  grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state

with mesh:
  new_state = train_step(initialized_state, x)
print(f'Sharding of Weight 1:')
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['Dense_0']['kernel'].value)
print(f'Sharding of Weight 2:')
jax.debug.visualize_array_sharding(initialized_state.params['DotReluDot_0']['W2'].value)
Sharding of Weight 1:
                                    
                                    
                                    
                                    
                                    
 CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 
                                    
                                    
                                    
                                    
                                    
Sharding of Weight 2:
                         
         CPU 0,4         
                         
                         
         CPU 1,5         
                         
                         
         CPU 2,6         
                         
                         
         CPU 3,7         
                         

然后,创建一个编译的推理步骤。请注意,输出也沿着 (data, None) 进行分片。

@functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding),
                   out_shardings=x_sharding)
def apply_fn(state, x):
  return state.apply_fn({'params': state.params}, x)

with mesh:
  y = apply_fn(new_state, x)
print(type(y))
print(y.dtype)
print(y.shape)
jax.debug.visualize_array_sharding(y)
<class 'jaxlib.xla_extension.ArrayImpl'>
float32
(8, 1024)
                                                                                
                                                                                
                                  CPU 0,1,2,3                                   
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                  CPU 4,5,6,7                                   
                                                                                
                                                                                
                                                                                

性能分析#

如果您在 TPU pod 或 pod 切片上运行,则可以使用如下定义的自定义 block_all 实用程序函数来衡量性能

%%timeit

def block_all(xs):
  jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
  return xs

with mesh:
  new_state = block_all(train_step(initialized_state, x))
240 ms ± 4.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

逻辑轴注释#

JAX 的自动 SPMD 鼓励用户探索不同的 sharding 布局以找到最佳布局。为此,在 Flax 中,您实际上可以使用更具描述性的轴名称(而不仅仅是像 'data''model' 这样的设备网格轴名称)来注释任何数据的维度。

下面的 LogicalDotReluDotLogicalMLP 模块定义与您之前创建的模块类似,但以下情况除外

  1. 所有轴都使用更具体、更有意义的名称进行注释,例如 'embed''hidden''batch''layer'。这些名称在 Flax 中被称为 *逻辑轴名称*。它们使模型定义内部的维度更改更具可读性。

  2. flax.linen.with_logical_partitioning 替换了 flax.linen.with_partitioningflax.linen.with_logical_constraint 替换了 jax.lax.with_sharding_constraint,以识别逻辑轴名称。

class LogicalDotReluDot(nn.Module):
  depth: int
  dense_init: Callable = nn.initializers.xavier_normal()
  @nn.compact
  def __call__(self, x):
    y = nn.Dense(self.depth,
                 kernel_init=nn.with_logical_partitioning(self.dense_init, ('embed', 'hidden')),
                 use_bias=False,  # or overwrite with `bias_init`
                 )(x)

    y = jax.nn.relu(y)
    # Force a local sharding annotation.
    y = with_sharding_constraint(y, mesh_sharding(PartitionSpec('data', 'model')))

    W2 = self.param(
        'W2',
        nn.with_logical_partitioning(self.dense_init, ('hidden', 'embed')),
        (self.depth, x.shape[-1]))

    z = jnp.dot(y, W2)
    # Force a local sharding annotation.
    z = nn.with_logical_constraint(z, ('batch', 'embed'))
    return z, None

class LogicalMLP(nn.Module):
  num_layers: int
  depth: int
  use_scan: bool
  @nn.compact
  def __call__(self, x):
    if self.use_scan:
      x, _ = nn.scan(LogicalDotReluDot, length=self.num_layers,
                    variable_axes={"params": 0},
                    split_rngs={"params": True},
                    metadata_params={nn.PARTITION_NAME: 'layer'}
                    )(self.depth)(x)
    else:
      for i in range(self.num_layers):
        x, _ = LogicalDotReluDot(self.depth)(x)
    return x

现在,初始化一个模型并尝试找出其 state 应具有的 sharding。

为了使设备网格正确获取您的模型,您需要确定这些逻辑轴名称中的哪些映射到设备轴 'data''model'。此规则是 (logical_axis_name, device_axis_name) 元组列表,并且 flax.linen.logical_to_mesh_sharding 会将它们转换为设备网格可以理解的那种 sharding。

这允许您更改规则并尝试新的分区布局,而无需修改模型定义。

# Unspecified rule means unsharded by default, so no need to specify `('embed', None)` and `('layer', None)`.
rules = (('batch', 'data'),
         ('hidden', 'model'))

logical_model = LogicalMLP(LAYERS, DEPTH, USE_SCAN)

logical_abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=logical_model, optimizer=optimizer), k, x)
logical_state_spec = nn.get_partition_spec(logical_abstract_variables)
print('annotations are logical, not mesh-specific: ',
      logical_state_spec.params['LogicalDotReluDot_0']['Dense_0']['kernel'])

logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, rules)
print('sharding annotations are mesh-specific: ',
      logical_state_sharding.params['LogicalDotReluDot_0']['Dense_0']['kernel'].spec)
annotations are logical, not mesh-specific:  PartitionSpec('embed', 'hidden')
sharding annotations are mesh-specific:  PartitionSpec(None, 'model')

您可以验证这里的 logical_state_spec 是否与前一个(“非逻辑”)示例中的 state_spec 具有相同的内容。这允许您以与上述相同的方式 jax.jit 您的模块的 flax.linen.Module.initflax.linen.Module.apply

state_sharding.params['DotReluDot_0'] == logical_state_sharding.params['LogicalDotReluDot_0']
True
logical_jit_init_fn = jax.jit(init_fn, static_argnums=(2, 3),
                      in_shardings=(mesh_sharding(()), x_sharding),  # PRNG key and x
                      out_shardings=logical_state_sharding)

logical_initialized_state = logical_jit_init_fn(k, x, logical_model, optimizer)
print(f'Sharding of Weight 1:')
jax.debug.visualize_array_sharding(logical_initialized_state.params['LogicalDotReluDot_0']['Dense_0']['kernel'].value)
print(f'Sharding of Weight 2:')
jax.debug.visualize_array_sharding(logical_initialized_state.params['LogicalDotReluDot_0']['W2'].value)
Sharding of Weight 1:
                                    
                                    
                                    
                                    
                                    
 CPU 0,4  CPU 1,5  CPU 2,6  CPU 3,7 
                                    
                                    
                                    
                                    
                                    
Sharding of Weight 2:
                         
         CPU 0,4         
                         
                         
         CPU 1,5         
                         
                         
         CPU 2,6         
                         
                         
         CPU 3,7         
                         

何时使用设备轴/逻辑轴#

选择何时使用设备轴或逻辑轴取决于您想对模型的分区进行多少控制。

  • 设备网格轴:如果您想要一个非常简单的模型,或者您对自己的分区方式非常有信心,那么使用设备网格轴定义它可能会为您节省一些额外的代码行,从而将逻辑命名转换回设备命名。

  • 逻辑命名:另一方面,逻辑命名助手可用于探索不同的分片布局。如果您想进行实验并找到模型的最优分区布局,请使用此方法。

  • 设备轴名称:在非常高级的用例中,您可能会有更复杂的分片模式,需要以不同于参数维度名称的方式注释激活维度名称。如果您希望对手动网格分配进行更精细的控制,直接使用设备轴名称可能会更有帮助。

保存数据#

要保存跨设备数组,您可以使用flax.training.checkpoints,如保存和加载检查点指南 - 多主机/多进程检查点中所示。如果您在多主机环境(例如,TPU pod)上运行,则尤其需要这样做。

在实践中,您可能希望将原始的jax.Array pytree 作为检查点保存,而不是包装的 Partitioned 值,以降低复杂性。您可以按原样恢复它,并使用 flax.linen.meta.replace_boxed() 将其放回带注释的 pytree 中。

请记住,要将数组恢复到所需的分区,您需要提供一个示例 target pytree,该 pytree 具有相同的结构,并且每个 JAX 数组都具有所需的 jax.sharding.Sharding。您用来恢复数组的分片不必与您用来存储数组的分片相同。