Dropout#

本指南概述了如何使用 flax.linen.Dropout() 应用 dropout

Dropout 是一种随机正则化技术,可以随机删除网络中的隐藏和可见单元。

在整个指南中,您将能够比较使用和不使用 Flax Dropout 的代码示例。

拆分 PRNG 密钥#

由于 dropout 是一种随机操作,因此它需要伪随机数生成器 (PRNG) 状态。Flax 使用 JAX 的(可拆分的)PRNG 密钥,这些密钥对于神经网络具有许多理想的属性。要了解更多信息,请参阅 JAX 中的伪随机数教程

注意: 回想一下,JAX 有一种显式的方式为您提供 PRNG 密钥:您可以使用 key, subkey = jax.random.split(key) 将主 PRNG 状态(例如 key = jax.random.key(seed=0))分叉为多个新的 PRNG 密钥。您可以在 🔪 JAX - 尖锐的部分 🔪 随机性和 PRNG 密钥 中重新查看。

首先使用 jax.random.split() 将 PRNG 密钥拆分为三个密钥,包括一个用于 Flax Linen Dropout 的密钥。

root_key = jax.random.key(seed=0)
main_key, params_key = jax.random.split(key=root_key)
root_key = jax.random.key(seed=0)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)

注意: 在 Flax 中,您使用名称提供 PRNG 流,以便稍后可以在 flax.linen.Module() 中使用它们。例如,您传递流 'params' 以初始化参数,传递 'dropout' 以应用 flax.linen.Dropout()

使用 Dropout 定义您的模型#

要创建带 dropout 的模型

一种常见的模式是在父 Flax Module 中接受一个 training(或 train)参数(一个布尔值),并使用它来启用或禁用 dropout(如本指南的后续部分所示)。在其他机器学习框架(如 PyTorch 或 TensorFlow (Keras))中,这是通过可变状态或调用标志指定的(例如,在 torch.nn.Module.eval 中或通过设置 training 标志的 tf.keras.Model 中)。

注意: Flax 提供了一种通过 Flax flax.linen.Module()flax.linen.Module.make_rng() 方法隐式处理 PRNG 密钥流的方法。这使您可以在 Flax 模块(或其子模块)内部从 PRNG 流中拆分出一个新的 PRNG 密钥。make_rng 方法保证每次调用时都提供唯一的密钥。在内部,flax.linen.Dropout() 使用 flax.linen.Module.make_rng() 创建 dropout 的密钥。您可以查看 源代码。简而言之,flax.linen.Module.make_rng() 保证完全可重复性

class MyModel(nn.Module):
  num_neurons: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.num_neurons)(x)

    return x
class MyModel(nn.Module):
  num_neurons: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.num_neurons)(x)
    # Set the dropout layer with a `rate` of 50%.
    # When the `deterministic` flag is `True`, dropout is turned off.
    x = nn.Dropout(rate=0.5, deterministic=not training)(x)
    return x

初始化模型#

创建模型后

这里,没有使用 Flax Dropout 和使用了 Dropout 的代码之间的主要区别在于,如果需要启用 dropout,则必须提供 training (或 train)参数。

my_model = MyModel(num_neurons=3)
x = jnp.empty((3, 4, 4))

variables = my_model.init(params_key, x)
params = variables['params']
my_model = MyModel(num_neurons=3)
x = jnp.empty((3, 4, 4))
# Dropout is disabled with `training=False` (that is, `deterministic=True`).
variables = my_model.init(params_key, x, training=False)
params = variables['params']

在训练期间执行前向传播#

当使用 flax.linen.apply() 运行模型时

  • training=True 传递给 flax.linen.apply()

  • 然后,为了在前向传播期间(使用 dropout)绘制 PRNG 密钥,在调用 flax.linen.apply() 时,提供一个 PRNG 密钥来播种 'dropout' 流。

# No need to pass the `training` and `rngs` flags.
y = my_model.apply({'params': params}, x)
# Dropout is enabled with `training=True` (that is, `deterministic=False`).
y = my_model.apply({'params': params}, x, training=True, rngs={'dropout': dropout_key})

这里,没有使用 Flax Dropout 和使用了 Dropout 的代码之间的主要区别在于,如果需要启用 dropout,则必须提供 training (或 train)和 rngs 参数。

在评估期间,使用上述代码,不启用 dropout (这意味着您也不必传递 RNG)。

TrainState 和训练步骤#

本节介绍如何在启用 dropout 的情况下修改训练步骤函数内部的代码。

注意: 回顾一下,Flax 有一个常见的模式,即创建一个数据类来表示整个训练状态,包括参数和优化器状态。然后,您可以将单个参数 state: TrainState 传递给训练步骤函数。有关更多信息,请参阅 flax.training.train_state.TrainState() API 文档。

  • 首先,向自定义的 flax.training.train_state.TrainState() 类添加一个 key 字段。

  • 然后,将 key 值(在本例中为 dropout_key)传递给 train_state.TrainState.create() 方法。

from flax.training import train_state

state = train_state.TrainState.create(
  apply_fn=my_model.apply,
  params=params,

  tx=optax.adam(1e-3)
)
from flax.training import train_state

class TrainState(train_state.TrainState):
  key: jax.Array

state = TrainState.create(
  apply_fn=my_model.apply,
  params=params,
  key=dropout_key,
  tx=optax.adam(1e-3)
)

  • 接下来,在 Flax 训练步骤函数 train_step 中,从 dropout_key 生成一个新的 PRNG 密钥,以便在每一步应用 dropout。这可以通过以下方式之一完成

    使用 jax.random.fold_in() 通常更快。当您使用 jax.random.split() 时,您会拆分出一个可以稍后重用的 PRNG 密钥。但是,使用 jax.random.fold_in() 可以确保 1) 折叠唯一数据;2) 可以产生更长的 PRNG 流序列。

  • 最后,在执行前向传播时,将新的 PRNG 密钥作为额外的参数传递给 state.apply_fn()

@jax.jit
def train_step(state: train_state.TrainState, batch):

  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      x=batch['image'],


      )
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label'])
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state
@jax.jit
def train_step(state: TrainState, batch, dropout_key):
  dropout_train_key = jax.random.fold_in(key=dropout_key, data=state.step)
  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      x=batch['image'],
      training=True,
      rngs={'dropout': dropout_train_key}
      )
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label'])
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state

带有 dropout 的 Flax 示例#

更多使用模块 make_rng() 的 Flax 示例#