激活函数#

激活函数。

class flax.linen.activation.PReLU(param_dtype=<class 'jax.numpy.float32'>, negative_slope_init=0.01, parent=<flax.linen.module._Sentinel object>, name=None)[源代码]#

参数化修正线性单元 (PReLU) 激活函数。

请注意,PReLU 是一个 Flax 层,而不是一个简单的激活函数,因此需要在调用之前进行初始化。

用法示例:
>>> import flax.linen as nn
>>> class MLP(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(2)(x)
...     x = nn.PReLU()(x) # initialized
...     return x
param_dtype#

传递给参数初始化器的 dtype(默认值:float32)。

类型

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

negative_slope_init#

初始化负斜率的值(默认值 0.01)。

类型

float

__call__(inputs)[源代码]#

将激活应用于输入。

参数

inputs – 应用激活函数的 nd-array。

返回

转换后的输入。

param_dtype#

别名 of float32

flax.linen.activation.celu(x, alpha=1.0)[源代码]#

连续可微指数线性单元激活。

计算逐元素函数

\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]

有关更多信息,请参阅 Continuously Differentiable Exponential Linear Units

参数
  • x – 输入数组

  • alpha – 数组或标量(默认值:1.0)

返回

一个数组。

flax.linen.activation.elu(x, alpha=1.0)[源代码]#

指数线性单元激活函数。

计算逐元素函数

\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]
参数
  • x – 输入数组

  • alpha – alpha 值的标量或数组(默认值:1.0)

返回

一个数组。

另请参阅

selu()

flax.linen.activation.gelu(x, approximate=True)[源代码]#

高斯误差线性单元激活函数。

如果 approximate=False,则计算逐元素函数

\[\mathrm{gelu}(x) = \frac{x}{2} \left(\mathrm{erfc} \left( \frac{-x}{\sqrt{2}} \right) \right)\]

如果 approximate=True,则使用 GELU 的近似公式

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]

有关更多信息,请参阅 Gaussian Error Linear Units (GELUs),第 2 节。

参数
  • x – 输入数组

  • approximate – 是否使用近似或精确公式。

flax.linen.activation.glu(x, axis=-1)[源代码]#

门控线性单元激活函数。

计算函数

\[\mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right] \right)\]

其中数组沿 axis 分成两部分。 axis 维度的大小必须可以被 2 整除。

参数
  • x – 输入数组

  • axis – 应计算拆分的轴(默认值:-1)

返回

一个数组。

另请参阅

sigmoid()

flax.linen.activation.hard_sigmoid(x)[源代码]#

硬 Sigmoid 激活函数。

计算逐元素函数

\[\mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}\]
参数

x – 输入数组

返回

一个数组。

另请参阅

relu6()

flax.linen.activation.hard_silu(x)[源代码]#

硬 SiLU (swish) 激活函数

计算逐元素函数

\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]

hard_silu()hard_swish() 都是同一个函数的别名。

参数

x – 输入数组

返回

一个数组。

另请参阅

hard_sigmoid()

flax.linen.activation.hard_swish(x)#

硬 SiLU (swish) 激活函数

计算逐元素函数

\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]

hard_silu()hard_swish() 都是同一个函数的别名。

参数

x – 输入数组

返回

一个数组。

另请参阅

hard_sigmoid()

flax.linen.activation.hard_tanh(x)[source]#

\(\mathrm{tanh}\) 激活函数。

计算逐元素函数

\[\begin{split}\mathrm{hard\_tanh}(x) = \begin{cases} -1, & x < -1\\ x, & -1 \le x \le 1\\ 1, & 1 < x \end{cases}\end{split}\]
参数

x – 输入数组

返回

一个数组。

flax.linen.activation.leaky_relu(x, negative_slope=0.01)[source]#

Leaky ReLU 激活函数。

计算逐元素函数

\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]

其中 \(\alpha\) = negative_slope

参数
  • x – 输入数组

  • negative_slope – 指定负斜率的数组或标量 (默认值: 0.01)

返回

一个数组。

另请参阅

relu()

flax.linen.activation.log_sigmoid(x)[source]#

Log-sigmoid 激活函数。

计算逐元素函数

\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]
参数

x – 输入数组

返回

一个数组。

另请参阅

sigmoid()

flax.linen.activation.log_softmax(x, axis=-1, where=None, initial=_UNSPECIFIED)[source]#

Log-Softmax 函数。

计算 softmax 函数的对数,该函数将元素重新缩放到范围 \([-\infty, 0)\)

\[\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]
参数
  • x – 输入数组

  • axis – 应计算 log_softmax 的轴或轴。可以是整数或整数元组。

  • where – 要包含在 log_softmax 中的元素。

返回

一个数组。

注意

如果任何输入值为 +inf,结果将全部为 NaN:这反映了在浮点数学的上下文中 inf / inf 没有明确定义的这一事实。

另请参阅

softmax()

flax.linen.activation.logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, where=None)[source]#

Log-sum-exp 归约。

JAX 实现的 scipy.special.logsumexp()

\[\mathrm{logsumexp}(a) = \mathrm{log} \sum_j b \cdot \mathrm{exp}(a_{ij})\]

其中 \(j\) 索引范围覆盖一个或多个要归约的维度。

