Flax 中的随机性和 PRNG#

在本指南中,您将了解 Flax 如何使用 JAX 的显式伪随机数生成器 (PRNG) 密钥 来模拟随机性,并添加一些其他功能,以便用户更容易地在不同的 Flax Module 中传递 PRNG 密钥。

如果您是 JAX PRNG 密钥的新手或需要复习,请查看

设置#

安装或升级 Flax,然后导入一些必要的依赖项。

注意:本指南使用 --xla_force_host_platform_device_count=8 标志,在 Google Colab/Jupyter Notebook 的 CPU 环境中模拟多个设备。如果您已经在 Google Cloud 或具有 TPU 的 Kaggle VM 等多设备 Google Cloud TPU 环境中使用,则不需要此标志。

!pip install -q flax
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import flax, flax.linen as nn
import jax, jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

import hashlib
jax.devices()
[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

将 JAX 配置变量 jax_threefry_partitionable 设置为 True。这将在未来成为默认值,并使 PRNG 在 jax.jit 下更有效地自动并行化。有关更多详细信息,请参阅 JAX 讨论

jax.config.update('jax_threefry_partitionable', True)
assert jax.config.jax_threefry_partitionable == True
assert jax.config.jax_default_prng_impl == 'threefry2x32'

使用 Module.make_rng 接收、操作和创建 PRNG 密钥#

Flax 用于接收、操作和创建 PRNG 密钥的主要方法是通过 Module 方法 self.make_rng。它是一个接受字符串名称的方法,该字符串名称表示“RNG 流”。每个 RNG 流都有一个初始起始种子 PRNG 密钥,用户将其作为字典参数传入(即传入 .init.apply 函数),起始种子由 self.make_rng 用于为该流生成更多 PRNG 密钥。如果 self.make_rng 在一个没有初始起始种子 PRNG 密钥的字符串名称上被调用(即用户没有将具有相应名称的密钥传入 .init.apply),则 self.make_rng 将默认使用 'params' 密钥作为初始起始种子。

请注意,此方法只能与有界模块一起调用(请参阅 Flax 模块生命周期)。

class RNGModule(nn.Module):
  @nn.compact
  def __call__(self):
    print(self.make_rng('rng_stream'))
    print(self.make_rng('rng_stream'))
    print(self.make_rng('rng_stream'))

rng_module = RNGModule()
variables = rng_module.init({'rng_stream': jax.random.key(0)})
Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]

现在,如果我们使用不同的起始种子 PRNG 密钥,我们将生成不同的值(正如预期的那样)。

variables = rng_module.init({'rng_stream': jax.random.key(1)})
Array((), dtype=key<fry>) overlaying:
[3077990774 2166202870]
Array((), dtype=key<fry>) overlaying:
[3825832496 2886313970]
Array((), dtype=key<fry>) overlaying:
[ 791337683 1373966058]

为一个流调用 self.make_rng 不会影响从另一个流生成的随机值;即,调用顺序无关紧要。

class RNGModuleTwoStreams(nn.Module):
  @nn.compact
  def __call__(self):
    # same value as first code snippet above
    print(f"rng_stream1: {self.make_rng('rng_stream1')}")
    # same value as second code snippet above
    print(f"rng_stream2: {self.make_rng('rng_stream2')}")
    # same value as first code snippet above
    print(f"rng_stream1: {self.make_rng('rng_stream1')}")
    # same value as second code snippet above
    print(f"rng_stream2: {self.make_rng('rng_stream2')}")
    # same value as first code snippet above
    print(f"rng_stream1: {self.make_rng('rng_stream1')}")
    # same value as second code snippet above
    print(f"rng_stream2: {self.make_rng('rng_stream2')}")

rng_module_two_streams = RNGModuleTwoStreams()
variables = rng_module_two_streams.init(
  {'rng_stream1': jax.random.key(0), 'rng_stream2': jax.random.key(1)}
)
rng_stream1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[3077990774 2166202870]
rng_stream1: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[3825832496 2886313970]
rng_stream1: Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[ 791337683 1373966058]

提供相同的种子 PRNG 密钥将导致生成相同的值(前提是对这些密钥使用相同的操作)。

variables = rng_module_two_streams.init(
  {'rng_stream1': jax.random.key(0), 'rng_stream2': jax.random.key(0)}
)
rng_stream1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
rng_stream1: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
rng_stream1: Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]

self.make_rng 的底层工作原理#

