flax.training 包#

检查点#

检查点辅助函数。

处理基于步骤数或其他数值指标在文件名中保存和恢复优化器检查点。清理较旧/性能较差的检查点文件。

flax.training.checkpoints.save_checkpoint(ckpt_dir, target, step, prefix='checkpoint_', keep=1, overwrite=False, keep_every_n_steps=None, async_manager=None, orbax_checkpointer=None)[源代码]#

保存模型检查点。适用于单主机。

在此方法中,每个 JAX 进程都会自行保存检查点。如果您的多个进程打算将数据保存到公共目录(例如 GCloud 存储桶),请不要使用它。要将多进程检查点保存到共享存储或保存 GlobalDeviceArray``s, 请改用 ``save_checkpoint_multiprocess()

通过先写入临时文件,然后再最终重命名和清理过去的文件,可以安全地进行抢占。但是,如果使用了 async_manager,则最终提交将在异步回调中进行,可以通过调用 async_manager.wait_previous_save() 显式等待。

用法示例

>>> from flax.training import checkpoints
>>> import jax.numpy as jnp
>>> import tempfile

>>> with tempfile.TemporaryDirectory() as dir_path:
...   test_object = {
...     'a': jnp.array([1, 2, 3], jnp.int32),
...     'b': jnp.array([1, 1, 1], jnp.int32),
...   }
...   file_path = checkpoints.save_checkpoint(
...     dir_path, target=test_object, step=0, prefix='test_', keep=1
...   )
...   restored_object = checkpoints.restore_checkpoint(
...     file_path, target=None
...   )
>>> restored_object
{'a': Array([1, 2, 3], dtype=int32), 'b': Array([1, 1, 1], dtype=int32)}
参数
  • ckpt_dir – 用于存储检查点文件的 str 或类路径 pathlib。

  • target – 可序列化的 flax 对象,通常是 flax 优化器。

  • step – int 或 float:训练步骤数或其他指标数。

  • prefix – str:检查点文件名前缀。

  • keep – 要保留的过去检查点文件数。

  • overwrite – 如果当前或更高步骤的检查点已存在,则覆盖现有检查点文件(默认值:False)。

  • keep_every_n_steps – 如果已定义,则每隔 n 步保留每个检查点(除了保留最后的 'keep' 个检查点之外)。

  • async_manager – 如果已定义,则保存将在不阻塞主线程的情况下运行。仅适用于单主机。请注意,正在进行的保存仍将阻止后续保存,以确保覆盖/保留逻辑正常工作。

  • orbax_checkpointer – 如果已定义,则保存将由 ocp 完成。将来,所有 Flax 检查点功能都将迁移到 Orbax,建议开始使用 orbax_checkpointer。有关如何使用 Orbax 检查点,请查看检查点指南 (https://flax.org.cn/en/latest/guides/training_techniques/use_checkpointing.html#save-checkpoints)。

返回

已保存的检查点文件名。

flax.training.checkpoints.save_checkpoint_multiprocess(ckpt_dir, target, step, prefix='checkpoint_', keep=1, overwrite=False, keep_every_n_steps=None, async_manager=None, gda_manager=None, orbax_checkpointer=None)[源代码]#

在多进程环境中保存模型检查点。

使用此方法保存``GlobalDeviceArray``s,或将数据保存到公共目录。只有进程 0 会保存主检查点文件并删除旧的检查点文件。

通过先写入临时文件,然后再最终重命名和清理过去的文件,可以安全地进行抢占。但是,如果使用了 async_manager 或 gda_manager,则最终提交将在异步回调中进行,可以通过调用 async_manager.wait_previous_save()gda_manager.wait_until_finished() 显式等待。

参数
  • ckpt_dir – 用于存储检查点文件的 str 或类路径 pathlib。

  • target – 可序列化的 flax 对象,通常是 flax 优化器。

  • step – int 或 float:训练步骤数或其他指标数。

  • prefix – str:检查点文件名前缀。

  • keep – 要保留的过去检查点文件数。

  • overwrite – 如果当前或更高步骤的检查点已存在,则覆盖现有检查点文件(默认值:False)。

  • keep_every_n_steps – 如果已定义,则每隔 n 步保留每个检查点(除了保留最后的 'keep' 个检查点之外)。

  • async_manager – 如果已定义,则保存将在不阻塞主线程的情况下运行。仅适用于单主机。请注意,正在进行的保存仍将阻止后续保存,以确保覆盖/保留逻辑正常工作。

  • gda_manager – 如果 target 包含 JAX GlobalDeviceArray,则为必需项。将异步将 GDA 保存到带有后缀“_gda”的单独子目录中。与 async_manager 相同,这会阻止后续保存。

  • orbax_checkpointer – 如果已定义,则保存将由 Orbax 完成。将来,所有 Flax 检查点功能都将迁移到 Orbax,建议开始使用 orbax_checkpointer。有关如何使用 Orbax 检查点,请查看检查点指南 (https://flax.org.cn/en/latest/guides/training_techniques/use_checkpointing.html#save-checkpoints)。

返回

已保存的检查点文件名。

flax.training.checkpoints.latest_checkpoint(ckpt_dir, prefix='checkpoint_')[源代码]#

检索目录中最新检查点的路径。

参数
  • ckpt_dir – str:要从中还原检查点的目录。

  • prefix – str:检查点文件名的前缀。

返回

最新的检查点路径,如果未找到检查点,则为 None。

flax.training.checkpoints.restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_', parallel=True, gda_manager=None, allow_partial_mpa_restoration=False, orbax_checkpointer=None, orbax_transforms=None)[source]#

从路径中的检查点恢复最后一个/最佳检查点。

自然地对检查点文件进行排序,返回最高值的文件,例如:

  • ckpt_1, ckpt_2, ckpt_3 --> ckpt_3

  • ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1

  • ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5

用法示例

>>> from flax.training import checkpoints
>>> import jax.numpy as jnp
>>> import tempfile
...
>>> with tempfile.TemporaryDirectory() as dir_path:
...   test_object = {
...     'a': jnp.array([1, 2, 3], jnp.int32),
...     'b': jnp.array([1, 1, 1], jnp.int32),
...   }
...   file_path = checkpoints.save_checkpoint(
...     dir_path, target=test_object, step=0, prefix='test_', keep=1
...   )
...   restored_object = checkpoints.restore_checkpoint(
...     file_path, target=None
...   )
>>> restored_object
{'a': Array([1, 2, 3], dtype=int32), 'b': Array([1, 1, 1], dtype=int32)}
参数
  • ckpt_dir – str: 要从中恢复的检查点文件或检查点目录。

  • target – 通过反序列化的状态字典重建的匹配对象。如果为None,则按原样返回反序列化的状态字典。

  • step – int 或 float: 要加载的步数,如果为 None 则加载最新的。如果指定,则 ckpt_dir 必须是目录。

  • prefix – str:检查点文件名的前缀。

  • parallel – bool: 为了速度,是否并行加载可搜索的检查点。

  • gda_manager – 如果检查点包含多进程数组(来自 pjit 的 GlobalDeviceArray 或 jax Array),则为必需项。将从带有后缀“_gda”的单独子目录中读取数组。

  • allow_partial_mpa_restoration – 如果为 true,则给定的 target 不必包含所有有效的多进程数组。因此,恢复的 Pytree 可能有一些 MPA 没有正确恢复。如果您无法提供完全有效的 target 并且不需要恢复检查点中的所有 MPA,请使用此选项。

  • orbax_checkpointerocp.Checkpointer,如果给定的检查点是用 ocp 保存的,则它处理底层恢复。

  • orbax_transforms – 将传递到 orbax_checkpointer.restore() 调用的 Orbax 转换。

返回

从检查点文件更新的已恢复 target,或者如果未指定步数且不存在检查点文件,则返回传入的 target,且不会更改。如果指定的文件路径未找到,将返回传入的 target。这与指定目录路径但尚未创建目录的情况的行为相匹配。

flax.training.checkpoints.convert_pre_linen(params)[source]#

转换预 Linen 参数 pytree。

在预 Linen API 中,子模块是递增编号的,与子模块类无关。对于 Linen,此行为已更改为为每个模块类保留单独的子模块计数。

考虑以下模块

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(1, 1)(x)
    x = nn.Dense(1)(x)
    return x

在预 Linen 中,生成的参数将具有以下结构

{'Conv_0': { ... }, 'Dense_1': { ... } }

对于 Linen,生成的参数将具有以下结构

{'Conv_0': { ... }, 'Dense_0': { ... } }

要从预 Linen 格式转换为 Linen,只需调用

params = convert_pre_linen(pre_linen_params)

请注意,您也可以使用此实用程序来转换预 Linen 集合,因为它们遵循相同的模块命名。但是请注意,集合在预 Linen 中是“扁平的”,需要先展平,然后才能与此函数一起使用

batch_stats = convert_pre_linen(flax.traverse_util.unflatten_dict({
    tuple(k.split('/')[1:]): v
    for k, v in pre_linen_model_state.as_dict().items()
}))

然后可以从这些转换后的集合中定义 Linen 变量

variables = {'params': params, 'batch_stats': batch_stats}
参数

params – 预 Linen 格式的参数 pytree。如果 pytree 已经是 Linen 格式,则返回的 pytree 不会更改(即,可以安全地在任何加载的检查点上调用此函数以与 Linen 一起使用)。

返回

具有 Linen 子模块命名的参数 pytree。

学习率计划#

FLAX 图像分类示例中使用的学习率计划。

请注意,使用 FLIP #1009flax.training 中的学习率计划实际上已弃用,而支持 Optax 计划。有关更多信息,请参阅优化器计划

flax.training.lr_schedule.create_constant_learning_rate_schedule(base_learning_rate, steps_per_epoch, warmup_length=0.0)[source]#

创建带有可选预热的恒定学习率计划。

请注意,使用 FLIP #1009flax.training 中的学习率计划实际上已弃用,而支持 Optax 计划。有关更多信息,请参阅优化器计划

使学习率保持恒定。此函数还提供了根据 https://arxiv.org/abs/1706.02677 的学习率预热,以便使用大型小批量进行训练。

参数
  • base_learning_rate – 基本学习率

  • steps_per_epoch – 每个 epoch 的迭代次数

  • warmup_length – 如果 > 0,学习率将通过预热因子进行调整,该因子将在前 warmup_length 个 epoch 内从 0 线性增加到 1

返回

函数 f(step) -> lr 计算给定步数的学习率。

flax.training.lr_schedule.create_stepped_learning_rate_schedule(base_learning_rate, steps_per_epoch, lr_sched_steps, warmup_length=0.0)[source]#

创建带有可选预热的阶梯式学习率计划。

请注意,使用 FLIP #1009flax.training 中的学习率计划实际上已弃用,而支持 Optax 计划。有关更多信息,请参阅优化器计划

阶梯式学习率计划在指定的 epoch 处按指定的量减少学习率。这些步数作为 lr_sched_steps 参数给出。常见的 ImageNet 计划在 epoch 30、60 和 80 时将学习率降低 0.1。这将指定为

[
  [30, 0.1],
  [60, 0.01],
  [80, 0.001]
]

此函数还提供了根据 https://arxiv.org/abs/1706.02677 的学习率预热,以便使用大型小批量进行训练。

参数
  • base_learning_rate – 基本学习率

  • steps_per_epoch – 每个 epoch 的迭代次数

  • lr_sched_steps – 计划作为步数列表,每个步数都是一个 [epoch, lr_factor] 对;步数发生在 epoch 处并将学习率设置为 base_learning_rage * lr_factor

  • warmup_length – 如果 > 0,学习率将通过预热因子进行调整,该因子将在前 warmup_length 个 epoch 内从 0 线性增加到 1

返回

函数 f(step) -> lr 计算给定步数的学习率。

flax.training.lr_schedule.create_cosine_learning_rate_schedule(base_learning_rate, steps_per_epoch, halfcos_epochs, warmup_length=0.0)[source]#

创建带有可选预热的余弦学习率计划。

请注意,使用 FLIP #1009flax.training 中的学习率计划实际上已弃用,而支持 Optax 计划。有关更多信息,请参阅优化器计划

余弦学习率计划使用半个余弦波来调整学习率,在训练结束时逐渐将其缩放到 0。

此函数还提供了根据 https://arxiv.org/abs/1706.02677 的学习率预热,以便使用大型小批量进行训练。

参数
  • base_learning_rate – 基本学习率

  • steps_per_epoch – 每个 epoch 的迭代次数

  • halfcos_epochs – 完成半个余弦波所需的 epoch 数;通常是用于训练的 epoch 数

  • warmup_length – 如果 > 0,学习率将通过预热因子进行调整,该因子将在前 warmup_length 个 epoch 内从 0 线性增加到 1

返回

函数 f(step) -> lr 计算给定步数的学习率。

训练状态#

class flax.training.train_state.TrainState(step, apply_fn, params, tx, opt_state)[source]#

带有单个 Optax 优化器的常见情况的简单训练状态。

用法示例

>>> import flax.linen as nn
>>> from flax.training.train_state import TrainState
>>> import jax, jax.numpy as jnp
>>> import optax

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 2))
>>> model = nn.Dense(2)
>>> variables = model.init(jax.random.key(0), x)
>>> tx = optax.adam(1e-3)

