Save and load checkpoints#

This guide demonstrates how to save and load Flax checkpoints with Orbax.

Orbax provides a variety of features for saving and loading model data, which you will learn about in this document:

  • Support for various array types and storage formats

  • Asynchronous saving to reduce training wait time

  • Versioning and automatic bookkeeping of past checkpoints

  • Flexible transformations to tweak and load old checkpoints

  • A jax.sharding-based API for saving and loading in multi-host scenarios


Ongoing migration to Orbax

After July 30, 2023, Flax's legacy flax.training.checkpoints API will be deprecated in favor of Orbax.

  • If you are a new Flax user: use the new orbax.checkpoint API, as demonstrated in this guide.

  • If you have legacy flax.training.checkpoints code in your project: consider the following options:

    • Migrate your code to Orbax (recommended): follow this migration guide to move your API calls over to the orbax.checkpoint API.

    • Automatically use the Orbax backend: add flax.config.update('flax_use_orbax_checkpointing', True) to your project, and your flax.training.checkpoints calls will then automatically use the Orbax backend to save your checkpoints.

      • Scheduled flip: this will become the default mode after May 2023 (tentative date).

      • If you run into any issues with the automatic migration, go to the Orbax-as-backend troubleshooting section.


For backward compatibility, this guide shows the calls in Flax's legacy flax.training.checkpoints API that are equivalent to the Orbax ones.

If you need to learn more about orbax.checkpoint, refer to the Orbax documentation.

Setup#

Install/upgrade Flax and Orbax. For JAX installation with GPU/TPU support, visit this section on GitHub.

Note: Before running import jax, create eight fake devices to mimic a multi-host environment in this notebook. Note that the order of imports matters here. The os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' command only works with the CPU backend, which means it won't use GPU/TPU acceleration if you are running this notebook in Google Colab. If you are already running the code on multiple devices (for example, in a 4x2 TPU environment), you can skip running the next cell.

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
from typing import Optional, Any
import shutil

import numpy as np
import jax
from jax import random, numpy as jnp

import flax
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
import orbax.checkpoint

import optax
WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed.
ckpt_dir = '/tmp/flax_ckpt'

if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)  # Remove any existing checkpoints from the last notebook run.

Save checkpoints#

In both Orbax and Flax, you can save and load any given JAX pytree. This includes not only typical Python and NumPy containers, but also custom classes extended from flax.struct.dataclass. That means you can store almost any data generated: not only your model parameters, but also any arrays/dicts, metadata/configs, and so on (a minimal custom class is sketched below).
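
As a reference, here is a minimal sketch of such a custom class (the class and field names are hypothetical): decorating a class with flax.struct.dataclass registers it as a pytree, which is what makes it checkpoint-compatible.

@flax.struct.dataclass
class RunMetadata:
    seed: jnp.ndarray          # array leaf, stored in the checkpoint
    loss_history: jnp.ndarray  # array leaf, stored in the checkpoint

# `meta` can now be placed anywhere inside a checkpointed pytree.
meta = RunMetadata(seed=jnp.array(42), loss_history=jnp.zeros(3))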

First, create a pytree with many data structures and containers, and play with it:

# A simple model with one linear layer.
key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,))      # A simple JAX array.
model = nn.Dense(features=3)
variables = model.init(key2, x1)

# Flax's TrainState is a pytree dataclass and is supported in checkpointing.
# Define your class with `@flax.struct.dataclass` decorator to make it compatible.
tx = optax.sgd(learning_rate=0.001)      # An Optax SGD optimizer.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx)
# Perform a simple gradient update similar to the one during a normal training workflow.
state = state.apply_gradients(grads=jax.tree_util.tree_map(jnp.ones_like, state.params))

# Some arbitrary nested pytree with a dictionary and a NumPy array.
config = {'dimensions': np.array([5, 3])}

# Bundle everything together.
ckpt = {'model': state, 'config': config, 'data': [x1]}
ckpt
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1695322343.254588       1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
{'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': Array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': Array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState())),
 'config': {'dimensions': array([5, 3])},
 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],      dtype=float32)]}

With Orbax#

Save the checkpoint straight to the /tmp/flax_ckpt/orbax/single_save directory with orbax.checkpoint.PyTreeCheckpointer.

Note: An optional save_args is provided. It is recommended for performance, because it bundles the smaller arrays in your pytree into a single large file instead of multiple smaller files.

from flax.training import orbax_utils

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save('/tmp/flax_ckpt/orbax/single_save', ckpt, save_args=save_args)

Next, to take advantage of versioning and automatic bookkeeping, you need to wrap orbax.checkpoint.CheckpointManager over orbax.checkpoint.PyTreeCheckpointer.

In addition, provide orbax.checkpoint.CheckpointManagerOptions to customize your needs, such as how often and on what criteria you want to delete older checkpoints. Refer to the documentation for the full list of options offered (a sketch of a richer configuration follows the example below).

The orbax.checkpoint.CheckpointManager should be placed at the top level, outside your training steps, to manage your saves.

options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2, create=True)
checkpoint_manager = orbax.checkpoint.CheckpointManager(
    '/tmp/flax_ckpt/orbax/managed', orbax_checkpointer, options)

# Inside a training loop
for step in range(5):
    # ... do your training
    checkpoint_manager.save(step, ckpt, save_kwargs={'save_args': save_args})

os.listdir('/tmp/flax_ckpt/orbax/managed')  # Because max_to_keep=2, only steps 3 and 4 are retained
['4', '3']
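
Here is a hedged sketch of the richer configuration mentioned above, keeping only the two checkpoints with the lowest loss. It assumes the best_fn/best_mode fields of orbax.checkpoint.CheckpointManagerOptions and the metrics argument of save, which are part of orbax.checkpoint at the time of writing.

best_options = orbax.checkpoint.CheckpointManagerOptions(
    max_to_keep=2,
    best_fn=lambda metrics: metrics['loss'],  # score each checkpoint by its loss
    best_mode='min',                          # a lower loss is better
    create=True)
best_manager = orbax.checkpoint.CheckpointManager(
    '/tmp/flax_ckpt/orbax/best_managed', orbax_checkpointer, best_options)

for step in range(5):
    # Pass the metrics so the manager can rank checkpoints and discard the rest.
    best_manager.save(step, ckpt, metrics={'loss': float(5 - step)},
                      save_kwargs={'save_args': save_args})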

With the legacy API#

Here is how to save with the legacy Flax checkpointing utilities (note that it offers fewer management features than orbax.checkpoint.CheckpointManagerOptions):

# Import Flax Checkpoints.
from flax.training import checkpoints

checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing',
                            target=ckpt,
                            step=0,
                            overwrite=True,
                            keep=2)
'/tmp/flax_ckpt/flax-checkpointing/checkpoint_0'

Restore checkpoints#

With Orbax#

In Orbax, call .restore() on either orbax.checkpoint.PyTreeCheckpointer or orbax.checkpoint.CheckpointManager to restore your checkpoint in the raw pytree format.

raw_restored = orbax_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save')
raw_restored
{'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)],
 'model': {'opt_state': [None, None],
  'params': {'bias': array([-0.001, -0.001, -0.001], dtype=float32),
   'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
          [ 0.11050402, -0.8765793 ,  0.9800635 ],
          [ 0.36260957,  0.18276349, -0.6856061 ],
          [-0.8519373 , -0.6416717 , -0.4818122 ],
          [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)},
  'step': 1}}

Note that the step number is required for a CheckpointManager. You can also use .latest_step() to find the latest available step.

step = checkpoint_manager.latest_step()  # step = 4
checkpoint_manager.restore(step)
{'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)],
 'model': {'opt_state': [None, None],
  'params': {'bias': array([-0.001, -0.001, -0.001], dtype=float32),
   'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
          [ 0.11050402, -0.8765793 ,  0.9800635 ],
          [ 0.36260957,  0.18276349, -0.6856061 ],
          [-0.8519373 , -0.6416717 , -0.4818122 ],
          [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)},
  'step': 1}}

With the legacy API#

Note that, with the migration to Orbax in progress, flax.training.checkpoints.restore_checkpoint can automatically identify whether a checkpoint was saved in the legacy Flax format or with the Orbax backend, and restore the pytree correctly. So adding flax.config.update('flax_use_orbax_checkpointing', True) won't hurt your ability to restore old checkpoints.

Here is how to restore a checkpoint with the legacy API:

raw_restored = checkpoints.restore_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', target=None)
raw_restored
{'config': {'dimensions': array([5, 3])},
 'data': {'0': array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)},
 'model': {'opt_state': {'0': None, '1': None},
  'params': {'bias': array([-0.001, -0.001, -0.001], dtype=float32),
   'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
          [ 0.11050402, -0.8765793 ,  0.9800635 ],
          [ 0.36260957,  0.18276349, -0.6856061 ],
          [-0.8519373 , -0.6416717 , -0.4818122 ],
          [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)},
  'step': 1}}

Restore with custom dataclasses#

With Orbax#

  • The pytrees restored in the previous examples are in the form of raw dictionaries, whereas the original pytrees contained custom dataclasses such as TrainState and optax states.

  • This is because, at restore time, the program does not yet know which structure the checkpoint once belonged to.

  • To solve this, you should first provide an example pytree to let Orbax or Flax know exactly which structure to restore to.

This section demonstrates how to explicitly set up any custom Flax dataclass so that it has the same structure as the saved checkpoint.

Note: Data saved as a JAX NumPy array (jnp.array) will be restored as a NumPy array (numpy.array). This does not affect your work, because JAX automatically converts NumPy arrays to JAX arrays once computation starts.
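
As a quick sanity check of this note, a sketch using the raw_restored pytree from the single-save restore above:

# `x1` went in as a jax.Array, but comes back as a plain NumPy array.
assert isinstance(raw_restored['data'][0], np.ndarray)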

empty_state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=jax.tree_util.tree_map(np.zeros_like, variables['params']),  # the values of the tree leaves don't matter
    tx=tx,
)
empty_config = {'dimensions': np.array([0, 0])}
target = {'model': empty_state, 'config': empty_config, 'data': [jnp.zeros_like(x1)]}
state_restored = orbax_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save', item=target)
state_restored
{'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)],
 'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()))}

With the legacy API#

Alternatively, you can restore from the Orbax CheckpointManager, or with the legacy Flax code, as follows:

checkpoint_manager.restore(4, items=target)
{'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)],
 'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()))}
checkpoints.restore_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', target=target)
WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
{'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState())),
 'config': {'dimensions': array([5, 3])},
 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],      dtype=float32)]}

It is generally recommended to refactor out the process of initializing a checkpoint's structure (such as a TrainState), so that saving/loading is easier and less error-prone. This is because functions and complex objects like apply_fn and tx (the optimizer) cannot be serialized into the checkpoint file and must be initialized by code.
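
A minimal sketch of that pattern (the helper name is hypothetical): rebuild the non-serializable parts in code, and let the checkpoint fill in only the arrays.

def make_reference_ckpt():
    # Re-create everything that cannot live in a checkpoint file:
    # the module, its `apply_fn`, and the optimizer `tx`.
    model = nn.Dense(features=3)
    params = model.init(random.key(0), jnp.zeros(5))['params']
    empty_state = train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=optax.sgd(learning_rate=0.001))
    return {'model': empty_state,
            'config': {'dimensions': np.array([0, 0])},
            'data': [jnp.zeros(5)]}

restored = orbax_checkpointer.restore(
    '/tmp/flax_ckpt/orbax/single_save', item=make_reference_ckpt())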

Restore when checkpoint structures differ#

During development, your checkpoint structure will change as you alter your model, add/remove fields during tuning, and so on.

This section explains how to load old data into your new code.

Below is a simple example: a CustomTrainState extended from flax.training.train_state.TrainState, containing an extra field called batch_stats. You may need it when applying batch normalization in real-world models.

Here, you store the new CustomTrainState as step 5, while step 4 contains the old/previous TrainState:

class CustomTrainState(train_state.TrainState):
    batch_stats: Any = None

custom_state = CustomTrainState.create(
    apply_fn=state.apply_fn,
    params=state.params,
    tx=state.tx,
    batch_stats=np.arange(10),
)

custom_ckpt = {'model': custom_state, 'config': config, 'data': [x1]}
# Use a custom state to read the old `TrainState` checkpoint.
custom_target = {'model': custom_state, 'config': None, 'data': [jnp.zeros_like(x1)]}

# Save it in Orbax.
custom_save_args = orbax_utils.save_args_from_target(custom_ckpt)
checkpoint_manager.save(5, custom_ckpt, save_kwargs={'save_args': custom_save_args})
True

It is recommended to keep your checkpoints in sync with your pytree dataclass definitions. However, you might be forced to restore checkpoints with incompatible reference objects at runtime. When this happens, the checkpoint restoration will try to respect the structure of the reference when given.

Below are examples of a few common scenarios.

Scenario 1: When a reference object is partial#

If your reference object is a subtree of the checkpoint, the restoration will ignore the additional fields and restore a checkpoint with the same structure as the reference.

As in the example below, the batch_stats field in CustomTrainState is ignored, and the checkpoint is restored as a TrainState.

This can also be useful for reading only part of a checkpoint.

restored = checkpoint_manager.restore(5, items=target)
assert not hasattr(restored, 'batch_stats')
assert type(restored['model']) == train_state.TrainState
restored
{'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)],
 'model': TrainState(step=0, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()))}

Scenario 2: When a checkpoint is partial#

On the other hand, if the reference object contains values that are not available in the checkpoint, the checkpointing code will, by default, complain that some data is not compatible.

To bypass the error, you need to pass an Orbax transform that teaches Orbax how to conform this checkpoint to the structure of custom_target.

In this case, pass a default {}, which lets Orbax use values from custom_target to fill in the blanks. This allows you to restore an old checkpoint into the new data structure, CustomTrainState.

try:
    checkpoint_manager.restore(4, items=custom_target)
except KeyError as e:
    print(f'KeyError when target state has an unmentioned field: {e}')
    print('')

# Step 4 is an original `TrainState`, without the `batch_stats`
custom_restore_args = orbax_utils.restore_args_from_target(custom_target)
restored = checkpoint_manager.restore(4, items=custom_target,
                                      restore_kwargs={'transforms': {}, 'restore_args': custom_restore_args})
assert type(restored['model']) == CustomTrainState
np.testing.assert_equal(restored['model'].batch_stats,
                        custom_target['model'].batch_stats)
restored
WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
KeyError when target state has an unmentioned field: 'batch_stats'
{'config': None,
 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],      dtype=float32)],
 'model': CustomTrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': Array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': Array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))}

With Orbax#

If you have already saved your checkpoints with the Orbax backend, you can use orbax_transforms to access this transforms argument in the Flax API.

# Save in the "Flax-with-Orbax" backend.
flax.config.update('flax_use_orbax_checkpointing', True)
checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing',
                            target=ckpt,
                            step=4,
                            overwrite=True,
                            keep=2)

checkpoints.restore_checkpoint('/tmp/flax_ckpt/flax-checkpointing', target=custom_target, step=4,
                               orbax_transforms={})
WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
{'model': CustomTrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': Array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': Array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])),
 'config': None,
 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],      dtype=float32)]}

With the legacy API#

Similar things can be done with the legacy flax.training.checkpoints API, but it is not as flexible as Orbax transformations.

You need to restore the checkpoint to a raw dictionary with target=None, modify the structure accordingly, and then deserialize it back to the original target.

# Save using the legacy Flax `checkpoints` API.
flax.config.update('flax_use_orbax_checkpointing', False)
checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing',
                            target=ckpt,
                            step=5,
                            overwrite=True,
                            keep=2)

# Pass no target to get a raw state dictionary first.
raw_state_dict = checkpoints.restore_checkpoint('/tmp/flax_ckpt/flax-checkpointing', target=None, step=5)
# Add/remove fields as needed.
raw_state_dict['model']['batch_stats'] = np.flip(np.arange(10))
# Restore the classes with the correct target now.
flax.serialization.from_state_dict(custom_target, raw_state_dict)
{'model': CustomTrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])),
 'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)]}

Asynchronous checkpointing#

Checkpointing is I/O heavy, and if you have a large amount of data to save, it may be worthwhile to put it into a background thread while continuing with your training.

You can do that by creating an orbax.checkpoint.AsyncCheckpointer in place of the orbax.checkpoint.PyTreeCheckpointer.

Note: You should use the same async_checkpointer to handle all your async saves across your training steps, so that it can make sure a previous async save has completed before the next one begins. This keeps bookkeeping, such as keep (the number of checkpoints) and overwrite, consistent across steps.

Whenever you want to explicitly wait until an async save has completed, call async_checkpointer.wait_until_finished().

# `orbax.checkpoint.AsyncCheckpointer` needs some multi-process initialization, because it was
# originally designed for multi-process large model checkpointing.
# For Python notebooks or other single-process settings, just set up with `num_processes=1`.
# Refer to https://jax.readthedocs.io/en/latest/multi_process.html#initializing-the-cluster
# for how to set it up in multi-process scenarios.
jax.distributed.initialize("localhost:8889", num_processes=1, process_id=0)

async_checkpointer = orbax.checkpoint.AsyncCheckpointer(
    orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=50)

# Save your job:
async_checkpointer.save('/tmp/flax_ckpt/orbax/single_save_async', ckpt, save_args=save_args)
# ... Continue with your work...

# ... Until a time when you want to wait until the save completes:
async_checkpointer.wait_until_finished()  # Blocks until the checkpoint saving is completed.
async_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save_async', item=target)
{'config': {'dimensions': array([5, 3])},
 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],
        dtype=float32)],
 'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = dot_general
     dot_general_cls = None
 )>, params={'bias': array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],
        [ 0.11050402, -0.8765793 ,  0.9800635 ],
        [ 0.36260957,  0.18276349, -0.6856061 ],
        [-0.8519373 , -0.6416717 , -0.4818122 ],
        [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x13d5d83a0>, update=<function chain.<locals>.update_fn at 0x13d5d8dc0>), opt_state=(EmptyState(), EmptyState()))}

If you are using the Orbax CheckpointManager, just pass in the async_checkpointer when initializing it. Then, in practice, call async_checkpoint_manager.wait_until_finished() instead.

async_checkpoint_manager = orbax.checkpoint.CheckpointManager(
    '/tmp/flax_ckpt/orbax/managed_async', async_checkpointer, options)
async_checkpoint_manager.wait_until_finished()

Multi-host/multi-process checkpointing#

JAX provides a few ways to scale up your code on multiple hosts at the same time. This usually happens when the number of devices (CPU/GPU/TPU) is so large that different devices are managed by different hosts (CPUs). To get started on JAX in multi-process settings, check out Using JAX in multi-host and multi-process environments and the distributed array guide.

In the Single Program Multi Data (SPMD) paradigm with JAX jit, a large multi-process array can have its data sharded across different devices. (Note that JAX pjit and jit have been merged into a single unified interface. To learn about compiling and executing JAX functions in multi-host or multi-core environments, refer to this guide and the jax.Array migration guide.) When a multi-process array is serialized, each host dumps its data shards to a single shared storage, such as a Google Cloud bucket.

Orbax supports saving and loading pytrees with multi-process arrays in the same fashion as single-process pytrees. However, it is recommended to use the asynchronous orbax.checkpoint.AsyncCheckpointer to save large multi-process arrays on another thread, so that you can perform computation alongside the saves. With pure Orbax, saving checkpoints in a multi-process context uses the same API as in a single-process context.

from jax.sharding import PartitionSpec, NamedSharding

# Create an array sharded across multiple devices.
mesh_shape = (4, 2)
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, ('x', 'y'))

mp_array = jax.device_put(np.arange(8 * 2).reshape(8, 2),
                          NamedSharding(mesh, PartitionSpec('x', 'y')))

# Make it a pytree.
mp_ckpt = {'model': mp_array}
async_checkpoint_manager.save(0, mp_ckpt)
async_checkpoint_manager.wait_until_finished()

When restoring a checkpoint with multi-process arrays, you need to specify which sharding each array should be restored to. Otherwise, they will be restored as large np.arrays on process 0, costing time and memory.

(In this notebook, since we are running in a single process, the arrays will be restored as np.arrays even when shardings are provided.)

With Orbax#

Orbax lets you specify this by passing a pytree of shardings in restore_args. If you already have a reference pytree in which all arrays have the correct sharding, you can use orbax_utils.restore_args_from_target to convert it into the restore_args that Orbax needs.

# The reference doesn't need to be as large as your checkpoint!
# Just make sure it has the `.sharding` you want.
mp_smaller = jax.device_put(np.arange(8).reshape(4, 2),
                            NamedSharding(mesh, PartitionSpec('x', 'y')))
ref_ckpt = {'model': mp_smaller}

restore_args = orbax_utils.restore_args_from_target(ref_ckpt)
async_checkpoint_manager.restore(
    0, items=ref_ckpt, restore_kwargs={'restore_args': restore_args})
{'model': Array([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11],
        [12, 13],
        [14, 15]], dtype=int32)}

With legacy Flax: use save_checkpoint_multiprocess#

In legacy Flax, to save multi-process arrays, use flax.training.checkpoints.save_checkpoint_multiprocess() in place of save_checkpoint(), with the same arguments.

If your checkpoint is too large, you can specify timeout_secs in the manager and give it more time to finish writing.

async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=50)
checkpoints.save_checkpoint_multiprocess(ckpt_dir,
                                         mp_ckpt,
                                         step=3,
                                         overwrite=True,
                                         keep=4,
                                         orbax_checkpointer=async_checkpointer)
'/tmp/flax_ckpt/checkpoint_3'
mp_restored = checkpoints.restore_checkpoint(ckpt_dir,
                                             target=ref_ckpt,
                                             step=3,
                                             orbax_checkpointer=async_checkpointer)
mp_restored
WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
{'model': Array([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11],
        [12, 13],
        [14, 15]], dtype=int32)}

Orbax-as-backend troubleshooting#

As an intermediate stage of the migration (from the legacy Flax checkpoints API to Orbax), the flax.training.checkpoints API will start to use Orbax as its backend for saving checkpoints, starting on May 15, 2023.

Checkpoints saved with the Orbax backend can be read by flax.training.checkpoints.restore_checkpoint or orbax.checkpoint.PyTreeCheckpointer.

At the code level, this is equivalent to setting the default value of the config flag flax.config.flax_use_orbax_checkpointing to True. You can override this value at any time in your project with flax.config.update('flax_use_orbax_checkpointing', <BoolValue>).

In general, this automatic migration will not affect most users. However, you may run into issues if your API usage follows certain specific patterns. Check out the sections below for troubleshooting.

If your devices hang when writing checkpoints#

If you are running in a multi-host environment (usually with more than 8 TPU devices) and your devices hang when writing checkpoints, check whether your code follows this pattern (i.e., save_checkpoint is run only on host 0):

if jax.process_index() == 0:
  flax.training.checkpoints.save_checkpoint(...)

Unfortunately, this is a legacy pattern that will be deprecated and no longer supported, because in a multi-process setting the checkpointing code should be coordinated across hosts rather than triggered only on host 0. Replacing the code above with the following should resolve the hang:

flax.training.checkpoints.save_checkpoint_multiprocess(...)

If you don't save pytrees#

Orbax uses orbax.checkpoint.PyTreeCheckpointHandler to save checkpoints, which means it only saves pytrees.

If you want to save singular arrays or numbers, you have two options (option 2 is sketched after this list):

  1. Use orbax.checkpoint.ArrayCheckpointHandler to save them, following this migration section.

  2. Wrap it within a pytree and save as usual.
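
A minimal sketch of option 2 (the directory and key names here are hypothetical):

# Wrap the singular array in a dict so that it forms a valid pytree.
single_array = np.arange(4)
orbax_checkpointer.save('/tmp/flax_ckpt/orbax/wrapped_single', {'data': single_array})

# Restoring returns the dict; unwrap it to recover the array.
restored_single = orbax_checkpointer.restore('/tmp/flax_ckpt/orbax/wrapped_single')['data']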