这是调用 self.make_rng (flax.linen.Module.make_rng) 时发生的情况

  • 收集以下数据

    • self.scope.path 提供的 Module 的路径(顶层根模块具有空路径 ())。

    • self.make_rng 调用计数。也就是说,对于此特定流,self.make_rng 已被调用的次数(包括此调用)。

      • 注意:每个子 Module 都有自己的调用计数,该计数与其他 Module 分开。例如,一个调用 self.make_rng('params') 两次的 Module 包含一个调用 self.make_rng('params') 一次的子 Module,则对于每个 RNG 流 'params',其调用计数分别为 2 和 1。

  • 将数据捆绑到一个元组中,并馈送到哈希函数中,生成一个整数。

  • 将生成的整数折叠到 RNG 流的起始种子 PRNG 密钥中,以生成一个新的唯一 PRNG 密钥。

以下是 Flax 用于 self.make_rng 的哈希函数的稍微简化的版本

def produce_hash(data):
  m = hashlib.sha1()
  for x in data:
    if isinstance(x, str):
      m.update(x.encode('utf-8'))
    elif isinstance(x, int):
      m.update(x.to_bytes((x.bit_length() + 7) // 8, byteorder='big'))
    else:
      raise ValueError(f'Expected int or string, got: {x}')
  d = m.digest()
  hash_int = int.from_bytes(d[:4], byteorder='big')
  return hash_int

现在您可以手动重现从 self.make_rng 生成的 PRNG 密钥

stream_seed = jax.random.key(0)
for call_count in range(1, 4):
  hash_int = produce_hash(data=(call_count,))
  print(jax.random.fold_in(stream_seed, jnp.uint32(hash_int)))
Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
variables = rng_module.init({'rng_stream': jax.random.key(0)})
Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]

Moduleself.make_rng#

本节探讨 self.make_rng (flax.linen.Module.make_rng) 在子 Module 中的行为方式。

考虑以下示例

class RNGSubSubModule(nn.Module):
  def __call__(self):
    print(f"{self.name}, count 1: {self.make_rng('rng_stream')}")
    print(f"{self.name}, count 2: {self.make_rng('rng_stream')}")

class RNGSubModule(nn.Module):
  @nn.compact
  def __call__(self):
    print(f"{self.name}, count 1: {self.make_rng('rng_stream')}")
    print(f"{self.name}, count 2: {self.make_rng('rng_stream')}")
    RNGSubSubModule()()

class RNGModule(nn.Module):
  @nn.compact
  def __call__(self):
    print(f"RNGModule, count 1: {self.make_rng('rng_stream')}")
    print(f"RNGModule, count 2: {self.make_rng('rng_stream')}")
    RNGSubModule()()

rng_module = RNGModule()
variables = rng_module.init({'rng_stream': jax.random.key(0)})
RNGModule, count 1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
RNGModule, count 2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
RNGSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[3858825717 2323087578]
RNGSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[ 601859108 3782857444]
RNGSubSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[ 234240654 1028548813]
RNGSubSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[3650462303 2124609379]

如前所述,馈送到 Flax 哈希函数的数据包括

  • self.scope.path 提供的 Module 的路径(顶层根模块具有空路径 ());以及

  • 特定 RNG 流的调用计数。

此外,请注意,即使对于相同的 RNG 流,每个 Flax Module 和子 Module 也都有其各自的调用计数。子 Module 名称的约定是:f'{module_name}_{module_number}'。例如,第一个 DenseModule 将被称为 Dense_0,第二个将被称为 Dense_1,依此类推。

因此,以下数据将馈送到哈希函数

  • 对于 RNGModule:由于根 Module 具有空路径,因此数据只是调用计数,例如 (1,)(2,)

  • 对于 RNGSubModule:数据为 ('RNGSubModule_0', 1)('RNGSubModule_0', 2)

  • 对于 RNGSubSubModule:数据为 ('RNGSubModule_0', 'RNGSubSubModule_0', 1)('RNGSubModule_0', 'RNGSubSubModule_0', 2)

通过这些数据,你可以使用 self.make_rng 手动复现从 Module 和子 Module 生成的 PRNG 密钥。

例如:

stream_seed = jax.random.key(0)
for initial_data in ((), ('RNGSubModule_0',), ('RNGSubModule_0', 'RNGSubSubModule_0')):
  if initial_data:
    module_name = initial_data[-1]
  else:
    module_name = 'RNGModule'
  for call_count in (1, 2):
    hash_int = produce_hash(data=initial_data+(call_count,))
    rng_key = jax.random.fold_in(stream_seed, jnp.uint32(hash_int))
    print(f"{module_name}, count {call_count}: {rng_key}")
RNGModule, count 1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
RNGModule, count 2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
RNGSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[3858825717 2323087578]
RNGSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[ 601859108 3782857444]
RNGSubSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[ 234240654 1028548813]
RNGSubSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[3650462303 2124609379]

如果同一个子 Module 类被多次使用,你可以相应地增加子 Module 名称的后缀。例如:RNGSubModule_0RNGSubModule_1 等。

class RNGSubModule(nn.Module):
  @nn.compact
  def __call__(self):
    print(f"{self.name}, count 1: {self.make_rng('rng_stream')}")
    print(f"{self.name}, count 2: {self.make_rng('rng_stream')}")

class RNGModule(nn.Module):
  @nn.compact
  def __call__(self):
    print(f"RNGModule, count 1: {self.make_rng('rng_stream')}")
    print(f"RNGModule, count 2: {self.make_rng('rng_stream')}")
    RNGSubModule()()
    RNGSubModule()()

rng_module = RNGModule()
variables = rng_module.init({'rng_stream': jax.random.key(0)})
RNGModule, count 1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
RNGModule, count 2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
RNGSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[3858825717 2323087578]
RNGSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[ 601859108 3782857444]
RNGSubModule_1, count 1: Array((), dtype=key<fry>) overlaying:
[ 426957352 2006350344]
RNGSubModule_1, count 2: Array((), dtype=key<fry>) overlaying:
[4006253729 4205356731]
stream_seed = jax.random.key(0)
for initial_data in ((), ('RNGSubModule_0',), ('RNGSubModule_1',)):
  if initial_data:
    module_name = initial_data[-1]
  else:
    module_name = 'RNGModule'
  for call_count in (1, 2):
    hash_int = produce_hash(data=initial_data+(call_count,))
    rng_key = jax.random.fold_in(stream_seed, jnp.uint32(hash_int))
    print(f"{module_name}, count {call_count}: {rng_key}")
RNGModule, count 1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
RNGModule, count 2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
RNGSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[3858825717 2323087578]
RNGSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[ 601859108 3782857444]
RNGSubModule_1, count 1: Array((), dtype=key<fry>) overlaying:
[ 426957352 2006350344]
RNGSubModule_1, count 2: Array((), dtype=key<fry>) overlaying:
[4006253729 4205356731]

使用 self.paramself.variable#

Flax 用户可以选择使用 self.paramself.variable Module 方法在其模块中创建额外的参数和变量。必须将 init_fn 参数传递给这些方法,以便它可以生成参数/变量的初始值。self.make_rng 通常在这个 init_fn 中隐式或显式地使用,因为许多初始化函数本质上是随机的,需要一个 PRNG 密钥。请参阅 此处 的 Flax 初始化器完整列表。

用户应注意这两种方法之间存在一些差异

  • self.param 始终在 'params' 集合中创建一个参数,而 self.variable 在用户指定的任何集合中创建一个变量

  • self.param 将自动调用 self.make_rng('params') 并将生成的 PRNG 密钥隐式传递给你实例化的参数的 init_fn(它将作为第一个参数传递),而用户必须在 self.variableinit_fn 中手动指定要调用 self.make_rng 的 RNG 流(它可以是 'params' 或其他内容)。

下面是一个同时使用 self.paramself.variable 的示例

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    # kernel will use 'params' seed, initial data will include 'Dense_0', call count 1
    x = nn.Dense(2, kernel_init=jax.random.normal, use_bias=False)(x)
    # model_param will use 'params' seed, call count 1
    model_param = self.param('model_param', jax.random.normal, x.shape)
    # model_variable1 will use 'params' seed, call count 2
    model_variable1 = self.variable(
      'other_collection',
      'model_variable1',
      lambda: jax.random.normal(self.make_rng('params'), x.shape),
    )
    # model_variable2 will use 'other' seed, call count 1
    model_variable2 = self.variable(
      'other_collection',
      'model_variable2',
      lambda: jax.random.normal(self.make_rng('other'), x.shape),
    )
    # kernel will use 'params' seed, initial data will include 'Dense_1', call count 1
    # bias will use 'params' seed, initial data will include 'Dense_1', call count 2
    x = nn.Dense(2, kernel_init=jax.random.normal, bias_init=jax.random.normal)(
      x
    )
    return x

model = Model()
variables = model.init(
  {'params': jax.random.key(0), 'other': jax.random.key(1)}, jnp.ones((2, 2))
)
print(variables['params']['Dense_0']['kernel'])
print(variables['params']['model_param'])
print(variables['other_collection']['model_variable1'])
print(variables['other_collection']['model_variable2'])
print(variables['params']['Dense_1']['kernel'])
print(variables['params']['Dense_1']['bias'])
[[-1.6185919   0.700908  ]
 [-1.3146383  -0.79342234]]
[[ 0.0761425 -1.6157459]
 [-1.6857724  0.7126891]]
[[ 0.60175574  0.2553228 ]
 [ 0.27367848 -2.1975214 ]]
[[1.6249592  0.30813068]
 [1.6613585  1.0404155 ]]
[[ 0.0030665   0.29551846]
 [ 0.16670242 -0.78252524]]
[1.582462   0.15216611]

记住

  • 每个 RNG 流都有一个单独的计数;这就是为什么即使之前调用过 self.make_rng('params')self.make_rng('other') 的计数也从 1 开始的原因

  • 每个子模块对每个 rng 流都有自己的单独计数;这就是为什么每个 Dense 层都有自己单独的 self.make_rng('params') 计数,以及为什么 model_parammodel_variable1 共享相同的计数(因为它们是在同一个顶级父模块中定义的)

params_seed = jax.random.key(0)
other_seed = jax.random.key(1)
for initial_data, count, seed, shape in (
  (('Dense_0',), 1, params_seed, (2, 2)),
  ((), 1, params_seed, (2, 2)),
  ((), 2, params_seed, (2, 2)),
  ((), 1, other_seed, (2, 2)),
  (('Dense_1',), 1, params_seed, (2, 2)),
  (('Dense_1',), 2, params_seed, (1, 2)),
):
  hash_int = produce_hash(data=(*initial_data, count))
  rng_key = jax.random.fold_in(seed, jnp.uint32(hash_int))
  print(jax.random.normal(rng_key, shape))
[[-1.6185919   0.700908  ]
 [-1.3146383  -0.79342234]]
[[ 0.0761425 -1.6157459]
 [-1.6857724  0.7126891]]
[[ 0.60175574  0.2553228 ]
 [ 0.27367848 -2.1975214 ]]
[[1.6249592  0.30813068]
 [1.6613585  1.0404155 ]]
[[ 0.0030665   0.29551846]
 [ 0.16670242 -0.78252524]]
[[1.582462   0.15216611]]

在训练循环中管理 RNG 流#

下面是一个在训练循环中管理来自 self.make_rngself.paramself.variablenn.Dropout 的 RNG 流的示例(注意:nn.Dropout 需要在 'dropout' RNG 流中传入一个种子 PRNG 密钥,因为它会隐式调用 self.make_rng('dropout')

class SubModule(nn.Module):
  @nn.compact
  def __call__(self, x, train):
    # variables created using `self.param` will use `self.make_rng('params')`
    kernel = self.param('submodule_kernel', jax.random.normal, x.shape)
    x = x + kernel
    # `nn.Dropout` will use self.make_rng('dropout')
    x = nn.Dropout(0.2)(x, deterministic=not train)
    # `nn.Dense` will use self.make_rng('params')
    x = nn.Dense(3)(x)
    return x

class Model(nn.Module):
  @nn.compact
  def __call__(self, x, train):
    # make kernel use `self.make_rng('other')`
    kernel = self.variable(
      'other_collection',
      'module_kernel',
      lambda: jax.random.normal(self.make_rng('other'), x.shape),
    )
    x = (
      x + kernel.value
    )  # `.value` will extract the underlying value of the variable
    x = SubModule()(x, train)
    # `nn.Dropout` will use self.make_rng('dropout')
    x = nn.Dropout(0.2)(x, deterministic=not train)
    # `nn.Dense` will use self.make_rng('params')
    x = nn.Dense(2)(x)
    return x

params_rng, other_rng, train_rng = jax.random.split(jax.random.key(0), 3)
init_rngs = {'params': params_rng, 'other': other_rng}

x = jnp.ones((1, 3))
y = jnp.ones((1, 2))

module = Model()
variables = module.init(init_rngs, x, train=False)
def update(variables, rng):
  # we don't need to provide a 'params' or 'other' rng, as only 'dropout' rng will be used during training
  # split the rng to get a dropout_rng to be used for this training iteration,
  # and to get another rng key to be used for the next training iteration
  dropout_rng, next_rng = jax.random.split(rng)
  def loss(params):
    out = module.apply(
      {'params': params, 'other_collection': variables['other_collection']},
      x,
      train=True,
      rngs={'dropout': dropout_rng},
    )
    return jnp.mean((y - out) ** 2)
  grads = jax.grad(loss)(variables['params'])
  params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, variables['params'], grads)
  return {
    'params': params,
    'other_collection': variables['other_collection'],
  }, next_rng

for _ in range(10):
  variables, train_rng = update(variables, train_rng)
  out = module.apply(variables, x, train=False)
  print(jnp.mean((y - out)**2))
2.518454
2.4859657
2.4171872
2.412684
2.3435805
2.2773488
2.2592616
2.2009292
2.1839895
2.1707344

🔪 锋利的边缘 🔪 - 无意中生成相同的值#

存在一个边缘情况,可能会无意中生成相同的值。有关更多详细信息,请参阅 Flax 问题

class Leaf(nn.Module):
  def __call__(self, x):
    return x + jax.random.randint(self.make_rng("rng"), (), 0, 100)

class Node(nn.Module):
  leaf_name: str
  @nn.compact
  def __call__(self, x):
    return Leaf(name=self.leaf_name)(x)

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    return (Node(name="ab", leaf_name="cdef")(x),
            Node(name="abc", leaf_name="def")(x),
    )

out1, out2 = Model().apply({}, 0, rngs={"rng": jax.random.key(33)})
out1 == out2 # same output, despite having different submodule names
Array(True, dtype=bool)

发生这种情况是因为哈希函数 将字符串连接在一起,因此当将数据输入到哈希函数时,数据 ('AB', 'C') 等同于数据 ('A', 'BC'),因此会生成相同的哈希 int。

print(produce_hash(data=('A', 'B', 'C', 1)))
print(produce_hash(data=('AB', 'C', 1)))
print(produce_hash(data=('A', 'BC', 1)))
print(produce_hash(data=('ABC', 1)))
947574064
947574064
947574064
947574064

为了避免这种情况,用户可以将 flax_fix_rng_separator 配置标志翻转为 True

flax.config.update('flax_fix_rng_separator', True)
out1, out2 = Model().apply({}, 0, rngs={"rng": jax.random.key(33)})
out1 == out2 # different output
Array(False, dtype=bool)

在多个设备上管理 RNG#

本节将展示如何在多设备设置中使用 jitshard_map 来使用 RNG 的示例。

使用 jax.jit#

当使用 jax.jit 时,我们可以像以前一样使用 RNG,但现在我们包括 in_shardingsout_shardings 参数来指定如何分片输入和输出数据。

有关在 Flax 中使用 jax.jit 在多个设备上进行训练的更多详细信息,请参阅我们的在多个设备上扩展 Flax 模块指南lm1b 示例

# Create a mesh and annotate the axis with a name.
device_mesh = mesh_utils.create_device_mesh((8,))
print(device_mesh)

mesh = Mesh(devices=device_mesh, axis_names=('data',))
print(mesh)

data_sharding = NamedSharding(mesh, PartitionSpec('data',))
print(data_sharding)
[CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)
 CpuDevice(id=4) CpuDevice(id=5) CpuDevice(id=6) CpuDevice(id=7)]