>>> state = TrainState.create(
...     apply_fn=model.apply,
...     params=variables['params'],
...     tx=tx)

>>> def loss_fn(params, x, y):
...   predictions = state.apply_fn({'params': params}, x)
...   loss = optax.l2_loss(predictions=predictions, targets=y).mean()
...   return loss
>>> loss_fn(state.params, x, y)
Array(3.3514676, dtype=float32)

>>> grads = jax.grad(loss_fn)(state.params, x, y)
>>> state = state.apply_gradients(grads=grads)
>>> loss_fn(state.params, x, y)
Array(3.343844, dtype=float32)

请注意,您可以通过对其进行子类化来轻松扩展此数据类,以存储其他数据(例如,其他变量集合)。

对于更特殊的使用案例(例如,多个优化器),最好 fork 该类并对其进行修改。

参数
  • step – 计数器从 0 开始,每次调用 .apply_gradients() 时都会递增。

  • apply_fn – 通常设置为 model.apply()。将其保留在此数据类中是为了方便地为训练循环中的 train_step() 函数提供较短的参数列表。

  • params – 要由 tx 更新并由 apply_fn 使用的参数。

  • tx – Optax 梯度变换。

  • opt_statetx 的状态。

apply_gradients(*, grads, **kwargs)[源代码]#

