flax.serialization 包
用于 Jax 的序列化实用程序。
所有带有状态的 Flax 类(例如,优化器)都可以转换为 numpy 数组的状态字典,以便于序列化。
状态字典
-
flax.serialization.from_state_dict(target, state, name='.')[源代码]
使用状态字典恢复给定目标的状态。
此函数将当前目标作为参数。这让我们知道目标的精确结构,以及让我们添加断言来确保形状和 dtype 不会改变。
实际上,target
中的叶值实际上都不会使用。只使用树结构、形状和 dtype。
- 参数
-
- 返回
具有恢复状态的对象的副本。
-
flax.serialization.to_state_dict(target)[源代码]
返回包含给定目标状态的字典。
-
flax.serialization.register_serialization_state(ty, ty_to_state_dict, ty_from_state_dict, override=False)[源代码]
注册一个用于序列化的类型。
- 参数
ty – 要注册的类型
ty_to_state_dict – 一个函数,它接受 ty 的实例并将其状态作为字典返回。
ty_from_state_dict – 一个函数,它接受 ty 的实例和一个状态字典,并返回带有恢复状态的实例副本。
override – 覆盖先前注册的序列化处理程序(默认值:False)。
使用 MessagePack 进行序列化
-
flax.serialization.msgpack_serialize(pytree, in_place=False)[源代码]
以 msgpack 格式将数据结构保存为字节。
仅支持带有数组叶子的 python 树的低级函数,对于自定义对象,请使用 to_bytes
。 它将高于 MAX_CHUNK_SIZE 的数组拆分为多个块。
- 参数
-
- 返回
pytree 的 msgpack 编码字节。
-
flax.serialization.msgpack_restore(encoded_pytree)[源代码]
从 msgpack 格式的字节恢复数据结构。
仅支持带有数组叶子的 python 树的低级函数,对于自定义对象,请使用 from_bytes
。
- 参数
encoded_pytree – python 树的 msgpack 编码字节。
- 返回
带有 python 原始类型和数组叶子的字典、列表、元组的 Python 树。
-
flax.serialization.to_bytes(target)[源代码]
将优化器或其他对象保存为 msgpack 序列化的状态字典。
- 参数
target – 带有状态字典注册的模板对象,该对象将被序列化为 msgpack 格式。通常是 flax 模型或优化器。
- 返回
target
对象的 msgpack 编码状态字典的字节。
-
flax.serialization.from_bytes(target, encoded_bytes)[源代码]
从 msgpack 序列化的状态字典中恢复优化器或其他对象。
- 参数
-
- 返回
一个与 target
在结构上同构的新对象,其中包含来自已保存数据的更新的叶数据。