Mesh('data': 8)
NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec('data',), memory_kind=unpinned_host)
class Model(nn.Module):
  @nn.compact
  def __call__(self, x, add_noise):
    x = nn.Dense(1)(x)
    # use jnp.where for control flow; for more details see: https://jax.net.cn/en/latest/errors.html#jax.errors.TracerBoolConversionError
    return jnp.where(
      add_noise, x + jax.random.normal(self.make_rng('params'), x.shape), x
    )

module = Model()
init_rng, apply_rng = jax.random.split(jax.random.key(0))
x = jnp.ones((8, 1))
variables = module.init(init_rng, x, False)

# create custom forward function, since jit does not support kwargs when in_shardings is specified
def forward(variables, x, add_noise, rng):
  return module.apply(variables, x, add_noise, rngs={'params': rng})

# shard the inputs x across devices
# replicate the variables, add_noise boolean and rng key across devices
# shard the output across devices
jit_forward = jax.jit(
  forward,
  in_shardings=(None, data_sharding, None, None),
  out_shardings=data_sharding,
)
out = jit_forward(variables, x, True, apply_rng)
out
Array([[-2.2187614 ],
       [-2.8055234 ],
       [-2.5464184 ],
       [ 1.027039  ],
       [-3.5243359 ],
       [-2.2795477 ],
       [-0.6504516 ],
       [ 0.17373265]], dtype=float32)

