目录

#

线性模块#

class flax.linen.Dense(features, use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, dot_general=None, dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

对输入的最后一个维度应用线性变换。

用法示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> layer = nn.Dense(features=4)
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3)))
>>> jax.tree_util.tree_map(jnp.shape, params)
{'params': {'bias': (4,), 'kernel': (3, 4)}}
features#

输出特征的数量。

类型

int

use_bias#

是否在输出中添加偏置(默认值:True)。

类型

bool

dtype#

计算的数据类型(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

precision#

计算的数值精度,详情请参阅 jax.lax.Precision

类型

Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

权重矩阵的初始化函数。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏置的初始化函数。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

__call__(inputs)[source]#

对输入的最后一个维度应用线性变换。

参数

inputs – 要变换的 nd 数组。

返回

变换后的输入。

方法

class flax.linen.DenseGeneral(features, axis=-1, batch_dims=(), use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, precision=None, dot_general=None, dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

具有灵活轴的线性变换。

用法示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> # equivalent to `nn.Dense(features=4)`
>>> layer = nn.DenseGeneral(features=4)
>>> # output features (4, 5)
>>> layer = nn.DenseGeneral(features=(4, 5))
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3)))
>>> jax.tree_util.tree_map(jnp.shape, params)
{'params': {'bias': (4, 5), 'kernel': (3, 4, 5)}}
>>> # apply transformation on the the second and last axes
>>> layer = nn.DenseGeneral(features=(4, 5), axis=(1, -1))
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3, 6, 7)))
>>> jax.tree_util.tree_map(jnp.shape, params)
{'params': {'bias': (4, 5), 'kernel': (3, 7, 4, 5)}}
features#

int 或 包含输出特征数量的元组。

类型

int | collections.abc.Sequence[int]

axis#

int 或 包含要应用变换的轴的元组。例如,(-2, -1) 将变换应用于最后两个轴。

类型

int | collections.abc.Sequence[int]

batch_dims#

包含批量轴的元组。

类型

collections.abc.Sequence[int]

use_bias#

是否在输出中添加偏置(默认值:True)。

类型

bool

dtype#

计算的数据类型(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

kernel_init#

权重矩阵的初始化函数。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏置的初始化函数。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

precision#

计算的数值精度,详情请参阅 jax.lax.Precision

类型

Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

__call__(inputs)[source]#

对输入沿多个维度应用线性变换。

参数

inputs – 要变换的 nd 数组。

返回

变换后的输入。

方法

class flax.linen.Conv(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, conv_general_dilated=None, conv_general_dilated_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

包装 lax.conv_general_dilated 的卷积模块。

用法示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> # valid padding
>>> layer = nn.Conv(features=4, kernel_size=(3,), padding='VALID')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (4,), 'kernel': (3, 3, 4)}}
>>> out.shape
(1, 6, 4)
>>> # circular padding with stride 2
>>> layer = nn.Conv(features=4, kernel_size=(3, 3), strides=2, padding='CIRCULAR')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (4,), 'kernel': (3, 3, 3, 4)}}
>>> out.shape
(1, 4, 4)
>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nn.Conv(features=4, kernel_size=(3,), mask=mask, padding='VALID')
>>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
features#

卷积滤波器的数量。

类型

int

kernel_size#

卷积核的形状。整数将被解释为单个整数的元组。

类型

int | collections.abc.Sequence[int]

strides#

一个整数或一个由 n 个整数组成的序列,表示窗口间的步长(默认值:1)。

类型

None | int | collections.abc.Sequence[int]

padding#

字符串 'SAME'、字符串 'VALID'、字符串 'CIRCULAR'(周期性边界条件),或者一个由 n(low, high) 整数对组成的序列,表示在每个空间维度之前和之后应用的填充。单个整数被解释为在所有维度应用相同的填充,而序列中的单个整数导致在两侧使用相同的填充。'CAUSAL' 用于 1D 卷积的填充将左填充卷积轴,从而产生相同大小的输出。

类型

Union[str, int, collections.abc.Sequence[Union[int, tuple[int, int]]]]

input_dilation#

一个整数或一个由 n 个整数组成的序列,给出在 inputs 的每个空间维度中应用的扩张因子(默认值:1)。具有输入扩张 d 的卷积等效于具有步长 d 的转置卷积。

类型

None | int | collections.abc.Sequence[int]

kernel_dilation#

一个整数或一个由 n 个整数组成的序列,给出在卷积核的每个空间维度中应用的扩张因子(默认值:1)。具有核扩张的卷积也称为“空洞卷积”。

类型

None | int | collections.abc.Sequence[int]

feature_group_count#

整数,默认为 1。如果指定,则将输入特征分为若干组。

类型

int

use_bias#

是否在输出中添加偏置(默认值:True)。

类型

bool

mask#

用于屏蔽卷积期间权重的可选掩码。掩码必须与卷积权重矩阵的形状相同。

类型

Optional[Union[jax.Array, Any]]

dtype#