参数
  • a – 输入数组

  • axis – 要在其上进行归约的轴或轴。可以是 None、整数或整数元组。

  • b\(\mathrm{exp}(a)\) 的缩放因子。必须可以广播到 a 的形状。

  • keepdims – 如果为 True,则归约的轴将保留在输出中作为大小为 1 的维度。

  • return_sign – 如果为 True,则输出将为 (result, sign) 对,其中 sign 是总和的符号,result 包含其绝对值的对数。如果为 False,则仅返回 result,并且如果总和为负,它将包含 NaN 值。

  • where – 要包含在归约中的元素。

返回

根据 return_sign 参数的值,可以是数组 result 或数组对 (result, sign)

flax.linen.activation.one_hot(x, num_classes, *, dtype=<class 'jax.numpy.float64'>, axis=-1)[source]#

对给定索引进行独热编码。

输入 x 中的每个索引都编码为长度为 num_classes 的零向量,并且 index 处的元素设置为 1

>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

范围 [0, num_classes) 之外的索引将编码为零

>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
参数
  • x – 索引张量。

  • num_classes – 独热维度中的类数。

  • dtype – 可选,返回值的浮点 dtype(默认 jnp.float_)。

  • axis – 应计算函数的轴或轴。

flax.linen.activation.relu(x)[source]#

修正线性单元激活函数。

计算逐元素函数

\[\mathrm{relu}(x) = \max(x, 0)\]

但在微分下,我们取

\[\nabla \mathrm{relu}(0) = 0\]

有关更多信息,请参阅 ReLU'(0) 对反向传播的数值影响

参数

x – 输入数组

返回

一个数组。

示例

>>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)

另请参阅

relu6()

flax.linen.activation.selu(x)[source]#

缩放指数线性单元激活。

计算逐元素函数

\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]

其中 \(\lambda = 1.0507009873554804934193349852946\)\(\alpha = 1.6732632423543772848170429916717\)

有关更多信息,请参阅 自归一化神经网络

参数

x – 输入数组

返回

一个数组。

另请参阅

elu()

flax.linen.activation.sigmoid(x)[source]#

Sigmoid 激活函数。

计算逐元素函数

\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]
参数

x – 输入数组

返回

一个数组。

另请参阅

log_sigmoid()

flax.linen.activation.silu(x)[source]#

SiLU(又名 swish)激活函数。

计算逐元素函数

\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]

swish()silu() 都是同一函数的别名。

参数

x – 输入数组

返回

一个数组。

另请参阅

sigmoid()

flax.linen.activation.soft_sign(x)[source]#

Soft-sign 激活函数。

计算逐元素函数

\[\mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}\]
参数

x – 输入数组

flax.linen.activation.softmax(x, axis=-1, where=None, initial=_UNSPECIFIED)[源代码]#

Softmax 函数。

计算将元素重新缩放到范围 \([0, 1]\) 的函数,使得沿 axis 的元素之和为 \(1\)

\[\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
参数
  • x – 输入数组

  • axis – 应该计算 softmax 的轴或多个轴。 这些维度上求和的 softmax 输出应该总和为 \(1\)。 可以是整数或整数元组。

  • where – 要包含在 softmax 中的元素。

返回

一个数组。

注意

如果任何输入值为 +inf,结果将全部为 NaN:这反映了在浮点数学的上下文中 inf / inf 没有明确定义的这一事实。

另请参阅

log_softmax()

flax.linen.activation.softplus(x)[源代码]#

Softplus 激活函数。

计算逐元素函数

\[\mathrm{softplus}(x) = \log(1 + e^x)\]
参数

x – 输入数组

flax.linen.activation.standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-05, where=None)[源代码]#

通过减去 mean 并除以 \(\sqrt{\mathrm{variance}}\) 来标准化数组。

flax.linen.activation.swish(x)#

SiLU(又名 swish)激活函数。

计算逐元素函数

\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]

swish()silu() 都是同一函数的别名。

参数

x – 输入数组

返回

一个数组。

另请参阅

sigmoid()

flax.linen.activation.tanh(x, /)#

计算输入的逐元素双曲正切值。

numpy.tanh 的 JAX 实现。

双曲正切定义为

\[tanh(x) = \frac{sinh(x)}{cosh(x)} = \frac{e^x - e^{-x}}{e^x + e^{-x}}\]
参数

x – 输入数组或标量。

返回

一个包含 x 的每个元素的双曲正切值的数组,并提升到非精确 dtype。

注意

jnp.tanh 等效于计算 -1j * jnp.tan(1j * x)

另请参阅

  • jax.numpy.sinh():计算输入的逐元素双曲正弦值。

  • jax.numpy.cosh():计算输入的逐元素双曲余弦值。

  • jax.numpy.arctanh():计算输入的逐元素双曲正切的反函数。

示例

>>> x = jnp.array([[-1, 0, 1],
...                [3, -2, 5]])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.tanh(x)
Array([[-0.762,  0.   ,  0.762],
       [ 0.995, -0.964,  1.   ]], dtype=float32)
>>> with jnp.printoptions(precision=3, suppress=True):
...   -1j * jnp.tan(1j * x)
Array([[-0.762+0.j,  0.   -0.j,  0.762-0.j],
       [ 0.995-0.j, -0.964+0.j,  1.   -0.j]],      dtype=complex64, weak_type=True)

对于复数值输入

>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.tanh(2-5j)
Array(1.031+0.021j, dtype=complex64, weak_type=True)
>>> with jnp.printoptions(precision=3, suppress=True):
...   -1j * jnp.tan(1j * (2-5j))
Array(1.031+0.021j, dtype=complex64, weak_type=True)