给定相同的输入,输出是不同的,这意味着 RNG 密钥被用来为输出添加噪声。

我们还可以确认输出在设备之间是分片的

out.addressable_shards
[Shard(device=CpuDevice(id=0), index=(slice(0, 1, None), slice(None, None, None)), replica_id=0, data=[[-2.2187614]]),
 Shard(device=CpuDevice(id=1), index=(slice(1, 2, None), slice(None, None, None)), replica_id=0, data=[[-2.8055234]]),
 Shard(device=CpuDevice(id=2), index=(slice(2, 3, None), slice(None, None, None)), replica_id=0, data=[[-2.5464184]]),
 Shard(device=CpuDevice(id=3), index=(slice(3, 4, None), slice(None, None, None)), replica_id=0, data=[[1.027039]]),
 Shard(device=CpuDevice(id=4), index=(slice(4, 5, None), slice(None, None, None)), replica_id=0, data=[[-3.5243359]]),
 Shard(device=CpuDevice(id=5), index=(slice(5, 6, None), slice(None, None, None)), replica_id=0, data=[[-2.2795477]]),
 Shard(device=CpuDevice(id=6), index=(slice(6, 7, None), slice(None, None, None)), replica_id=0, data=[[-0.6504516]]),
 Shard(device=CpuDevice(id=7), index=(slice(7, 8, None), slice(None, None, None)), replica_id=0, data=[[0.17373265]])]