计算的数据类型(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

precision#

计算的数值精度,详情请参阅 ``jax.lax.Precision`。

类型

Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

卷积核的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏置的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

__call__(inputs)#

对输入应用(可能未共享的)卷积。

参数

inputs – 输入数据,维度为 (*batch_dims, spatial_dims..., features)。这是 channels-last 约定,即 2D 卷积为 NHWC,3D 卷积为 NDHWC。注意:这与 lax.conv_general_dilated 使用的输入约定不同,后者将空间维度放在最后。注意:如果输入具有多个批次维度,则所有批次维度将被展平为单个维度进行卷积,并在返回之前恢复。在某些情况下,直接 vmap 层的性能可能比此默认展平方法更好。如果输入缺少批次维度,则会为卷积添加批次维度,并在返回时删除,这允许编写单例代码。

返回

卷积后的数据。

方法

class flax.linen.ConvTranspose(features, kernel_size, strides=None, padding='SAME', kernel_dilation=None, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, transpose_kernel=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

封装 lax.conv_transpose 的卷积模块。

用法示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> # valid padding
>>> layer = nn.ConvTranspose(features=4, kernel_size=(3,), padding='VALID')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (4,), 'kernel': (3, 3, 4)}}
>>> out.shape
(1, 10, 4)
>>> # circular padding with stride 2
>>> layer = nn.ConvTranspose(features=4, kernel_size=(6, 6), strides=(2, 2), padding='CIRCULAR', transpose_kernel=True)
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 15, 15, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (4,), 'kernel': (6, 6, 4, 3)}}
>>> out.shape
(1, 30, 30, 4)
>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nn.ConvTranspose(features=4, kernel_size=(3,), mask=mask, padding='VALID')
>>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
features#

卷积滤波器的数量。

类型

int

kernel_size#

卷积核的形状。对于 1D 卷积,可以将内核大小作为整数传递,它将被解释为单个整数的元组。对于所有其他情况,它必须是整数序列。

类型

int | collections.abc.Sequence[int]

strides#

一个整数或一个包含 n 个整数的序列,表示窗口之间的步长。

类型

collections.abc.Sequence[int] | None

padding#

可以是字符串 ‘SAME’,字符串 ‘VALID’,字符串 ‘CIRCULAR’(周期性边界条件),或一个由 n(low, high) 整数对组成的序列,表示在每个空间维度之前和之后应用的填充。单个整数被解释为在所有维度上应用相同的填充,在序列中分配单个整数会导致在两侧使用相同的填充。

类型

Union[str, int, collections.abc.Sequence[Union[int, tuple[int, int]]]]

kernel_dilation#

None,或一个整数或一个包含 n 个整数的序列,表示在卷积核的每个空间维度中应用的扩张因子。具有内核扩张的卷积也称为“空洞卷积”。

类型

collections.abc.Sequence[int] | None

use_bias#

是否在输出中添加偏置(默认值:True)。

类型

bool

mask#

用于屏蔽卷积期间权重的可选掩码。掩码必须与卷积权重矩阵的形状相同。

类型

Optional[Union[jax.Array, Any]]

dtype#

计算的数据类型(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

precision#

计算的数值精度,详情请参阅 jax.lax.Precision

类型

Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

卷积核的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏置的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

transpose_kernel#

如果 True,则翻转空间轴并交换内核的输入/输出通道轴。

类型

bool

__call__(inputs)[source]#

对输入应用转置卷积。

行为与 jax.lax.conv_transpose 镜像。

参数

inputs – 输入数据,维度为 (*batch_dims, spatial_dims..., features)。这是 channels-last 约定,即 2D 卷积为 NHWC,3D 卷积为 NDHWC。注意:这与 lax.conv_general_dilated 使用的输入约定不同,后者将空间维度放在最后。注意:如果输入具有多个批次维度,则所有批次维度将被展平为单个维度进行卷积,并在返回之前恢复。在某些情况下,直接 vmap 层的性能可能比此默认展平方法更好。如果输入缺少批次维度,则会为卷积添加批次维度,并在返回时删除,这允许编写单例代码。

返回

卷积后的数据。

方法

class flax.linen.ConvLocal(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, conv_general_dilated=None, conv_general_dilated_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

封装 lax.conv_general_dilated_local 的局部卷积模块。

用法示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> # valid padding
>>> layer = nn.ConvLocal(features=4, kernel_size=(3,), padding='VALID')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (6, 4), 'kernel': (6, 9, 4)}}
>>> out.shape
(1, 6, 4)
>>> # circular padding with stride 2
>>> layer = nn.ConvLocal(features=4, kernel_size=(3, 3), strides=2, padding='CIRCULAR')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (1, 4, 4), 'kernel': (1, 4, 27, 4)}}
>>> out.shape
(1, 4, 4)
>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((6, 9, 4)))
>>> layer = nn.ConvLocal(features=4, kernel_size=(3,), mask=mask, padding='VALID')
>>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
features#

卷积滤波器的数量。

类型

int

kernel_size#

卷积核的形状。整数将被解释为单个整数的元组。

类型

int | collections.abc.Sequence[int]

strides#

一个整数或一个由 n 个整数组成的序列,表示窗口间的步长(默认值:1)。

类型

None | int | collections.abc.Sequence[int]

padding#

字符串 'SAME'、字符串 'VALID'、字符串 'CIRCULAR'(周期性边界条件),或者一个由 n(low, high) 整数对组成的序列,表示在每个空间维度之前和之后应用的填充。单个整数被解释为在所有维度应用相同的填充,而序列中的单个整数导致在两侧使用相同的填充。'CAUSAL' 用于 1D 卷积的填充将左填充卷积轴,从而产生相同大小的输出。

类型

Union[str, int, collections.abc.Sequence[Union[int, tuple[int, int]]]]

input_dilation#

一个整数或一个由 n 个整数组成的序列,给出在 inputs 的每个空间维度中应用的扩张因子(默认值:1)。具有输入扩张 d 的卷积等效于具有步长 d 的转置卷积。

类型

None | int | collections.abc.Sequence[int]

kernel_dilation#

一个整数或一个由 n 个整数组成的序列,给出在卷积核的每个空间维度中应用的扩张因子(默认值:1)。具有核扩张的卷积也称为“空洞卷积”。

类型

None | int | collections.abc.Sequence[int]

feature_group_count#

整数,默认为 1。如果指定,则将输入特征分为若干组。

类型

int

use_bias#

是否在输出中添加偏置(默认值:True)。

类型

bool

mask#

用于屏蔽卷积期间权重的可选掩码。掩码必须与卷积权重矩阵的形状相同。

类型

Optional[Union[jax.Array, Any]]

dtype#

计算的数据类型(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

precision#

计算的数值精度,详情请参阅 jax.lax.Precision

类型

Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

卷积核的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏置的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

__call__(inputs)#

对输入应用(可能未共享的)卷积。

参数

inputs – 输入数据,维度为 (*batch_dims, spatial_dims..., features)。这是 channels-last 约定,即 2D 卷积为 NHWC,3D 卷积为 NDHWC。注意:这与 lax.conv_general_dilated 使用的输入约定不同,后者将空间维度放在最后。注意:如果输入具有多个批次维度,则所有批次维度将被展平为单个维度进行卷积,并在返回之前恢复。在某些情况下,直接 vmap 层的性能可能比此默认展平方法更好。如果输入缺少批次维度,则会为卷积添加批次维度,并在返回时删除,这允许编写单例代码。

返回

卷积后的数据。

方法

class flax.linen.Einsum(shape, einsum_str=None, use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

一个带有可学习的核和偏置的 Einsum 变换。

用法示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> layer = nn.Einsum((5, 6, 7), 'abc,cde->abde')
>>> variables = layer.init(jax.random.key(0), jnp.ones((3, 4, 5)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (6, 7), 'kernel': (5, 6, 7)}}
shape#

核的形状。

类型

collections.abc.Sequence[int]

einsum_str#

一个表示 Einsum 方程的字符串。该方程必须恰好有两个操作数,左侧操作数是传入的输入,右侧操作数是可学习的核。构造函数参数和调用参数中的 einsum_str 必须恰好有一个不为 None,而另一个必须为 None。

类型

str | None

use_bias#

是否在输出中添加偏置(默认值:True)。

类型

bool

dtype#

计算的数据类型(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

precision#

计算的数值精度,详情请参阅 jax.lax.Precision

类型

Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

权重矩阵的初始化函数。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏置的初始化函数。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

__call__(inputs, einsum_str=None)[source]#

对输入的最后一个维度应用线性变换。

参数
  • inputs – 要变换的 nd 数组。

  • einsum_str – 一个表示 Einsum 方程的字符串。该方程必须恰好有两个操作数,左侧操作数是传入的输入,右侧操作数是可学习的核。传入调用方法的 einsum_str 将优先于传入构造函数的 einsum_str

返回

变换后的输入。

方法

class flax.linen.Embed(num_embeddings, features, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=<function variance_scaling.<locals>.init>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

嵌入模块。

一个参数化的函数,将整数 [0, num_embeddings) 映射到 features 维向量。此 Module 将创建一个形状为 (num_embeddings, features)embedding 矩阵。调用此层时,输入值将用于 0 索引到 embedding 矩阵中。索引大于或等于 num_embeddings 的值将导致 nan 值。当 num_embeddings 等于 1 时,它会将 embedding 矩阵广播到具有附加的 features 维度的输入形状。

用法示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> layer = nn.Embed(num_embeddings=5, features=3)
>>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
>>> variables = layer.init(jax.random.key(0), indices_input)
>>> variables
{'params': {'embedding': Array([[-0.28884724,  0.19018005, -0.414205  ],
       [-0.11768015, -0.54618824, -0.3789283 ],
       [ 0.30428642,  0.49511626,  0.01706631],
       [-0.0982546 , -0.43055868,  0.20654906],
       [-0.688412  , -0.46882293,  0.26723292]], dtype=float32)}}
>>> # get the first three and last three embeddings
>>> layer.apply(variables, indices_input)
Array([[[-0.28884724,  0.19018005, -0.414205  ],
        [-0.11768015, -0.54618824, -0.3789283 ],
        [ 0.30428642,  0.49511626,  0.01706631]],

       [[-0.688412  , -0.46882293,  0.26723292],
        [-0.0982546 , -0.43055868,  0.20654906],
        [ 0.30428642,  0.49511626,  0.01706631]]], dtype=float32)
num_embeddings#

嵌入数量/词汇量大小。

类型

int

features#

每个嵌入的特征维度数。

类型

int

dtype#

嵌入向量的数据类型(默认:与嵌入相同)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

embedding_init#

嵌入初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

__call__(inputs)[source]#

沿最后一个维度嵌入输入。

参数

inputs – 输入数据,所有维度都被视为批次维度。输入数组中的值必须是整数。

返回

输出是嵌入的输入数据。输出形状遵循输入,并附加一个额外的 features 维度。

attend(query)[source]#

使用查询数组对嵌入进行注意力计算。

参数

query – 数组,其最后一个维度等于嵌入的特征深度 features

返回

一个数组,其最终维度 num_embeddings 对应于查询向量数组与每个嵌入的批处理内积。通常用于 NLP 模型中嵌入和 logit 变换之间的权重共享。

方法

attend(query)

使用查询数组对嵌入进行注意力计算。

池化#

flax.linen.max_pool(inputs, window_shape, strides=None, padding='VALID')[source]#

通过取窗口切片的最大值来池化输入。

参数
  • inputs – 输入数据,维度为 (batch, window dims…, features)。

  • window_shape – 一个形状元组,定义要缩减的窗口。

  • strides – 一个 n 个整数的序列,表示窗口之间的步幅(默认值:(1, ..., 1))。

  • padding – 字符串 'SAME'、字符串 'VALID'n(low, high) 整数对的序列,用于给出在每个空间维度之前和之后应用的填充(默认值:'VALID')。

返回

每个窗口切片的最大值。

flax.linen.avg_pool(inputs, window_shape, strides=None, padding='VALID', count_include_pad=True)[source]#

通过取窗口的平均值来池化输入。

参数
  • inputs – 输入数据,维度为 (batch, window dims…, features)。

  • window_shape – 一个形状元组,定义要缩减的窗口。

  • strides – 一个 n 个整数的序列,表示窗口之间的步幅(默认值:(1, ..., 1))。

  • padding – 字符串 'SAME'、字符串 'VALID'n(low, high) 整数对的序列,用于给出在每个空间维度之前和之后应用的填充(默认值:'VALID')。

  • count_include_pad – 一个布尔值,指示是否将填充的标记包含在平均值计算中(默认值:True)。

返回

每个窗口切片的平均值。

flax.linen.pool(inputs, init, reduce_fn, window_shape, strides, padding)[源代码]#

用于定义池化函数的辅助函数。

池化函数使用 ReduceWindow XLA 操作实现。

注意

请注意,池化通常是不可微分的。这意味着提供一个可微分的 reduce_fn 并不意味着 pool 是可微分的。

参数
  • inputs – 输入数据,维度为 (batch, window dims…, features)。

  • init – 归约的初始值

  • reduce_fn – 形式为 (T, T) -> T 的归约函数。

  • window_shape – 一个形状元组,定义要缩减的窗口。

  • strides – 一个 n 个整数的序列,表示窗口之间的步幅(默认值:(1, ..., 1))。

  • padding – 字符串 'SAME'、字符串 'VALID'n(low, high) 整数对的序列,用于指定在每个空间维度之前和之后应用的填充。

返回

每个窗口切片的归约输出。

归一化#

class flax.linen.BatchNorm(use_running_average=None, axis=-1, momentum=0.99, epsilon=1e-05, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[源代码]#

BatchNorm 模块。

用法说明:如果我们使用 BatchNorm 定义一个模型,例如

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> BN = nn.BatchNorm(momentum=0.9, epsilon=1e-5, dtype=jnp.float32)

初始化的变量字典除了 ‘params’ 集合外,还将包含一个单独的 ‘batch_stats’ 集合,其中包含模型中所有 BatchNorm 层的运行统计信息

>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> variables = BN.init(jax.random.key(1), x, use_running_average=False)
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'batch_stats': {'mean': (6,), 'var': (6,)}, 'params': {'bias': (6,), 'scale': (6,)}}

然后,我们在训练期间通过指定 batch_stats 集合在模块的 apply 方法中是可变的来更新 batch_stats。

>>> y, new_batch_stats = BN.apply(variables, x, mutable=['batch_stats'], use_running_average=False)

在评估期间,我们将使用 use_running_average=True 定义 BN,并使用训练中的 batch_stats 集合来设置统计信息。在这种情况下,我们不会改变批处理统计信息集合,因此无需将其标记为可变

>>> y = BN.apply(variables, x, mutable=['batch_stats'], use_running_average=True)
use_running_average#

如果为 True,则将使用 batch_stats 中存储的统计信息,而不是计算输入的批处理统计信息。

类型

bool | None

axis#

输入的特征轴或非批处理轴。

类型

int

momentum#

批处理统计信息的指数移动平均值的衰减率。

类型

float

epsilon#

添加到方差中的一个很小的浮点数,以避免除以零。

类型

float

dtype#

结果的 dtype(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

use_bias#

如果为 True,则添加偏差 (beta)。

类型

bool

use_scale#

如果为 True,则乘以比例 (gamma)。如果下一层是线性的(例如 nn.relu),则可以禁用此选项,因为缩放将由下一层完成。

类型

bool

bias_init#

偏差的初始化器,默认为零。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

scale_init#

比例的初始化器,默认为一。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

axis_name#

用于合并来自多个设备的批处理统计信息的轴名称。 有关轴名称的说明,请参阅 jax.pmap(默认值:无)。 请注意,这仅用于 pmap 和 shard map。 对于 SPMD jit,您不需要手动同步。 只需确保轴已正确注释,XLA:SPMD 将插入必要的集合。

类型

str | None

axis_index_groups#

该命名轴内的一组轴索引,表示要缩减的设备子集(默认值:无)。 例如,[[0, 1], [2, 3]] 将独立地对前两个和后两个设备上的示例进行批处理规范化。 有关更多详细信息,请参阅 jax.lax.psum

类型

Any

use_fast_variance#

如果为 true,则使用更快但数值稳定性较差的方差计算。

类型

bool

__call__(x, use_running_average=None, *, mask=None)[源代码]#

使用批处理统计信息对输入进行归一化。

注意

在初始化期间(当 self.is_initializing()True 时),批处理统计信息的运行平均值不会更新。因此,初始化期间输入的输入不需要与实际的输入分布匹配,并且缩减轴(使用 axis_name 设置)不必存在。

参数
  • x – 要归一化的输入。

  • use_running_average – 如果为 true,则将使用 batch_stats 中存储的统计信息,而不是计算输入的批处理统计信息。

  • mask – 可广播到 inputs 张量的形状的二进制数组,指示应计算均值和方差的位置。

返回

归一化的输入(与输入相同的形状)。

方法

class flax.linen.LayerNorm(epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[源代码]#

层归一化 (https://arxiv.org/abs/1607.06450)。

LayerNorm 独立地对批处理中每个给定示例的层激活进行归一化,而不是像批处理归一化那样跨批处理进行归一化。即,应用一种变换,使每个示例中的平均激活保持接近 0,并且激活标准偏差接近 1。

注意

此归一化操作与 InstanceNorm 和 GroupNorm 相同;区别仅在于缩减的轴和特征轴的形状(即,可学习的比例和偏差参数的形状)。

用法示例

>>> import flax.linen as nn
>>> import jax
>>> import numpy as np

>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nn.LayerNorm()
>>> variables = layer.init(jax.random.key(1), x)
>>> variables
{'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0.], dtype=float32)}}
>>> y = layer.apply(variables, x)

>>> y = nn.LayerNorm(reduction_axes=(1, 2, 3)).apply(variables, x)
>>> y2 = nn.GroupNorm(num_groups=1).apply(variables, x)
>>> np.testing.assert_allclose(y, y2)

>>> y = nn.LayerNorm(reduction_axes=(1, 2), feature_axes=-1).apply(variables, x)
>>> y2 = nn.InstanceNorm(feature_axes=-1).apply(variables, x)
>>> np.testing.assert_allclose(y, y2)
epsilon#

添加到方差中的一个很小的浮点数,以避免除以零。

类型

float

dtype#

结果的 dtype(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

use_bias#

如果为 True,则添加偏差 (beta)。

类型

bool

use_scale#

如果为 True,则乘以比例 (gamma)。如果下一层是线性的(例如 nn.relu),则可以禁用此选项,因为缩放将由下一层完成。

类型

bool

bias_init#

偏差的初始化器,默认为零。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

scale_init#

比例的初始化器,默认为一。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

reduction_axes#

用于计算归一化统计信息的轴。

类型

Union[int, collections.abc.Sequence[int]]

feature_axes#

用于学习偏差和缩放的特征轴。

类型

Union[int, collections.abc.Sequence[int]]

axis_name#

用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的描述,请参阅 jax.pmap (默认值:None)。只有当模型在设备之间进行细分时才需要此操作,即被归一化的数组在 pmap 或 shard map 内的设备之间进行分片。对于 SPMD jit,您无需手动同步。只需确保轴被正确注释,XLA:SPMD 将插入必要的集合操作。

类型

str | None

axis_index_groups#

该命名轴内的一组轴索引,表示要缩减的设备子集(默认值:无)。 例如,[[0, 1], [2, 3]] 将独立地对前两个和后两个设备上的示例进行批处理规范化。 有关更多详细信息,请参阅 jax.lax.psum

类型

Any

use_fast_variance#

如果为 true,则使用更快但数值稳定性较差的方差计算。

类型

bool

__call__(x, *, mask=None)[source]#

对输入应用层归一化。

参数
  • x – 输入

  • mask – 可广播到 inputs 张量的形状的二进制数组,指示应计算均值和方差的位置。

返回

归一化的输入(与输入相同的形状)。

方法

class flax.linen.GroupNorm(num_groups=32, group_size=None, epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, reduction_axes=None, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

组归一化 (arxiv.org/abs/1803.08494)。

此操作类似于批次归一化,但统计信息在大小相等的通道组之间共享,而不是在批次维度之间共享。因此,组归一化不依赖于批次组成,并且不需要维护内部状态来存储统计信息。用户应指定通道组的总数或每个组的通道数。

注意

LayerNorm 是 GroupNorm 的一个特例,其中 num_groups=1,InstanceNorm 是 GroupNorm 的一个特例,其中 group_size=1

用法示例

>>> import flax.linen as nn
>>> import jax
>>> import numpy as np

>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nn.GroupNorm(num_groups=3)
>>> variables = layer.init(jax.random.key(1), x)
>>> variables
{'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0.], dtype=float32)}}
>>> y = layer.apply(variables, x)

>>> y = nn.GroupNorm(num_groups=1).apply(variables, x)
>>> y2 = nn.LayerNorm(reduction_axes=(1, 2, 3)).apply(variables, x)
>>> np.testing.assert_allclose(y, y2)

>>> y = nn.GroupNorm(num_groups=None, group_size=1).apply(variables, x)
>>> y2 = nn.InstanceNorm(feature_axes=-1).apply(variables, x)
>>> np.testing.assert_allclose(y, y2)
num_groups#

通道组的总数。原始组归一化论文提出了默认值 32。

类型

int | None

group_size#

一个组中的通道数。

类型

int | None

epsilon#

添加到方差中的一个很小的浮点数,以避免除以零。

类型

float

dtype#

结果的 dtype(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

use_bias#

如果为 True,则添加偏差 (beta)。

类型

bool

use_scale#

如果为 True,则乘以比例 (gamma)。如果下一层是线性的(例如 nn.relu),则可以禁用此选项,因为缩放将由下一层完成。

类型

bool

bias_init#

偏差的初始化器,默认为零。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

scale_init#

比例的初始化器,默认为一。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

reduction_axes#

用于计算归一化统计信息的轴列表。此列表必须包含最后一个维度,该维度被假定为特征轴。此外,如果在调用时使用的输入与用于初始化的数据相比具有额外的引导轴(例如由于批处理),则需要明确定义缩减轴。

类型

Optional[Union[int, collections.abc.Sequence[int]]]

axis_name#

用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的描述,请参阅 jax.pmap (默认值:None)。只有当模型在设备之间进行细分时才需要此操作,即被归一化的数组在 pmap 或 shard map 内的设备之间进行分片。对于 SPMD jit,您无需手动同步。只需确保轴被正确注释,XLA:SPMD 将插入必要的集合操作。

类型

str | None

axis_index_groups#

该命名轴内的一组轴索引,表示要缩减的设备子集(默认值:无)。 例如,[[0, 1], [2, 3]] 将独立地对前两个和后两个设备上的示例进行批处理规范化。 有关更多详细信息,请参阅 jax.lax.psum

类型

Any

use_fast_variance#

如果为 true,则使用更快但数值稳定性较差的方差计算。

类型

bool

__call__(x, *, mask=None)[source]#

将组归一化应用于输入 (arxiv.org/abs/1803.08494)。

参数
  • x – 形状为 ...C 的输入,其中 C 是通道维度,... 表示可用于累积统计信息的任意数量的额外维度。如果未指定任何缩减轴,则所有额外维度 ... 将用于累积统计信息,除了表示批次的引导维度。

  • mask – 可广播到 inputs 张量的形状的二进制数组,指示应计算均值和方差的位置。

返回

归一化的输入(与输入相同的形状)。

方法

class flax.linen.RMSNorm(epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_scale=True, scale_init=<function ones>, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

RMS 层归一化 (https://arxiv.org/abs/1910.07467)。

RMSNorm 独立地归一化批次中每个给定示例的层的激活,而不是像批量归一化那样跨批次归一化。与 LayerNorm 将均值重新居中为 0 并按激活的标准差归一化不同,RMSNorm 完全不重新居中,而是按激活的均方根归一化。

用法示例

>>> import flax.linen as nn
>>> import jax

>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> layer = nn.RMSNorm()
>>> variables = layer.init(jax.random.key(1), x)
>>> variables
{'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32)}}
>>> y = layer.apply(variables, x)
epsilon#

添加到方差中的一个很小的浮点数,以避免除以零。

类型

float

dtype#

结果的 dtype(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

use_scale#

如果为 True,则乘以比例 (gamma)。如果下一层是线性的(例如 nn.relu),则可以禁用此选项,因为缩放将由下一层完成。

类型

bool

scale_init#

比例的初始化器,默认为一。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

reduction_axes#

用于计算归一化统计信息的轴。

类型

Union[int, collections.abc.Sequence[int]]

feature_axes#

用于学习偏差和缩放的特征轴。

类型

Union[int, collections.abc.Sequence[int]]

axis_name#

用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的描述,请参阅 jax.pmap (默认值:None)。只有当模型在设备之间进行细分时才需要此操作,即被归一化的数组在 pmap 或 shard map 内的设备之间进行分片。对于 SPMD jit,您无需手动同步。只需确保轴被正确注释,XLA:SPMD 将插入必要的集合操作。

类型

str | None

axis_index_groups#

该命名轴内的一组轴索引,表示要缩减的设备子集(默认值:无)。 例如,[[0, 1], [2, 3]] 将独立地对前两个和后两个设备上的示例进行批处理规范化。 有关更多详细信息,请参阅 jax.lax.psum

类型

Any

use_fast_variance#

如果为 true,则使用更快但数值稳定性较差的方差计算。

类型

bool

__call__(x, *, mask=None)[source]#

对输入应用 RMS 层归一化。

参数
  • x – 输入

  • mask – 可广播到 inputs 张量的形状的二进制数组,指示应计算均值和方差的位置。

返回

归一化的输入(与输入相同的形状)。

方法

class flax.linen.InstanceNorm(epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, force_float32_reductions=True, parent=<flax.linen.module._Sentinel object>, name=None)[源]#

实例归一化 (https://arxiv.org/abs/1607.08022v3)。

InstanceNorm 对每个通道的激活进行归一化(而不是像层归一化那样跨所有通道),并且独立地对批次中的每个给定示例进行归一化(而不是像批次归一化那样跨整个批次)。即,应用一种转换,保持每个示例中每个通道内的平均激活接近 0,并且激活标准差接近 1。

注意

此归一化操作与 LayerNorm 和 GroupNorm 相同;不同之处仅在于减少的轴以及特征轴的形状(即,可学习的比例和偏差参数的形状)。

用法示例

>>> import flax.linen as nn
>>> import jax
>>> import numpy as np

>>> # dimensions: (batch, height, width, channel)
>>> x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5))
>>> layer = nn.InstanceNorm()
>>> variables = layer.init(jax.random.key(1), x)
>>> variables
{'params': {'scale': Array([1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0.], dtype=float32)}}
>>> y = layer.apply(variables, x)

>>> # having a channel_axis of -1 in InstanceNorm is identical to reducing all non-batch,
>>> # non-channel axes and using the feature_axes as the feature_axes in LayerNorm
>>> y2 = nn.LayerNorm(reduction_axes=[1, 2], feature_axes=-1).apply(variables, x)
>>> np.testing.assert_allclose(y, y2, atol=1e-7)
>>> y3 = nn.GroupNorm(num_groups=x.shape[-1]).apply(variables, x)
>>> np.testing.assert_allclose(y, y3, atol=1e-7)
epsilon#

添加到方差中的一个很小的浮点数,以避免除以零。

类型

float

dtype#

结果的 dtype(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

use_bias#

如果为 True,则添加偏差 (beta)。

类型

bool

use_scale#

如果为 True,则乘以比例 (gamma)。如果下一层是线性的(例如 nn.relu),则可以禁用此选项,因为缩放将由下一层完成。

类型

bool

bias_init#

偏差的初始化器,默认为零。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

scale_init#

比例的初始化器,默认为一。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

feature_axes#

特征的轴。学习的偏差和比例参数将采用由特征轴定义的形状。除了批次轴(假设为前导轴)之外的所有其他轴都将被缩减。

类型

Union[int, collections.abc.Sequence[int]]

axis_name#

用于组合来自多个设备的批次统计信息的轴名称。有关轴名称的描述,请参阅 jax.pmap (默认值:None)。只有当模型在设备之间进行细分时才需要此操作,即被归一化的数组在 pmap 或 shard map 内的设备之间进行分片。对于 SPMD jit,您无需手动同步。只需确保轴被正确注释,XLA:SPMD 将插入必要的集合操作。

类型

str | None

axis_index_groups#

该命名轴内的一组轴索引,表示要缩减的设备子集(默认值:无)。 例如,[[0, 1], [2, 3]] 将独立地对前两个和后两个设备上的示例进行批处理规范化。 有关更多详细信息,请参阅 jax.lax.psum

类型

Any

use_fast_variance#

如果为 true,则使用更快但数值稳定性较差的方差计算。

类型

bool

__call__(x, *, mask=None)[源]#

对输入应用实例归一化。

参数
  • x – 输入

  • mask – 可广播到 inputs 张量的形状的二进制数组,指示应计算均值和方差的位置。

返回

归一化的输入(与输入相同的形状)。

方法

class flax.linen.SpectralNorm(layer_instance, n_steps=1, epsilon=1e-12, dtype=None, param_dtype=<class 'jax.numpy.float32'>, error_on_non_matrix=False, collection_name='batch_stats', parent=<flax.linen.module._Sentinel object>, name=None)[源]#

谱归一化。

请参阅

谱归一化对权重参数进行归一化,使矩阵的谱范数等于 1。这被实现为层包装器,其中每个包装的层都会在计算其 __call__ 输出之前对其参数进行谱归一化。

注意

初始化的变量字典将包含除了“参数”集合之外,还有一个单独的“batch_stats”集合,其中将包含 u 向量和 sigma 值,这些值是执行谱归一化时使用的中间值。在训练期间,我们传入 update_stats=Truemutable=['batch_stats'],以便 usigma 使用幂迭代计算出的最新值进行更新。这将有助于幂迭代方法随着时间的推移更准确地逼近真实的奇异值。在评估期间,我们传入 update_stats=False 以确保我们从模型中获得确定性的行为。

用法示例

>>> import flax, flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> import optax

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(3)(x)
...     # only spectral normalize the params of the second Dense layer
...     x = nn.SpectralNorm(nn.Dense(4))(x, update_stats=train)
...     x = nn.Dense(5)(x)
...     return x

>>> # init
>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 5))
>>> model = Foo()
>>> variables = model.init(jax.random.PRNGKey(0), x, train=False)
>>> flax.core.freeze(jax.tree_util.tree_map(jnp.shape, variables))
FrozenDict({
    batch_stats: {
        SpectralNorm_0: {
            Dense_1/kernel/sigma: (),
            Dense_1/kernel/u: (1, 4),
        },
    },
    params: {
        Dense_0: {
            bias: (3,),
            kernel: (2, 3),
        },
        Dense_1: {
            bias: (4,),
            kernel: (3, 4),
        },
        Dense_2: {
            bias: (5,),
            kernel: (4, 5),
        },
    },
})

>>> # train
>>> def train_step(variables, x, y):
...   def loss_fn(params):
...     logits, updates = model.apply(
...         {'params': params, 'batch_stats': variables['batch_stats']},
...         x,
...         train=True,
...         mutable=['batch_stats'],
...     )
...     loss = jnp.mean(optax.l2_loss(predictions=logits, targets=y))
...     return loss, updates
...
...   (loss, updates), grads = jax.value_and_grad(loss_fn, has_aux=True)(
...       variables['params']
...   )
...   return {
...       'params': jax.tree_util.tree_map(
...           lambda p, g: p - 0.1 * g, variables['params'], grads
...       ),
...       'batch_stats': updates['batch_stats'],
...   }, loss
>>> for _ in range(10):
...   variables, loss = train_step(variables, x, y)

>>> # inference / eval
>>> out = model.apply(variables, x, train=False)
layer_instance#

使用 SpectralNorm 包装的模块实例

类型

flax.linen.module.Module

n_steps#

执行多少步幂迭代来逼近权重参数的奇异值。

类型

int

epsilon#

添加到 l2 归一化中的一个小浮点数,以避免除以零。

类型

float

dtype#

结果的 dtype(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

error_on_non_matrix#

谱归一化仅在矩阵上定义。默认情况下,此模块将返回不变的标量,并将其高阶张量展平到其前导维度中。如果该层使用了维度大于 2 的权重张量,则将此标志设置为 True 将会引发错误。

类型

bool

collection_name#

存储执行谱归一化时使用的中间值的集合的名称。

类型

str

__call__(*args, update_stats, **kwargs)[源]#

在计算 __call__ 输出之前,使用幂迭代计算 self.layer_instance 中权重的最大奇异值,并使用此值对权重进行归一化。

参数
  • *args – 要传递到 self.layer_instance 中底层图层实例的 call 方法的位置参数。

  • update_stats – 如果为 True,则使用幂迭代计算出更新值后,更新内部 u 向量和 sigma 值。这将有助于幂迭代方法随着时间的推移更准确地逼近真实的奇异值。

  • **kwargs – 要传递到 self.layer_instance 中底层图层实例的 call 方法的关键字参数。

返回

使用谱归一化权重的图层输出。

方法

class flax.linen.WeightNorm(layer_instance, epsilon=1e-12, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_scale=True, scale_init=<function ones>, feature_axes=-1, variable_filter=<factory>, parent=<flax.linen.module._Sentinel object>, name=None)[源]#

L2 权重归一化 (https://arxiv.org/abs/1602.07868)。

权重归一化将权重参数归一化,使矩阵的 l2 范数等于 1。这被实现为一个层包装器,其中每个被包装的层在计算其 __call__ 输出之前,其参数都会进行 l2 归一化。

用法示例

>>> import flax, flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Baz(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     return nn.Dense(2)(x)

>>> class Bar(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = Baz()(x)
...     x = nn.Dense(3)(x)
...     x = Baz()(x)
...     x = nn.Dense(3)(x)
...     return x

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     # l2-normalize all params of the second Dense layer
...     x = nn.WeightNorm(nn.Dense(4), variable_filter=None)(x)
...     x = nn.Dense(5)(x)
...     # l2-normalize all kernels in the Bar submodule and all params in
...     # the Baz submodule
...     x = nn.WeightNorm(Bar(), variable_filter={'kernel', 'Baz'})(x)
...     return x

>>> # init
>>> x = jnp.ones((1, 2))
>>> model = Foo()
>>> variables = model.init(jax.random.key(0), x)
>>> flax.core.freeze(jax.tree_util.tree_map(jnp.shape, variables))
FrozenDict({
    params: {
        Bar_0: {
            Baz_0: {
                Dense_0: {
                    bias: (2,),
                    kernel: (5, 2),
                },
            },
            Baz_1: {
                Dense_0: {
                    bias: (2,),
                    kernel: (3, 2),
                },
            },
            Dense_0: {
                bias: (3,),
                kernel: (2, 3),
            },
            Dense_1: {
                bias: (3,),
                kernel: (2, 3),
            },
        },
        Dense_0: {
            bias: (3,),
            kernel: (2, 3),
        },
        Dense_1: {
            bias: (4,),
            kernel: (3, 4),
        },
        Dense_2: {
            bias: (5,),
            kernel: (4, 5),
        },
        WeightNorm_0: {
            Dense_1/bias/scale: (4,),
            Dense_1/kernel/scale: (4,),
        },
        WeightNorm_1: {
            Bar_0/Baz_0/Dense_0/bias/scale: (2,),
            Bar_0/Baz_0/Dense_0/kernel/scale: (2,),
            Bar_0/Baz_1/Dense_0/bias/scale: (2,),
            Bar_0/Baz_1/Dense_0/kernel/scale: (2,),
            Bar_0/Dense_0/kernel/scale: (3,),
            Bar_0/Dense_1/kernel/scale: (3,),
        },
    },
})
layer_instance#

用 WeightNorm 包装的模块实例

类型

flax.linen.module.Module

epsilon#

添加到 l2 归一化中的一个小浮点数,以避免除以零。

类型

float

dtype#

结果的 dtype(默认值:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

use_scale#

如果为 True,则创建一个可学习的变量 scale,该变量在 l2 归一化后乘以 layer_instance 的变量。

类型

bool

scale_init#

缩放函数的初始化函数。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

feature_axes#

特征轴维度。l2 范数是通过对剩余的(非特征)轴上的 layer_instance 变量进行缩减计算的。因此,为每个指定的特征计算单独的 l2 范数值,并学习单独的缩放比例(如果 use_scale=True)。默认情况下,尾部维度被视为特征轴。

类型

Optional[Union[int, collections.abc.Sequence[int]]]

variable_filter#

一个可选的可迭代对象,其中包含字符串项。WeightNorm 层会选择性地将 l2 归一化应用于 layer_instance 变量,这些变量的键路径(由“/”分隔)与 variable_filter 匹配。例如,variable_filter={'kernel'} 将仅将 l2 归一化应用于键路径包含“kernel”的变量。默认情况下,variable_filter={'kernel'}

类型

collections.abc.Iterable | None

__call__(*args, **kwargs)[源代码]#

计算 self.layer_instance 中权重的 l2 范数,并在计算 __call__ 输出之前使用该值对权重进行归一化。

参数
  • *args – 要传递到 self.layer_instance 中底层图层实例的 call 方法的位置参数。

  • **kwargs – 要传递到 self.layer_instance 中底层图层实例的 call 方法的关键字参数。

返回

使用 l2 归一化权重的层的输出。

方法

组合器#

class flax.linen.Sequential(layers, parent=<flax.linen.module._Sentinel object>, name=None)[源代码]#

应用模块的线性链。

仅适用于将可调用对象融合在一起的简单情况,其中特定模块/操作的输入是前一个模块/操作的输出。

模块将按照它们在构造函数中传递的顺序应用。

Sequential 的 __call__ 方法接受任何输入,并将其转发到它包含的第一个模块。它按顺序将输出链接到下一个模块的输入,并返回最后一个模块的输出。

用法示例

>>> import flax.linen as nn

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     return nn.Sequential([nn.Dense(4),
...                           nn.relu,
...                           nn.Dense(2),
...                           nn.log_softmax])(x)

由于 Sequential.__call__ 是一个 compact 方法,如果你需要形状推断,你也可以内联传递构造模块的函数。

module = nn.Sequential([
    # << more layers
    lambda x: SomeModule(x.shape[-1])(x), # shape inference
    # << more layers
])

如果返回为元组或字典,此组合器还支持返回多个输出的层。如果一个层的输出是 tuple,它将在下一层中展开为 *args,如果它是一个 dict,它将展开为 **kwargs

用法示例

>>> class CrossAttentionBlock(nn.Module):
...   num_heads: int = 2
...   qkv_features: int = 16
...
...   @nn.compact
...   def __call__(self, query, key_value):
...     output = nn.MultiHeadDotProductAttention(
...       num_heads=self.num_heads, qkv_features=self.qkv_features)(query,
...                                                                 key_value)
...     output = nn.Dense(self.qkv_features)(output)
...     return dict(query=output, key_value=key_value)  # also works for tuples

>>> from typing import Sequence
>>> class CrossAttentionNetwork(nn.Module):
...   num_layers: Sequence[int]
...
...   @nn.compact
...   def __call__(self, x):
...     return nn.Sequential([CrossAttentionBlock() for _ in
...                           range(self.num_layers)])(query, key_value)
layers#

按顺序应用的可调用对象序列。

类型

collections.abc.Sequence[collections.abc.Callable[[…], Any]]

引发

ValueError – 如果 layers 不是序列。

__call__(*args, **kwargs)[源代码]#

将自身作为函数调用。

方法

随机#

class flax.linen.Dropout(rate, broadcast_dims=(), deterministic=None, rng_collection='dropout', parent=<flax.linen.module._Sentinel object>, name=None)[源代码]#

创建 dropout 层。

注意

当使用 Module.apply() 时,请确保包含一个名为 'dropout' 的 RNG 种子。Dropout 对于变量初始化不是必需的。

用法示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class MLP(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(4)(x)
...     x = nn.Dropout(0.5, deterministic=not train)(x)
...     return x

>>> model = MLP()
>>> x = jnp.ones((1, 3))
>>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout
>>> model.apply(variables, x, train=False) # don't use dropout
Array([[-0.88686204, -0.5928178 , -0.5184689 , -0.4345976 ]], dtype=float32)
>>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout
Array([[ 0.       , -1.1856356, -1.0369378,  0.       ]], dtype=float32)
rate#

dropout 的概率。(_不是_ 保留率!)

类型

float

broadcast_dims#

将共享相同的 dropout 掩码的维度

类型

collections.abc.Sequence[int]

deterministic#

如果为 false,则输入按 1 / (1 - rate) 缩放并进行掩码,而如果为 true,则不应用掩码,并按原样返回输入。

类型

bool | None

rng_collection#

请求 rng 键时使用的 rng 集合名称。

类型

str

__call__(inputs, deterministic=None, rng=None)[源代码]#

将随机 dropout 掩码应用于输入。

参数
  • inputs – 应该随机掩码的输入。

  • deterministic – 如果为 false,则输入按 1 / (1 - rate) 缩放并进行掩码,而如果为 true,则不应用掩码,并按原样返回输入。

  • rng – 一个可选的 PRNGKey,用作随机键,如果未指定,则将使用 make_rngrng_collection 名称生成一个。

返回

重新加权的掩码输入以保留平均值。

方法

注意力机制#

class flax.linen.MultiHeadDotProductAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, out_kernel_init=None, bias_init=<function zeros>, out_bias_init=None, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, force_fp32_for_softmax=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, qk_attn_weights_einsum_cls=None, attn_weights_value_einsum_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

多头点积注意力机制。

用法示例

>>> import flax.linen as nn
>>> import jax

>>> layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)
>>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
>>> variables = layer.init(jax.random.key(0), q)

>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer.apply(variables, q, k, v)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
>>> out = layer.apply(variables, q, k)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
>>> out = layer.apply(variables, q)

>>> attention_kwargs = dict(
...     num_heads=8,
...     qkv_features=16,
...     kernel_init=nn.initializers.ones,
...     bias_init=nn.initializers.zeros,
...     dropout_rate=0.5,
...     deterministic=False,
...     )
>>> class Module(nn.Module):
...   attention_kwargs: dict
...
...   @nn.compact
...   def __call__(self, x, dropout_rng=None):
...     out1 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     out2 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     return out1, out2
>>> module = Module(attention_kwargs)
>>> variables = module.init({'params': key1, 'dropout': key2}, q)

>>> # out1 and out2 are different.
>>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
>>> # out3 and out4 are different.
>>> # out1 and out3 are different. out2 and out4 are different.
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
>>> # out1 and out2 are the same.
>>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
>>> # out1 and out2 are the same as out3 and out4.
>>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)
num_heads#

注意力头的数量。特征(即 inputs_q.shape[-1])应可被注意力头的数量整除。

类型

int

dtype#

计算的数据类型(默认:从输入和参数推断)

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认:float32)

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

qkv_features#

键(key)、查询(query)和值(value)的维度。

类型

int | None

out_features#

最后投影的维度。

类型

int | None

broadcast_dropout#

沿批次维度使用广播 dropout。

类型

bool

dropout_rate#

Dropout 比率。

类型

float

deterministic#

如果为 False,则使用 dropout 随机屏蔽注意力权重,如果为 True,则注意力权重是确定的。

类型

bool | None

precision#

计算的数值精度,有关详细信息,请参阅 jax.lax.Precision

类型

Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

用于密集层(Dense layers)的核的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

out_kernel_init#

输出密集层的核的可选初始化器,如果为 None,则将使用 kernel_init

类型

Optional[Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]]

bias_init#

用于密集层的偏置的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

out_bias_init#

输出密集层的偏置的可选初始化器,如果为 None,则将使用 bias_init

类型

Optional[Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]]

use_bias#

点式 QKVO 密集变换是否使用偏置。

类型

bool

attention_fn#

点积注意力或兼容的函数。接受查询(query)、键(key)、值(value),并返回形状为 [bs, dim1, dim2, ..., dimN,, num_heads, value_channels] 的输出

类型

collections.abc.Callable[[…], Union[jax.Array, Any]]

decode#

是否准备和使用自回归缓存。

类型

bool

normalize_qk#

是否应应用 QK 归一化 (arxiv.org/abs/2302.05442)。

类型

bool

qk_attn_weights_einsum_cls#

用于创建计算注意力权重的 einsum 的工厂函数。

类型

collections.abc.Callable[[…], collections.abc.Callable[[…], Union[jax.Array, Any]]] | None

attn_weights_value_einsum_cls#

用于创建计算注意力权重和值的乘积的 einsum 的工厂函数。

类型

collections.abc.Callable[[…], collections.abc.Callable[[…], Union[jax.Array, Any]]] | None

__call__(inputs_q, inputs_k=None, inputs_v=None, *, inputs_kv=None, mask=None, deterministic=None, dropout_rng=None, sow_weights=False)[source]#

在输入数据上应用多头点积注意力。

将输入投影到多头的查询、键和值向量中,应用点积注意力,并将结果投影到输出向量。

如果 inputs_k 和 inputs_v 都为 None,它们都将复制 inputs_q 的值(自注意力)。如果只有 inputs_v 为 None,它将复制 inputs_k 的值。

参数
  • inputs_q – 形状为 [batch_sizes..., length, features] 的输入查询。

  • inputs_k – 形状为 [batch_sizes..., length, features] 的键。如果为 None,则 inputs_k 将复制 inputs_q 的值。

  • inputs_v – 形状为 [batch_sizes..., length, features] 的值。如果为 None,则 inputs_v 将复制 inputs_k 的值。

  • inputs_kv – 形状为 [batch_sizes..., length, features] 的键/值。如果为 None,则 inputs_kv 将复制 inputs_q 的值。此参数即将被弃用。请改用 inputs_k 和 inputs_v。

  • mask – 形状为 [batch_sizes..., num_heads, query_length, key/value_length] 的注意力掩码。如果对应的掩码值为 False,则会屏蔽掉注意力权重。

  • deterministic – 如果为 false,则使用 dropout 随机屏蔽注意力权重,如果为 true,则注意力权重是确定的。

  • dropout_rng – 可选的 rng 键,传递给注意力层的 dropout 掩码。否则,将改为使用 self.make_rng(‘dropout’)。

  • sow_weights – 如果为 True,则注意力权重将播种到“intermediates”集合中。请记住通过 mutable=['intermediates'] 将“intermediates”标记为可变的,以便返回该集合。

返回

形状为 [batch_sizes..., length, features] 的输出。

方法

class flax.linen.MultiHeadAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, out_kernel_init=None, bias_init=<function zeros>, out_bias_init=None, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, force_fp32_for_softmax=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, qk_attn_weights_einsum_cls=None, attn_weights_value_einsum_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[源代码]#

多头点积注意力。 是 MultiHeadDotProductAttention 的别名。

注意: MultiHeadAttentionMultiHeadDotProductAttention 的一个包装器,因此它们的实现是相同的。但是,MultiHeadAttention 层默认会被命名为 MultiHeadAttention_{index},而 MultiHeadDotProductAttention 将被命名为 MultiHeadDotProductAttention_{index}。因此,这可能会影响检查点、参数收集名称和模块内的 RNG 线程(因为层名称在生成新的 RNG 时使用)。

用法示例

>>> import flax.linen as nn
>>> import jax

>>> layer = nn.MultiHeadAttention(num_heads=8, qkv_features=16)
>>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
>>> variables = layer.init(jax.random.key(0), q)

>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer.apply(variables, q, k, v)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
>>> out = layer.apply(variables, q, k)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
>>> out = layer.apply(variables, q)

>>> attention_kwargs = dict(
...     num_heads=8,
...     qkv_features=16,
...     kernel_init=nn.initializers.ones,
...     bias_init=nn.initializers.zeros,
...     dropout_rate=0.5,
...     deterministic=False,
...     )
>>> class Module(nn.Module):
...   attention_kwargs: dict
...
...   @nn.compact
...   def __call__(self, x, dropout_rng=None):
...     out1 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     out2 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     return out1, out2
>>> module = Module(attention_kwargs)
>>> variables = module.init({'params': key1, 'dropout': key2}, q)

>>> # out1 and out2 are different.
>>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
>>> # out3 and out4 are different.
>>> # out1 and out3 are different. out2 and out4 are different.
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
>>> # out1 and out2 are the same.
>>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
>>> # out1 and out2 are the same as out3 and out4.
>>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)
num_heads#

注意力头的数量。特征(即 inputs_q.shape[-1])应该可以被头的数量整除。

类型

int

dtype#

计算的数据类型(默认:从输入和参数推断)

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化的数据类型(默认:float32)

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

qkv_features#

键、查询和值的维度。

类型

int | None

out_features#

最后一个投影的维度

类型

int | None

broadcast_dropout#

布尔值:沿批次维度使用广播的 dropout。

类型

bool

dropout_rate#

dropout 率

类型

float

deterministic#

如果为 false,则使用 dropout 随机掩盖注意力权重,而如果为 true,则注意力权重是确定性的。

类型

bool | None

precision#

计算的数值精度,详情请参阅 jax.lax.Precision

类型

Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

Dense 层内核的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

Dense 层偏差的初始化器。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

use_bias#

布尔值:逐点 QKVO 密集变换是否使用偏差。

类型

bool

attention_fn#

点积注意力或兼容的函数。接受查询(query)、键(key)、值(value),并返回形状为 [bs, dim1, dim2, ..., dimN,, num_heads, value_channels] 的输出

类型

collections.abc.Callable[[…], Union[jax.Array, Any]]

decode#

是否准备和使用自回归缓存。

类型

bool

normalize_qk#

是否应应用 QK 归一化 (arxiv.org/abs/2302.05442)。

类型

bool

__call__(inputs_q, inputs_k=None, inputs_v=None, *, inputs_kv=None, mask=None, deterministic=None, dropout_rng=None, sow_weights=False)#

在输入数据上应用多头点积注意力。

将输入投影到多头的查询、键和值向量中,应用点积注意力,并将结果投影到输出向量。

如果 inputs_k 和 inputs_v 都为 None,它们都将复制 inputs_q 的值(自注意力)。如果只有 inputs_v 为 None,它将复制 inputs_k 的值。

参数
  • inputs_q – 形状为 [batch_sizes..., length, features] 的输入查询。

  • inputs_k – 形状为 [batch_sizes..., length, features] 的键。如果为 None,则 inputs_k 将复制 inputs_q 的值。

  • inputs_v – 形状为 [batch_sizes..., length, features] 的值。如果为 None,则 inputs_v 将复制 inputs_k 的值。

  • inputs_kv – 形状为 [batch_sizes..., length, features] 的键/值。如果为 None,则 inputs_kv 将复制 inputs_q 的值。此参数即将被弃用。请改用 inputs_k 和 inputs_v。

  • mask – 形状为 [batch_sizes..., num_heads, query_length, key/value_length] 的注意力掩码。如果对应的掩码值为 False,则会屏蔽掉注意力权重。

  • deterministic – 如果为 false,则使用 dropout 随机屏蔽注意力权重,如果为 true,则注意力权重是确定的。

  • dropout_rng – 可选的 rng 键,传递给注意力层的 dropout 掩码。否则,将改为使用 self.make_rng(‘dropout’)。

  • sow_weights – 如果为 True,则注意力权重将播种到“intermediates”集合中。请记住通过 mutable=['intermediates'] 将“intermediates”标记为可变的,以便返回该集合。

返回

形状为 [batch_sizes..., length, features] 的输出。

方法

class flax.linen.SelfAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, out_kernel_init=None, bias_init=<function zeros>, out_bias_init=None, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, force_fp32_for_softmax=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, qk_attn_weights_einsum_cls=None, attn_weights_value_einsum_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[源代码]#

自注意力是多头点积注意力的特殊情况。 此层已弃用,建议使用 MultiHeadDotProductAttention

用法示例:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)
>>> variables = layer.init(jax.random.key(0), jnp.ones((4, 3, 2, 5)))
__call__(inputs_q, mask=None, deterministic=None, dropout_rng=None, sow_weights=False)[源代码]#

在输入数据上应用多头点积自注意力。

将输入投影到多头的查询、键和值向量中,应用点积注意力,并将结果投影到输出向量。

参数
  • inputs_q – 形状为 [batch_sizes..., length, features] 的输入查询。

  • mask – 形状为 [batch_sizes..., num_heads, query_length, key/value_length] 的注意力掩码。如果对应的掩码值为 False,则会屏蔽掉注意力权重。

  • deterministic – 如果为 false,则使用 dropout 随机屏蔽注意力权重,如果为 true,则注意力权重是确定的。

返回

形状为 [batch_sizes..., length, features] 的输出。

方法

flax.linen.dot_product_attention_weights(query, key, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None, force_fp32_for_softmax=False, einsum_dot_general=None, einsum=None)[源代码]#

计算给定查询和键的点积注意力权重。

dot_product_attention() 使用,这很可能是你最常用的方法。但如果你想访问注意力权重以进行自省,那么你可以直接调用此函数并自行调用 einsum。

参数
  • query – 用于计算注意力的查询,形状为 [batch..., q_length, num_heads, qk_depth_per_head]

  • key – 用于计算注意力的键,形状为 [batch..., kv_length, num_heads, qk_depth_per_head]

  • bias – 注意力权重的偏置。其形状应可广播为 [batch..., num_heads, q_length, kv_length]。 这可用于结合因果掩码、填充掩码、邻近偏置等。

  • mask – 注意力权重的掩码。其形状应可广播为 [batch..., num_heads, q_length, kv_length]。 这可用于结合因果掩码。 如果对应的掩码值为 False,则注意力权重将被掩盖。

  • broadcast_dropout – bool:使用沿批次维度的广播 dropout。

  • dropout_rng – JAX PRNGKey:用于 dropout。

  • dropout_rate – dropout 率。

  • deterministic – bool,确定性或不确定性(用于应用 dropout)。

  • dtype – 计算的数据类型(默认:从输入和参数推断)。

  • precision – 计算的数值精度,有关详细信息,请参阅 jax.lax.Precision

  • module – 将注意力权重播种到 'intermediates' 集合中的模块。 记住要通过 mutable=['intermediates'] 将 'intermediates' 标记为可变的,以便返回该集合。如果 module 为 None,则不会播种注意力权重。

  • force_fp32_for_softmax – bool,是否强制在 fp32 中计算 softmax。 这对于需要更高精度的混合精度训练以确保数值稳定性非常有用。

  • einsum_dot_general – 在 einsum 中使用的 dot_general。

  • einsum – 如果未指定,将使用默认的 jnp.einsum。 此参数与 precisioneinsum_dot_general 互斥。

引发

ValueError – 如果同时指定了 precision/einsum_dot_generaleinsum

返回

输出的形状为 [batch..., num_heads, q_length, kv_length]

flax.linen.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None, force_fp32_for_softmax=False, einsum_dot_general=None, qk_attn_weights_einsum=None, attn_weights_value_einsum=None)[源代码]#

计算给定查询、键和值的点积注意力。

这是基于 https://arxiv.org/abs/1706.03762 应用注意力的核心函数。它计算给定查询和键的注意力权重,并使用注意力权重组合值。

注意

querykeyvalue 不需要有任何批次维度。

参数
  • query – 用于计算注意力的查询,形状为 [batch..., q_length, num_heads, qk_depth_per_head]

  • key – 用于计算注意力的键,形状为 [batch..., kv_length, num_heads, qk_depth_per_head]

  • value – 在注意力中使用的值,形状为 [batch..., kv_length, num_heads, v_depth_per_head]

  • bias – 注意力权重的偏置。其形状应可广播为 [batch..., num_heads, q_length, kv_length]。 这可用于结合因果掩码、填充掩码、邻近偏置等。

  • mask – 注意力权重的掩码。其形状应可广播为 [batch..., num_heads, q_length, kv_length]。 这可用于结合因果掩码。 如果对应的掩码值为 False,则注意力权重将被掩盖。

  • broadcast_dropout – bool:使用沿批次维度的广播 dropout。

  • dropout_rng – JAX PRNGKey:用于 dropout。

  • dropout_rate – dropout 率。

  • deterministic – bool,确定性或不确定性(用于应用 dropout)。

  • dtype – 计算的数据类型(默认:从输入推断)。

  • precision – 计算的数值精度,有关详细信息,请参阅 ``jax.lax.Precision`。

  • module – 将注意力权重播种到 'intermediates' 集合中的模块。 记住要通过 mutable=['intermediates'] 将 'intermediates' 标记为可变的,以便返回该集合。如果 module 为 None,则不会播种注意力权重。

  • force_fp32_for_softmax – bool,是否强制在 fp32 中计算 softmax。 这对于需要更高精度的混合精度训练以确保数值稳定性非常有用。

  • einsum_dot_general – 在 jnp.einsum 中使用的 dot_general。

  • qk_attn_weights_einsum – 用于计算注意力权重的 einsum。 如果未指定,将使用默认的 jnp.einsum。 此参数与 precisioneinsum_dot_general 互斥。

  • attn_weights_value_einsum – 用于计算注意力权重和值的乘积的 einsum。 如果未指定,将使用默认的 jnp.einsum。 此参数与 precisioneinsum_dot_general 互斥。

返回

输出的形状为 [batch..., q_length, num_heads, v_depth_per_head]

引发
  • ValueError – 如果同时指定了 precision/einsum_dot_general

  • qk_attn_weights_einsum – 。

flax.linen.make_attention_mask(query_input, key_input, pairwise_fn=<jnp.ufunc 'multiply'>, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[源代码]#

用于注意力权重的掩码生成辅助函数。

对于 1 维输入(即,[batch..., len_q][batch..., len_kv]),注意力权重将为 [batch..., heads, len_q, len_kv],并且此函数将生成 [batch..., 1, len_q, len_kv]

参数
  • query_input – 批处理的,平坦的查询长度大小的输入。

  • key_input – 批处理的,平坦的键长度大小的输入。

  • pairwise_fn – 广播元素比较函数。

  • extra_batch_dims – 要为其添加单例轴的额外批次维度数,默认为 none。

  • dtype – 掩码返回的数据类型。

返回

适用于 1 维注意力的 [batch..., 1, len_q, len_kv] 形状的掩码。

flax.linen.make_causal_mask(x, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[源代码]#

为自注意力生成因果掩码。

对于 1 维输入(即,[batch..., len]),自注意力权重将为 [batch..., heads, len, len],并且此函数将生成形状为 [batch..., 1, len, len] 的因果掩码。

参数
  • x – 形状为 [batch..., len] 的输入数组。

  • extra_batch_dims – 要为其添加单例轴的批次维度数,默认为 none。

  • dtype – 掩码返回的数据类型。

返回

适用于 1 维注意力的 [batch..., 1, len, len] 形状的因果掩码。

循环#

class flax.linen.RNNCellBase(parent=<flax.linen.module._Sentinel object>, name=None)[源代码]#

RNN 单元基类。

__call__(**kwargs)#

将自身作为函数调用。

initialize_carry(rng, input_shape)[源代码]#

初始化 RNN 单元的 carry。

参数
  • rng – 传递给 init_fn 的随机数生成器。

  • input_shape – 一个元组,提供单元输入的形状。

返回

给定 RNN 单元的初始化 carry。

方法

initialize_carry(rng, input_shape)

初始化 RNN 单元的 carry。

class flax.linen.LSTMCell(features, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[源代码]#

LSTM 单元。

该单元的数学定义如下

\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]

其中 x 是输入,h 是前一时间步的输出,c 是记忆。

用法示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (2, 3))
>>> layer = nn.LSTMCell(features=4)
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
features#

输出特征的数量。

类型

int

gate_fn#

用于门的激活函数(默认:sigmoid)。

类型

collections.abc.Callable[[…], Any]

activation_fn#

用于输出和内存更新的激活函数(默认:tanh)。

类型

collections.abc.Callable[[…], Any]

kernel_init#

用于转换输入的核的初始化函数(默认:lecun_normal)。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

recurrent_kernel_init#

用于转换隐藏状态的核的初始化函数(默认:initializers.orthogonal())。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏置参数的初始化器(默认:initializers.zeros_init())

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

dtype#

计算的数据类型(默认:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

__call__(carry, inputs)[源代码]#

一个长短期记忆 (LSTM) 单元。

参数
  • carry – LSTM 单元的隐藏状态,使用 LSTMCell.initialize_carry 初始化。

  • inputs – 一个 ndarray,包含当前时间步的输入。除了最后一个维度之外的所有维度都被视为批次维度。

返回

一个包含新 carry 和输出的元组。

initialize_carry(rng, input_shape)[源代码]#

初始化 RNN 单元的 carry。

参数
  • rng – 传递给 init_fn 的随机数生成器。

  • input_shape – 一个元组,提供单元输入的形状。

返回

给定 RNN 单元的初始化 carry。

方法

initialize_carry(rng, input_shape)

初始化 RNN 单元的 carry。

class flax.linen.OptimizedLSTMCell(features, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[源代码]#

更高效的 LSTM 单元,它在矩阵乘法之前连接状态组件。

这些参数与 LSTMCell 兼容。请注意,只要隐藏大小大约 <= 2048 个单位,此单元通常比 LSTMCell 快。

该单元的数学定义与 LSTMCell 相同,如下所示

\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]

其中 x 是输入,h 是前一时间步的输出,c 是记忆。

用法示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (2, 3))
>>> layer = nn.OptimizedLSTMCell(features=4)
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
gate_fn#

用于门的激活函数(默认:sigmoid)。

类型

collections.abc.Callable[[…], Any]

activation_fn#

用于输出和内存更新的激活函数(默认:tanh)。

类型

collections.abc.Callable[[…], Any]

kernel_init#

用于转换输入的核的初始化函数(默认:lecun_normal)。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

recurrent_kernel_init#

用于转换隐藏状态的核的初始化函数(默认:initializers.orthogonal())。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏置参数的初始化器(默认:initializers.zeros_init())。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

dtype#

计算的数据类型(默认:从输入和参数推断)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

__call__(carry, inputs)[源代码]#

优化的长短期记忆 (LSTM) 单元。

参数
  • carry – LSTM 单元的隐藏状态,使用 LSTMCell.initialize_carry 初始化。

  • inputs – 一个 ndarray,包含当前时间步的输入。除了最后一个维度之外的所有维度都被视为批次维度。

返回

一个包含新 carry 和输出的元组。

initialize_carry(rng, input_shape)[源代码]#

初始化 RNN 单元的 carry。

参数
  • rng – 传递给 init_fn 的随机数生成器。

  • input_shape – 一个元组,提供单元输入的形状。

返回

给定 RNN 单元的初始化 carry。

方法

initialize_carry(rng, input_shape)

初始化 RNN 单元的 carry。

class flax.linen.ConvLSTMCell(features, kernel_size, strides=None, padding='SAME', use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[源代码]#

卷积 LSTM 单元。

该实现基于 xingjian2015convolutional。给定 x_t 和先前的状态 (h_{t-1}, c_{t-1}),核心计算为

\[\begin{split}\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\end{split}\]

其中 * 表示卷积运算符;i_t、f_t、o_t 分别是输入门、遗忘门和输出门的激活值,而 g_t 是单元更新的向量。

注意

遗忘门初始化

根据 jozefowicz2015empirical 的做法,我们在初始化后向 b_f 添加 1.0,以减少训练开始时遗忘的规模。

用法示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (3, 5, 5))
>>> layer = nn.ConvLSTMCell(features=4, kernel_size=(2, 2))
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
features#

卷积滤波器的数量。

类型

int

kernel_size#

卷积核的形状。

类型

collections.abc.Sequence[int]

strides#

一个包含 n 个整数的序列,表示窗口间的步长。

类型

collections.abc.Sequence[int] | None

padding#

字符串 'SAME',字符串 'VALID',或者一个包含 n(low, high) 整数对的序列,表示在每个空间维度之前和之后应用的填充。

类型

str | collections.abc.Sequence[tuple[int, int]]

bias#

是否在输出中添加偏置(默认值:True)。

dtype#

计算的数据类型(默认值:None)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

__call__(carry, inputs)[source]#

构建卷积 LSTM。

参数
  • carry – Conv2DLSTM 单元的隐藏状态,使用 Conv2DLSTM.initialize_carry 初始化。

  • inputs – 具有维度(批次,空间维度…,特征)的输入数据。

返回

一个包含新 carry 和输出的元组。

initialize_carry(rng, input_shape)[source]#

初始化 RNN 单元的 carry。

参数
  • rng – 传递给 init_fn 的随机数生成器。

  • input_shape – 一个元组,提供单元输入的形状。

返回

给定 RNN 单元的初始化 carry。

方法

initialize_carry(rng, input_shape)

初始化 RNN 单元的 carry。

class flax.linen.SimpleCell(features, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, residual=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

简单单元。

该单元的数学定义如下

\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h) \end{array}\]

其中 x 是输入,h 是上一个时间步的输出。

如果 residualTrue,则

\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h + h) \end{array}\]

