SPMD#
用于处理 jit 和分区模型的实用工具。
此模块引入 axis_rules
、logical_to_mesh_axes
、logical_to_mesh
、with_logical_constraint
,用于根据“逻辑命名轴”而不是 jit 的默认网格轴应用 jit 分片约束。
此外,还定义了 LogicallyPartitioned
元数据包装器以及初始化函数包装器 ``with_logical_partitioning``,用于将逻辑轴元数据引入模型的变量中。
- flax.linen.Partitioned(value, names, mesh=None)[source]#
用于分区元数据的包装器。
Partitioned
用于使用jax.experimental.pjit
所需的分区信息来扩展变量。定义 Partitioned 变量的最简单方法是使用变量初始化器周围的
with_partitioning
包装器。示例
class MLP(nn.Module): hidden_size: int @nn.compact def __call__(self, x): ki = nn.linear.default_kernel_init h = nn.Dense( self.hidden_size, kernel_init=nn.with_partitioning(ki, ('data', 'model')))(x) h = nn.relu(h) return nn.Dense( x.shape[-1], kernel_init=nn.with_partitioning(ki, ('model', 'data')))(h) mlp = MLP(4096) x = jnp.ones((8 * 1024, 1024)) # use eval_shape to get the Partitioned instances for the variables. # this way we can determine the PartitionSpecs for the init variables # before we call the init fn. var_spec = nn.get_partition_spec( jax.eval_shape(mlp.init, random.key(0), x)) init_fn = mesh(pjit(mlp.init, (None, PartitionSpec("data", "model")), var_spec)) variables = init_fn(random.key(0), x) apply_fn = mesh(pjit( mlp.apply, (var_spec, PartitionSpec("data", "model")), PartitionSpec("data", "model"))) apply_fn(variables, x)
当使用
nn.vmap
和nn.scan
等转换时,Partitioned
值可以获得额外的轴。在这种情况下,您可以使用 vmap/scan 中的 metadata_params 参数指定新轴的名称class Model(nn.Module): @nn.compact def __call__(self, x): def body(mdl, c): c = MLP(4096)(c) return c, () c, _ = nn.scan( body, variable_axes={"params": 0}, split_rngs={"params": 0}, length=8, metadata_params={nn.meta.PARTITION_NAME: "layers"})(self, x) return c
- flax.linen.with_partitioning(fn, names, mesh=None)[source]#
使用 Partitioned 包装函数的返回值。
示例
>>> import flax.linen as nn >>> kernel_init = nn.with_partitioning( ... nn.initializers.lecun_normal(), (None, "data")) >>> partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init)
- 参数
fn – 要包装的函数。通常这是一个初始化器。
names – 传递给
Partitioned
的逻辑轴。mesh – 用于分区的网格。如果为 None,则如果可用,则使用全局网格资源。
- 返回
一个包装
fn
的函数,它将返回Partitioned
的实例。
- flax.linen.LogicallyPartitioned(value: Any, names: tuple[Optional[str], ...], mesh: jax._src.mesh.Mesh | None = None, rules: collections.abc.Sequence[tuple[str, str | tuple[str, ...] | None]] | None = None)[source]#
- flax.linen.logical_to_mesh_axes(array_dim_names, rules=None)[source]#
计算数组的布局。
规则的优先级按顺序排列,并且由以下几对组成:
(ArrayDimensionName, MeshDimensionName)
,这意味着给定的数组维度(如果存在且未使用)应在给定的网格维度(如果存在且未使用)上进行分片。数组的布局表示为一个元组,该元组的每个元素对应数组中的一个维度。该元素要么为 None,要么为网格维度的名称,这意味着数组的此维度在网格的此维度上进行分片。
例如,给定一个具有以下内容的数组
array_dim_names = ('batch', 'length', 'heads', 'features')
并且布局规则是
rules = (('batch', 'X'), ('features', 'X'), ('heads', 'Y'), ('batch', 'Z'))
那么此函数将返回
PartitionSpec('X', None, 'Y', None)
- 参数
array_dim_names – 数组维度名称或 None 的元组。
rules – 可选的逻辑到网格规则覆盖。默认为使用从
axis_rules
函数中设置的动态上下文中的规则。
- 返回
参数的 PartitionSpec。
- flax.linen.logical_to_mesh(tree, rules=None)[源代码]#
将 logical_to_mesh_axes 应用于逻辑 PartitionSpecs 的 pytrees。
- flax.linen.logical_to_mesh_sharding(tree, mesh, rules=None)[源代码]#
将逻辑 PartitionSpecs 的 pytrees 转换为分片。
- flax.linen.with_logical_constraint(x, logical_axis_resources, rules=None, mesh=None, fallback=RulesFallback.AXIS_IS_UNSHARDED)[源代码]#
使用逻辑轴名称的 jit 的 with_sharding_constraint 版本。
- flax.linen.with_logical_partitioning(fn, names, mesh=None, rules=None)[源代码]#
使用 LogicallyPartitioned 包装函数的返回值。
示例
>>> import flax.linen as nn >>> kernel_init = nn.with_logical_partitioning( ... nn.initializers.lecun_normal(), (None, "data")) >>> partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init)
- 参数
fn – 要包装的函数。通常这是一个初始化器。
names – 传递给
LogicallyPartitioned
的逻辑轴。mesh – 用于分区的网格。如果为 None,则如果可用,则使用全局网格资源。
rules – 可选的逻辑到网格规则使用。如果为 None,则如果可用,则使用全局规则。
- 返回
一个包装
fn
的函数,它将返回LogicallyPartitioned
的一个实例。