flax.serialization 包#

用于 Jax 的序列化实用程序。

所有带有状态的 Flax 类(例如,优化器)都可以转换为 numpy 数组的状态字典,以便于序列化。

状态字典#

flax.serialization.from_state_dict(target, state, name='.')[源代码]#

使用状态字典恢复给定目标的状态。

此函数将当前目标作为参数。这让我们知道目标的精确结构,以及让我们添加断言来确保形状和 dtype 不会改变。

实际上,target 中的叶值实际上都不会使用。只使用树结构、形状和 dtype。

参数
  • target – 应恢复其状态的对象。

  • state – 由 to_state_dict 生成的字典,其中包含 target 的所需新状态。

  • name – 采用的分支名称,用于改进反序列化错误消息。

返回

具有恢复状态的对象的副本。

flax.serialization.to_state_dict(target)[源代码]#

返回包含给定目标状态的字典。

flax.serialization.register_serialization_state(ty, ty_to_state_dict, ty_from_state_dict, override=False)[源代码]#

注册一个用于序列化的类型。

参数
  • ty – 要注册的类型

  • ty_to_state_dict – 一个函数,它接受 ty 的实例并将其状态作为字典返回。

  • ty_from_state_dict – 一个函数,它接受 ty 的实例和一个状态字典,并返回带有恢复状态的实例副本。

  • override – 覆盖先前注册的序列化处理程序(默认值:False)。

使用 MessagePack 进行序列化#

flax.serialization.msgpack_serialize(pytree, in_place=False)[源代码]#

以 msgpack 格式将数据结构保存为字节。

仅支持带有数组叶子的 python 树的低级函数,对于自定义对象,请使用 to_bytes。 它将高于 MAX_CHUNK_SIZE 的数组拆分为多个块。

参数
  • pytree – 带有 python 原始类型和数组叶子的字典、列表、元组的 python 树。

  • in_place – 布尔值,指定是否应就地修改 pytree。

返回

pytree 的 msgpack 编码字节。

flax.serialization.msgpack_restore(encoded_pytree)[源代码]#

从 msgpack 格式的字节恢复数据结构。

仅支持带有数组叶子的 python 树的低级函数,对于自定义对象,请使用 from_bytes

参数

encoded_pytree – python 树的 msgpack 编码字节。

返回

带有 python 原始类型和数组叶子的字典、列表、元组的 Python 树。

flax.serialization.to_bytes(target)[源代码]#

将优化器或其他对象保存为 msgpack 序列化的状态字典。

参数

target – 带有状态字典注册的模板对象,该对象将被序列化为 msgpack 格式。通常是 flax 模型或优化器。

返回

target 对象的 msgpack 编码状态字典的字节。

flax.serialization.from_bytes(target, encoded_bytes)[源代码]#

从 msgpack 序列化的状态字典中恢复优化器或其他对象。

参数
  • target – 带有状态字典注册的模板对象,该对象与从 encoded_bytes 反序列化的结构匹配。

  • encoded_bytes – 与 target 在结构上同构的 msgpack 序列化对象。通常是 flax 模型或优化器。

返回

一个与 target 在结构上同构的新对象,其中包含来自已保存数据的更新的叶数据。