flax.jax_utils 包#
我们可以考虑上游到 Jax 的实用程序。
- flax.jax_utils.partial_eval_by_shape(fn, input_spec, *args, **kwargs)[源代码]#
通过使用输入的形状来惰性评估函数。
此函数类似于
jax.eval_shape
,主要区别在于,可以在没有输入具体值的情况下计算的函数输出将按原样返回,而不仅仅是形状。例如,请参阅module.init_by_shape
,其中此功能用于初始化模型而不使用输入数据 lr 计算。- 参数
fn – 要惰性评估的函数。
input_spec – 形状或(形状、dtype)元组的可迭代对象,指定输入的形状和类型。如果未指定,则 dtype 为 float32。
*args – 传递给模块 apply 函数的其他参数
**kwargs – 传递给模块 apply 函数的关键字参数
- 返回
一个由模型输出和 Model 实例组成的对
多设备实用程序#
- flax.jax_utils.replicate(tree, devices=None)[源代码]#
将数组复制到多个设备。
- 参数
tree – 包含应复制的数组的 pytree。
devices – 数据复制到的设备(默认:与
jax.pmap()
预期的顺序相同)。
- 返回
包含复制数组的新 pytree。
- flax.jax_utils.prefetch_to_device(iterator, size, devices=None)[源代码]#
在设备上分片和预取批次。
此实用程序接受一个迭代器,并返回一个新迭代器,该迭代器填充设备上的预取缓冲区。通过重叠计算和数据传输,尽早预取可以显着提高训练循环的性能。
此实用程序主要对 GPU 有用,对于 TPU 和 CPU,则不需要 – TPU 和 CPU 内存分配器(通常)不会选择尚未空闲的内存位置,因此它们不会阻塞。相反,这些分配器会 OOM。
- 参数
iterator – 一个迭代器,它产生一个 ndarray 的 pytree,其中第一维度在设备之间分片。
size –
预取缓冲区的大小。
如果您在 GPU 上进行训练,通常 2 是最佳选择,因为这可以保证您可以将 GPU 上的训练步骤与 CPU 上的数据预取步骤重叠。
devices –
应该预取数组的设备列表。
默认为
jax.pmap
预期的设备顺序。
- 产生
来自迭代器的原始项,其中每个 ndarray 现在都分片到指定的设备。
- flax.jax_utils.pad_shard_unpad(wrapped, static_argnums=(0,), static_argnames=(), static_return=False)[源代码]#
使用在填充、分片、然后取消分片、取消填充的代码包装函数。
- 参数
wrapped – 要包装的函数。签名是
params, *args, *kwargs
。static_argnums –
wrapped
的参数索引,这些参数_不应_填充和分片,而应按原样转发。默认值为 (0,),因为到目前为止,最常见的用例是首先传递params
。static_argnames –
wrapped
的 kwargs 名称,这些参数_不应_填充和分片,而应按原样转发。static_return – 是否不对返回值进行取消分片和取消填充;静态返回值通常与计算指标的 eval 步骤一起使用
- 返回
在将其传递给包装的函数之前填充和分片其参数的新函数,以及取消分片和取消填充返回的 pytree。
这对于调用输入不能被设备数量整除的 pmap 函数非常有用。一个典型的用法是
@pad_shard_unpad @jax.pmap def forward(params, x): …
注意
填充在主机内存中完成,然后再传递给该函数,并且该函数返回的值将传输回主机内存。
返回的函数增加了一个新的仅关键字参数
min_device_batch
,如果指定,则强制将输入填充到每个设备的至少此大小。这对于避免最后一个批次的重新编译和减少内存碎片非常有用。有关更多信息,请参阅 https://flax.org.cn/en/latest/guides/data_preprocessing/full_eval.html