更新返回值中的 stepparamsopt_state**kwargs

请注意,此函数内部会先调用 .tx.update(),然后调用 optax.apply_updates() 来更新 paramsopt_state

参数
  • grads – 梯度,其 pytree 结构与 .params 相同。

  • **kwargs – 应该使用 .replace() 替换的其他 dataclass 属性。

返回

返回 self 的更新实例,其中 step 递增 1,paramsopt_state 通过应用 grads 更新,并且其他属性按照 kwargs 指定的方式替换。

classmethod create(*, apply_fn, params, tx, **kwargs)[源代码]#

创建一个新的实例,其中 step=0 并且初始化 opt_state

提前停止#

class flax.training.early_stopping.EarlyStopping(min_delta=0, patience=0, best_metric=inf, patience_count=0, should_stop=False, has_improved=False)[源代码]#

提前停止,以避免在训练期间过拟合。

以下示例将在当前 epoch 和前一个 epoch 中记录的损失之间的差异连续 2 次小于 1e-3 时提前停止训练

>>> from flax.training.early_stopping import EarlyStopping

>>> def train_epoch(optimizer, train_ds, batch_size, epoch, input_rng):
...   ...
...   loss = [4, 3, 3, 3, 2, 2, 2, 2, 1, 1][epoch]
...   return None, {'loss': loss}

