SPMD#

用于处理 jit 和分区模型的实用工具。

此模块引入 axis_ruleslogical_to_mesh_axeslogical_to_meshwith_logical_constraint,用于根据“逻辑命名轴”而不是 jit 的默认网格轴应用 jit 分片约束。

此外,还定义了 LogicallyPartitioned 元数据包装器以及初始化函数包装器 ``with_logical_partitioning``,用于将逻辑轴元数据引入模型的变量中。

flax.linen.Partitioned(value, names, mesh=None)[source]#

用于分区元数据的包装器。

Partitioned 用于使用 jax.experimental.pjit 所需的分区信息来扩展变量。

定义 Partitioned 变量的最简单方法是使用变量初始化器周围的 with_partitioning 包装器。

示例

class MLP(nn.Module):
  hidden_size: int
  @nn.compact
  def __call__(self, x):
    ki = nn.linear.default_kernel_init
    h = nn.Dense(
        self.hidden_size,
        kernel_init=nn.with_partitioning(ki, ('data', 'model')))(x)
    h = nn.relu(h)
    return nn.Dense(
        x.shape[-1],
        kernel_init=nn.with_partitioning(ki, ('model', 'data')))(h)
mlp = MLP(4096)
x = jnp.ones((8 * 1024, 1024))
# use eval_shape to get the Partitioned instances for the variables.
# this way we can determine the PartitionSpecs for the init variables
# before we call the init fn.
var_spec = nn.get_partition_spec(
    jax.eval_shape(mlp.init, random.key(0), x))
init_fn = mesh(pjit(mlp.init,
                    (None, PartitionSpec("data", "model")), var_spec))
variables = init_fn(random.key(0), x)
apply_fn = mesh(pjit(
    mlp.apply,
    (var_spec, PartitionSpec("data", "model")),
     PartitionSpec("data", "model")))
apply_fn(variables, x)

当使用 nn.vmapnn.scan 等转换时,Partitioned 值可以获得额外的轴。在这种情况下,您可以使用 vmap/scan 中的 metadata_params 参数指定新轴的名称

class Model(nn.Module):
@nn.compact
def __call__(self, x):
  def body(mdl, c):
    c = MLP(4096)(c)
    return c, ()
  c, _ = nn.scan(
      body, variable_axes={"params": 0}, split_rngs={"params": 0}, length=8,
      metadata_params={nn.meta.PARTITION_NAME: "layers"})(self, x)
  return c
flax.linen.with_partitioning(fn, names, mesh=None)[source]#

使用 Partitioned 包装函数的返回值。

示例

>>> import flax.linen as nn
>>> kernel_init = nn.with_partitioning(
...     nn.initializers.lecun_normal(), (None, "data"))
>>> partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init)
参数
  • fn – 要包装的函数。通常这是一个初始化器。

  • names – 传递给 Partitioned 的逻辑轴。

  • mesh – 用于分区的网格。如果为 None,则如果可用,则使用全局网格资源。

返回

一个包装 fn 的函数,它将返回 Partitioned 的实例。

flax.linen.get_partition_spec(tree)[source]#

从包含 Partitioned 值的 PyTree 中提取 PartitionSpec 树。

flax.linen.get_sharding(tree, mesh)[source]#

从包含 Partitioned 值和网格的 PyTree 中提取 jax.sharding 树。

flax.linen.LogicallyPartitioned(value: Any, names: tuple[Optional[str], ...], mesh: jax._src.mesh.Mesh | None = None, rules: collections.abc.Sequence[tuple[str, str | tuple[str, ...] | None]] | None = None)[source]#
flax.linen.logical_axis_rules(rules)[source]#

用于设置逻辑到网格轴绑定的上下文管理器。

flax.linen.set_logical_axis_rules(rules)[source]#

设置全局逻辑轴到网格轴的绑定。

flax.linen.get_logical_axis_rules()[source]#

返回全局逻辑轴到网格轴的绑定。

flax.linen.logical_to_mesh_axes(array_dim_names, rules=None)[source]#

计算数组的布局。

规则的优先级按顺序排列,并且由以下几对组成: (ArrayDimensionName, MeshDimensionName),这意味着给定的数组维度(如果存在且未使用)应在给定的网格维度(如果存在且未使用)上进行分片。

数组的布局表示为一个元组,该元组的每个元素对应数组中的一个维度。该元素要么为 None,要么为网格维度的名称,这意味着数组的此维度在网格的此维度上进行分片。

例如,给定一个具有以下内容的数组

array_dim_names = ('batch', 'length', 'heads', 'features')

并且布局规则是

rules = (('batch', 'X'),
         ('features', 'X'),
         ('heads', 'Y'),
         ('batch', 'Z'))

那么此函数将返回

PartitionSpec('X', None, 'Y', None)
参数
  • array_dim_names – 数组维度名称或 None 的元组。

  • rules – 可选的逻辑到网格规则覆盖。默认为使用从 axis_rules 函数中设置的动态上下文中的规则。

返回

参数的 PartitionSpec。

flax.linen.logical_to_mesh(tree, rules=None)[源代码]#

将 logical_to_mesh_axes 应用于逻辑 PartitionSpecs 的 pytrees。

flax.linen.logical_to_mesh_sharding(tree, mesh, rules=None)[源代码]#

将逻辑 PartitionSpecs 的 pytrees 转换为分片。

flax.linen.with_logical_constraint(x, logical_axis_resources, rules=None, mesh=None, fallback=RulesFallback.AXIS_IS_UNSHARDED)[源代码]#

使用逻辑轴名称的 jit 的 with_sharding_constraint 版本。

flax.linen.with_logical_partitioning(fn, names, mesh=None, rules=None)[源代码]#

使用 LogicallyPartitioned 包装函数的返回值。

示例

>>> import flax.linen as nn
>>> kernel_init = nn.with_logical_partitioning(
...     nn.initializers.lecun_normal(), (None, "data"))
>>> partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init)
参数
  • fn – 要包装的函数。通常这是一个初始化器。

  • names – 传递给 LogicallyPartitioned 的逻辑轴。

  • mesh – 用于分区的网格。如果为 None,则如果可用,则使用全局网格资源。

  • rules – 可选的逻辑到网格规则使用。如果为 None,则如果可用,则使用全局规则。

返回

一个包装 fn 的函数,它将返回 LogicallyPartitioned 的一个实例。