转换#

模块上的 JAX 转换。

Jax 函数转换对纯函数进行操作。Flax 将这些转换扩展到对具有状态变量和 PRNG 序列的模块进行操作。我们将这些扩展版本称为“提升的转换”。

提升的转换可以应用于 Module 类或将 Module 实例作为其第一个参数的函数。

flax.linen.vmap(target, variable_axes=FrozenDict({}), split_rngs=FrozenDict({}), in_axes=0, out_axes=0, axis_size=None, axis_name=None, spmd_axis_name=None, metadata_params={}, methods=None)[source]#

jax.vmap 的提升版本。

有关 Jax 中未提升的批处理转换,请参阅 jax.vmap

vmap 可用于将批处理轴添加到 Module。例如,我们可以创建一个具有不共享参数的批处理轴的 Dense 版本

>>> import flax.linen as nn
>>> BatchDense = nn.vmap(
...     nn.Dense,
...     in_axes=0, out_axes=0,
...     variable_axes={'params': 0},
...     split_rngs={'params': True})

通过使用 variable_axes={'params': 0},我们表明参数本身是映射的,因此不会沿映射轴共享。因此,我们还拆分了“params” RNG,否则参数将沿映射轴进行相同的初始化。

类似地,vmap 可以用于添加具有参数共享的批处理轴

>>> import flax.linen as nn
>>> BatchDense = nn.vmap(
...     nn.Dense,
...     in_axes=0, out_axes=0,
...     variable_axes={'params': None},
...     split_rngs={'params': False})

在这里,我们使用 variable_axes={'params': None} 来表示参数变量沿映射轴共享。因此,“params” RNG 也必须共享。

参数
  • targetModule 或将 Module 作为其第一个参数的函数。

  • variable_axes – 提升到批处理转换中的变量集合。使用 None 表示广播集合或整数以映射到轴上。例如,传入 variable_axes={'params': None} 将表示参数变量应沿映射轴共享。

  • split_rngs – 拆分的 PRNG 序列对于批处理维度的每个索引都将不同。未拆分的 PRNG 将被广播。

  • in_axes – 指定输入参数的映射(请参阅 jax.vmap)。

  • out_axes – 指定返回值的映射(请参阅 jax.vmap)。

  • axis_size – 指定批处理轴的大小。只有当无法从输入参数中派生出时,才需要指定此参数。

  • axis_name – 指定批处理轴的名称。可以与并行缩减原语(例如 jax.lax.pmeanjax.lax.ppermute 等)一起使用。请注意,这仅用于 pmap 和分片映射。对于 SPMD jit,您不需要手动同步。只需确保轴已正确注释,并且 XLA:SPMD 将插入必要的集合。

  • methods – 如果 targetModule,则为要使用 vmap 的 Module 的方法。

  • spmd_axis_name – 添加到 fn 中出现的任何 pjit 分片约束的轴名称。另请参阅 google/flax

  • metadata_params – 传递给变量树中 AxisMetadata 实例的参数字典。

返回

一个 target 的批处理/向量化版本,具有相同的参数,但在 in_axes 指示的位置有额外的轴,并且具有相同的返回值,但在 out_axes 指示的位置有额外的轴。

flax.linen.scan(target, variable_axes=FrozenDict({}), variable_broadcast=False, variable_carry=False, split_rngs=FrozenDict({}), in_axes=0, out_axes=0, length=None, reverse=False, unroll=1, data_transform=None, metadata_params={}, methods=None, _split_transpose=False, check_constancy_invariants=True)[source]#

jax.lax.scan 的提升版本。

有关 Jax 中未提升的扫描,请参阅 jax.lax.scan

为了提高与 vmap 的一致性,此版本的扫描使用 in_axesout_axes 来确定扫描的参数以及沿哪个轴扫描。

