🔪 Flax - 尖锐之处 🔪#
Flax 展现了 JAX 的全部力量。就像使用 JAX 一样,在使用 Flax 时您可能会遇到某些 “尖锐之处”。这份不断更新的文档旨在帮助您解决这些问题。
首先,安装和/或更新 Flax
! pip install -qq flax
🔪 flax.linen.Dropout
层和随机性#
要点总结#
在处理具有 dropout(从 Flax Module
继承)的模型时,仅在正向传播期间添加 'dropout'
PRNGkey。
从
jax.random.split()
开始,显式地为'params'
和'dropout'
创建 PRNG 密钥。将
flax.linen.Dropout
层添加到您的模型(从 FlaxModule
继承)。在初始化模型时(
flax.linen.init()
),无需传入额外的'dropout'
PRNG 密钥 — 就像在“更简单”的模型中一样,仅传入'params'
密钥。在用
flax.linen.apply()
进行正向传播期间,传入rngs={'dropout': dropout_key}
。
查看下面的完整示例。
为什么它有效#
在内部,
flax.linen.Dropout
利用flax.linen.Module.make_rng
为 dropout 创建密钥(查看 源代码)。每次调用
make_rng
时(在本例中,它是在Dropout
中隐式完成的),您都会获得一个从主/根 PRNG 密钥拆分出的新 PRNG 密钥。make_rng
仍然保证完全的可重复性。
背景#
dropout 随机正则化技术随机删除网络中的隐藏单元和可见单元。Dropout 是一种随机操作,需要 PRNG 状态,而 Flax(如 JAX)使用可拆分的 Threefry PRNG。
注意:回想一下,JAX 有一种显式的方法来为您提供 PRNG 密钥:您可以使用
key, subkey = jax.random.split(key)
将主 PRNG 状态(例如key = jax.random.key(seed=0)
)分成多个新的 PRNG 密钥。在 🔪 JAX - 尖锐之处 🔪 随机性和 PRNG 密钥中复习一下。
Flax 提供了一种通过 Flax Module
的 flax.linen.Module.make_rng
辅助函数处理 PRNG 密钥流的隐式方式。它允许 Flax Module
中的代码(或其子 Module
)“拉取 PRNG 密钥”。make_rng
保证每次调用时都提供唯一的密钥。有关更多详细信息,请参阅 RNG 指南。
注意:回想一下,
flax.linen.Module
是所有神经网络模块的基类。所有层和模型都从中继承。
示例#
请记住,每个 Flax PRNG 流都有一个名称。下面的示例使用 'params'
流来初始化参数,以及 'dropout'
流。提供给 flax.linen.init()
的 PRNG 密钥是为 'params'
PRNG 密钥流播种的密钥。要在正向传播期间(使用 dropout)提取 PRNG 密钥,请在调用 Module.apply()
时提供一个 PRNG 密钥来为该流 ('dropout'
) 播种。
# Setup.
import jax
import jax.numpy as jnp
import flax.linen as nn
# Randomness.
seed = 0
root_key = jax.random.key(seed=seed)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
# A simple network.
class MyModel(nn.Module):
num_neurons: int
training: bool
@nn.compact
def __call__(self, x):
x = nn.Dense(self.num_neurons)(x)
# Set the dropout layer with a rate of 50% .
# When the `deterministic` flag is `True`, dropout is turned off.
x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)
return x
# Instantiate `MyModel` (you don't need to set `training=True` to
# avoid performing the forward pass computation).
my_model = MyModel(num_neurons=3, training=False)
x = jax.random.uniform(key=main_key, shape=(3, 4, 4))
# Initialize with `flax.linen.init()`.
# The `params_key` is equivalent to a dictionary of PRNGs.
# (Here, you are providing only one PRNG key.)
variables = my_model.init(params_key, x)
# Perform the forward pass with `flax.linen.apply()`.
y = my_model.apply(variables, x, rngs={'dropout': dropout_key})
现实生活中的例子