初始化/应用

初始化/应用#

flax.linen.apply(fn, module, mutable=False, capture_intermediates=False)[源代码]#

创建一个 apply 函数,使用绑定模块调用 fn

Module.apply 不同,此函数返回一个新函数,其签名为 (variables, *args, rngs=None, **kwargs) -> T,其中 Tfn 的返回类型。 如果 mutable 不是 False,则返回类型是一个元组,其中第二个项是具有已变异变量的 FrozenDict

返回的 apply 函数可以直接与 JAX 转换(如 jax.jit)组合。

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> variables = {}
>>> foo = Foo()
>>> f_jitted = jax.jit(nn.apply(f, foo))
>>> f_jitted(variables, jnp.ones((1, 3)))
参数
  • fn – 应应用的函数。传递的第一个参数将是 module 的模块实例,其中变量和 RNG 已绑定到该实例。

  • module – 将用于将变量和 RNG 绑定到的 Module。 作为第一个参数传递给 fnModule 将是 module 的克隆。

  • mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变:bool:所有/无集合是可变的。str:单个可变集合的名称。list:可变集合的名称列表。

  • capture_intermediates – 如果为 True,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改筛选器行为。筛选器函数接收模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。

返回

包装 fn 的 apply 函数。

flax.linen.init(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[源代码]#

创建一个 init 函数,使用绑定模块调用 fn

Module.init 不同,此函数返回一个新函数,其签名为 (rngs, *args, **kwargs) -> variables。 rngs 可以是 PRNGKeys 的字典或单个 `PRNGKey,它等效于传递一个名为“params”的 PRNGKey 字典。

返回的 init 函数可以直接与 JAX 转换(如 jax.jit)组合。

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> foo = Foo()
>>> f_jitted = jax.jit(nn.init(f, foo))
>>> variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
参数
  • fn – 应应用的函数。传递的第一个参数将是 module 的模块实例,其中变量和 RNG 已绑定到该实例。

  • module – 将用于将变量和 RNG 绑定到的 Module。 作为第一个参数传递给 fnModule 将是 module 的克隆。

  • mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变:bool:所有/无集合是可变的。str:单个可变集合的名称。list:可变集合的名称列表。默认情况下,除“intermediates”之外的所有集合都是可变的。

  • capture_intermediates – 如果为 True,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改筛选器行为。筛选器函数接收模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。

返回

包装 fn 的 init 函数。

flax.linen.init_with_output(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)[源代码]#

创建一个 init 函数,使用绑定模块调用 fn,该函数还返回函数输出。

Module.init_with_output 不同,此函数返回一个新函数,其签名为 (rngs, *args, **kwargs) -> (T, variables),其中 Tfn 的返回类型。rngs 可以是 PRNGKeys 的字典或单个 `PRNGKey,它等效于传递一个名为“params”的 PRNGKey 字典。

返回的 init 函数可以直接与 JAX 转换(如 jax.jit)组合。

>>> class Foo(nn.Module):
...   def encode(self, x):
...     ...
...   def decode(self, x):
...     ...

>>> def f(foo, x):
...   z = foo.encode(x)
...   y = foo.decode(z)
...   # ...
...   return y

>>> foo = Foo()
>>> f_jitted = jax.jit(nn.init_with_output(f, foo))
>>> y, variables = f_jitted(jax.random.key(0), jnp.ones((1, 3)))
参数
  • fn – 应应用的函数。传递的第一个参数将是 module 的模块实例,其中变量和 RNG 已绑定到该实例。

  • module – 将用于将变量和 RNG 绑定到的 Module。 作为第一个参数传递给 fnModule 将是 module 的克隆。

  • mutable – 可以是 bool、str 或 list。指定哪些集合应被视为可变:bool:所有/无集合是可变的。str:单个可变集合的名称。list:可变集合的名称列表。默认情况下,除“intermediates”之外的所有集合都是可变的。

  • capture_intermediates – 如果为 True,则捕获“intermediates”集合中所有模块的中间返回值。默认情况下,仅存储所有 __call__ 方法的返回值。可以传递一个函数来更改筛选器行为。筛选器函数接收模块实例和方法名称,并返回一个布尔值,指示是否应存储该方法调用的输出。

返回

包装 fn 的 init 函数。