scan 区分循环内的 3 种不同类型的值

  1. scan:在循环中迭代的值。所有扫描值在其扫描的轴中必须具有相同的大小。扫描输出将沿扫描轴堆叠。

  2. carry:携带的值在每次循环迭代时更新。它在整个循环中必须具有相同的形状和 dtype。

  3. broadcast:循环关闭的值。当广播变量时,它们通常在循环体内部初始化,但独立于循环变量。

target 的签名应为 (module, carry, *xs) -> (carry, ys),其中 xsys 是进出循环的扫描值。

示例

>>> import flax.linen as nn
>>> import jax
>>> import jax.numpy as jnp
...
>>> class LSTM(nn.Module):
...   features: int
...
...   @nn.compact
...   def __call__(self, x):
...     ScanLSTM = nn.scan(
...       nn.LSTMCell, variable_broadcast="params",
...       split_rngs={"params": False}, in_axes=1, out_axes=1)
...
...     lstm = ScanLSTM(self.features)
...     input_shape =  x[:, 0].shape
...     carry = lstm.initialize_carry(jax.random.key(0), input_shape)
...     carry, x = lstm(carry, x)
...     return x
...
>>> x = jnp.ones((4, 12, 7))
>>> module = LSTM(features=32)
>>> y, variables = module.init_with_output(jax.random.key(0), x)

请注意,当向 nn.scan 提供函数时,扫描发生在从第三个参数开始的所有参数上,由 in_axes 指定。前面的示例也可以使用函数形式编写为

>>> class LSTM(nn.Module):
...   features: int
...
...   @nn.compact
...   def __call__(self, x):
...
...     cell = nn.LSTMCell(self.features)
...     def body_fn(cell, carry, x):
...       carry, y = cell(carry, x)
...       return carry, y
...     scan = nn.scan(
...       body_fn, variable_broadcast="params",
...       split_rngs={"params": False}, in_axes=1, out_axes=1)
...
...     input_shape =  x[:, 0].shape
...     carry = cell.initialize_carry(
...       jax.random.key(0), input_shape)
...     carry, x = scan(cell, carry, x)
...     return x
...
>>> module = LSTM(features=32)
>>> variables = module.init(jax.random.key(0), jnp.ones((4, 12, 7)))

您还可以使用 scan 通过将多个层合并到单个扫描循环中来减少 JAX 程序的编译时间。当您有一系列相同的层想要迭代地应用于输入时,可以这样做。例如

>>> class ResidualMLPBlock(nn.Module):
...   @nn.compact
...   def __call__(self, x, _):
...     h = nn.Dense(features=2)(x)
...     h = nn.relu(h)
...     return x + h, None
...
>>> class ResidualMLP(nn.Module):
...   n_layers: int = 4
...
...   @nn.compact
...   def __call__(self, x):
...     ScanMLP = nn.scan(
...       ResidualMLPBlock, variable_axes={'params': 0},
...       variable_broadcast=False, split_rngs={'params': True},
...       length=self.n_layers)
...     x, _ = ScanMLP()(x, None)
...     return x
...
>>> model = ResidualMLP(n_layers=4)
>>> variables = model.init(jax.random.key(42), jnp.ones((1, 2)))

为了减少编译时间和内存使用,您可以使用 remat_scan(),它还将在扫描循环中检查每个层。

