Flax 中的随机性和 PRNG#
在本指南中,您将了解 Flax 如何使用 JAX 的显式伪随机数生成器 (PRNG) 密钥 来模拟随机性,并添加一些其他功能,以便用户更容易地在不同的 Flax Module
中传递 PRNG 密钥。
如果您是 JAX PRNG 密钥的新手或需要复习,请查看
设置#
安装或升级 Flax,然后导入一些必要的依赖项。
注意:本指南使用 --xla_force_host_platform_device_count=8
标志,在 Google Colab/Jupyter Notebook 的 CPU 环境中模拟多个设备。如果您已经在 Google Cloud 或具有 TPU 的 Kaggle VM 等多设备 Google Cloud TPU 环境中使用,则不需要此标志。
!pip install -q flax
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import flax, flax.linen as nn
import jax, jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
import hashlib
jax.devices()
[CpuDevice(id=0),
CpuDevice(id=1),
CpuDevice(id=2),
CpuDevice(id=3),
CpuDevice(id=4),
CpuDevice(id=5),
CpuDevice(id=6),
CpuDevice(id=7)]
将 JAX 配置变量 jax_threefry_partitionable
设置为 True
。这将在未来成为默认值,并使 PRNG 在 jax.jit
下更有效地自动并行化。有关更多详细信息,请参阅 JAX 讨论。
jax.config.update('jax_threefry_partitionable', True)
assert jax.config.jax_threefry_partitionable == True
assert jax.config.jax_default_prng_impl == 'threefry2x32'
使用 Module.make_rng
接收、操作和创建 PRNG 密钥#
Flax 用于接收、操作和创建 PRNG 密钥的主要方法是通过 Module
方法 self.make_rng
。它是一个接受字符串名称的方法,该字符串名称表示“RNG 流”。每个 RNG 流都有一个初始起始种子 PRNG 密钥,用户将其作为字典参数传入(即传入 .init
或 .apply
函数),起始种子由 self.make_rng
用于为该流生成更多 PRNG 密钥。如果 self.make_rng
在一个没有初始起始种子 PRNG 密钥的字符串名称上被调用(即用户没有将具有相应名称的密钥传入 .init
或 .apply
),则 self.make_rng
将默认使用 'params'
密钥作为初始起始种子。
请注意,此方法只能与有界模块一起调用(请参阅 Flax 模块生命周期)。
class RNGModule(nn.Module):
@nn.compact
def __call__(self):
print(self.make_rng('rng_stream'))
print(self.make_rng('rng_stream'))
print(self.make_rng('rng_stream'))
rng_module = RNGModule()
variables = rng_module.init({'rng_stream': jax.random.key(0)})
Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
现在,如果我们使用不同的起始种子 PRNG 密钥,我们将生成不同的值(正如预期的那样)。
variables = rng_module.init({'rng_stream': jax.random.key(1)})
Array((), dtype=key<fry>) overlaying:
[3077990774 2166202870]
Array((), dtype=key<fry>) overlaying:
[3825832496 2886313970]
Array((), dtype=key<fry>) overlaying:
[ 791337683 1373966058]
为一个流调用 self.make_rng
不会影响从另一个流生成的随机值;即,调用顺序无关紧要。
class RNGModuleTwoStreams(nn.Module):
@nn.compact
def __call__(self):
# same value as first code snippet above
print(f"rng_stream1: {self.make_rng('rng_stream1')}")
# same value as second code snippet above
print(f"rng_stream2: {self.make_rng('rng_stream2')}")
# same value as first code snippet above
print(f"rng_stream1: {self.make_rng('rng_stream1')}")
# same value as second code snippet above
print(f"rng_stream2: {self.make_rng('rng_stream2')}")
# same value as first code snippet above
print(f"rng_stream1: {self.make_rng('rng_stream1')}")
# same value as second code snippet above
print(f"rng_stream2: {self.make_rng('rng_stream2')}")
rng_module_two_streams = RNGModuleTwoStreams()
variables = rng_module_two_streams.init(
{'rng_stream1': jax.random.key(0), 'rng_stream2': jax.random.key(1)}
)
rng_stream1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[3077990774 2166202870]
rng_stream1: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[3825832496 2886313970]
rng_stream1: Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[ 791337683 1373966058]
提供相同的种子 PRNG 密钥将导致生成相同的值(前提是对这些密钥使用相同的操作)。
variables = rng_module_two_streams.init(
{'rng_stream1': jax.random.key(0), 'rng_stream2': jax.random.key(0)}
)
rng_stream1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
rng_stream1: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
rng_stream1: Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
self.make_rng
的底层工作原理#
这是调用 self.make_rng
(flax.linen.Module.make_rng
) 时发生的情况
收集以下数据
由
self.scope.path
提供的Module
的路径(顶层根模块具有空路径()
)。self.make_rng
调用计数。也就是说,对于此特定流,self.make_rng
已被调用的次数(包括此调用)。注意:每个子
Module
都有自己的调用计数,该计数与其他Module
分开。例如,一个调用self.make_rng('params')
两次的Module
包含一个调用self.make_rng('params')
一次的子Module
,则对于每个 RNG 流'params'
,其调用计数分别为 2 和 1。
将数据捆绑到一个元组中,并馈送到哈希函数中,生成一个整数。
将生成的整数折叠到 RNG 流的起始种子 PRNG 密钥中,以生成一个新的唯一 PRNG 密钥。
以下是 Flax 用于 self.make_rng
的哈希函数的稍微简化的版本
def produce_hash(data):
m = hashlib.sha1()
for x in data:
if isinstance(x, str):
m.update(x.encode('utf-8'))
elif isinstance(x, int):
m.update(x.to_bytes((x.bit_length() + 7) // 8, byteorder='big'))
else:
raise ValueError(f'Expected int or string, got: {x}')
d = m.digest()
hash_int = int.from_bytes(d[:4], byteorder='big')
return hash_int
现在您可以手动重现从 self.make_rng
生成的 PRNG 密钥
stream_seed = jax.random.key(0)
for call_count in range(1, 4):
hash_int = produce_hash(data=(call_count,))
print(jax.random.fold_in(stream_seed, jnp.uint32(hash_int)))
Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
variables = rng_module.init({'rng_stream': jax.random.key(0)})
Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
子 Module
和 self.make_rng
#
本节探讨 self.make_rng
(flax.linen.Module.make_rng
) 在子 Module
中的行为方式。
考虑以下示例
class RNGSubSubModule(nn.Module):
def __call__(self):
print(f"{self.name}, count 1: {self.make_rng('rng_stream')}")
print(f"{self.name}, count 2: {self.make_rng('rng_stream')}")
class RNGSubModule(nn.Module):
@nn.compact
def __call__(self):
print(f"{self.name}, count 1: {self.make_rng('rng_stream')}")
print(f"{self.name}, count 2: {self.make_rng('rng_stream')}")
RNGSubSubModule()()
class RNGModule(nn.Module):
@nn.compact
def __call__(self):
print(f"RNGModule, count 1: {self.make_rng('rng_stream')}")
print(f"RNGModule, count 2: {self.make_rng('rng_stream')}")
RNGSubModule()()
rng_module = RNGModule()
variables = rng_module.init({'rng_stream': jax.random.key(0)})
RNGModule, count 1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
RNGModule, count 2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
RNGSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[3858825717 2323087578]
RNGSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[ 601859108 3782857444]
RNGSubSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[ 234240654 1028548813]
RNGSubSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[3650462303 2124609379]
如前所述,馈送到 Flax 哈希函数的数据包括
由
self.scope.path
提供的Module
的路径(顶层根模块具有空路径()
);以及特定 RNG 流的调用计数。
此外,请注意,即使对于相同的 RNG 流,每个 Flax Module
和子 Module
也都有其各自的调用计数。子 Module
名称的约定是:f'{module_name}_{module_number}'
。例如,第一个 Dense
子 Module
将被称为 Dense_0
,第二个将被称为 Dense_1
,依此类推。
因此,以下数据将馈送到哈希函数
对于
RNGModule
:由于根Module
具有空路径,因此数据只是调用计数,例如(1,)
和(2,)
。对于
RNGSubModule
:数据为('RNGSubModule_0', 1)
和('RNGSubModule_0', 2)
。对于
RNGSubSubModule
:数据为('RNGSubModule_0', 'RNGSubSubModule_0', 1)
和('RNGSubModule_0', 'RNGSubSubModule_0', 2)
。
通过这些数据,你可以使用 self.make_rng
手动复现从 Module
和子 Module
生成的 PRNG 密钥。
例如:
stream_seed = jax.random.key(0)
for initial_data in ((), ('RNGSubModule_0',), ('RNGSubModule_0', 'RNGSubSubModule_0')):
if initial_data:
module_name = initial_data[-1]
else:
module_name = 'RNGModule'
for call_count in (1, 2):
hash_int = produce_hash(data=initial_data+(call_count,))
rng_key = jax.random.fold_in(stream_seed, jnp.uint32(hash_int))
print(f"{module_name}, count {call_count}: {rng_key}")
RNGModule, count 1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
RNGModule, count 2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
RNGSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[3858825717 2323087578]
RNGSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[ 601859108 3782857444]
RNGSubSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[ 234240654 1028548813]
RNGSubSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[3650462303 2124609379]
如果同一个子 Module
类被多次使用,你可以相应地增加子 Module
名称的后缀。例如:RNGSubModule_0
、RNGSubModule_1
等。
class RNGSubModule(nn.Module):
@nn.compact
def __call__(self):
print(f"{self.name}, count 1: {self.make_rng('rng_stream')}")
print(f"{self.name}, count 2: {self.make_rng('rng_stream')}")
class RNGModule(nn.Module):
@nn.compact
def __call__(self):
print(f"RNGModule, count 1: {self.make_rng('rng_stream')}")
print(f"RNGModule, count 2: {self.make_rng('rng_stream')}")
RNGSubModule()()
RNGSubModule()()
rng_module = RNGModule()
variables = rng_module.init({'rng_stream': jax.random.key(0)})
RNGModule, count 1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
RNGModule, count 2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
RNGSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[3858825717 2323087578]
RNGSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[ 601859108 3782857444]
RNGSubModule_1, count 1: Array((), dtype=key<fry>) overlaying:
[ 426957352 2006350344]
RNGSubModule_1, count 2: Array((), dtype=key<fry>) overlaying:
[4006253729 4205356731]
stream_seed = jax.random.key(0)
for initial_data in ((), ('RNGSubModule_0',), ('RNGSubModule_1',)):
if initial_data:
module_name = initial_data[-1]
else:
module_name = 'RNGModule'
for call_count in (1, 2):
hash_int = produce_hash(data=initial_data+(call_count,))
rng_key = jax.random.fold_in(stream_seed, jnp.uint32(hash_int))
print(f"{module_name}, count {call_count}: {rng_key}")
RNGModule, count 1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
RNGModule, count 2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
RNGSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[3858825717 2323087578]
RNGSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[ 601859108 3782857444]
RNGSubModule_1, count 1: Array((), dtype=key<fry>) overlaying:
[ 426957352 2006350344]
RNGSubModule_1, count 2: Array((), dtype=key<fry>) overlaying:
[4006253729 4205356731]
使用 self.param
和 self.variable
#
Flax 用户可以选择使用 self.param
和 self.variable
Module
方法在其模块中创建额外的参数和变量。必须将 init_fn
参数传递给这些方法,以便它可以生成参数/变量的初始值。self.make_rng
通常在这个 init_fn
中隐式或显式地使用,因为许多初始化函数本质上是随机的,需要一个 PRNG 密钥。请参阅 此处 的 Flax 初始化器完整列表。
用户应注意这两种方法之间存在一些差异
self.param
始终在'params'
集合中创建一个参数,而self.variable
在用户指定的任何集合中创建一个变量self.param
将自动调用self.make_rng('params')
并将生成的 PRNG 密钥隐式传递给你实例化的参数的init_fn
(它将作为第一个参数传递),而用户必须在self.variable
的init_fn
中手动指定要调用self.make_rng
的 RNG 流(它可以是'params'
或其他内容)。
下面是一个同时使用 self.param
和 self.variable
的示例
class Model(nn.Module):
@nn.compact
def __call__(self, x):
# kernel will use 'params' seed, initial data will include 'Dense_0', call count 1
x = nn.Dense(2, kernel_init=jax.random.normal, use_bias=False)(x)
# model_param will use 'params' seed, call count 1
model_param = self.param('model_param', jax.random.normal, x.shape)
# model_variable1 will use 'params' seed, call count 2
model_variable1 = self.variable(
'other_collection',
'model_variable1',
lambda: jax.random.normal(self.make_rng('params'), x.shape),
)
# model_variable2 will use 'other' seed, call count 1
model_variable2 = self.variable(
'other_collection',
'model_variable2',
lambda: jax.random.normal(self.make_rng('other'), x.shape),
)
# kernel will use 'params' seed, initial data will include 'Dense_1', call count 1
# bias will use 'params' seed, initial data will include 'Dense_1', call count 2
x = nn.Dense(2, kernel_init=jax.random.normal, bias_init=jax.random.normal)(
x
)
return x
model = Model()
variables = model.init(
{'params': jax.random.key(0), 'other': jax.random.key(1)}, jnp.ones((2, 2))
)
print(variables['params']['Dense_0']['kernel'])
print(variables['params']['model_param'])
print(variables['other_collection']['model_variable1'])
print(variables['other_collection']['model_variable2'])
print(variables['params']['Dense_1']['kernel'])
print(variables['params']['Dense_1']['bias'])
[[-1.6185919 0.700908 ]
[-1.3146383 -0.79342234]]
[[ 0.0761425 -1.6157459]
[-1.6857724 0.7126891]]
[[ 0.60175574 0.2553228 ]
[ 0.27367848 -2.1975214 ]]
[[1.6249592 0.30813068]
[1.6613585 1.0404155 ]]
[[ 0.0030665 0.29551846]
[ 0.16670242 -0.78252524]]
[1.582462 0.15216611]
记住
每个 RNG 流都有一个单独的计数;这就是为什么即使之前调用过
self.make_rng('params')
,self.make_rng('other')
的计数也从 1 开始的原因每个子模块对每个 rng 流都有自己的单独计数;这就是为什么每个
Dense
层都有自己单独的self.make_rng('params')
计数,以及为什么model_param
和model_variable1
共享相同的计数(因为它们是在同一个顶级父模块中定义的)
params_seed = jax.random.key(0)
other_seed = jax.random.key(1)
for initial_data, count, seed, shape in (
(('Dense_0',), 1, params_seed, (2, 2)),
((), 1, params_seed, (2, 2)),
((), 2, params_seed, (2, 2)),
((), 1, other_seed, (2, 2)),
(('Dense_1',), 1, params_seed, (2, 2)),
(('Dense_1',), 2, params_seed, (1, 2)),
):
hash_int = produce_hash(data=(*initial_data, count))
rng_key = jax.random.fold_in(seed, jnp.uint32(hash_int))
print(jax.random.normal(rng_key, shape))
[[-1.6185919 0.700908 ]
[-1.3146383 -0.79342234]]
[[ 0.0761425 -1.6157459]
[-1.6857724 0.7126891]]
[[ 0.60175574 0.2553228 ]
[ 0.27367848 -2.1975214 ]]
[[1.6249592 0.30813068]
[1.6613585 1.0404155 ]]
[[ 0.0030665 0.29551846]
[ 0.16670242 -0.78252524]]
[[1.582462 0.15216611]]
在训练循环中管理 RNG 流#
下面是一个在训练循环中管理来自 self.make_rng
、self.param
、self.variable
和 nn.Dropout
的 RNG 流的示例(注意:nn.Dropout
需要在 'dropout'
RNG 流中传入一个种子 PRNG 密钥,因为它会隐式调用 self.make_rng('dropout')
)
class SubModule(nn.Module):
@nn.compact
def __call__(self, x, train):
# variables created using `self.param` will use `self.make_rng('params')`
kernel = self.param('submodule_kernel', jax.random.normal, x.shape)
x = x + kernel
# `nn.Dropout` will use self.make_rng('dropout')
x = nn.Dropout(0.2)(x, deterministic=not train)
# `nn.Dense` will use self.make_rng('params')
x = nn.Dense(3)(x)
return x
class Model(nn.Module):
@nn.compact
def __call__(self, x, train):
# make kernel use `self.make_rng('other')`
kernel = self.variable(
'other_collection',
'module_kernel',
lambda: jax.random.normal(self.make_rng('other'), x.shape),
)
x = (
x + kernel.value
) # `.value` will extract the underlying value of the variable
x = SubModule()(x, train)
# `nn.Dropout` will use self.make_rng('dropout')
x = nn.Dropout(0.2)(x, deterministic=not train)
# `nn.Dense` will use self.make_rng('params')
x = nn.Dense(2)(x)
return x
params_rng, other_rng, train_rng = jax.random.split(jax.random.key(0), 3)
init_rngs = {'params': params_rng, 'other': other_rng}
x = jnp.ones((1, 3))
y = jnp.ones((1, 2))
module = Model()
variables = module.init(init_rngs, x, train=False)
def update(variables, rng):
# we don't need to provide a 'params' or 'other' rng, as only 'dropout' rng will be used during training
# split the rng to get a dropout_rng to be used for this training iteration,
# and to get another rng key to be used for the next training iteration
dropout_rng, next_rng = jax.random.split(rng)
def loss(params):
out = module.apply(
{'params': params, 'other_collection': variables['other_collection']},
x,
train=True,
rngs={'dropout': dropout_rng},
)
return jnp.mean((y - out) ** 2)
grads = jax.grad(loss)(variables['params'])
params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, variables['params'], grads)
return {
'params': params,
'other_collection': variables['other_collection'],
}, next_rng
for _ in range(10):
variables, train_rng = update(variables, train_rng)
out = module.apply(variables, x, train=False)
print(jnp.mean((y - out)**2))
2.518454
2.4859657
2.4171872
2.412684
2.3435805
2.2773488
2.2592616
2.2009292
2.1839895
2.1707344
🔪 锋利的边缘 🔪 - 无意中生成相同的值#
存在一个边缘情况,可能会无意中生成相同的值。有关更多详细信息,请参阅 Flax 问题。
class Leaf(nn.Module):
def __call__(self, x):
return x + jax.random.randint(self.make_rng("rng"), (), 0, 100)
class Node(nn.Module):
leaf_name: str
@nn.compact
def __call__(self, x):
return Leaf(name=self.leaf_name)(x)
class Model(nn.Module):
@nn.compact
def __call__(self, x):
return (Node(name="ab", leaf_name="cdef")(x),
Node(name="abc", leaf_name="def")(x),
)
out1, out2 = Model().apply({}, 0, rngs={"rng": jax.random.key(33)})
out1 == out2 # same output, despite having different submodule names
Array(True, dtype=bool)
发生这种情况是因为哈希函数 将字符串连接在一起,因此当将数据输入到哈希函数时,数据 ('AB', 'C')
等同于数据 ('A', 'BC')
,因此会生成相同的哈希 int。
print(produce_hash(data=('A', 'B', 'C', 1)))
print(produce_hash(data=('AB', 'C', 1)))
print(produce_hash(data=('A', 'BC', 1)))
print(produce_hash(data=('ABC', 1)))
947574064
947574064
947574064
947574064
为了避免这种情况,用户可以将 flax_fix_rng_separator
配置标志翻转为 True
。
flax.config.update('flax_fix_rng_separator', True)
out1, out2 = Model().apply({}, 0, rngs={"rng": jax.random.key(33)})
out1 == out2 # different output
Array(False, dtype=bool)
在多个设备上管理 RNG#
本节将展示如何在多设备设置中使用 jit
和 shard_map
来使用 RNG 的示例。
使用 jax.jit
#
当使用 jax.jit
时,我们可以像以前一样使用 RNG,但现在我们包括 in_shardings
和 out_shardings
参数来指定如何分片输入和输出数据。
有关在 Flax 中使用 jax.jit
在多个设备上进行训练的更多详细信息,请参阅我们的在多个设备上扩展 Flax 模块指南和 lm1b 示例。
# Create a mesh and annotate the axis with a name.
device_mesh = mesh_utils.create_device_mesh((8,))
print(device_mesh)
mesh = Mesh(devices=device_mesh, axis_names=('data',))
print(mesh)
data_sharding = NamedSharding(mesh, PartitionSpec('data',))
print(data_sharding)
[CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)
CpuDevice(id=4) CpuDevice(id=5) CpuDevice(id=6) CpuDevice(id=7)]
Mesh('data': 8)
NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec('data',), memory_kind=unpinned_host)
class Model(nn.Module):
@nn.compact
def __call__(self, x, add_noise):
x = nn.Dense(1)(x)
# use jnp.where for control flow; for more details see: https://jax.net.cn/en/latest/errors.html#jax.errors.TracerBoolConversionError
return jnp.where(
add_noise, x + jax.random.normal(self.make_rng('params'), x.shape), x
)
module = Model()
init_rng, apply_rng = jax.random.split(jax.random.key(0))
x = jnp.ones((8, 1))
variables = module.init(init_rng, x, False)
# create custom forward function, since jit does not support kwargs when in_shardings is specified
def forward(variables, x, add_noise, rng):
return module.apply(variables, x, add_noise, rngs={'params': rng})
# shard the inputs x across devices
# replicate the variables, add_noise boolean and rng key across devices
# shard the output across devices
jit_forward = jax.jit(
forward,
in_shardings=(None, data_sharding, None, None),
out_shardings=data_sharding,
)
out = jit_forward(variables, x, True, apply_rng)
out
Array([[-2.2187614 ],
[-2.8055234 ],
[-2.5464184 ],
[ 1.027039 ],
[-3.5243359 ],
[-2.2795477 ],
[-0.6504516 ],
[ 0.17373265]], dtype=float32)
给定相同的输入,输出是不同的,这意味着 RNG 密钥被用来为输出添加噪声。
我们还可以确认输出在设备之间是分片的
out.addressable_shards
[Shard(device=CpuDevice(id=0), index=(slice(0, 1, None), slice(None, None, None)), replica_id=0, data=[[-2.2187614]]),
Shard(device=CpuDevice(id=1), index=(slice(1, 2, None), slice(None, None, None)), replica_id=0, data=[[-2.8055234]]),
Shard(device=CpuDevice(id=2), index=(slice(2, 3, None), slice(None, None, None)), replica_id=0, data=[[-2.5464184]]),
Shard(device=CpuDevice(id=3), index=(slice(3, 4, None), slice(None, None, None)), replica_id=0, data=[[1.027039]]),
Shard(device=CpuDevice(id=4), index=(slice(4, 5, None), slice(None, None, None)), replica_id=0, data=[[-3.5243359]]),
Shard(device=CpuDevice(id=5), index=(slice(5, 6, None), slice(None, None, None)), replica_id=0, data=[[-2.2795477]]),
Shard(device=CpuDevice(id=6), index=(slice(6, 7, None), slice(None, None, None)), replica_id=0, data=[[-0.6504516]]),
Shard(device=CpuDevice(id=7), index=(slice(7, 8, None), slice(None, None, None)), replica_id=0, data=[[0.17373265]])]
可视化输出分片的另一种方式
jax.debug.visualize_array_sharding(out)
CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
如果我们选择不添加噪声,那么所有批次的输出都是相同的(正如预期的那样,因为所有批次的输入都相同)
out = jit_forward(variables, x, False, apply_rng)
out
Array([[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764]], dtype=float32)
我们可以确认未进行 JIT 编译的函数产生相同的值,尽管是未分片的(请注意,由于 JIT 编译器的优化,可能存在小的数值差异)
out = forward(variables, x, True, apply_rng)
out
Array([[-2.2187614 ],
[-2.8055234 ],
[-2.5464187 ],
[ 1.0270392 ],
[-3.5243359 ],
[-2.2795477 ],
[-0.6504516 ],
[ 0.17373264]], dtype=float32)
out = forward(variables, x, False, apply_rng)
out
Array([[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764],
[-1.2839764]], dtype=float32)
使用 shard_map
#
当使用 jax.experimental.shard_map.shard_map
时,要记住的重要部分是
拆分你的 PRNG 密钥,以便为每个设备生成不同的密钥
PRNG 密钥将自动分片到每个设备(前提是你使用正确的分区规范),但 原始批处理 PRNG 密钥数组的秩不会减小;例如,对于一批 8 个 PRNG 密钥和 8 个设备,每个设备将在
shard_map
化的函数中看到大小为 1 的 PRNG 密钥批次因此,要访问 PRNG 密钥本身,我们需要对其进行索引切片(请参见下面的示例)
def forward(variables, x, add_noise, rng_key_batch):
# rng_key_batch is a batch of size 1 containing 1 PRNG key
# index slice into the rng_key_batch to access the PRNG key
return module.apply(
variables, x, add_noise, rngs={'params': rng_key_batch[0]}
)
# define partition specifications
data_pspec = PartitionSpec('data')
no_pspec = PartitionSpec()
# shard the inputs x and rng keys across devices
# replicate the variables and add_noise boolean across devices
# shard the output across devices
shmap_forward = shard_map(
forward,
mesh=mesh,
in_specs=(no_pspec, data_pspec, no_pspec, data_pspec),
out_specs=data_pspec,
)
# get 8 different rng's that will be used by the 8 devices when doing forward inference
apply_rngs = jax.random.split(apply_rng, 8)
out = shmap_forward(variables, x, True, apply_rngs)
out
Array([[-1.2605132 ],
[-1.2405176 ],
[-0.99350417],
[-1.0277128 ],
[-1.4154483 ],
[-0.3905797 ],
[-2.417677 ],
[ 0.9023453 ]], dtype=float32)
确认输出在设备之间是分片的
out.addressable_shards
[Shard(device=CpuDevice(id=0), index=(slice(0, 1, None), slice(None, None, None)), replica_id=0, data=[[-1.2605132]]),
Shard(device=CpuDevice(id=1), index=(slice(1, 2, None), slice(None, None, None)), replica_id=0, data=[[-1.2405176]]),
Shard(device=CpuDevice(id=2), index=(slice(2, 3, None), slice(None, None, None)), replica_id=0, data=[[-0.99350417]]),
Shard(device=CpuDevice(id=3), index=(slice(3, 4, None), slice(None, None, None)), replica_id=0, data=[[-1.0277128]]),
Shard(device=CpuDevice(id=4), index=(slice(4, 5, None), slice(None, None, None)), replica_id=0, data=[[-1.4154483]]),
Shard(device=CpuDevice(id=5), index=(slice(5, 6, None), slice(None, None, None)), replica_id=0, data=[[-0.3905797]]),
Shard(device=CpuDevice(id=6), index=(slice(6, 7, None), slice(None, None, None)), replica_id=0, data=[[-2.417677]]),
Shard(device=CpuDevice(id=7), index=(slice(7, 8, None), slice(None, None, None)), replica_id=0, data=[[0.9023453]])]
jax.debug.visualize_array_sharding(out)
CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
提升的变换#
Flax 提升的变换允许你将 JAX 变换 与 Module
参数一起使用。本节将向你展示如何在 Flax 提升的变换中控制 PRNG 密钥的拆分方式。
有关更多详细信息,请参阅提升的变换。
nn.vmap
#
我们可以使用 nn.vmap
创建批处理的 Dense
层
x = jnp.ones((3, 2))
BatchDense = nn.vmap(
nn.Dense,
in_axes=0, out_axes=0,
variable_axes={'params': None},
split_rngs={'params': False})
BatchDense(2).init(jax.random.key(0), x)
{'params': {'bias': Array([0., 0.], dtype=float32),
'kernel': Array([[-1.2488099 , -0.6127134 ],
[-0.07084481, 0.60130936]], dtype=float32)}}
通过表示 variable_axes={'params': 0}'
,我们在第一个轴上向量化 params
数组。但是,生成的参数值彼此相同
BatchDense = nn.vmap(
nn.Dense,
in_axes=0, out_axes=0,
variable_axes={'params': 0},
split_rngs={'params': False})
BatchDense(2).init(jax.random.key(0), x)
{'params': {'bias': Array([[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32),
'kernel': Array([[[-1.2488099 , -0.6127134 ],
[-0.07084481, 0.60130936]],
[[-1.2488099 , -0.6127134 ],
[-0.07084481, 0.60130936]],
[[-1.2488099 , -0.6127134 ],
[-0.07084481, 0.60130936]]], dtype=float32)}}
如果我们还使 split_rngs={'params': True}
,那么我们提供的 PRNG 密钥将在变量轴(在这种情况下为批处理轴 0)上拆分,并且我们可以为每个批处理输入生成不同的参数
BatchDense = nn.vmap(
nn.Dense,
in_axes=0, out_axes=0,
variable_axes={'params': 0},
split_rngs={'params': True})
BatchDense(2).init(jax.random.key(0), x)
{'params': {'bias': Array([[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32),
'kernel': Array([[[-0.2526208 , -0.15088455],
[-1.1987205 , -0.40843305]],
[[-0.7064888 , -1.108805 ],
[-0.938775 , 1.4812315 ]],
[[-0.59468937, -0.2502723 ],
[-1.33515 , 0.5067442 ]]], dtype=float32)}}
通过 self.variable
添加变量很简单
class Model(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(2)(x)
kernel = self.variable(
'other_collection',
'kernel',
lambda: jax.random.normal(self.make_rng('other'), x.shape),
)
return x + kernel.value
BatchModel = nn.vmap(
Model,
in_axes=0,
out_axes=0,
variable_axes={'params': 0, 'other_collection': 0},
split_rngs={'params': True, 'other': True},
)
BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x)
{'params': {'Dense_0': {'bias': Array([[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32),
'kernel': Array([[[-0.9079084 , 0.76390624],
[-0.01285526, 0.4320353 ]],
[[ 0.12398645, 0.7884565 ],
[ 1.5344163 , 1.3186085 ]],
[[-0.44171348, 0.43430036],
[-0.40732604, 0.29774475]]], dtype=float32)}},
'other_collection': {'kernel': Array([[-0.8193048 , 0.711106 ],
[-0.37802765, -0.66705877],
[-0.44808003, 0.93031347]], dtype=float32)}}
我们可以控制要拆分的 RNG 流,例如,如果我们只想拆分 'params'
RNG 流,那么从 self.variable
生成的变量对于每个批处理输入都将相同
BatchModel = nn.vmap(
Model,
in_axes=0, out_axes=0,
variable_axes={'params': 0, 'other_collection': 0},
split_rngs={'params': True, 'other': False})
BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x)
{'params': {'Dense_0': {'bias': Array([[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32),
'kernel': Array([[[-0.9079084 , 0.76390624],
[-0.01285526, 0.4320353 ]],
[[ 0.12398645, 0.7884565 ],
[ 1.5344163 , 1.3186085 ]],
[[-0.44171348, 0.43430036],
[-0.40732604, 0.29774475]]], dtype=float32)}},
'other_collection': {'kernel': Array([[ 0.44956833, -1.1854612 ],
[ 0.44956833, -1.1854612 ],
[ 0.44956833, -1.1854612 ]], dtype=float32)}}
我们还可以控制应为每个批处理输入生成哪些参数/变量,例如,如果我们只希望 'params'
为每个批处理输入生成单独的参数
BatchModel = nn.vmap(
Model,
in_axes=0, out_axes=0,
variable_axes={'params': 0, 'other_collection': None},
split_rngs={'params': True, 'other': False})
BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x)
{'params': {'Dense_0': {'bias': Array([[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32),
'kernel': Array([[[-0.9079084 , 0.76390624],
[-0.01285526, 0.4320353 ]],
[[ 0.12398645, 0.7884565 ],
[ 1.5344163 , 1.3186085 ]],
[[-0.44171348, 0.43430036],
[-0.40732604, 0.29774475]]], dtype=float32)}},
'other_collection': {'kernel': Array([ 0.44956833, -1.1854612 ], dtype=float32)}}
nn.scan
#
我们可以使用 nn.scan
创建扫描的 Module
层(这对于简化重复堆叠的子模块很有用)
x = jnp.ones((3, 2))
class ResidualMLPBlock(nn.Module):
@nn.compact
def __call__(self, x, _):
h = nn.Dense(features=2)(x)
h = nn.relu(h)
return x + h, None # return an empty carry
ScanMLP = nn.scan(
ResidualMLPBlock, variable_axes={'params': 0},
variable_broadcast=False, split_rngs={'params': True},
length=3)
ScanMLP().init(jax.random.key(0), x, None) # pass in an empty carry
{'params': {'Dense_0': {'bias': Array([[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32),
'kernel': Array([[[-0.07838312, -0.7422982 ],
[ 0.87488323, 0.13773395]],
[[ 0.97309333, 0.9087693 ],
[-0.12564984, -1.0920651 ]],
[[-0.99055105, 1.1499453 ],
[-0.15721127, -0.62520015]]], dtype=float32)}}}
与之前类似,我们可以控制是否拆分 RNG 流,例如,如果我们希望所有堆叠的模块都初始化为相同的参数值,则可以传入 split_rngs={'params': False}
ScanMLP = nn.scan(
ResidualMLPBlock, variable_axes={'params': 0},
variable_broadcast=False, split_rngs={'params': False},
length=3)
ScanMLP().init(jax.random.key(0), x, None)
{'params': {'Dense_0': {'bias': Array([[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32),
'kernel': Array([[[-0.66715515, -0.0484313 ],
[ 0.9867164 , 0.75408363]],
[[-0.66715515, -0.0484313 ],
[ 0.9867164 , 0.75408363]],
[[-0.66715515, -0.0484313 ],
[ 0.9867164 , 0.75408363]]], dtype=float32)}}}