flax.struct 包

flax.struct 包#

用于定义可与 jax 转换一起使用的自定义类的实用程序。

flax.struct.dataclass(clz=None, **kwargs)[源代码]#

创建一个可以传递给函数式转换的类。

注意

继承 PyTreeNode 以避免在使用 PyType 时出现类型检查问题。

诸如 jax.jitjax.grad 之类的 Jax 转换需要不可变的对象,并且可以使用 jax.tree_util 方法进行映射。dataclass 装饰器可以轻松定义可以安全地传递给 Jax 的自定义类。例如

>>> from flax import struct
>>> import jax
>>> from typing import Any, Callable

>>> @struct.dataclass
... class Model:
...   params: Any
...   # use pytree_node=False to indicate an attribute should not be touched
...   # by Jax transformations.
...   apply_fn: Callable = struct.field(pytree_node=False)

...   def __apply__(self, *args):
...     return self.apply_fn(*args)

>>> params = {}
>>> params_b = {}
>>> apply_fn = lambda v, x: x
>>> model = Model(params, apply_fn)

>>> # model.params = params_b  # Model is immutable. This will raise an error.
>>> model_b = model.replace(params=params_b)  # Use the replace method instead.

>>> # This class can now be used safely in Jax to compute gradients w.r.t. the
>>> # parameters.
>>> model = Model(params, apply_fn)
>>> loss_fn = lambda model: 3.
>>> model_grad = jax.grad(loss_fn)(model)

请注意,数据类具有自动生成的 __init__,其中构造函数的参数和创建的实例的属性 1:1 匹配。这种对应关系使得这些对象成为可与 JAX 转换以及更一般的 jax.tree_util 库一起使用的有效容器。

有时需要一个“智能构造函数”,例如因为某些属性可以(可选地)从其他属性派生。使用 Flax 数据类执行此操作的方法是制作一个提供智能构造函数的静态方法或类方法。这样就保留了 jax.tree_util 使用的简单构造函数。请考虑以下示例

>>> @struct.dataclass
... class DirectionAndScaleKernel:
...   direction: jax.Array
...   scale: jax.Array

...   @classmethod
...   def create(cls, kernel):
...     scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True)
...     direction = direction / scale
...     return cls(direction, scale)
参数
  • clz – 将被装饰器转换的类。

  • **kwargs – 传递给数据类构造函数的参数。

返回

新类。

class flax.struct.PyTreeNode(*args, **kwargs)[源代码]#

应该像 JAX pytree 节点一样的数据类的基类。

有关 jax.tree_util 行为,请参见 flax.struct.dataclass。此基类还可以避免在使用 PyType 时出现类型检查错误。

示例

>>> from flax import struct
>>> import jax
>>> from typing import Any, Callable

>>> class Model(struct.PyTreeNode):
...   params: Any
...   # use pytree_node=False to indicate an attribute should not be touched
...   # by Jax transformations.
...   apply_fn: Callable = struct.field(pytree_node=False)

...   def __apply__(self, *args):
...     return self.apply_fn(*args)

>>> params = {}
>>> params_b = {}
>>> apply_fn = lambda v, x: x
>>> model = Model(params, apply_fn)

>>> # model.params = params_b  # Model is immutable. This will raise an error.
>>> model_b = model.replace(params=params_b)  # Use the replace method instead.

>>> # This class can now be used safely in Jax to compute gradients w.r.t. the
>>> # parameters.
>>> model = Model(params, apply_fn)
>>> loss_fn = lambda model: 3.
>>> model_grad = jax.grad(loss_fn)(model)