参数
  • targetModule 或将 Module 作为其第一个参数的函数。

  • variable_axes – 要扫描的变量集合。

  • variable_broadcast – 指定广播的变量集合。广播的变量不应依赖于任何无法从循环中提升的计算。这通常用于在 fn 中定义共享参数。

  • variable_carry – 指定在循环中传递的变量集合。对这些变量的更改将传递到下一次迭代,并在扫描完成时保留。

  • split_rngs – 每个循环迭代的拆分 PRNG 序列将不同。如果 split 为 False,则 PRNG 在迭代之间将相同。

  • in_axes – 指定要扫描的参数轴。应为参数的前缀树。使用 flax.core.broadcast 将整个输入馈送到扫描主体的每次迭代。

  • out_axes – 指定要扫描的返回值轴。应为返回值的前缀树。

  • length – 指定循环迭代的次数。仅当无法从扫描参数中推导出时才需要指定。

  • reverse – 如果为 true,则从结尾到开头按相反的顺序扫描。

  • unroll – 在循环的单次迭代中展开的扫描迭代次数(默认值:1)。

  • data_transform – 可选函数,用于转换提升的扫描 body_fn 内的原始函数式核心变量和 rng 组,旨在用于内联 SPMD 注释。

  • metadata_params – 传递给变量树中 AxisMetadata 实例的参数字典。

  • methods – 如果 target 是一个 Module,则是要扫描的 Module 的方法。

  • _split_transpose – 一个实验性功能,用于将扫描的转置拆分为扫描和映射,由实验性 Jax lax.scan() 功能支持。

  • check_constancy_invariants – 如果为 true,则扫描将验证广播常量是否为真循环不变量,并且还支持广播函数(非 carry)输出。但是,这需要额外的 jax 跟踪步骤,因此设置为 false 可以减少大型模型的跟踪时间。

返回

具有签名 (module, carry, *xs) -> (carry, ys) 的扫描函数,其中 xsys 是进出循环的扫描值。

flax.linen.jit(target, variables=True, rngs=True, static_argnums=(), static_argnames=(), donate_argnums=(), device=None, backend=None, methods=None)[源代码]#

jax.jit 的提升版本。

参数
  • targetModule 或将 Module 作为其第一个参数的函数。

  • variables – 要提升的变量集合。默认情况下,会提升所有集合。

  • rngs – 要提升的 PRNG 序列。默认情况下,会提升所有 PRNG 序列。

  • static_argnums – 一个整数或整数集合,指定将哪些位置参数视为静态(编译时常量)。仅依赖于静态参数的操作将在 Python 中进行常量折叠(在跟踪期间),因此相应的参数值可以是任何 Python 对象。静态参数应该是可哈希的,这意味着实现了 __hash____eq__,并且是不可变的。使用这些常数的不同值调用 jitted 函数将触发重新编译。如果使用比 static_argnums 指示的位置参数更少的位置参数调用 jitted 函数,则会引发错误。不是数组或其容器的参数必须标记为静态。默认为 ()。

  • static_argnames – 一个可选的字符串或字符串集合,指定将哪些命名参数视为静态(编译时常量)。有关详细信息,请参阅 static_argnums 上的注释。如果未提供但设置了 static_argnums,则默认值基于调用 inspect.signature(fun) 来查找相应的命名参数。

  • donate_argnums – 指定将哪些参数“捐赠”给计算。如果您在计算完成后不再需要这些参数,则可以安全地捐赠这些参数。在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如,回收您的一个输入缓冲区来存储结果。您不应重用捐赠给计算的缓冲区,如果您尝试这样做,JAX 将引发错误。

  • device – 这是一个实验性功能,API 可能会更改。可选,jitted 函数将在其上运行的设备。(可以通过 jax.devices() 检索可用设备。)默认值继承自 XLA 的 DeviceAssignment 逻辑,通常是使用 jax.devices()[0]

  • backend – 一个表示 XLA 后端的字符串:'cpu''gpu''tpu'

  • methods – 如果 target 是一个 Module,则是要 jit 的 Module 的方法。

返回

目标函数的包装版本,设置为即时编译。

flax.linen.remat(target, variables=True, rngs=True, concrete=False, prevent_cse=True, static_argnums=(), policy=None, methods=None)#

jax.checkpoint 的提升版本。

检查点是一种通过在反向传播期间重新计算激活来减少内存使用量的技术。训练大型模型时,检查模型的部分内容以权衡内存使用量以进行额外的计算可能会很有用。

