初始化/应用#
- flax.linen.apply(fn, module, mutable=False, capture_intermediates=False)[源代码]#
创建一个 apply 函数,使用绑定模块调用
fn
。与
Module.apply
不同,此函数返回一个新函数,其签名为(variables, *args, rngs=None, **kwargs) -> T
,其中T
是fn
的返回类型。 如果mutable
不是False
,则返回类型是一个元组,其中第二个项是具有已变异变量的FrozenDict
。返回的 apply 函数可以直接与 JAX 转换(如
jax.jit
)组合。>>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> variables = {} >>> foo = Foo() >>> f_jitted = jax.jit(nn.apply(f, foo)) >>> f_jitted(variables, jnp.ones((1, 3)))
- 参数
fn – 应应用的函数。传递的第一个参数将是
module
的模块实例,其中变量和 RNG 已绑定到该实例。module – 将用于将变量和 RNG 绑定到的
Module
。 作为第一个参数传递给fn
的Module
将是 module 的克隆。mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变:
bool
:所有/无集合是可变的。str
:单个可变集合的名称。list
:可变集合的名称列表。capture_intermediates – 如果为
True
,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改筛选器行为。筛选器函数接收模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。
- 返回
包装
fn
的 apply 函数。
- flax.linen.init(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[源代码]#
创建一个 init 函数,使用绑定模块调用
fn
。与
Module.init
不同,此函数返回一个新函数,其签名为(rngs, *args, **kwargs) -> variables
。 rngs 可以是 PRNGKeys 的字典或单个`PRNGKey
,它等效于传递一个名为“params”的 PRNGKey 字典。返回的 init 函数可以直接与 JAX 转换(如
jax.jit
)组合。>>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> foo = Foo() >>> f_jitted = jax.jit(nn.init(f, foo)) >>> variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
- 参数
fn – 应应用的函数。传递的第一个参数将是
module
的模块实例,其中变量和 RNG 已绑定到该实例。module – 将用于将变量和 RNG 绑定到的
Module
。 作为第一个参数传递给fn
的Module
将是 module 的克隆。mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变:
bool
:所有/无集合是可变的。str
:单个可变集合的名称。list
:可变集合的名称列表。默认情况下,除“intermediates”之外的所有集合都是可变的。capture_intermediates – 如果为 True,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改筛选器行为。筛选器函数接收模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。
- 返回
包装
fn
的 init 函数。
- flax.linen.init_with_output(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[源代码]#
创建一个 init 函数,使用绑定模块调用
fn
,该函数还返回函数输出。与
Module.init_with_output
不同,此函数返回一个新函数,其签名为(rngs, *args, **kwargs) -> (T, variables)
,其中T
是fn
的返回类型。rngs 可以是 PRNGKeys 的字典或单个`PRNGKey
,它等效于传递一个名为“params”的 PRNGKey 字典。返回的 init 函数可以直接与 JAX 转换(如
jax.jit
)组合。>>> class Foo(nn.Module): ... def encode(self, x): ... ... ... def decode(self, x): ... ... >>> def f(foo, x): ... z = foo.encode(x) ... y = foo.decode(z) ... # ... ... return y >>> foo = Foo() >>> f_jitted = jax.jit(nn.init_with_output(f, foo)) >>> y, variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
- 参数
fn – 应应用的函数。传递的第一个参数将是
module
的模块实例,其中变量和 RNG 已绑定到该实例。module – 将用于将变量和 RNG 绑定到的
Module
。 作为第一个参数传递给fn
的Module
将是 module 的克隆。mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变:
bool
:所有/无集合是可变的。str
:单个可变集合的名称。list
:可变集合的名称列表。默认情况下,除“intermediates”之外的所有集合都是可变的。capture_intermediates – 如果为
True
,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改筛选器行为。筛选器函数接收模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。
- 返回
包装
fn
的 init 函数。