🔪 Flax - 尖锐之处 🔪#

Flax 展现了 JAX 的全部力量。就像使用 JAX 一样,在使用 Flax 时您可能会遇到某些 “尖锐之处”。这份不断更新的文档旨在帮助您解决这些问题。

首先,安装和/或更新 Flax

! pip install -qq flax

🔪 flax.linen.Dropout 层和随机性#

要点总结#

在处理具有 dropout(从 Flax Module 继承)的模型时,仅在正向传播期间添加 'dropout' PRNGkey。

  1. jax.random.split() 开始,显式地为 'params''dropout' 创建 PRNG 密钥。

  2. flax.linen.Dropout 层添加到您的模型(从 Flax Module 继承)。

  3. 在初始化模型时(flax.linen.init()),无需传入额外的 'dropout' PRNG 密钥 — 就像在“更简单”的模型中一样,仅传入 'params' 密钥。

  4. 在用 flax.linen.apply() 进行正向传播期间,传入 rngs={'dropout': dropout_key}

查看下面的完整示例。

为什么它有效#

  • 在内部,flax.linen.Dropout 利用 flax.linen.Module.make_rng 为 dropout 创建密钥(查看 源代码)。

  • 每次调用 make_rng 时(在本例中,它是在 Dropout 中隐式完成的),您都会获得一个从主/根 PRNG 密钥拆分出的新 PRNG 密钥。

  • make_rng 仍然保证完全的可重复性

背景#

dropout 随机正则化技术随机删除网络中的隐藏单元和可见单元。Dropout 是一种随机操作,需要 PRNG 状态,而 Flax(如 JAX)使用可拆分的 Threefry PRNG。

注意:回想一下,JAX 有一种显式的方法来为您提供 PRNG 密钥:您可以使用 key, subkey = jax.random.split(key) 将主 PRNG 状态(例如 key = jax.random.key(seed=0))分成多个新的 PRNG 密钥。在 🔪 JAX - 尖锐之处 🔪 随机性和 PRNG 密钥中复习一下。

Flax 提供了一种通过 Flax Moduleflax.linen.Module.make_rng 辅助函数处理 PRNG 密钥流的隐式方式。它允许 Flax Module 中的代码(或其子 Module)“拉取 PRNG 密钥”。make_rng 保证每次调用时都提供唯一的密钥。有关更多详细信息,请参阅 RNG 指南

注意:回想一下,flax.linen.Module 是所有神经网络模块的基类。所有层和模型都从中继承。

示例#

请记住,每个 Flax PRNG 流都有一个名称。下面的示例使用 'params' 流来初始化参数,以及 'dropout' 流。提供给 flax.linen.init() 的 PRNG 密钥是为 'params' PRNG 密钥流播种的密钥。要在正向传播期间(使用 dropout)提取 PRNG 密钥,请在调用 Module.apply() 时提供一个 PRNG 密钥来为该流 ('dropout') 播种。

# Setup.
import jax
import jax.numpy as jnp
import flax.linen as nn
# Randomness.
seed = 0
root_key = jax.random.key(seed=seed)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)

# A simple network.
class MyModel(nn.Module):
  num_neurons: int
  training: bool
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.num_neurons)(x)
    # Set the dropout layer with a rate of 50% .
    # When the `deterministic` flag is `True`, dropout is turned off.
    x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)
    return x

# Instantiate `MyModel` (you don't need to set `training=True` to
# avoid performing the forward pass computation).
my_model = MyModel(num_neurons=3, training=False)

x = jax.random.uniform(key=main_key, shape=(3, 4, 4))

# Initialize with `flax.linen.init()`.
# The `params_key` is equivalent to a dictionary of PRNGs.
# (Here, you are providing only one PRNG key.) 
variables = my_model.init(params_key, x)

# Perform the forward pass with `flax.linen.apply()`.
y = my_model.apply(variables, x, rngs={'dropout': dropout_key})

现实生活中的例子