示例

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
...
>>> class CheckpointedMLP(nn.Module):
...   @nn.checkpoint
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(128)(x)
...     x = nn.relu(x)
...     x = nn.Dense(1)(x)
...     return x
...
>>> model = CheckpointedMLP()
>>> variables = model.init(jax.random.key(0), jnp.ones((1, 16)))

此函数被别名为 remat,就像 jax.remat 一样。

参数
  • target – 一个 Module 或一个以 Module 作为其第一个参数的函数。在计算目标的梯度时,将重新计算中间计算。

  • variables – 要提升的变量集合。默认情况下,会提升所有集合。

  • rngs – 要提升的 PRNG 序列。默认情况下,会提升所有 PRNG 序列。

  • concrete – 可选,布尔值,指示 fun 是否可能涉及依赖于值的 Python 控制流(默认值为 False)。对此类控制流的支持是可选的,默认情况下禁用,因为在与 jax.jit() 的某些边缘情况组合中,它可能会导致一些额外的计算。

  • prevent_cse – 可选,布尔值,指示是否阻止从微分生成的 HLO 中的公共子表达式消除 (CSE) 优化。这种 CSE 预防是有代价的,因为它会阻碍其他优化,并且在某些后端(尤其是 GPU)上可能会产生很高的开销。默认值为 True,因为否则在 jitpmap 下,CSE 可能会破坏此装饰器的目的。但在某些设置中,例如在 scan 内使用时,此 CSE 预防机制是不必要的,在这种情况下,应将 prevent_cse 设置为 False。

  • static_argnums – 可选,整数或整数序列,指示用于跟踪和缓存目的的参数值。将参数指定为静态可以避免在跟踪时出现 ConcretizationTypeErrors,但会增加重新跟踪开销。

  • policy – 实验性检查点策略,请参阅 jax.checkpoint

  • methods – 一个可选的方法名称列表,如果 methods 为 None(默认值),则仅提升 __call__ 方法。如果 ``target`` 是一个函数,则忽略 methods

返回

target 的包装版本。在计算梯度时,将在反向传递中重新计算中间计算。

flax.linen.remat_scan(target, lengths=(), policy=None, variable_broadcast=False, variable_carry=False, variable_axes=FrozenDict({True: 0}), split_rngs=FrozenDict({True: True}))[源代码]#

结合 remat 和 scan 以提高内存效率和实现恒定时间编译。

remat_scan 允许实现恒定的编译时间和与模型深度相关的亚线性内存使用。 只需付出较小的恒定代价。 这通常对于非常深的模型有利。

示例

>>> import flax.linen as nn