可视化输出分片的另一种方式

jax.debug.visualize_array_sharding(out)
  CPU 0  
         
  CPU 1  
         
  CPU 2  
         
  CPU 3  
         
  CPU 4  
         
  CPU 5  
         
  CPU 6  
         
  CPU 7  
         

如果我们选择不添加噪声,那么所有批次的输出都是相同的(正如预期的那样,因为所有批次的输入都相同)

out = jit_forward(variables, x, False, apply_rng)
out
Array([[-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764]], dtype=float32)

我们可以确认未进行 JIT 编译的函数产生相同的值,尽管是未分片的(请注意,由于 JIT 编译器的优化,可能存在小的数值差异)

out = forward(variables, x, True, apply_rng)
out
Array([[-2.2187614 ],
       [-2.8055234 ],
       [-2.5464187 ],
       [ 1.0270392 ],
       [-3.5243359 ],
       [-2.2795477 ],
       [-0.6504516 ],
       [ 0.17373264]], dtype=float32)
out = forward(variables, x, False, apply_rng)
out
Array([[-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764]], dtype=float32)

使用 shard_map#

当使用 jax.experimental.shard_map.shard_map 时,要记住的重要部分是

  • 拆分你的 PRNG 密钥,以便为每个设备生成不同的密钥

  • PRNG 密钥将自动分片到每个设备(前提是你使用正确的分区规范),但 原始批处理 PRNG 密钥数组的秩不会减小;例如,对于一批 8 个 PRNG 密钥和 8 个设备,每个设备将在 shard_map 化的函数中看到大小为 1 的 PRNG 密钥批次

    • 因此,要访问 PRNG 密钥本身,我们需要对其进行索引切片(请参见下面的示例)