>>> early_stop = EarlyStopping(min_delta=1e-3, patience=2)
>>> optimizer = None
>>> for epoch in range(10):
...   optimizer, train_metrics = train_epoch(
...       optimizer=optimizer, train_ds=None, batch_size=None, epoch=epoch, input_rng=None)
...   early_stop = early_stop.update(train_metrics['loss'])
...   if early_stop.should_stop:
...     print(f'Met early stopping criteria, breaking at epoch {epoch}')
...     break
Met early stopping criteria, breaking at epoch 7
min_delta#

被视为改进的更新之间的最小增量。

类型

浮点数

patience#

停止之前没有改进的步数。

类型

整数

best_metric#

当前最佳指标值。

类型

浮点数

patience_count#

自上次改进更新以来的步数。

类型

整数

should_stop#

是否应停止训练循环以避免过拟合。

类型

布尔值

has_improved#

在上次调用 .update 时,该指标的改进是否大于或等于 min_delta。

类型

布尔值

update(metric)[源代码]#

根据指标更新状态。

返回

更新后的 EarlyStopping 类。当从之前的 best_metric 有大于 min_delta 的改进时,.has_improved 属性为 True。

通用实用程序#

flax.training.common_utils.shard(xs)[源代码]#

用于 pmap 的助手,用于按 local_device_count 对数组的 pytree 进行分片。

