初始化器#
Flax 的初始化器。
- flax.linen.initializers.constant(value, dtype=<class 'jax.numpy.float64'>)#
构建一个初始化器,该初始化器返回充满常数
value
的数组。- 参数
value – 用于填充初始化器的常量值。
dtype – 可选; 初始化器的默认 dtype。
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.constant(-7) >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[-7., -7., -7.], [-7., -7., -7.]], dtype=float32)
- flax.linen.initializers.delta_orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)#
构建一个用于 delta 正交内核的初始化器。
- 参数
scale – 均匀分布的上限。
column_axis – 包含应正交的列的轴。
dtype – 权重的默认 dtype。
- 返回值
一个 delta 正交初始化器。传递给初始化器的形状必须是 3D、4D 或 5D。
示例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.delta_orthogonal() >>> initializer(jax.random.key(42), (3, 3, 3), jnp.float32) Array([[[ 0. , 0. , 0. ], [ 0. , 0. , 0. ], [ 0. , 0. , 0. ]], [[ 0.27858758, -0.7949833 , -0.53887904], [ 0.9120717 , 0.04322892, 0.40774566], [-0.30085585, -0.6050892 , 0.73712474]], [[ 0. , 0. , 0. ], [ 0. , 0. , 0. ], [ 0. , 0. , 0. ]]], dtype=float32)
- flax.linen.initializers.glorot_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
构建一个 Glorot 正态初始化器(又名 Xavier 正态初始化器)。
Glorot 正态初始化器 是
jax.nn.initializers.variance_scaling()
的一种特殊形式,其中scale = 1.0
,mode="fan_avg"
和distribution="truncated_normal"
。- 参数
in_axis – 权重数组中输入维度的轴或轴序列。
out_axis – 权重数组中输出维度的轴或轴序列。
batch_axis – 权重数组中应忽略的轴或轴序列。
dtype – 权重的 dtype。
- 返回值
一个初始化器。
示例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.glorot_normal() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.41770416, 0.75262755, 0.7619329 ], [-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
- flax.linen.initializers.glorot_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
构建一个 Glorot 均匀初始化器(又名 Xavier 均匀初始化器)。
Glorot 均匀初始化器 是
jax.nn.initializers.variance_scaling()
的一种特殊形式,其中scale = 1.0
,mode="fan_avg"
和distribution="uniform"
。- 参数
in_axis – 权重数组中输入维度的轴或轴序列。
out_axis – 权重数组中输出维度的轴或轴序列。
batch_axis – 权重数组中应忽略的轴或轴序列。
dtype – 权重的 dtype。
- 返回值
一个初始化器。
示例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.glorot_uniform() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.50350785, 0.8088631 , 0.81566876], [-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
- flax.linen.initializers.he_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
构建一个 He 正态初始化器(又名 Kaiming 正态初始化器)。
He 正态初始化器 是
jax.nn.initializers.variance_scaling()
的一种特殊形式,其中scale = 2.0
,mode="fan_in"
和distribution="truncated_normal"
。- 参数
in_axis – 权重数组中输入维度的轴或轴序列。
out_axis – 权重数组中输出维度的轴或轴序列。
batch_axis – 权重数组中应忽略的轴或轴序列。
dtype – 权重的 dtype。
- 返回值
一个初始化器。
示例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.he_normal() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.6604483 , 1.1900088 , 1.2047218 ], [-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
- flax.linen.initializers.he_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
构建一个 He 均匀初始化器(又名 Kaiming 均匀初始化器)。
He 均匀初始化器 是
jax.nn.initializers.variance_scaling()
的一种特殊形式,其中scale = 2.0
,mode="fan_in"
和distribution="uniform"
。- 参数
in_axis – 权重数组中输入维度的轴或轴序列。
out_axis – 权重数组中输出维度的轴或轴序列。
batch_axis – 权重数组中应忽略的轴或轴序列。
dtype – 权重的 dtype。
- 返回值
一个初始化器。
示例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.he_uniform() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.79611576, 1.2789248 , 1.2896855 ], [-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
- flax.linen.initializers.kaiming_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
构建一个 He 正态初始化器(又名 Kaiming 正态初始化器)。
He 正态初始化器 是
jax.nn.initializers.variance_scaling()
的一种特殊形式,其中scale = 2.0
,mode="fan_in"
和distribution="truncated_normal"
。- 参数
in_axis – 权重数组中输入维度的轴或轴序列。
out_axis – 权重数组中输出维度的轴或轴序列。
batch_axis – 权重数组中应忽略的轴或轴序列。
dtype – 权重的 dtype。
- 返回值
一个初始化器。
示例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.he_normal() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.6604483 , 1.1900088 , 1.2047218 ], [-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
- flax.linen.initializers.kaiming_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
构建一个 He 均匀初始化器(又名 Kaiming 均匀初始化器)。
He 均匀初始化器 是
jax.nn.initializers.variance_scaling()
的一种特殊形式,其中scale = 2.0
,mode="fan_in"
和distribution="uniform"
。- 参数
in_axis – 权重数组中输入维度的轴或轴序列。
out_axis – 权重数组中输出维度的轴或轴序列。
batch_axis – 权重数组中应忽略的轴或轴序列。
dtype – 权重的 dtype。
- 返回值
一个初始化器。
示例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.he_uniform() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.79611576, 1.2789248 , 1.2896855 ], [-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
- flax.linen.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
构建一个 LeCun 正态初始化器。
LeCun 正态初始化器 是
jax.nn.initializers.variance_scaling()
的一种特殊形式,其中scale = 1.0
,mode="fan_in"
, 并且distribution="truncated_normal"
。- 参数
in_axis – 权重数组中输入维度的轴或轴序列。
out_axis – 权重数组中输出维度的轴或轴序列。
batch_axis – 权重数组中应忽略的轴或轴序列。
dtype – 权重的 dtype。
- 返回值
一个初始化器。
示例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.lecun_normal() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.46700746, 0.8414632 , 0.8518669 ], [-0.61677957, -0.67402434, 0.09683388]], dtype=float32)
- flax.linen.initializers.lecun_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
构建一个 LeCun 均匀初始化器。
LeCun 均匀初始化器 是
jax.nn.initializers.variance_scaling()
的一种特殊形式,其中scale = 1.0
,mode="fan_in"
, 并且distribution="uniform"
。- 参数
in_axis – 权重数组中输入维度的轴或轴序列。
out_axis – 权重数组中输出维度的轴或轴序列。
batch_axis – 权重数组中应忽略的轴或轴序列。
dtype – 权重的 dtype。
- 返回值
一个初始化器。
示例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.lecun_uniform() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.56293887, 0.90433645, 0.9119454 ], [-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)
- flax.linen.initializers.normal(stddev=0.01, dtype=<class 'jax.numpy.float64'>)#
构建一个返回实数正态分布随机数组的初始化器。
- 参数
stddev – 可选参数;分布的标准差。
dtype – 可选; 初始化器的默认 dtype。
- 返回值
一个返回数组的初始化器,该数组的值服从均值为
0
,标准差为stddev
的正态分布。
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.normal(5.0) >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 3.0613258 , 5.6129413 , 5.6866574 ], [-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
- flax.linen.initializers.truncated_normal(stddev=0.01, dtype=<class 'jax.numpy.float64'>, lower=-2.0, upper=2.0)#
构建一个返回截断正态随机数组的初始化器。
- 参数
stddev – 可选参数;未截断分布的标准差。 请注意,此函数不会像在 variancescaling 初始化器中那样应用 stddev 校正,用户如果希望使用它,则应通过 stddev 参数自行应用此校正。
dtype – 可选; 初始化器的默认 dtype。
lower – 表示截断下限的浮点数。 在输出乘以 stddev 之前应用。
upper – 表示截断上限的浮点数。 在输出乘以 stddev 之前应用。
- 返回值
一个返回数组的初始化器,该数组的值遵循均值为
0
,标准差为stddev
,范围为 \(\rm{lower * stddev} < x < \rm{upper * stddev}\) 的截断正态分布。
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.truncated_normal(5.0) >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 2.9047365, 5.2338114, 5.29852 ], [-3.836303 , -4.192359 , 0.6022964]], dtype=float32)
- flax.linen.initializers.ones(key, shape, dtype=<class 'jax.numpy.float64'>)#
一个返回充满 1 的常量数组的初始化器。
忽略
key
参数。>>> import jax, jax.numpy as jnp >>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32) Array([[1., 1.], [1., 1.], [1., 1.]], dtype=float32)
- flax.linen.initializers.ones_init()[source]#
构建一个返回充满 1 的常量数组的初始化器。
>>> import jax, jax.numpy as jnp >>> from flax.linen.initializers import ones_init >>> ones_initializer = ones_init() >>> ones_initializer(jax.random.key(42), (3, 2), jnp.float32) Array([[1., 1.], [1., 1.], [1., 1.]], dtype=float32)
- flax.linen.initializers.orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)#
构建一个返回均匀分布的正交矩阵的初始化器。
如果形状不是正方形,则矩阵将具有正交的行或列,具体取决于哪一侧较小。
- 参数
scale – 均匀分布的上限。
column_axis – 包含应正交的列的轴。
dtype – 权重的默认 dtype。
- 返回值
一个正交初始化器。
示例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.orthogonal() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01], [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
- flax.linen.initializers.uniform(scale=0.01, dtype=<class 'jax.numpy.float64'>)#
构建一个返回实数均匀分布随机数组的初始化器。
- 参数
scale – 可选参数;随机分布的上限。
dtype – 可选; 初始化器的默认 dtype。
- 返回值
一个返回数组的初始化器,该数组的值在
[0, scale)
范围内均匀分布。
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.uniform(10.0) >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[7.298188 , 8.691938 , 8.7230015], [2.0818567, 1.8662417, 5.5022564]], dtype=float32)
- flax.linen.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
初始化器,使其尺度适应权重张量的形状。
使用
distribution="truncated_normal"
或distribution="normal"
时,样本从均值为零且标准差(如果适用,则在截断后)为 \(\sqrt{\frac{scale}{n}}\) 的(截断)正态分布中抽取,其中 n 是如果
mode="fan_in"
,则为权重张量中的输入单元数,如果
mode="fan_out"
,则为输出单元数,或者如果
mode="fan_avg"
,则为输入和输出单元数的平均值。
可以使用
in_axis
、out_axis
和batch_axis
配置此初始化器,以用于通用卷积层或密集层;任何未在这些参数中的轴都假定为“感受野”(卷积核空间轴)。使用
distribution="truncated_normal"
时,样本的绝对值在缩放之前被截断为 2 个标准差。使用
distribution="uniform"
时,样本从以下项中抽取:如果 dtype 为实数,则为均匀区间,或者
如果 dtype 为复数,则为均匀圆盘,
均值为零,标准差为 \(\sqrt{\frac{scale}{n}}\),其中 n 的定义如上所述。
- 参数
scale – 缩放因子(正浮点数)。
mode –
"fan_in"
、"fan_out"
和"fan_avg"
之一。distribution – 要使用的随机分布。
"truncated_normal"
、"normal"
和"uniform"
之一。in_axis – 权重数组中输入维度的轴或轴序列。
out_axis – 权重数组中输出维度的轴或轴序列。
batch_axis – 权重数组中应忽略的轴或轴序列。
dtype – 权重的 dtype。
- flax.linen.initializers.xavier_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
构建一个 Glorot 正态初始化器(又名 Xavier 正态初始化器)。
Glorot 正态初始化器 是
jax.nn.initializers.variance_scaling()
的一种特殊形式,其中scale = 1.0
,mode="fan_avg"
和distribution="truncated_normal"
。- 参数
in_axis – 权重数组中输入维度的轴或轴序列。
out_axis – 权重数组中输出维度的轴或轴序列。
batch_axis – 权重数组中应忽略的轴或轴序列。
dtype – 权重的 dtype。
- 返回值
一个初始化器。
示例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.glorot_normal() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.41770416, 0.75262755, 0.7619329 ], [-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
- flax.linen.initializers.xavier_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
构建一个 Glorot 均匀初始化器(又名 Xavier 均匀初始化器)。
Glorot 均匀初始化器 是
jax.nn.initializers.variance_scaling()
的一种特殊形式,其中scale = 1.0
,mode="fan_avg"
和distribution="uniform"
。- 参数
in_axis – 权重数组中输入维度的轴或轴序列。
out_axis – 权重数组中输出维度的轴或轴序列。
batch_axis – 权重数组中应忽略的轴或轴序列。
dtype – 权重的 dtype。
- 返回值
一个初始化器。
示例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.glorot_uniform() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 0.50350785, 0.8088631 , 0.81566876], [-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
- flax.linen.initializers.zeros(key, shape, dtype=<class 'jax.numpy.float64'>)#
一个返回充满零的常数数组的初始化器。
忽略
key
参数。>>> import jax, jax.numpy as jnp >>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32) Array([[0., 0., 0.], [0., 0., 0.]], dtype=float32)