def forward(variables, x, add_noise, rng_key_batch):
  # rng_key_batch is a batch of size 1 containing 1 PRNG key
  # index slice into the rng_key_batch to access the PRNG key
  return module.apply(
    variables, x, add_noise, rngs={'params': rng_key_batch[0]}
  )

# define partition specifications
data_pspec = PartitionSpec('data')
no_pspec = PartitionSpec()

# shard the inputs x and rng keys across devices
# replicate the variables and add_noise boolean across devices
# shard the output across devices
shmap_forward = shard_map(
  forward,
  mesh=mesh,
  in_specs=(no_pspec, data_pspec, no_pspec, data_pspec),
  out_specs=data_pspec,
)
# get 8 different rng's that will be used by the 8 devices when doing forward inference
apply_rngs = jax.random.split(apply_rng, 8)
out = shmap_forward(variables, x, True, apply_rngs)
out
Array([[-1.2605132 ],
       [-1.2405176 ],
       [-0.99350417],
       [-1.0277128 ],
       [-1.4154483 ],
       [-0.3905797 ],
       [-2.417677  ],
       [ 0.9023453 ]], dtype=float32)

确认输出在设备之间是分片的

out.addressable_shards
[Shard(device=CpuDevice(id=0), index=(slice(0, 1, None), slice(None, None, None)), replica_id=0, data=[[-1.2605132]]),
 Shard(device=CpuDevice(id=1), index=(slice(1, 2, None), slice(None, None, None)), replica_id=0, data=[[-1.2405176]]),
 Shard(device=CpuDevice(id=2), index=(slice(2, 3, None), slice(None, None, None)), replica_id=0, data=[[-0.99350417]]),
 Shard(device=CpuDevice(id=3), index=(slice(3, 4, None), slice(None, None, None)), replica_id=0, data=[[-1.0277128]]),
 Shard(device=CpuDevice(id=4), index=(slice(4, 5, None), slice(None, None, None)), replica_id=0, data=[[-1.4154483]]),
 Shard(device=CpuDevice(id=5), index=(slice(5, 6, None), slice(None, None, None)), replica_id=0, data=[[-0.3905797]]),
 Shard(device=CpuDevice(id=6), index=(slice(6, 7, None), slice(None, None, None)), replica_id=0, data=[[-2.417677]]),
 Shard(device=CpuDevice(id=7), index=(slice(7, 8, None), slice(None, None, None)), replica_id=0, data=[[0.9023453]])]