>>> class BigModel(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     DenseStack = nn.remat_scan(nn.Dense, lengths=(10, 10))
...     # 100x dense with O(sqrt(N)) memory for gradient computation
...     return DenseStack(8, name="dense_stack")(x)
参数
  • targetModule 或将 Module 作为其第一个参数的函数。

  • lengths – 给定级别的循环迭代次数。总迭代次数 n = prod(lengths)。每次循环都会进行重物质化。这样,内存消耗与 n^(1 / d) 成正比,其中 d = len(lengths)。最小的内存消耗需要调整长度,以便在嵌套循环的每个级别消耗相同的内存量。

  • policy – 实验性检查点策略,请参阅 jax.checkpoint

  • variable_broadcast – 指定广播的变量集合。广播的变量不应依赖于任何无法从循环中提升的计算。这通常用于在 fn 中定义共享参数。

  • variable_carry – 指定在循环中传递的变量集合。对这些变量的更改将传递到下一次迭代,并在扫描完成时保留。

  • variable_axes – 要扫描的变量集合。默认为 {True: 0}

  • split_rngs – 每个循环迭代的拆分 PRNG 序列将不同。如果 split 为 False,则 PRNG 在迭代之间将相同。默认为 {True: True}

返回

target 的包装版本,该版本会重复自身 prod(lengths) 次。

flax.linen.map_variables(target, mapped_collections=True, trans_in_fn=<function <lambda>>, trans_out_fn=<function <lambda>>, init=False, mutable=False, rngs=True, variables=True, methods=None)[源代码]#

映射模块内的变量。

map_variables 可用于在应用模块之前和之后转换模块内的变量。这在其他方面对于掩盖模块的权重而无需修改模块本身非常有用。

示例

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
...
>>> class CausalDense(nn.Module):
...   '''A dense layer that masks the weights such that the output is
...   causal, i.e. output i only depends on input <= i.
...   '''
...   features: int
...
...   def apply_mask(self, variables):
...     return (jax.tree_util.tree_map(jnp.triu, variables)
...             if not self.is_initializing() else variables)
...
...   def setup(self):
...     # temporary class
...     _CausalDense = nn.map_variables(
...       nn.Dense, 'params', self.apply_mask, init=self.is_initializing())
...     self.dense = _CausalDense(features=self.features, use_bias=False)
...
...   def __call__(self, x):
...     return self.dense(x)
...
>>> module = CausalDense(features=5)
>>> variables = module.init(jax.random.key(0), jnp.ones((1, 5)))
参数
  • target – 要转换的模块或函数。

  • mapped_collections – 要转换的集合。

  • trans_in_fn – 在应用模块或函数之前修改变量。

  • trans_out_fn – 在应用模块或函数之后修改变量,只有当 initmutable 不为 False 时才应用。

  • init – 如果为 True,则在转换之前初始化变量。

  • mutable – 如果为 True,则映射的变量集合将是可变的。

  • rngs – 添加到转换作用域的 PRNGSequences(默认值:全部)。

  • variables – 添加到转换作用域的其他变量集合。除了 target 指定的那些之外(默认值:全部)。

  • methods – 如果 targetModule,则 Module 的方法将映射变量。

返回

target 的包装版本,它将映射指定的集合。

flax.linen.jvp(fn, mdl, primals, tangents, variable_tangents, variables=True, rngs=True)[源代码]#

jax.jvp 的提升版本。

有关未提升的雅可比-向量积(前向梯度),请参见 jax.jvp

请注意,不会为变量返回切线。当需要变量切线时,应使用 Module.variables 通过 fn 显式返回其值

>>> import flax.linen as nn
>>> import jax.numpy as jnp

>>> class LearnScale(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     p = self.param('test', nn.initializers._init(), ())
...     return p * x

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     scale = LearnScale()
...     vars_t = jax.tree_util.tree_map(jnp.ones_like,
...                                     scale.variables.get('params', {}))
...     _, out_t = nn.jvp(
...         lambda mdl, x: mdl(x), scale, (x,), (jnp.zeros_like(x),),
...         variable_tangents={'params': vars_t})
...     return out_t

示例

>>> def learn_scale(scope, x):
...   p = scope.param('scale', nn.initializers.zeros_init(), ())
...   return p * x

>>> def f(scope, x):
...   vars_t = jax.tree_util.tree_map(jnp.ones_like, scope.variables().get('params', {}))
...   x, out_t = lift.jvp(
...       learn_scale, scope, (x,), (jnp.zeros_like(x),),
...       variable_tangents={'params': vars_t})
...   return out_t
参数
  • fn – 要微分的函数。其参数应为数组、标量或数组或标量的标准 Python 容器。它应返回数组、标量或数组或标量的标准 Python 容器。它将接收作用域和原始值作为参数。

  • mdl – 将要对其变量进行微分的模块。

  • primals – 应在其中评估 fun 的雅可比的原始值。应为参数的元组或列表,其长度应等于 fun 的位置参数的数量。

  • tangents – 应在其中评估雅可比-向量积的切向量。应为切向量的元组或列表,其树结构和数组形状与 primals 相同。

  • variable_tangents – 具有与作用域相同结构的字典或字典的 PyTree。字典中的每个条目都指定变量集合的切线。在 variable_tangents 中不指定集合等效于传递零向量作为切线。

  • variablesfn 中可用的其他变量集合,但不接收切线。

  • rngsfn 内部可用的 prng。

返回

(primals_out, tangents_out) 对,其中 primals_outfun(*primals),而 tangents_out 是在 primals 处评估的 function 的雅可比-向量积,带有 tangentstangents_out 值具有与 primals_out 相同的 Python 树结构和形状。

flax.linen.vjp(fn, mdl, *primals, has_aux=False, reduce_axes=(), vjp_variables='params', variables=True, rngs=True, multi_scope=False)[源代码]#

jax.vjp 的提升版本。

有关未提升的向量-雅可比积(反向梯度),请参见 jax.vjp

请注意,对于 vjp_variables 指定的集合中的所有变量,都会返回梯度。但是,反向函数仅期望 fn 的返回值的余切。如果变量也需要余切,可以使用 Module.variablesfn 返回。

示例

>>> import flax.linen as nn
>>> import jax.numpy as jnp

>>> class LearnScale(nn.Module):
...   @nn.compact
...   def __call__(self, x, y):
...     p = self.param('scale', nn.initializers.zeros_init(), ())
...     return p * x * y

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, y):
...     z, bwd = nn.vjp(lambda mdl, x, y: mdl(x, y), LearnScale(), x, y)
...     params_grad, x_grad, y_grad = bwd(jnp.ones(z.shape))
...     return z, params_grad, x_grad, y_grad
参数
  • fn – 要微分的函数。其参数应为数组、标量或数组或标量的标准 Python 容器。它应返回数组、标量或数组或标量的标准 Python 容器。它将接收作用域和原始值作为参数。

  • mdl – 将要对其变量进行微分的模块。

  • *primals – 应该在其中计算 fn 的雅可比矩阵的原始值序列。primals 的长度应等于 fn 的位置参数的数量。每个原始值都应该是一个数组、标量或标准 Python 容器的元组。

  • has_aux – 可选的布尔值。指示 fn 是否返回一对,其中第一个元素被认为是要求微分的数学函数的输出,第二个元素是辅助数据。默认为 False

  • reduce_axes – 可选的轴名称元组。如果此处列出了一个轴,并且 fn 隐式地在该轴上广播一个值,则反向传播将执行相应梯度的 psum。否则,VJP 将按命名轴进行逐示例计算。例如,如果 'batch' 是一个命名的批次轴,vjp(f, *args, reduce_axes=('batch',)) 将创建一个在批次上求和的 VJP 函数,而 vjp(f, *args) 将创建一个逐示例的 VJP。

  • vjp_variables – vjpfun 将为该过滤器指定的所有变量集合返回一个余切向量。

  • variablesfn 内部可用的其他变量集合,但不接收余切。

  • rngsfn 内部可用的 prng。

  • multi_scope – 对于包含从外部模块传入的多个作用域的模块,允许为多个作用域返回变量梯度,而不是报错。

返回

如果 has_auxFalse,则返回一个 (primals_out, vjpfun) 对,其中 primals_outfn(*primals)vjpfun 是一个函数,它将与 primals_out 形状相同的余切向量映射到与 primals 形状相同的余切向量的元组,表示在 primals 处计算的 fn 的向量-雅可比矩阵乘积。如果 has_auxTrue,则返回一个 (primals_out, vjpfun, aux) 元组,其中 auxfn 返回的辅助数据。

flax.linen.custom_vjp(fn, forward_fn, backward_fn, grad_vars='params', nondiff_argnums=())[源代码]#

jax.custom_vjp 的提升版本。

forward_fnbackward_fn 共同为 fn 定义一个自定义 vjp。如果未计算 vjp(反向梯度),则原始的 fn 将运行。

forward_fn 接收与 fn 相同的参数,但预计返回一个元组,其中包含 fn(mdl, *args) 的输出和传递给 backward_fn 的残差。

backward_fn 接收 nondiff 参数、残差和输出切线。它应该返回一个包含变量和输入切线的元组。

请注意,nn.vjp 返回的 vjp 函数可以作为残差传递并在 backward_fn 中使用。在反向传播期间,作用域不可用。如果 backward_fn 中需要模块,则可以获取变量的快照并在 forward_fn 中将其作为残差返回。

示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     def f(mdl, x):
...       return mdl(x)
...
...     def fwd(mdl, x):
...       return nn.vjp(f, mdl, x)
...
...     def bwd(vjp_fn, y_t):
...       params_t, *inputs_t = vjp_fn(y_t)
...       params_t = jax.tree_util.tree_map(jnp.sign, params_t)
...       return (params_t, *inputs_t)
...
...     sign_grad = nn.custom_vjp(
...         f, forward_fn=fwd, backward_fn=bwd)
...     return sign_grad(nn.Dense(1), x).reshape(())

>>> x = jnp.ones((2,))
>>> variables = Foo().init(jax.random.key(0), x)
>>> grad = jax.grad(Foo().apply)(variables, x)
参数
  • fn – 要为其定义 custom_vjp 的函数。

  • forward_fn – 与 fn 具有相同参数的函数,返回一个元组,其中包含原始输出和将传递给 backward_fn 的残差。

  • backward_fn – 参数以 (*nondiff_args, residuals, tangents) 的形式传递。该函数应该返回一个元组,其中包含 grad_vars 指定的集合中的变量和输入参数(模块和 nondiff 参数除外)的切线。

  • grad_vars – 将为其计算 vjp 的集合(默认值:“params”)。

  • nondiff_argnums – 不计算 vjp 的参数。

返回

具有与 fn 相同签名的函数,带有自定义 vjp。

flax.linen.while_loop(cond_fn, body_fn, mdl, init, carry_variables=False, broadcast_variables=True, split_rngs=FrozenDict({}))[源代码]#

jax.lax.while_loop 的提升版本。

提升的作用域被传递给 cond_fnbody_fn。广播的变量是不可变的。携带的变量是可变的,但不能更改形状和数据类型。这也意味着你不能在主体内部初始化变量。如果需要变量初始化,请考虑在调用 while_loop 之前手动调用 body_fn 一次。

示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class WhileLoopExample(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     def cond_fn(mdl, c):
...       return mdl.variables['state']['acc'] < 10
...     def body_fn(mdl, c):
...       acc = mdl.variable('state', 'acc', lambda: jnp.array(0))
...       acc.value += 1
...       y = nn.Dense(c.shape[-1])(c)
...       return y
...     c = x
...     if self.is_mutable_collection('params'):
...       return body_fn(self, c)
...     else:
...       return nn.while_loop(cond_fn, body_fn, self, c,
...                             carry_variables='state')

>>> k = jax.random.key(0)
>>> x = jnp.ones((2, 2))
>>> initial_vars = WhileLoopExample().init(k, x)
>>> result, state = WhileLoopExample().apply(initial_vars, x, mutable=['state'])
参数
  • cond_fn – 只要循环应该继续,就应该返回 True。

  • body_fn – while 循环的主体。

  • mdl – 应该提升到循环中的模块。

  • init – 传递给循环的初始状态。

  • carry_variables – 在循环中携带的集合,因此是可变的(默认值:无)。

  • broadcast_variables – 关闭的集合,因此是只读的(默认值:所有集合)

  • split_rngs – 每个循环迭代的拆分 PRNG 序列将不同。如果 split 为 False,则 PRNG 在迭代之间将相同。

返回

执行 while 循环后的最终状态。

flax.linen.cond(pred, true_fun, false_fun, mdl, *operands, variables=True, rngs=True)[源代码]#

jax.lax.cond 的提升版本。

true_funfalse_fun 返回的值必须具有相同的 Pytree 结构、形状和数据类型。在分支内部创建或更新的变量也必须具有相同的结构。请注意,当仅在一个分支中创建变量或子模块时,此约束会被违反。因为仅在一个分支中初始化变量会导致参数结构不同。

示例

>>> import flax.linen as nn

>>> class CondExample(nn.Module):
...   @nn.compact
...   def __call__(self, x, pred):
...     self.variable('state', 'true_count', lambda: 0)
...     self.variable('state', 'false_count', lambda: 0)
...     def true_fn(mdl, x):
...       mdl.variable('state', 'true_count').value += 1
...       return nn.Dense(2, name='dense')(x)
...     def false_fn(mdl, x):
...       mdl.variable('state', 'false_count').value += 1
...       return -nn.Dense(2, name='dense')(x)
...     return nn.cond(pred, true_fn, false_fn, self, x)
参数
  • pred – 确定是评估 true_fun 还是 false_fun。

  • true_fun – 当 predTrue 时执行的函数。函数签名为 (模块, *操作数) -> T。

  • false_fun – 当 predFalse 时执行的函数。函数签名为 (模块, *操作数) -> T。

  • mdl – 要传递的模块目标。

  • *operands – 传递给 true_funfalse_fun 的参数。

  • variables – 传递给条件分支的变量集合(默认值:全部)。

  • rngs – 传递给条件分支的 PRNG 序列(默认值:全部)。

返回

已执行分支(true_funfalse_fun)的结果。

flax.linen.switch(index, branches, mdl, *operands, variables=True, rngs=True)[source]#

jax.lax.switch 的提升版本。

来自 branches 的返回值必须具有相同的 Pytree 结构、形状和数据类型。在分支内部创建或更新的变量也必须具有相同的结构。请注意,当仅在一个分支中创建变量或子模块时,会违反此约束。因为仅在一个分支中初始化变量会导致参数结构不同。

示例

>>> import flax.linen as nn

>>> class SwitchExample(nn.Module):
...   @nn.compact
...   def __call__(self, x, index):
...     self.variable('state', 'a_count', lambda: 0)
...     self.variable('state', 'b_count', lambda: 0)
...     self.variable('state', 'c_count', lambda: 0)
...     def a_fn(mdl, x):
...       mdl.variable('state', 'a_count').value += 1
...       return nn.Dense(2, name='dense')(x)
...     def b_fn(mdl, x):
...       mdl.variable('state', 'b_count').value += 1
...       return -nn.Dense(2, name='dense')(x)
...     def c_fn(mdl, x):
...       mdl.variable('state', 'c_count').value += 1
...       return nn.Dense(2, name='dense')(x)
...     return nn.switch(index, [a_fn, b_fn, c_fn], self, x)

如果要为每个分支设置不同的参数结构,则应在调用 switch 之前在初始化时运行所有分支。

>>> class MultiHeadSwitchExample(nn.Module):
...   def setup(self) -> None:
...     self.heads = [
...       nn.Sequential([nn.Dense(10), nn.Dense(7), nn.Dense(5)]),
...       nn.Sequential([nn.Dense(11), nn.Dense(5)]),
...       nn.Dense(5),
...     ]
...
...   @nn.compact
...   def __call__(self, x, index):
...     def head_fn(i):
...       return lambda mdl, x: mdl.heads[i](x)
...     branches = [head_fn(i) for i in range(len(self.heads))]
...
...     # run all branches on init
...     if self.is_mutable_collection('params'):
...       for branch in branches:
...         _ = branch(self, x)
...
...     return nn.switch(index, branches, self, x)
参数
  • index – 整数标量类型,表示要应用的哪个分支函数。

  • branches – 要根据索引应用的函数序列。每个函数的签名为 (模块, *操作数) -> T。

  • mdl – 要传递的模块目标。

  • *operands – 传递给分支的参数。

  • variables – 传递给条件分支的变量集合(默认值:全部)。

  • rngs – 传递给条件分支的 PRNG 序列(默认值:全部)。

返回

已执行分支的结果。