参数

xs – 数组的 pytree。

返回

一个匹配的 pytree,其数组的开头维度按本地设备计数进行分片。

flax.training.common_utils.shard_prng_key(prng_key)[源代码]#

助手,用于对 PRNGKey 进行分片(又名拆分),以便与 pmap'd 函数一起使用。

PRNG 密钥可在训练时用于驱动随机模块,例如 Dropout。我们希望每个本地设备都有不同的 PRNG 密钥,以便我们在每个设备上获得不同的随机数,因此我们拆分了 PRNG 密钥。

参数

prng_key – JAX PRNGKey

返回

一个新 PRNGKey 数组,其开头维度等于本地设备计数。

flax.training.common_utils.stack_forest(forest)[源代码]#

助手函数,用于堆叠 pytree 序列的叶子。

参数

forest – 一个 pytree 序列(例如元组或列表),其结构匹配,叶子是具有单独匹配形状的数组。

返回

一个相同结构的单个 pytree,其叶子是单独的

堆叠数组。

flax.training.common_utils.get_metrics(device_metrics)[源代码]#

用于 pmap 的助手实用程序,用于收集复制的时间序列指标数据。

参数

device_metrics – 复制的、驻留在设备中的指标数据 pytree,其叶子被假定为随时间记录的数组序列。

返回

一个未复制的、驻留在主机中的、随时间堆叠的数组的 pytree,可用于计算主机本地统计信息和日志记录。

flax.training.common_utils.onehot(labels, num_classes, on_value=1.0, off_value=0.0)[源代码]#

创建索引数组的密集 one-hot 版本。

注意:考虑使用更标准的 jax.nn.one_hot 来代替。

参数
  • labels – 一个 n 维 JAX 数组,其最后一个维度包含整数索引。

  • num_classes – 最大可能的索引。

  • on_value – one-hot 数组的“on”值,默认为 1.0。

  • off_value – one-hot 数组的“off”值,默认为 0.0。

返回

一个(n+1)维数组,其最后一个维度包含长度为 num_classes 的 one-hot 向量。