模块#
Flax 模块系统。
- class flax.linen.Module[源代码]#
所有神经网络模块的基类。
层和模型应子类化此类。
所有 Flax 模块都是 Python 3.7 数据类。由于数据类接管了
__init__
,您应该改为重写setup()
,它会自动调用以初始化模块。模块可以包含子模块,并通过这种方式嵌套在树结构中。子模型可以在
setup()
方法内部以常规属性的形式分配。您可以在 Module 子类上定义任意“前向传递”方法。虽然没有特殊处理的方法,但
__call__
是一个受欢迎的选择,因为它允许您像使用函数一样使用模块实例>>> from flax import linen as nn >>> from typing import Tuple >>> class Module(nn.Module): ... features: Tuple[int, ...] = (16, 4) ... def setup(self): ... self.dense1 = nn.Dense(self.features[0]) ... self.dense2 = nn.Dense(self.features[1]) ... def __call__(self, x): ... return self.dense2(nn.relu(self.dense1(x)))
或者,为了更简洁的模块实现,其中子模块定义与其用法并置,您可以使用
compact()
包装器。- __setattr__(name, val)[源代码]#
在此模块上设置属性。
我们重载 setattr 只是为了支持在特殊的
setup()
函数中通过子模块赋值进行 Python 式命名self.submodule_name = MyModule(...)
我们也支持列表和其他通用 pytree,例如
self.submodules = [MyModule0(..), MyModule1(..), ...]
- 参数
name – 要设置的属性。
val – 属性的值。
- apply(variables, *args, rngs=None, method=None, mutable=False, capture_intermediates=False, **kwargs)[源代码]#
将模块方法应用于变量并返回输出和修改后的变量。
请注意,如果要对
__call__
以外的其他类方法调用apply
,则应设置method
。例如,假设 Transformer 模块有一个名为encode
的方法,则以下代码在那个方法上调用apply
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> import numpy as np >>> class Transformer(nn.Module): ... def encode(self, x): ... ... >>> x = jnp.ones((16, 9)) >>> model = Transformer() >>> variables = model.init(jax.random.key(0), x, method=Transformer.encode) >>> encoded = model.apply(variables, x, method=Transformer.encode)
如果提供了函数实例,则使用未绑定的函数。例如,下面的示例等效于上面的示例
>>> encoded = model.apply(variables, x, method=model.encode)
您还可以将字符串传递给模块的可调用属性。例如,前面的内容可以写成
>>> encoded = model.apply(variables, x, method='encode')
请注意,
method
也可以是不在Transformer
中定义的函数。在这种情况下,该函数应至少有一个参数表示 Module 类的实例>>> def other_fn(instance, x): ... # instance.some_module_attr(...) ... instance.encode ... ... >>> model.apply(variables, x, method=other_fn)
如果您传递单个
PRNGKey
,Flax 将使用它来馈送'params'
RNG 流。如果您想使用不同的 RNG 流或需要使用多个流,您可以将每个 RNG 流名称映射到其相应的PRNGKey
的字典传递给apply
。如果在用户未传递的 RNG 流名称上调用self.make_rng(name)
,则它将默认使用'params'
RNG 流。示例
>>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x, add_noise=False): ... x = nn.Dense(16)(x) ... x = nn.relu(x) ... ... if add_noise: ... # Add gaussian noise ... noise_key = self.make_rng('noise') ... x = x + jax.random.normal(noise_key, x.shape) ... ... return nn.Dense(1)(x) >>> x = jnp.empty((1, 7)) >>> module = Foo() >>> rngs = {'params': jax.random.key(0), 'noise': jax.random.key(1)} >>> variables = module.init(rngs, x) >>> out0 = module.apply(variables, x, add_noise=True, rngs=rngs) >>> rngs['noise'] = jax.random.key(0) >>> out1 = module.apply(variables, x, add_noise=True, rngs=rngs) >>> # different output (key(1) vs key(0)) >>> np.testing.assert_raises(AssertionError, np.testing.assert_allclose, out0, out1) >>> del rngs['noise'] >>> # self.make_rng('noise') will default to using the 'params' RNG stream >>> out2 = module.apply(variables, x, add_noise=True, rngs=rngs) >>> # same output (key(0)) >>> np.testing.assert_allclose(out1, out2) >>> # passing in a single key is equivalent to passing in {'params': key} >>> out3 = module.apply(variables, x, add_noise=True, rngs=jax.random.key(0)) >>> # same output (key(0)) >>> np.testing.assert_allclose(out2, out3)
- 参数
variables – 一个字典,其中包含按变量集合键控的变量。有关变量的更多详细信息,请参阅
flax.core.variables
。*args – 传递给指定 apply 方法的命名参数。
rngs – 用于初始化 PRNG 序列的 PRNGKey 字典。“params” PRNG 序列用于初始化参数。
method – 要在其上调用 apply 的函数。这通常是模块中的一个函数。如果提供,则应用此方法。如果未提供,则应用模块的
__call__
方法。也可以提供字符串来按名称指定方法。mutable – 可以是布尔值、字符串或列表。指定哪些集合应视为可变的:
bool
:所有/没有集合是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。capture_intermediates – 如果
True
,则捕获“intermediates”集合内部所有模块的中间返回值。默认情况下,仅存储所有__call__
方法的返回值。可以传递一个函数来更改过滤器行为。过滤器函数接受 Module 实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。**kwargs – 传递给指定 apply 方法的关键字参数。
- 返回
如果
mutable
为 False,则返回输出。如果任何集合是可变的,则返回(output, vars)
,其中vars
是修改后的集合的字典。
- bind(variables, *args, rngs=None, mutable=False)[source]#
通过绑定变量和 RNG 创建一个交互式的 Module 实例。
bind
直接提供一个 Module 的“交互式”实例,而无需使用apply
转换函数。这对于调试和交互式用例(如笔记本)特别有用,在这些用例中,函数会限制将代码拆分到不同单元格的能力。一旦变量(以及可选的 RNG)绑定到
Module
,它就会变成一个有状态的对象。请注意,符合习惯的 JAX 是函数式的,因此交互式实例与普通的 JAX API 不太兼容。bind()
应该仅用于交互式实验,在所有其他情况下,我们强烈建议用户使用apply()
。示例
>>> import jax >>> import jax.numpy as jnp >>> import flax.linen as nn >>> class AutoEncoder(nn.Module): ... def setup(self): ... self.encoder = nn.Dense(3) ... self.decoder = nn.Dense(5) ... ... def __call__(self, x): ... return self.decoder(self.encoder(x)) >>> x = jnp.ones((16, 9)) >>> ae = AutoEncoder() >>> variables = ae.init(jax.random.key(0), x) >>> model = ae.bind(variables) >>> z = model.encoder(x) >>> x_reconstructed = model.decoder(z)
- 参数
variables – 一个字典,其中包含按变量集合键控的变量。有关变量的更多详细信息,请参阅
flax.core.variables
。*args – 命名参数(未使用)。
rngs – 用于初始化 PRNG 序列的 PRNGKeys 字典。
mutable – 可以是布尔值、字符串或列表。指定哪些集合应视为可变的:
bool
:所有/没有集合是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。
- 返回
此实例的副本,带有绑定的变量和 RNG。
- copy(*, parent=<flax.linen.module._Sentinel object>, name=None, **updates)[source]#
创建此 Module 的副本,可以选择更新参数。
- 参数
parent – 副本的父级。默认情况下,如果未显式指定,则将当前模块作为父级。
name – 复制的 Module 的新名称,默认情况下会给出新的自动名称。
**updates – 属性更新。
- 返回
此 Module 的副本,具有更新的名称、父级和属性。
- get_variable(col, name, default=None)[source]#
检索变量的值。
- 参数
col – 变量集合。
name – 变量的名称。
default – 如果此作用域中不存在变量,则返回的默认值。
- 返回
输入变量的值,如果此作用域中不存在该变量,则返回默认值。
- has_variable(col, name)[source]#
检查此 Module 中是否存在给定集合和名称的变量。
有关变量和集合的更多说明,请参阅
flax.core.variables
。- 参数
col – 变量集合名称。
name – 变量的名称。
- 返回
如果变量存在,则为 True。
- init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)[source]#
使用变量初始化模块方法并返回修改后的变量。
init
将单个PRNGKey
或将变量集合名称映射到其PRNGKeys
的字典作为第一个参数,并且将调用method
(默认情况下是模块的__call__
函数),传递*args
和**kwargs
,并返回初始化的变量字典。示例
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> import numpy as np >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x, train): ... x = nn.Dense(16)(x) ... x = nn.BatchNorm(use_running_average=not train)(x) ... x = nn.relu(x) ... return nn.Dense(1)(x) >>> x = jnp.empty((1, 7)) >>> module = Foo() >>> key = jax.random.key(0) >>> variables = module.init(key, x, train=False)
如果传递单个
PRNGKey
,Flax 将使用它来提供'params'
RNG 流。如果要使用不同的 RNG 流或需要使用多个流,则可以传递一个字典,将每个 RNG 流名称映射到其对应的PRNGKey
到init
。如果在用户未传递的 RNG 流名称上调用self.make_rng(name)
,则它将默认使用'params'
RNG 流。示例
>>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(16)(x) ... x = nn.relu(x) ... ... other_variable = self.variable( ... 'other_collection', ... 'other_variable', ... lambda x: jax.random.normal(self.make_rng('other_rng'), x.shape), ... x, ... ) ... x = x + other_variable.value ... ... return nn.Dense(1)(x) >>> module = Foo() >>> rngs = {'params': jax.random.key(0), 'other_rng': jax.random.key(1)} >>> variables0 = module.init(rngs, x) >>> rngs['other_rng'] = jax.random.key(0) >>> variables1 = module.init(rngs, x) >>> # equivalent params (key(0)) >>> _ = jax.tree_util.tree_map( ... np.testing.assert_allclose, variables0['params'], variables1['params'] ... ) >>> # different other_variable (key(1) vs key(0)) >>> np.testing.assert_raises( ... AssertionError, ... np.testing.assert_allclose, ... variables0['other_collection']['other_variable'], ... variables1['other_collection']['other_variable'], ... ) >>> del rngs['other_rng'] >>> # self.make_rng('other_rng') will default to using the 'params' RNG stream >>> variables2 = module.init(rngs, x) >>> # equivalent params (key(0)) >>> _ = jax.tree_util.tree_map( ... np.testing.assert_allclose, variables1['params'], variables2['params'] ... ) >>> # equivalent other_variable (key(0)) >>> np.testing.assert_allclose( ... variables1['other_collection']['other_variable'], ... variables2['other_collection']['other_variable'], ... ) >>> # passing in a single key is equivalent to passing in {'params': key} >>> variables3 = module.init(jax.random.key(0), x) >>> # equivalent params (key(0)) >>> _ = jax.tree_util.tree_map( ... np.testing.assert_allclose, variables2['params'], variables3['params'] ... ) >>> # equivalent other_variable (key(0)) >>> np.testing.assert_allclose( ... variables2['other_collection']['other_variable'], ... variables3['other_collection']['other_variable'], ... )
Jitting
init
仅使用提供的参数的形状延迟初始化模型,并避免使用实际值计算正向传递。示例>>> module = nn.Dense(1) >>> init_jit = jax.jit(module.init) >>> variables = init_jit(jax.random.key(0), x)
init
是apply
的一个轻量级包装器,因此其他apply
参数(如method
、mutable
和capture_intermediates
)也可用。- 参数
rngs – 变量集合的 rngs。
*args – 传递给 init 函数的命名参数。
method – 一个可选方法。如果提供,则应用此方法。如果未提供,则应用
__call__
方法。也可以提供一个字符串来按名称指定方法。mutable – 可以是 bool、str 或 list。指定应将哪些集合视为可变的:
bool
:所有/没有集合是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。默认情况下,除了“intermediates”之外的所有集合都是可变的。capture_intermediates – 如果为
True
,则捕获“intermediates”集合中所有 Module 的中间返回值。默认情况下,仅存储所有__call__
方法的返回值。可以传递一个函数来更改筛选器行为。筛选器函数接受 Module 实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。**kwargs – 传递给 init 函数的关键字参数。
- 返回
初始化的变量字典。
- init_with_output(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)[source]#
使用变量初始化模块方法并返回输出和修改后的变量。
- 参数
rngs – 变量集合的 rngs。
*args – 传递给 init 函数的命名参数。
method – 一个可选方法。如果提供,则应用此方法。如果未提供,则应用
__call__
方法。也可以提供一个字符串来按名称指定方法。mutable – 可以是 bool、str 或 list。指定应将哪些集合视为可变的:
bool
:所有/没有集合是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。默认情况下,除了“intermediates”之外的所有集合都是可变的。capture_intermediates – 如果为
True
,则捕获“intermediates”集合中所有 Module 的中间返回值。默认情况下,仅存储所有__call__
方法的返回值。可以传递一个函数来更改筛选器行为。筛选器函数接受 Module 实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。**kwargs – 传递给 init 函数的关键字参数。
- 返回
(output, vars)
,其中vars
是修改后的集合的字典。
- is_initializing()[source]#
如果在 self.init(…) 或 nn.init(…)() 下运行,则返回 True。
这是一个辅助方法,用于处理简单初始化的常见情况,在这种情况下,我们希望仅在
module.init
或nn.init
下调用时发生设置逻辑。对于更复杂的多阶段初始化方案,最好测试特定变量集合的可变性或是否存在可能需要初始化的特定变量。
- lazy_init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), **kwargs)[源代码]#
初始化模块,无需在实际输入上进行计算。
lazy_init
将在不进行不必要计算的情况下初始化变量。输入数据应作为jax.ShapeDtypeStruct
传递,该结构指定输入的形状和数据类型,但不包含具体数据。示例
>>> model = nn.Dense(features=256) >>> variables = model.lazy_init( ... jax.random.key(0), jax.ShapeDtypeStruct((1, 128), jnp.float32))
传递给
lazy_init
的 args 和 kwargs 参数可以是具体值(jax 数组、标量、布尔值)和抽象值(ShapeDtypeStruct)的混合。只有在初始化变量时才会用到具体值。例如,模型可能需要一个关键字参数来启用/禁用模型的子部分。在这种情况下,应传递显式值 (True/False),否则lazy_init
无法推断应初始化哪些变量。- 参数
rngs – 变量集合的 rngs。
*args – 传递给 init 函数的参数。
method – 一个可选的方法。如果提供,则应用此方法。如果未提供,则应用
__call__
方法。mutable – 可以是 bool、str 或 list。指定应将哪些集合视为可变的:
bool
:所有/没有集合是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。默认情况下,除了“intermediates”之外的所有集合都是可变的。**kwargs – 传递给 init 函数的关键字参数。
- 返回
初始化的变量字典。
- make_rng(name='params')[源代码]#
从此模块的给定 RNG 序列返回一个新的 RNG 密钥。
新的 RNG 密钥是从之前的密钥分割出来的。因此,每次调用
make_rng
都会返回一个新的 RNG 密钥,同时仍然保证完全可重现性。注意
如果传递无效名称(即用户在
.init
或.apply
中没有为此名称传递 RNG 密钥),则name
将默认为'params'
。示例
>>> import jax >>> import flax.linen as nn >>> class ParamsModule(nn.Module): ... def __call__(self): ... return self.make_rng('params') >>> class OtherModule(nn.Module): ... def __call__(self): ... return self.make_rng('other') >>> key = jax.random.key(0) >>> params_out, _ = ParamsModule().init_with_output({'params': key}) >>> # self.make_rng('other') will default to using the 'params' RNG stream >>> other_out, _ = OtherModule().init_with_output({'params': key}) >>> assert params_out == other_out
通过阅读 Flax RNG 指南了解有关 RNG 的更多信息:https://flax.org.cn/en/latest/guides/flax_fundamentals/rng_guide.html
- 参数
name – RNG 序列名称。
- 返回
新生成的 RNG 密钥。
- module_paths(rngs, *args, show_repeated=False, mutable=DenyList(deny='intermediates'), **kwargs)[源代码]#
返回一个字典,该字典将模块路径映射到模块实例。
此方法具有相同的签名,并且在内部调用
Module.init
,但是它不是返回变量,而是返回一个字典,该字典将模块路径映射到运行时使用的模块实例的无界副本。module_paths
使用jax.eval_shape
来运行前向计算,而不消耗任何 FLOP 或分配内存。示例
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... h = nn.Dense(4)(x) ... return nn.Dense(2)(h) >>> x = jnp.ones((16, 9)) >>> modules = Foo().module_paths(jax.random.key(0), x) >>> print({ ... p: type(m).__name__ for p, m in modules.items() ... }) {'': 'Foo', 'Dense_0': 'Dense', 'Dense_1': 'Dense'}
- 参数
rngs – 传递给
Module.init
的变量集合的 rngs。*args – 前向计算的参数。
show_repeated – 如果为
True
,则表中将显示对同一模块的重复调用,否则仅显示首次调用。默认为False
。mutable – 可以是 bool、str 或列表。指定哪些集合应被视为可变的:
bool
:所有/无集合是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。默认情况下,除“intermediates”以外的所有集合都是可变的。**kwargs – 传递给前向计算的关键字参数。
- 返回
一个将模块路径映射到模块实例的字典。
- param(name, init_fn, *init_args, unbox=True, **init_kwargs)[源代码]#
在此模块中声明并返回一个参数。
参数是名为“params”的集合中的只读变量。有关变量的更多详细信息,请参见
flax.core.variables
。假定
init_fn
的第一个参数是 PRNG 密钥,该密钥会自动提供,无需使用init_args
或init_kwargs
传递。>>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(4)(x) ... mean = self.param('mean', nn.initializers.lecun_normal(), x.shape) ... ... ... return x * mean >>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}, 'mean': (2, 4)}}
在上面的示例中,函数
lecun_normal
需要两个参数:key
和shape
,但只需要显式提供shape
;key
是使用 PRNG 自动设置的,该 PRNG 用于使用init()
初始化模块时传递的params
。- 参数
name – 参数名称。
init_fn – 将被调用以计算此变量初始值的函数。此函数仅在此参数首次在此模块中使用时调用。
*init_args – 传递给 init_fn 的位置参数。
unbox – 如果为 True,则
AxisMetadata
实例将替换为其未装箱的值,请参阅flax.nn.meta.unbox
(默认值:True)。**init_kwargs – 传递给 init_fn 的关键字参数。
- 返回
已初始化参数的值。如果参数已存在,则抛出错误。
- property path#
获取此模块的路径。顶级根模块具有空路径
()
。请注意,此方法只能在具有有效范围的绑定模块上使用。用法示例
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class SubModel(nn.Module): ... @nn.compact ... def __call__(self, x): ... print(f'SubModel path: {self.path}') ... return x >>> class Model(nn.Module): ... @nn.compact ... def __call__(self, x): ... print(f'Model path: {self.path}') ... return SubModel()(x) >>> model = Model() >>> variables = model.init(jax.random.key(0), jnp.ones((1, 2))) Model path: () SubModel path: ('SubModel_0',)
- perturb(name, value, collection='perturbations')[源代码]#
向中间值添加一个零值变量(“扰动”)。
value
的梯度与此扰动变量的梯度相同。因此,如果您将损失函数定义为具有参数和扰动作为独立参数,则可以通过对扰动参数运行jax.grad
来获得value
的中间梯度。注意
这是一个实验性 API,可能会在以后进行调整以获得更好的性能和可用性。在当前阶段,它会创建占用额外内存空间的额外虚拟变量。仅使用它来调试训练中的梯度。
示例
>>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(3)(x) ... x = self.perturb('dense3', x) ... return nn.Dense(2)(x) >>> def loss(variables, inputs, targets): ... preds = model.apply(variables, inputs) ... return jnp.square(preds - targets).mean() >>> x = jnp.ones((2, 9)) >>> y = jnp.ones((2, 2)) >>> model = Foo() >>> variables = model.init(jax.random.key(0), x) >>> intm_grads = jax.grad(loss, argnums=0)(variables, x, y) >>> print(intm_grads['perturbations']['dense3']) [[-1.456924 -0.44332537 0.02422847] [-1.456924 -0.44332537 0.02422847]]
如果未将扰动传递给
apply
,则perturb
的行为类似于空操作,因此您可以在不需要时轻松禁用该行为。>>> model.apply(variables, x) # works as expected Array([[-1.0980128 , -0.67961735], [-1.0980128 , -0.67961735]], dtype=float32) >>> model.apply({'params': variables['params']}, x) # behaves like a no-op Array([[-1.0980128 , -0.67961735], [-1.0980128 , -0.67961735]], dtype=float32) >>> intm_grads = jax.grad(loss, argnums=0)({'params': variables['params']}, x, y) >>> 'perturbations' not in intm_grads True
- put_variable(col, name, value)[源代码]#
如果给定的变量是可变的,则更新其值,否则会报错。
- 参数
col – 变量集合。
name – 变量的名称。
value – 变量的新值。
- setup()[源代码]#
延迟初始化模块(类似于延迟
__init__
)。当模块绑定时,在调用任何其他方法(如
__call__
)之前,或者在访问self
上定义的setup
属性之前,会延迟调用一次模块实例上的setup
。这可能发生在三种情况下
一旦通过分配给另一个模块在其他模块的
setup
方法中的属性,该模块被赋予名称(请参阅__setattr__()
)>>> class MyModule(nn.Module): ... def setup(self): ... submodule = nn.Conv(...) ... # Accessing `submodule` attributes does not yet work here. ... # The following line invokes `self.__setattr__`, which gives ... # `submodule` the name "conv1". ... self.conv1 = submodule ... # Accessing `submodule` attributes or methods is now safe and ... # either causes setup() to be called once.
一旦模块在用
compact()
包装的方法内部构建,则会在调用另一个方法或访问setup
定义的属性之前立即构建。
- sow(col, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)[源代码]#
将值存储在集合中。
集合可用于收集中间值,而无需显式地通过每个模块调用传递容器。
如果目标集合是不可变的,则
sow
的行为类似于空操作,并返回False
。示例
>>> import jax >>> import jax.numpy as jnp >>> import flax.linen as nn >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... h = nn.Dense(4)(x) ... self.sow('intermediates', 'h', h) ... return nn.Dense(2)(h) >>> x = jnp.ones((16, 9)) >>> model = Foo() >>> variables = model.init(jax.random.key(0), x) >>> y, state = model.apply(variables, x, mutable=['intermediates']) >>> jax.tree.map(jnp.shape, state['intermediates']) {'h': ((16, 4),)}
默认情况下,这些值存储在元组中,并且每个存储的值都追加在末尾。这样,当多次调用同一模块时,可以跟踪所有中间值。或者,可以传递自定义的 init/reduce 函数。
>>> class Foo2(nn.Module): ... @nn.compact ... def __call__(self, x): ... init_fn = lambda: 0 ... reduce_fn = lambda a, b: a + b ... self.sow('intermediates', 'h', x, ... init_fn=init_fn, reduce_fn=reduce_fn) ... self.sow('intermediates', 'h', x * 2, ... init_fn=init_fn, reduce_fn=reduce_fn) ... return x >>> x = jnp.ones((1, 1)) >>> model = Foo2() >>> variables = model.init(jax.random.key(0), x) >>> y, state = model.apply( ... variables, x, mutable=['intermediates']) >>> print(state['intermediates']) {'h': Array([[3.]], dtype=float32)}
- 参数
col – 变量集合的名称。
name – 变量的名称。
value – 变量的值。
reduce_fn – 用于将现有值与新值组合的函数。默认是将值追加到元组。
init_fn – 对于存储的第一个值,
reduce_fn
将传递init_fn
的结果以及要存储的值。默认值是空元组。
- 返回
如果值已成功存储,则返回
True
,否则返回False
。
- tabulate(rngs, *args, depth=None, show_repeated=False, mutable=DenyList(deny='intermediates'), console_kwargs=None, table_kwargs=mappingproxy({}), column_kwargs=mappingproxy({}), compute_flops=False, compute_vjp_flops=False, **kwargs)[源代码]#
创建一个以表格形式表示的模块摘要。
此方法具有相同的签名,并在内部调用
Module.init
,但它不返回变量,而是返回以表格形式总结模块的字符串。tabulate
使用jax.eval_shape
来运行前向计算,而不会消耗任何 FLOP 或分配内存。可以将其他参数传递到
console_kwargs
参数中,例如,{'width': 120}
。有关console_kwargs
参数的完整列表,请参阅:https://rich.pythonlang.cn/en/stable/reference/console.html#rich.console.Console示例
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... h = nn.Dense(4)(x) ... return nn.Dense(2)(h) >>> x = jnp.ones((16, 9)) >>> # print(Foo().tabulate( >>> # jax.random.key(0), x, compute_flops=True, compute_vjp_flops=True))
这会给出以下输出
Foo Summary ┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ path ┃ module ┃ inputs ┃ outputs ┃ flops ┃ vjp_flops ┃ params ┃ ┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ │ Foo │ float32[16,9] │ float32[16,2] │ 1504 │ 4460 │ │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ Dense_0 │ Dense │ float32[16,9] │ float32[16,4] │ 1216 │ 3620 │ bias: │ │ │ │ │ │ │ │ float32[4] │ │ │ │ │ │ │ │ kernel: │ │ │ │ │ │ │ │ float32[9,4] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 40 (160 B) │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ Dense_1 │ Dense │ float32[16,4] │ float32[16,2] │ 288 │ 840 │ bias: │ │ │ │ │ │ │ │ float32[2] │ │ │ │ │ │ │ │ kernel: │ │ │ │ │ │ │ │ float32[4,2] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 10 (40 B) │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ │ │ │ │ │ Total │ 50 (200 B) │ └─────────┴────────┴───────────────┴───────────────┴───────┴───────────┴─────────────────┘ Total Parameters: 50 (200 B)
注意:表格中的行顺序不代表执行顺序,而是与
variables
中键的顺序一致,这些键按字母顺序排序。注意:如果模块不可微分,则
vjp_flops
返回0
。- 参数
rngs – 传递给
Module.init
的变量集合的 rngs。*args – 前向计算的参数。
depth – 控制摘要可以深入多少个子模块。默认情况下,它为
None
,这意味着没有限制。如果由于深度限制而未显示子模块,则其参数计数和字节数将添加到其第一个显示的祖先的行中,以便所有行的总和始终等于模块的参数总数。show_repeated – 如果为
True
,则表中将显示对同一模块的重复调用,否则仅显示首次调用。默认为False
。mutable – 可以是 bool、str 或列表。指定哪些集合应被视为可变的:
bool
:所有/无集合是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。默认情况下,除“intermediates”以外的所有集合都是可变的。console_kwargs – 一个可选字典,其中包含在呈现表格时传递给
rich.console.Console
的其他关键字参数。默认参数为{'force_terminal': True, 'force_jupyter': False}
。table_kwargs – 一个可选字典,其中包含传递给
rich.table.Table
构造函数的其他关键字参数。column_kwargs – 一个可选字典,其中包含在向表格添加列时传递给
rich.table.Table.add_column
的其他关键字参数。compute_flops – 是否在表格中包含
flops
列,列出每个模块前向传递的估计 FLOP 成本。确实会产生实际的设备上计算/编译/内存分配,但对于大型模块仍然会引入开销(例如,对于 Stable Diffusion 的 UNet 额外需要 20 秒,而其他情况下制表会在 5 秒内完成)。compute_vjp_flops – 是否在表格中包含
vjp_flops
列,列出每个模块反向传递的估计 FLOP 成本。引入的计算开销约为compute_flops
的 2-3 倍。**kwargs – 传递给前向计算的关键字参数。
- 返回
一个总结模块的字符串。
- unbind()[源代码]#
返回模块及其变量的未绑定副本。
unbind
有助于创建绑定模块的无状态版本。一个常见用例的示例:提取在
setup()
内部定义的子模块及其相应的变量:1) 临时bind
父模块;然后 2)unbind
所需的子模块。(回想一下,setup()
仅在模块绑定时调用。)>>> class Encoder(nn.Module): ... @nn.compact ... def __call__(self, x): ... ... ... return nn.Dense(256)(x) >>> class Decoder(nn.Module): ... @nn.compact ... def __call__(self, x): ... ... ... return nn.Dense(784)(x) >>> class AutoEncoder(nn.Module): ... def setup(self): ... self.encoder = Encoder() ... self.decoder = Decoder() ... ... def __call__(self, x): ... return self.decoder(self.encoder(x)) >>> module = AutoEncoder() >>> variables = module.init(jax.random.key(0), jnp.ones((1, 784))) >>> # Extract the Encoder sub-Module and its variables >>> encoder, encoder_vars = module.bind(variables).encoder.unbind()
- 返回
一个包含此模块的未绑定副本及其变量的元组。
- variable(col, name, init_fn=None, *init_args, unbox=True, **init_kwargs)[源代码]#
在此模块中声明并返回一个变量。
有关更多信息,请参阅
flax.core.variables
。另请参阅param()
,以了解在“params”集合中定义只读变量的简写方式。与
param()
相反,使用init_fn
传递的所有参数都应显式传递。>>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(4)(x) ... key = self.make_rng('stats') ... mean = self.variable('stats', 'mean', nn.initializers.lecun_normal(), key, x.shape) ... ... ... return x * mean.value >>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3))) >>> jax.tree_util.tree_map(jnp.shape, variables) {'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}}, 'stats': {'mean': (2, 4)}}
在上面的示例中,函数
lecun_normal
期望两个参数:key
和shape
,并且都必须传递。当调用init()
和apply()
时,必须显式提供stats
的 PRNG。- 参数
col – 变量集合名称。
name – 变量名称。
init_fn – 将调用该函数来计算此变量的初始值。此函数仅在此变量第一次在此模块中使用时调用。如果为 None,则该变量必须已初始化,否则会引发错误。
*init_args – 传递给 init_fn 的位置参数。
unbox – 如果为 True,则
AxisMetadata
实例将替换为其未装箱的值,请参阅flax.nn.meta.unbox
(默认值:True)。**init_kwargs – 要传递给 init_fn 的关键字参数
- 返回
一个
flax.core.variables.Variable
,可以通过“.value”属性读取或设置。如果变量已存在,则会引发错误。
- property variables#
返回此模块中的变量。
- flax.linen.apply(fn, module, mutable=False, capture_intermediates=False)[源代码]#
创建一个 apply 函数以使用绑定模块调用
fn
。与
Module.apply
不同,此函数返回一个具有以下签名的新函数:(variables, *args, rngs=None, **kwargs) -> T
,其中T
是fn
的返回类型。如果mutable
不是False
,则返回类型是一个元组,其中第二个项是带有已变异变量的FrozenDict
。返回的 apply 函数可以直接与 JAX 转换(如
jax.jit
)组合使用。>>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> variables = {} >>> foo = Foo() >>> f_jitted = jax.jit(nn.apply(f, foo)) >>> f_jitted(variables, jnp.ones((1, 3)))
- 参数
fn – 要应用的函数。传递的第一个参数将是具有绑定变量和 RNG 的
module
的模块实例。module – 将用于绑定变量和 RNG 的
Module
。作为第一个参数传递给fn
的Module
将是 module 的克隆。mutable – 可以是布尔值、字符串或列表。指定哪些集合应视为可变的:
bool
:所有/没有集合是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。capture_intermediates – 如果为
True
,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改筛选行为。筛选函数接收模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。
- 返回
包装
fn
的 apply 函数。
- flax.linen.init(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[来源]#
创建一个 init 函数,使用绑定的模块调用
fn
。与
Module.init
不同,此函数返回一个具有以下签名的新函数:(rngs, *args, **kwargs) -> variables
。rngs 可以是 PRNGKey 的字典,也可以是单个`PRNGKey
,它等效于传递一个具有名为“params”的 PRNGKey 的字典。返回的 init 函数可以直接与 JAX 转换(如
jax.jit
)组合使用。>>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> foo = Foo() >>> f_jitted = jax.jit(nn.init(f, foo)) >>> variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
- 参数
fn – 要应用的函数。传递的第一个参数将是具有绑定变量和 RNG 的
module
的模块实例。module – 将用于绑定变量和 RNG 的
Module
。作为第一个参数传递给fn
的Module
将是 module 的克隆。mutable – 可以是 bool、str 或 list。指定应将哪些集合视为可变的:
bool
:所有/没有集合是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。默认情况下,除了“intermediates”之外的所有集合都是可变的。capture_intermediates – 如果为 True,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改筛选行为。筛选函数接收模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。
- 返回
包装
fn
的 init 函数。
- flax.linen.init_with_output(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[来源]#
创建一个 init 函数,以使用绑定的模块调用
fn
,该函数还会返回函数输出。与
Module.init_with_output
不同,此函数返回一个具有以下签名的新函数:(rngs, *args, **kwargs) -> (T, variables)
,其中T
是fn
的返回类型。rngs 可以是 PRNGKey 的字典,也可以是单个`PRNGKey
,它等效于传递一个具有名为“params”的 PRNGKey 的字典。返回的 init 函数可以直接与 JAX 转换(如
jax.jit
)组合使用。>>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> foo = Foo() >>> f_jitted = jax.jit(nn.init_with_output(f, foo)) >>> y, variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
- 参数
fn – 要应用的函数。传递的第一个参数将是具有绑定变量和 RNG 的
module
的模块实例。module – 将用于绑定变量和 RNG 的
Module
。作为第一个参数传递给fn
的Module
将是 module 的克隆。mutable – 可以是 bool、str 或 list。指定应将哪些集合视为可变的:
bool
:所有/没有集合是可变的。str
:单个可变集合的名称。list
:可变集合名称的列表。默认情况下,除了“intermediates”之外的所有集合都是可变的。capture_intermediates – 如果为
True
,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改筛选行为。筛选函数接收模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。
- 返回
包装
fn
的 init 函数。
- flax.linen.intercept_methods(interceptor)[来源]#
注册一个新的方法拦截器。
方法拦截器允许您(在远处)拦截对模块的方法调用。它的工作方式类似于装饰器。您可以在调用底层方法之前修改 args/kwargs,和/或修改调用底层方法返回的结果。或者您可以完全跳过调用底层方法,并决定执行其他操作。例如
>>> import flax.linen as nn >>> import jax.numpy as jnp ... >>> class Foo(nn.Module): ... def __call__(self, x): ... return x ... >>> def my_interceptor1(next_fun, args, kwargs, context): ... print('calling my_interceptor1') ... return next_fun(*args, **kwargs) ... >>> foo = Foo() >>> with nn.intercept_methods(my_interceptor1): ... _ = foo(jnp.ones([1])) calling my_interceptor1
您也可以在同一方法上注册多个拦截器。拦截器将按顺序运行。例如
>>> def my_interceptor2(next_fun, args, kwargs, context): ... print('calling my_interceptor2') ... return next_fun(*args, **kwargs) ... >>> with nn.intercept_methods(my_interceptor1), \ ... nn.intercept_methods(my_interceptor2): ... _ = foo(jnp.ones([1])) calling my_interceptor1 calling my_interceptor2
您可以通过直接调用
context.orig_method
来跳过其他拦截器。例如>>> def my_interceptor3(next_fun, args, kwargs, context): ... print('calling my_interceptor3') ... return context.orig_method(*args, **kwargs) >>> with nn.intercept_methods(my_interceptor3), \ ... nn.intercept_methods(my_interceptor1), \ ... nn.intercept_methods(my_interceptor2): ... _ = foo(jnp.ones([1])) calling my_interceptor3
以下方法无法拦截
使用
nn.nowrap
装饰的方法。包括
__eq__
、__repr__
、__init__
、__hash__
和__post_init__
在内的 Dunder 方法。模块数据类字段。
模块描述符。
- 参数
interceptor – 方法拦截器。
修改其中一个模块,使其共享同一作用域。当您想要包装一个模块并在不更改参数结构的情况下扩展其功能时,这非常有用。
share_scope
接受两个模块module
和other
。如果other
具有作用域,且它不是module
作用域的后代,则module
将使用other
的作用域。>>> import flax.linen as nn >>> import jax >>> from jax import numpy as jnp, random ... >>> class DenseLoRA(nn.Module): ... base: nn.Dense ... rank: int ... ... def setup(self): ... nn.share_scope(self, self.base) ... ... @nn.compact ... def __call__(self, x: jax.Array): ... din, dout = x.shape[-1], self.base.features ... A = self.param('A', nn.zeros_init(), (din, self.rank)) ... B = self.param('B', nn.zeros_init(), (self.rank, dout)) ... return self.base(x) + x @ A @ B ... >>> class Model(nn.Module): ... @nn.compact ... def __call__(self, x: jax.Array): ... dense = nn.Dense(10) # base scope ... return DenseLoRA(dense, rank=2)(x) # reuse the base scope ... >>> model = Model() ... >>> params = model.init(random.key(0), jnp.ones((1, 5)))['params'] >>> list(params['Dense_0'].keys()) ['A', 'B', 'kernel', 'bias']
当
other
的作用域是module
作用域的后代时,other
将改为使用module
的作用域。>>> class DenseLoRA(nn.Module): ... features: int ... rank: int ... ... def setup(self): ... self.child = nn.Dense(self.features) ... nn.share_scope(self, self.child) ... ... @nn.compact ... def __call__(self, x: jax.Array): ... din, dout = x.shape[-1], self.features ... A = self.param('A', nn.zeros_init(), (din, self.rank)) ... B = self.param('B', nn.zeros_init(), (self.rank, dout)) ... return self.child(x) + x @ A @ B ... >>> class Model(nn.Module): ... @nn.compact ... def __call__(self, x: jax.Array): ... return DenseLoRA(10, rank=2)(x) ... >>> model = Model() ... >>> params = model.init(random.key(0), jnp.ones((1, 5)))['params'] >>> list(params['DenseLoRA_0'].keys()) ['A', 'B', 'kernel', 'bias']