flax.jax_utils 包#

我们可以考虑上游到 Jax 的实用程序。

flax.jax_utils.partial_eval_by_shape(fn, input_spec, *args, **kwargs)[源代码]#

通过使用输入的形状来惰性评估函数。

此函数类似于 jax.eval_shape,主要区别在于,可以在没有输入具体值的情况下计算的函数输出将按原样返回,而不仅仅是形状。例如,请参阅 module.init_by_shape,其中此功能用于初始化模型而不使用输入数据 lr 计算。

参数
  • fn – 要惰性评估的函数。

  • input_spec – 形状或(形状、dtype)元组的可迭代对象,指定输入的形状和类型。如果未指定,则 dtype 为 float32。

  • *args – 传递给模块 apply 函数的其他参数

  • **kwargs – 传递给模块 apply 函数的关键字参数

返回

一个由模型输出和 Model 实例组成的对

多设备实用程序#

flax.jax_utils.replicate(tree, devices=None)[源代码]#

将数组复制到多个设备。

参数
  • tree – 包含应复制的数组的 pytree。

  • devices – 数据复制到的设备(默认:与 jax.pmap() 预期的顺序相同)。

返回

包含复制数组的新 pytree。

flax.jax_utils.unreplicate(tree)[源代码]#

返回复制数组的单个实例。

flax.jax_utils.prefetch_to_device(iterator, size, devices=None)[源代码]#

在设备上分片和预取批次。

此实用程序接受一个迭代器,并返回一个新迭代器,该迭代器填充设备上的预取缓冲区。通过重叠计算和数据传输,尽早预取可以显着提高训练循环的性能。

此实用程序主要对 GPU 有用,对于 TPU 和 CPU,则不需要 – TPU 和 CPU 内存分配器(通常)不会选择尚未空闲的内存位置,因此它们不会阻塞。相反,这些分配器会 OOM。

参数
  • iterator – 一个迭代器,它产生一个 ndarray 的 pytree,其中第一维度在设备之间分片。

  • size

    预取缓冲区的大小。

    如果您在 GPU 上进行训练,通常 2 是最佳选择,因为这可以保证您可以将 GPU 上的训练步骤与 CPU 上的数据预取步骤重叠。

  • devices

    应该预取数组的设备列表。

    默认为 jax.pmap 预期的设备顺序。

产生

来自迭代器的原始项,其中每个 ndarray 现在都分片到指定的设备。

flax.jax_utils.pmean(xs, axis_name)[源代码]#
flax.jax_utils.pad_shard_unpad(wrapped, static_argnums=(0,), static_argnames=(), static_return=False)[源代码]#

使用在填充、分片、然后取消分片、取消填充的代码包装函数。

参数
  • wrapped – 要包装的函数。签名是 params, *args, *kwargs

  • static_argnumswrapped 的参数索引,这些参数_不应_填充和分片,而应按原样转发。默认值为 (0,),因为到目前为止,最常见的用例是首先传递 params

  • static_argnameswrapped 的 kwargs 名称,这些参数_不应_填充和分片,而应按原样转发。

  • static_return – 是否不对返回值进行取消分片和取消填充;静态返回值通常与计算指标的 eval 步骤一起使用

返回

在将其传递给包装的函数之前填充和分片其参数的新函数,以及取消分片和取消填充返回的 pytree。

这对于调用输入不能被设备数量整除的 pmap 函数非常有用。一个典型的用法是

@pad_shard_unpad @jax.pmap def forward(params, x): …

注意

填充在主机内存中完成,然后再传递给该函数,并且该函数返回的值将传输回主机内存。

返回的函数增加了一个新的仅关键字参数 min_device_batch,如果指定,则强制将输入填充到每个设备的至少此大小。这对于避免最后一个批次的重新编译和减少内存碎片非常有用。

有关更多信息,请参阅 https://flax.org.cn/en/latest/guides/data_preprocessing/full_eval.html