Dropout#
本指南概述了如何使用 flax.linen.Dropout()
应用 dropout。
Dropout 是一种随机正则化技术,可以随机删除网络中的隐藏和可见单元。
在整个指南中,您将能够比较使用和不使用 Flax Dropout
的代码示例。
拆分 PRNG 密钥#
由于 dropout 是一种随机操作,因此它需要伪随机数生成器 (PRNG) 状态。Flax 使用 JAX 的(可拆分的)PRNG 密钥,这些密钥对于神经网络具有许多理想的属性。要了解更多信息,请参阅 JAX 中的伪随机数教程。
注意: 回想一下,JAX 有一种显式的方式为您提供 PRNG 密钥:您可以使用 key, subkey = jax.random.split(key)
将主 PRNG 状态(例如 key = jax.random.key(seed=0)
)分叉为多个新的 PRNG 密钥。您可以在 🔪 JAX - 尖锐的部分 🔪 随机性和 PRNG 密钥 中重新查看。
首先使用 jax.random.split() 将 PRNG 密钥拆分为三个密钥,包括一个用于 Flax Linen Dropout
的密钥。
root_key = jax.random.key(seed=0)
main_key, params_key = jax.random.split(key=root_key)
root_key = jax.random.key(seed=0)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
注意: 在 Flax 中,您使用名称提供 PRNG 流,以便稍后可以在 flax.linen.Module()
中使用它们。例如,您传递流 'params'
以初始化参数,传递 'dropout'
以应用 flax.linen.Dropout()
。
使用 Dropout
定义您的模型#
要创建带 dropout 的模型
子类化
flax.linen.Module()
,然后使用flax.linen.Dropout()
添加 dropout 层。回想一下,flax.linen.Module()
是 所有神经网络模块的基类,所有层和模型都从中子类化。在
flax.linen.Dropout()
中,需要将deterministic
参数作为关键字参数传递,可以是在构造
flax.linen.Module()
时;或在对构造的
Module
调用flax.linen.init()
或flax.linen.apply()
时。(有关更多详细信息,请参阅flax.linen.module.merge_param()
。)
因为
deterministic
是一个布尔值如果将其设置为
False
,则使用rate
设置的概率来屏蔽(即设置为零)输入。其余输入按1 / (1 - rate)
缩放,这确保了输入均值得到保留。如果将其设置为
True
,则不应用掩码(dropout 关闭),并且输入按原样返回。
一种常见的模式是在父 Flax Module
中接受一个 training
(或 train
)参数(一个布尔值),并使用它来启用或禁用 dropout(如本指南的后续部分所示)。在其他机器学习框架(如 PyTorch 或 TensorFlow (Keras))中,这是通过可变状态或调用标志指定的(例如,在 torch.nn.Module.eval 中或通过设置 training 标志的 tf.keras.Model
中)。
注意: Flax 提供了一种通过 Flax flax.linen.Module()
的 flax.linen.Module.make_rng()
方法隐式处理 PRNG 密钥流的方法。这使您可以在 Flax 模块(或其子模块)内部从 PRNG 流中拆分出一个新的 PRNG 密钥。make_rng
方法保证每次调用时都提供唯一的密钥。在内部,flax.linen.Dropout()
使用 flax.linen.Module.make_rng()
创建 dropout 的密钥。您可以查看 源代码。简而言之,flax.linen.Module.make_rng()
保证完全可重复性。
class MyModel(nn.Module):
num_neurons: int
@nn.compact
def __call__(self, x):
x = nn.Dense(self.num_neurons)(x)
return x
class MyModel(nn.Module):
num_neurons: int
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(self.num_neurons)(x)
# Set the dropout layer with a `rate` of 50%.
# When the `deterministic` flag is `True`, dropout is turned off.
x = nn.Dropout(rate=0.5, deterministic=not training)(x)
return x
初始化模型#
创建模型后
实例化模型。
然后,在
flax.linen.init()
调用中,设置training=False
。最后,从变量字典中提取
params
。
这里,没有使用 Flax Dropout
和使用了 Dropout
的代码之间的主要区别在于,如果需要启用 dropout,则必须提供 training
(或 train
)参数。
my_model = MyModel(num_neurons=3)
x = jnp.empty((3, 4, 4))
variables = my_model.init(params_key, x)
params = variables['params']
my_model = MyModel(num_neurons=3)
x = jnp.empty((3, 4, 4))
# Dropout is disabled with `training=False` (that is, `deterministic=True`).
variables = my_model.init(params_key, x, training=False)
params = variables['params']
在训练期间执行前向传播#
当使用 flax.linen.apply()
运行模型时
将
training=True
传递给flax.linen.apply()
。然后,为了在前向传播期间(使用 dropout)绘制 PRNG 密钥,在调用
flax.linen.apply()
时,提供一个 PRNG 密钥来播种'dropout'
流。
# No need to pass the `training` and `rngs` flags.
y = my_model.apply({'params': params}, x)
# Dropout is enabled with `training=True` (that is, `deterministic=False`).
y = my_model.apply({'params': params}, x, training=True, rngs={'dropout': dropout_key})
这里,没有使用 Flax Dropout
和使用了 Dropout
的代码之间的主要区别在于,如果需要启用 dropout,则必须提供 training
(或 train
)和 rngs
参数。
在评估期间,使用上述代码,不启用 dropout (这意味着您也不必传递 RNG)。
TrainState
和训练步骤#
本节介绍如何在启用 dropout 的情况下修改训练步骤函数内部的代码。
注意: 回顾一下,Flax 有一个常见的模式,即创建一个数据类来表示整个训练状态,包括参数和优化器状态。然后,您可以将单个参数 state: TrainState
传递给训练步骤函数。有关更多信息,请参阅 flax.training.train_state.TrainState()
API 文档。
首先,向自定义的
flax.training.train_state.TrainState()
类添加一个key
字段。然后,将
key
值(在本例中为dropout_key
)传递给train_state.TrainState.create()
方法。
from flax.training import train_state
state = train_state.TrainState.create(
apply_fn=my_model.apply,
params=params,
tx=optax.adam(1e-3)
)
from flax.training import train_state
class TrainState(train_state.TrainState):
key: jax.Array
state = TrainState.create(
apply_fn=my_model.apply,
params=params,
key=dropout_key,
tx=optax.adam(1e-3)
)
接下来,在 Flax 训练步骤函数
train_step
中,从dropout_key
生成一个新的 PRNG 密钥,以便在每一步应用 dropout。这可以通过以下方式之一完成使用
jax.random.fold_in()
通常更快。当您使用jax.random.split()
时,您会拆分出一个可以稍后重用的 PRNG 密钥。但是,使用jax.random.fold_in()
可以确保 1) 折叠唯一数据;2) 可以产生更长的 PRNG 流序列。最后,在执行前向传播时,将新的 PRNG 密钥作为额外的参数传递给
state.apply_fn()
。
@jax.jit
def train_step(state: train_state.TrainState, batch):
def loss_fn(params):
logits = state.apply_fn(
{'params': params},
x=batch['image'],
)
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label'])
return loss, logits
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
return state
@jax.jit
def train_step(state: TrainState, batch, dropout_key):
dropout_train_key = jax.random.fold_in(key=dropout_key, data=state.step)
def loss_fn(params):
logits = state.apply_fn(
{'params': params},
x=batch['image'],
training=True,
rngs={'dropout': dropout_train_key}
)
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label'])
return loss, logits
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
return state
带有 dropout 的 Flax 示例#
一个在 WMT 机器翻译数据集上训练的基于 Transformer 的模型。此示例使用 dropout 和注意力 dropout。
在 文本分类 上下文中,将词 dropout 应用于一批输入 ID。此示例使用自定义的
flax.linen.Dropout()
层。