用法示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (2, 3))
>>> layer = nn.SimpleCell(features=4)
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
features#

输出特征的数量。

类型

int

activation_fn#

用于输出和内存更新的激活函数(默认:tanh)。

类型

collections.abc.Callable[[…], Any]

kernel_init#

用于转换输入的核的初始化函数(默认:lecun_normal)。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

recurrent_kernel_init#

用于转换隐藏状态的核的初始化函数(默认:initializers.orthogonal())。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏置参数的初始化器(默认:initializers.zeros_init())

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

dtype#

计算的数据类型(默认值:None)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

residual#

预激活残差连接 (https://arxiv.org/abs/1801.06105)。

类型

bool

__call__(carry, inputs)[source]#

简单单元。

参数
  • carry – 简单单元的隐藏状态,使用 SimpleCell.initialize_carry 初始化。

  • inputs – 一个 ndarray,包含当前时间步的输入。除了最后一个维度之外的所有维度都被视为批次维度。

返回

一个包含新 carry 和输出的元组。

initialize_carry(rng, input_shape)[source]#

初始化 RNN 单元的 carry。

参数
  • rng – 传递给 init_fn 的随机数生成器。

  • input_shape – 一个元组,提供单元输入的形状。

返回

给定 RNN 单元的初始化 carry。

方法

initialize_carry(rng, input_shape)

初始化 RNN 单元的 carry。

class flax.linen.GRUCell(features, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

GRU 单元。

该单元的数学定义如下

\[\begin{split}\begin{array}{ll} r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\ z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\ n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ h' = (1 - z) * n + z * h \\ \end{array}\end{split}\]

其中 x 是输入,h 是上一个时间步的输出。

用法示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (2, 3))
>>> layer = nn.GRUCell(features=4)
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
features#

输出特征的数量。

类型

int

gate_fn#

用于门的激活函数(默认:sigmoid)。

类型

collections.abc.Callable[[…], Any]

activation_fn#

用于输出和内存更新的激活函数(默认:tanh)。

类型

collections.abc.Callable[[…], Any]

kernel_init#

用于转换输入的核的初始化函数(默认:lecun_normal)。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

recurrent_kernel_init#

用于转换隐藏状态的核的初始化函数(默认:initializers.orthogonal())。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

bias_init#

偏置参数的初始化器(默认:initializers.zeros_init())

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

dtype#

计算的数据类型(默认值:None)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

__call__(carry, inputs)[source]#

门控循环单元 (GRU) 单元。

参数
  • carry – GRU 单元的隐藏状态,使用 GRUCell.initialize_carry 初始化。

  • inputs – 一个 ndarray,包含当前时间步的输入。除了最后一个维度之外的所有维度都被视为批次维度。

返回

一个包含新 carry 和输出的元组。

initialize_carry(rng, input_shape)[source]#

初始化 RNN 单元的 carry。

参数
  • rng – 传递给 init_fn 的随机数生成器。

  • input_shape – 一个元组,提供单元输入的形状。

返回

给定 RNN 单元的初始化 carry。

方法

initialize_carry(rng, input_shape)

初始化 RNN 单元的 carry。

class flax.linen.MGUCell(features, gate_fn=<PjitFunction of <function sigmoid>>, activation_fn=<PjitFunction of <function tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, forget_bias_init=<function ones>, activation_bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, carry_init=<function zeros>, reset_gate=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

MGU 单元 (https://arxiv.org/pdf/1603.09420.pdf)。

该单元的数学定义如下

\[\begin{split}\begin{array}{ll} f = \sigma(W_{if} x + b_{if} + W_{hf} h) \\ n = \tanh(W_{in} x + b_{in} + f * (W_{hn} h + b_{hn})) \\ h' = (1 - f) * n + f * h \\ \end{array}\end{split}\]

其中 x 是输入,h 是上一个时间步的输出。

如果 reset_gate 为 false,则上述公式变为:

\[\begin{split}\begin{array}{ll} f = \sigma(W_{if} x + b_{if} + W_{hf} h) \\ n = \tanh(W_{in} x + b_{in} + W_{hn} h) \\ h' = (1 - f) * n + f * h \\ \end{array}\end{split}\]

用法示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> x = jax.random.normal(jax.random.key(0), (2, 3))
>>> layer = nn.MGUCell(features=4)
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
features#

输出特征的数量。

类型

int

gate_fn#

用于门的激活函数(默认:sigmoid)。

类型

collections.abc.Callable[[…], Any]

activation_fn#

用于输出和内存更新的激活函数(默认:tanh)。

类型

collections.abc.Callable[[…], Any]

kernel_init#

用于转换输入的核的初始化函数(默认:lecun_normal)。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

recurrent_kernel_init#

用于转换隐藏状态的核的初始化函数(默认:initializers.orthogonal())。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

forget_bias_init#

遗忘门偏置参数的初始化器。默认设置为 initializers.ones_init(),因为这可以防止梯度消失。更多详情请参考 https://proceedings.mlr.press/v37/jozefowicz15.pdf 的 2.2 节。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

activation_bias_init#

激活输出的偏置参数的初始化器(默认值:initializers.zeros_init())。

类型

Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]

dtype#

计算的数据类型(默认值:None)。

类型

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

传递给参数初始化器的数据类型(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

reset_gate#

用于应用重置门的标志。

类型

bool

__call__(carry, inputs)[source]#

最小门控单元 (MGU) 单元。

参数
  • carry – MGU 单元的隐藏状态,使用 MGUCell.initialize_carry 初始化。

  • inputs – 一个 ndarray,包含当前时间步的输入。除了最后一个维度之外的所有维度都被视为批次维度。

返回

一个包含新 carry 和输出的元组。

initialize_carry(rng, input_shape)[source]#

初始化 RNN 单元的 carry。

参数
  • rng – 传递给 init_fn 的随机数生成器。

  • input_shape – 一个元组,提供单元输入的形状。

返回

给定 RNN 单元的初始化 carry。

方法

initialize_carry(rng, input_shape)

初始化 RNN 单元的 carry。

class flax.linen.RNN(cell, time_major=False, return_carry=False, reverse=False, keep_order=False, unroll=1, variable_axes=FrozenDict({}), variable_broadcast='params', variable_carry=False, split_rngs=FrozenDict({     params: False, }), parent=<flax.linen.module._Sentinel object>, name=None)[source]#

RNN 模块接收任何 RNNCellBase 实例,并使用 flax.linen.scan() 将其应用于序列。

使用 flax.linen.scan()

示例

>>> import jax.numpy as jnp
>>> import jax
>>> import flax.linen as nn

>>> x = jnp.ones((10, 50, 32)) # (batch, time, features)
>>> lstm = nn.RNN(nn.LSTMCell(64))
>>> variables = lstm.init(jax.random.key(0), x)
>>> y = lstm.apply(variables, x)
>>> y.shape # (batch, time, cell_size)
(10, 50, 64)

如上所示,RNN 使用 cell_size 参数来设置单元格的 initialize_carry 方法的 size 参数,在实践中,这通常是您希望单元格具有的隐藏单元的数量。但是,这可能会因您使用的单元格而异,例如,ConvLSTMCell 需要一个 size 参数,其形式为 (kernel_height, kernel_width, features)

>>> x = jnp.ones((10, 50, 32, 32, 3)) # (batch, time, height, width, features)
>>> conv_lstm = nn.RNN(nn.ConvLSTMCell(64, kernel_size=(3, 3)))
>>> y, variables = conv_lstm.init_with_output(jax.random.key(0), x)
>>> y.shape # (batch, time, height, width, features)
(10, 50, 32, 32, 64)

默认情况下,RNN 期望时间维度在批处理维度之后((*batch, time, *features)),如果您设置 time_major=True,则 RNN 将改为期望时间维度位于开头 ((time, *batch, *features))

>>> x = jnp.ones((50, 10, 32)) # (time, batch, features)
>>> lstm = nn.RNN(nn.LSTMCell(64), time_major=True)
>>> variables = lstm.init(jax.random.key(0), x)
>>> y = lstm.apply(variables, x)
>>> y.shape # (time, batch, cell_size)
(50, 10, 64)

默认情况下,输出是一个形状为 (*batch, time, *cell_size) 的数组(通常),但是,如果您设置 return_carry=True,它将改为返回最终 carry 和输出的元组。

>>> x = jnp.ones((10, 50, 32)) # (batch, time, features)
>>> lstm = nn.RNN(nn.LSTMCell(64), return_carry=True)
>>> variables = lstm.init(jax.random.key(0), x)
>>> carry, y = lstm.apply(variables, x)
>>> jax.tree_util.tree_map(jnp.shape, carry) # ((batch, cell_size), (batch, cell_size))
((10, 64), (10, 64))
>>> y.shape # (batch, time, cell_size)
(10, 50, 64)

为了支持可变长度序列,您可以传递一个 seq_lengths,它是一个形状为 (*batch) 的整数数组,其中每个元素都是批处理中序列的长度。例如

>>> seq_lengths = jnp.array([3, 2, 5])

与填充元素相对应的输出元素不会归零。如果 return_carry 设置为 True,则 carry 将是每个序列的最后一个有效元素的状态。

RNN 还接受 flax.linen.scan() 的一些参数,默认情况下,它们设置为与 LSTMCellGRUCell 等单元格一起使用,但可以根据需要覆盖它们。覆盖扫描的默认值如下所示

>>> lstm = nn.RNN(
...   nn.LSTMCell(64),
...   unroll=1, variable_axes={}, variable_broadcast='params',
...   variable_carry=False, split_rngs={'params': False})
cell#

RNNCellBase 的一个实例。

类型

flax.linen.recurrent.RNNCellBase

time_major#

如果 time_major=False(默认),它将期望输入的形状为 (*batch, time, *features),否则它将期望输入的形状为 (time, *batch, *features)

类型

bool

return_carry#

如果 return_carry=False(默认),则仅返回输出序列,否则将返回最终 carry 和输出序列的元组。

类型

bool

reverse#

如果 reverse=False(默认),则序列将从左到右处理,并按原始顺序返回,否则将从右到左处理,并按相反顺序返回。如果传递了 seq_lengths,则填充将始终保留在序列的末尾。

类型

bool

keep_order#

如果 keep_order=True,当 reverse=True 时,输出将在处理后反转回原始顺序,这对于在双向 RNN 中对齐序列很有用。如果 keep_order=False(默认值),则输出将保持 reverse 指定的顺序。

类型

bool

unroll#

在单个循环迭代中展开的扫描迭代次数,默认为 1。此参数将传递给 nn.scan

类型

int

variable_axes#

一个字典,将每个集合映射到一个整数 i(表示我们扫描维度 i)或 None(复制而不是扫描)。此参数将转发到 nn.scan

类型

collections.abc.Mapping[Union[bool, str, Collection[str], DenyList], Union[int, flax.typing.In[int], flax.typing.Out[int]]]

variable_broadcast#

指定广播的变量集合。广播的变量不应依赖于任何无法从循环中提取的计算。这通常用于在 fn 内部定义共享参数。此参数将转发到 nn.scan

类型

Union[bool, str, Collection[str], DenyList]

variable_carry#

指定通过循环携带的变量集合。对这些变量的更改将传递到下一次迭代,并在扫描完成时保留。此参数将转发到 nn.scan

类型

Union[bool, str, Collection[str], DenyList]

split_rngs#

一个从 PRNGSequenceFilter 到布尔值的映射,指定是否应拆分集合的 PRNG 密钥,使其值在每个步骤都不同,或复制使其值在每个步骤保持不变。此参数将转发到 nn.scan

类型

collections.abc.Mapping[Union[bool, str, Collection[str], DenyList], bool]

__call__(inputs, *, initial_carry=None, init_key=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)[source]#

将 RNN 应用于输入。

__call__ 允许您选择性地覆盖构造函数中定义的某些属性,如 return_carrytime_major

参数
  • inputs – 输入序列。

  • initial_carry – 初始 carry,如果未提供,它将使用 cell 的 RNNCellBase.initialize_carry() 方法进行初始化。

  • init_key – 用于初始化 carry 的 PRNG 密钥,如果未提供,将使用 jax.random.key(0)。大多数 cell 将忽略此参数。

  • seq_lengths – 一个可选的整数数组,形状为 (*batch),指示每个序列的长度,时间维度中索引大于相应长度的元素将被视为填充并被忽略。

  • return_carry – 如果 return_carry=False(默认值),则仅返回输出序列,否则将返回最终 carry 和输出序列的元组。

  • time_major – 如果 time_major=False(默认值),它将期望输入的形状为 (*batch, time, *features),否则它将期望输入的形状为 (time, *batch, *features)

  • reverse – 覆盖 reverse 属性,如果 reverse=False(默认值),则序列将从左到右处理并按原始顺序返回,否则将从右到左处理,并按相反顺序返回。如果传递了 seq_lengths,则填充将始终保留在序列的末尾。

  • keep_order – 覆盖 keep_order 属性,如果 keep_order=True,当 reverse=True 时,输出将在处理后反转回原始顺序,这对于在双向 RNN 中对齐序列很有用。如果 keep_order=False(默认值),则输出将保持 reverse 指定的顺序。

返回

如果 return_carry=False(默认),则仅返回输出序列,否则将返回最终 carry 和输出序列的元组。

方法

class flax.linen.Bidirectional(forward_rnn, backward_rnn, merge_fn=<function _concatenate>, time_major=False, return_carry=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

在两个方向处理输入并合并结果。

用法示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> layer = nn.Bidirectional(nn.RNN(nn.GRUCell(4)), nn.RNN(nn.GRUCell(4)))
>>> x = jnp.ones((2, 3))
>>> variables = layer.init(jax.random.key(0), x)
>>> out = layer.apply(variables, x)
__call__(inputs, *, initial_carry=None, init_key=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)[source]#

将自身作为函数调用。

方法

BatchApply#

class flax.linen.BatchApply(f, num_dims=2)[source]#

临时合并输入张量的头部维度。

将张量的头部维度合并为单个维度,运行给定的可调用对象,然后拆分结果的头部维度以匹配输入。

秩小于要折叠的维度数量的输入数组将保持不变。

这对于将模块应用于例如 [Time, Batch, ...] 数组的每个时间步可能很有用。

对于某些 f 和平台,这可能比 jax.vmap() 更高效,特别是当与其他转换(如 jax.grad())结合使用时。

用法示例

>>> import jax, jax.numpy as jnp

>>> a = jax.random.normal(jax.random.key(0), [2, 3, 4])
>>> b = jax.random.normal(jax.random.key(1), [4])

>>> def raises(a, b):
...   if len(a.shape) != 2:
...     raise ValueError("a must be shape 2")
...   if len(b.shape) != 1:
...     raise ValueError("b must be shape 1")
...   return jnp.dot(a, b)

>>> out = BatchApply(raises)(a, b)
>>> expected_merged_leading = raises(a.reshape(2*3, 4), b)
>>> expected = expected_merged_leading.reshape((2, 3) + expected_merged_leading.shape[1:])
>>> np.testing.assert_array_equal(out, expected)
__call__(*args, **kwargs)[源代码]#

将自身作为函数调用。

方法