jax.debug.visualize_array_sharding(out)
  CPU 0  
         
  CPU 1  
         
  CPU 2  
         
  CPU 3  
         
  CPU 4  
         
  CPU 5  
         
  CPU 6  
         
  CPU 7  
         

提升的变换#

Flax 提升的变换允许你将 JAX 变换Module 参数一起使用。本节将向你展示如何在 Flax 提升的变换中控制 PRNG 密钥的拆分方式。

有关更多详细信息,请参阅提升的变换

nn.vmap#

我们可以使用 nn.vmap 创建批处理的 Dense

x = jnp.ones((3, 2))

BatchDense = nn.vmap(
    nn.Dense,
    in_axes=0, out_axes=0,
    variable_axes={'params': None},
    split_rngs={'params': False})

BatchDense(2).init(jax.random.key(0), x)
{'params': {'bias': Array([0., 0.], dtype=float32),
  'kernel': Array([[-1.2488099 , -0.6127134 ],
         [-0.07084481,  0.60130936]], dtype=float32)}}

通过表示 variable_axes={'params': 0}',我们在第一个轴上向量化 params 数组。但是,生成的参数值彼此相同

BatchDense = nn.vmap(
    nn.Dense,
    in_axes=0, out_axes=0,
    variable_axes={'params': 0},
    split_rngs={'params': False})

BatchDense(2).init(jax.random.key(0), x)
{'params': {'bias': Array([[0., 0.],
         [0., 0.],
         [0., 0.]], dtype=float32),
  'kernel': Array([[[-1.2488099 , -0.6127134 ],
          [-0.07084481,  0.60130936]],
  
         [[-1.2488099 , -0.6127134 ],
          [-0.07084481,  0.60130936]],
  
         [[-1.2488099 , -0.6127134 ],
          [-0.07084481,  0.60130936]]], dtype=float32)}}

如果我们还使 split_rngs={'params': True},那么我们提供的 PRNG 密钥将在变量轴(在这种情况下为批处理轴 0)上拆分,并且我们可以为每个批处理输入生成不同的参数

BatchDense = nn.vmap(
    nn.Dense,
    in_axes=0, out_axes=0,
    variable_axes={'params': 0},
    split_rngs={'params': True})

BatchDense(2).init(jax.random.key(0), x)
{'params': {'bias': Array([[0., 0.],
         [0., 0.],
         [0., 0.]], dtype=float32),
  'kernel': Array([[[-0.2526208 , -0.15088455],
          [-1.1987205 , -0.40843305]],
  
         [[-0.7064888 , -1.108805  ],
          [-0.938775  ,  1.4812315 ]],
  
         [[-0.59468937, -0.2502723 ],
          [-1.33515   ,  0.5067442 ]]], dtype=float32)}}

通过 self.variable 添加变量很简单

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(2)(x)
    kernel = self.variable(
      'other_collection',
      'kernel',
      lambda: jax.random.normal(self.make_rng('other'), x.shape),
    )
    return x + kernel.value

BatchModel = nn.vmap(
  Model,
  in_axes=0,
  out_axes=0,
  variable_axes={'params': 0, 'other_collection': 0},
  split_rngs={'params': True, 'other': True},
)

BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x)
{'params': {'Dense_0': {'bias': Array([[0., 0.],
          [0., 0.],
          [0., 0.]], dtype=float32),
   'kernel': Array([[[-0.9079084 ,  0.76390624],
           [-0.01285526,  0.4320353 ]],
   
          [[ 0.12398645,  0.7884565 ],
           [ 1.5344163 ,  1.3186085 ]],
   
          [[-0.44171348,  0.43430036],
           [-0.40732604,  0.29774475]]], dtype=float32)}},
 'other_collection': {'kernel': Array([[-0.8193048 ,  0.711106  ],
         [-0.37802765, -0.66705877],
         [-0.44808003,  0.93031347]], dtype=float32)}}

我们可以控制要拆分的 RNG 流,例如,如果我们只想拆分 'params' RNG 流,那么从 self.variable 生成的变量对于每个批处理输入都将相同

BatchModel = nn.vmap(
    Model,
    in_axes=0, out_axes=0,
    variable_axes={'params': 0, 'other_collection': 0},
    split_rngs={'params': True, 'other': False})

BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x)
{'params': {'Dense_0': {'bias': Array([[0., 0.],
          [0., 0.],
          [0., 0.]], dtype=float32),
   'kernel': Array([[[-0.9079084 ,  0.76390624],
           [-0.01285526,  0.4320353 ]],
   
          [[ 0.12398645,  0.7884565 ],
           [ 1.5344163 ,  1.3186085 ]],
   
          [[-0.44171348,  0.43430036],
           [-0.40732604,  0.29774475]]], dtype=float32)}},
 'other_collection': {'kernel': Array([[ 0.44956833, -1.1854612 ],
         [ 0.44956833, -1.1854612 ],
         [ 0.44956833, -1.1854612 ]], dtype=float32)}}

我们还可以控制应为每个批处理输入生成哪些参数/变量,例如,如果我们只希望 'params' 为每个批处理输入生成单独的参数

BatchModel = nn.vmap(
    Model,
    in_axes=0, out_axes=0,
    variable_axes={'params': 0, 'other_collection': None},
    split_rngs={'params': True, 'other': False})

BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x)
{'params': {'Dense_0': {'bias': Array([[0., 0.],
          [0., 0.],
          [0., 0.]], dtype=float32),
   'kernel': Array([[[-0.9079084 ,  0.76390624],
           [-0.01285526,  0.4320353 ]],
   
          [[ 0.12398645,  0.7884565 ],
           [ 1.5344163 ,  1.3186085 ]],
   
          [[-0.44171348,  0.43430036],
           [-0.40732604,  0.29774475]]], dtype=float32)}},
 'other_collection': {'kernel': Array([ 0.44956833, -1.1854612 ], dtype=float32)}}

nn.scan#

我们可以使用 nn.scan 创建扫描的 Module 层(这对于简化重复堆叠的子模块很有用)

x = jnp.ones((3, 2))

class ResidualMLPBlock(nn.Module):
  @nn.compact
  def __call__(self, x, _):
    h = nn.Dense(features=2)(x)
    h = nn.relu(h)
    return x + h, None # return an empty carry

ScanMLP = nn.scan(
      ResidualMLPBlock, variable_axes={'params': 0},
      variable_broadcast=False, split_rngs={'params': True},
      length=3)

ScanMLP().init(jax.random.key(0), x, None) # pass in an empty carry
{'params': {'Dense_0': {'bias': Array([[0., 0.],
          [0., 0.],
          [0., 0.]], dtype=float32),
   'kernel': Array([[[-0.07838312, -0.7422982 ],
           [ 0.87488323,  0.13773395]],
   
          [[ 0.97309333,  0.9087693 ],
           [-0.12564984, -1.0920651 ]],
   
          [[-0.99055105,  1.1499453 ],
           [-0.15721127, -0.62520015]]], dtype=float32)}}}

与之前类似,我们可以控制是否拆分 RNG 流,例如,如果我们希望所有堆叠的模块都初始化为相同的参数值,则可以传入 split_rngs={'params': False}

ScanMLP = nn.scan(
      ResidualMLPBlock, variable_axes={'params': 0},
      variable_broadcast=False, split_rngs={'params': False},
      length=3)

ScanMLP().init(jax.random.key(0), x, None)
{'params': {'Dense_0': {'bias': Array([[0., 0.],
          [0., 0.],
          [0., 0.]], dtype=float32),
   'kernel': Array([[[-0.66715515, -0.0484313 ],
           [ 0.9867164 ,  0.75408363]],
   
          [[-0.66715515, -0.0484313 ],
           [ 0.9867164 ,  0.75408363]],
   
          [[-0.66715515, -0.0484313 ],
           [ 0.9867164 ,  0.75408363]]], dtype=float32)}}}