提升的变换#
⚠️ 高级主题 ⚠️
此设计说明解释了 flax.linen.transform
的底层实现,它使得在 Flax Module
中使用 JAX 变换成为可能。
简介#
JAX 使用函数式 API,这意味着只有在使用没有副作用的函数时才能保证正确的行为(JAX 文档)。通常,这些副作用是由于改变函数外部的对象而产生的。
函数式范式有一些优点,例如能够明确地推理状态和随机性。只有当输入参数发生变化时,函数输出才会改变。因此,保证函数以确定性的方式运行。
但是纯函数为 JAX 提供了另一个巨大的优势:具体来说,它们使函数变换成为可能。例如,jax.vmap(f)
将向量化函数 f
。由于 f
不能有副作用,因此 f
的向量化/并行版本是明确定义的。要了解为什么我们需要这个限制,请考虑如果 f
会递增计数器或抽取随机数会发生什么。对于向量中的每个项目,f
会抽取相同还是不同的随机数?批次中的每个项目都有自己的计数器,还是计数器在项目之间共享?如果 f
并行计算,计数器以什么顺序递增?所有这些问题的答案是“这取决于”。行为是模糊的,函数式约束优雅地避免了这个问题。
Flax 引入了一种安全的方式,以 JAX 兼容的形式拥有有限的随机性和有状态变量。Flax 中的状态没有问题的原因是它是局部的:在 Flax Module
内部有变量和 PRNG 序列,但在外部只有 JAX 数组和 PRNG 密钥。
在大多数用例中,Flax 用于以有状态的方式定义模型。由于 Module
在外部表现得像一个纯函数,因此我们可以充分利用 JAX 及其所有转换。但是,在某些情况下,我们希望将转换和 Module
结合使用,以获得两全其美的效果。本设计说明解释了我们如何扩展 JAX 的函数变换以使其在具有内部状态和随机性的 Module
上工作。
函数化#
在深入细节之前,让我们考虑一个简单的示例,其中我们希望在 Module
内部使用 vmap
。
首先,我们定义一个简单的 MLP,没有任何转换
import jax
from jax import random, numpy as jnp
from flax import linen as nn
class MLP(nn.Module):
@nn.compact
def __call__(self, xs):
h = nn.Dense(4, name='hidden')(xs)
h = nn.relu(h)
return nn.Dense(1, name='out')(h)
现在,如果我们希望 xs
中的每个项目都有单独的 MLP 参数,该怎么办?如果这是“原始 JAX”,我们可以想象编写类似 jax.vmap(apply_mlp)(mlp_params, xs)
的内容。但是在 Linen 中执行类似的操作实际上会失败
class NaiveVmapMLP(nn.Module):
@nn.compact
def __call__(self, xs):
mlp = MLP()
return jax.vmap(lambda mlp, x: mlp(x))(mlp, xs) # fails
当在 mlp
上使用 vmap
时,JAX 将引发错误,因为它不是 JAX 数组或简单的数组容器。我们不能真的责怪 JAX 拒绝执行这项未明确指定的工作。毕竟,甚至不清楚这里应该发生什么。MLP 内部的参数尚未初始化,我们需要为每组参数单独的 PRNG 密钥。jax.vmap
只能广播或映射到一个轴上,但它不能自动拆分 PRNG 密钥。因此,我们必须手动调用 jax.random.split
。
我们可以通过首先将 MLP
转换为纯初始化和应用函数来解决此问题。之后,我们使用 param
方法存储参数
class ManualVmapMLP(nn.Module):
@nn.compact
def __call__(self, xs):
mlp = MLP(parent=None)
init_fn = lambda rng, xs: jax.vmap(mlp.init, in_axes=0)(random.split(rng, xs.shape[0]), xs)['params']
apply_fn = jax.vmap(mlp.apply, in_axes=0)
mlp_params = self.param('mlp', init_fn, xs)
return apply_fn({'params': mlp_params}, xs)
xs = jnp.ones((3, 4))
variables = ManualVmapMLP().init(random.key(0), xs)
print(jax.tree_util.tree_map(jnp.shape, variables['params']))
"""==>
{
mlp: {
hidden: {
bias: (3, 4),
kernel: (3, 4, 4),
},
out: {
bias: (3, 1),
kernel: (3, 4, 1),
},
},
}
"""
在此,MLP(parent=None)
创建 MLP
的分离实例。这避免了在当前模块内部为子模块保留名称。虽然不是绝对必要的,但这也可以确保我们不会意外地以有状态的方式使用 MLP 实例,并且我们被迫通过 .init
或 .apply
来使用它。
此示例仍然相对简洁,但它已经需要一些额外的“簿记”语句才能使其工作。但是,此实现存在许多限制
在初始化期间,我们通过
init_fn
和apply_fn
调用子模块两次。如果子模块使用相同的技巧来执行函数转换,那么随着模块调用次数像 2^d 一样增长,其中 d 是嵌套函数转换的数量,我们将最终执行大量代码。该实现假设子模块仅需要参数 RNG 序列。
该实现假设我们仅在
init
期间在“params”集合中创建变量。但是,它不支持其他变量集合,也不支持在apply
中创建/更新变量。
特别是第 3 点使得手动函数化很麻烦。请随意尝试在 MLP
模块中使用 nn.BatchNorm
层来扩展上面的示例。这将需要处理一些额外的复杂性,例如存储更新的批量统计信息,并确保当批量统计信息应该是不可变的(例如:评估模式)时,它在 vmap
内部不是可变的。
我们将把有状态的模块转换为纯函数的过程称为“函数化”。通过临时将有状态的 Module
转换为函数,我们使其与 JAX 的函数转换兼容。
提升#
Flax 为手动函数化提供了一种替代方案,我们称之为提升的变换。提升的变换在 flax.core.lift
中定义。所有提升的 JAX 变换都使用称为 pack
的单个通用提升 API 来定义。
为了定义 pack
,必须做出许多决策。pack
的实现控制着变量和 rng 如何提升以及用户控制的精细程度。它还必须决定是在变量还是在变换定义时做出提升决策。
提升粒度#
通过 Linen API,用户可以定义任意的变量集合和 PRNG 序列。集合中的每个变量都以相同的方式提升。
集合通常会赋予具有语义意义的名称,如 “params” 或 “batch_stats”,而不是像 “state” 这样的通用名称。由于集合带有语义意义,我们可以在转换级别决定如何提升每个集合。例如,当我们在模型中添加批次维度时,我们希望共享所有参数变量。
同时,我们可以编写使用转换的通用代码,而无需确切知道子模块将创建哪种变量。因此,集合在细粒度控制和通用性之间取得了平衡。我们也避免了脆弱的字符串匹配代码,这些代码会循环遍历所有变量,并尝试根据诸如“目标名称前缀为‘kernel’的所有变量”这样的命名约定,以临时的方式拆分集合。如果需要更细粒度的控制,用户可以简单地将一组变量拆分到多个应区别处理的集合中。
转换与变量控制#
提升行为可以在转换级别或变量定义期间定义。我们使用转换级别的提升行为定义。选择这种方式的原因是存在许多具有各种行为的不同转换。例如:vmap
具有广播和矢量化的参数,而 scan
具有扫描、携带和广播参数。一个变量必须定义其对所有这些转换的行为,否则 Module
将与这些转换不兼容。或者,我们必须对如何处理转换做出默认决策。但是,这可能会导致潜在的错误,因为该行为可能实际上不符合用户的意图。
提升包还提供了一个通用的 transform
,它允许任意函数转换变量集合。例如,这可以用于通过转置权重来绑定绑定的自动编码器中的权重。如果提升决策是在变量定义时做出的,则尚不清楚是否可以定义类似的通用转换。
Linen#
提升模块不知道 Linen Module
API。相反,它直接对 flax.core.Scope
的实例进行操作。Scope
实例包含 Module
的变量和 PRNG 序列。如果 Module
实例有父级,或者它是使用 init
或 apply
创建的,则每个 Module
实例在 .scope
字段中都有一个 Scope
实例。通常,您在其中调用 init
或 apply
的顶层 Module
实例是唯一没有绑定 Scope
的 Module
实例。
当 Module
被转换时,我们使用 flax.core.lift
API 来提升作用域,并使用 Module.clone()
来创建一个新的 Module
实例,其中绑定了提升的作用域。
flax.linen.transforms
公开了 flax.core.lift
中转换的包装器。核心提升 API 对函数进行操作,而 Linen 包装器可以转换 Module
类或 Module
方法。
因此,提升的实现独立于 Linen API。这种关注点分离简化了实现,同时可能允许替代的 Module
抽象构建在用于提升和状态管理的公共核心之上。
实现#
pack(fn, in_vars, out_vars, rngs)
API 经历以下阶段
作用域去重
此阶段仅在多个作用域一起提升时相关。在这种情况下,我们必须首先找到根作用域的集合。如果其祖先中没有需要提升的作用域,则该作用域是根作用域。
通过仅提升根,我们避免了两次提升相同的变量。
对于非根作用域,我们存储对其祖先作用域的引用和路径,以便以后可以重建它(阶段 4)。
过滤阶段
变量和 PRNG 序列被分成组。这样,
fn
可以将每个组分别提升到转换中。组由指定为的过滤器定义集合/prng 名称列表
True
(匹配所有内容)False
(不匹配任何内容)DenyList(filter)
(匹配除指定集合以外的所有内容(例如:DenyList(['params'])
匹配除“params”集合以外的所有内容))。
一个集合或 PRNG 序列只能放入一个组中。如果一个集合匹配多个过滤器,它将放入第一个匹配过滤器的组中。如果一个集合或 PRNG 序列与任何过滤器都不匹配,则不会提升。这意味着它不能在转换内部使用,并且尝试这样做会导致引发错误。例如,
in_vars = (["params"], True)
将导致 “params” 集合放入第一组,而所有其他集合放入第二组。对于匹配的每个 PRNG 序列,我们通过调用
make_rng
来播种一个新的 PRNG 序列。这避免了在提升的转换完成后更新 PRNG 状态的需要。特定于转换的提升
使用变量和 PRNG 组调用
fn
。JAX 转换具有不同的签名和提升选项。可以说最清晰的例子是vmap
。在 vmap 的情况下,函数参数、PRNG 和变量集合被传递给jax.vmap
包装的函数。作用域重建
现在变量和 PRNG 已在转换内部提升,我们希望重建提升的作用域。Pack 使用
scope_fn
调用fn
,该scope_fn
接受提升的变量和 PRNG,并返回带有提升的变量和 rng 序列的重建的作用域。重新打包阶段
在我们使用提升的作用域后,我们必须检索更新的变量(PRNG 序列可以简单地丢弃)。pack 传递
repack_fn
以支持此操作。此阶段类似于阶段 2,不同之处在于我们仅提升变量,并且忽略不可变的变量。不可变变量无法更新。因此,不应从转换后的函数返回它们。提交阶段
pack
期望fn
返回一对,其中第一项将简单地从 pack 返回,第二项应是重新打包的变量。更新的变量存储在原始/未提升的作用域中,以便在转换完成后,转换内部发生的变异仍然存在。
使用 pack 示例#
使用 pack
来转置变量集合中每个矩阵的最小示例
from flax.core import lift
from flax.core import Scope, init, apply, nn as core_nn
def lift_transpose(fn, target='params', variables=True, rngs=True):
# by default we transpose 'params' and simply pass through all other variables.
def wrapper(scope_fn, repack_fn, variable_groups, rng_groups, *args):
# normally we would first call into a JAX transformed function here...
target, rest = variable_groups
def trans(x):
if x.ndim == 2:
return x.T
return x
target = jax.tree_util.tree_map(trans, target)
variable_groups = (target, rest)
scope = scope_fn(variable_groups, rng_groups)
y = fn(scope, *args)
out_variables = repack_fn(scope)
return y, out_variables
return lift.pack(
wrapper,
in_variable_filters=(target, variables),
out_variable_filters=(variables,),
rng_filters=(rngs,))
x = jnp.ones((3, 2))
y, params = init(lift_transpose(core_nn.dense))(random.key(0), x, 4)
注意,大多数用户不需要直接与 pack
交互。当您发现现有提升的转换尚不支持的用例时,请打开一个 GitHub 问题。
支持的转换#
Jax 转换 |
在 Linen 中是否支持? |
注释 |
---|---|---|
vmap |
✅ |
|
scan |
✅ |
携带变量不能在扫描主体内部初始化。 |
remat |
✅ |
|
jit |
✅ |
当前实现可能会导致不必要的重新编译。 |
jvp |
✅ |
|
vjp |
✅ |
|
custom_vjp |
✅ |
|
custom_jvp |
❌ |
|
while_loop |
✅ |
携带变量不能在 while_loop 主体内部初始化。 |
cond |
✅ |
变量初始化/变异必须在分支之间结构匹配。 |
switch |
✅ |
变量初始化/变异必须在分支之间结构匹配。 |
pmap |
❌ |
|
xmap |
❌ |
参考
Linen 示例#
回到我们最初的示例,我们现在可以使用 nn.vmap
来简化我们的实现
class LinenVmapMLP(nn.Module):
@nn.compact
def __call__(self, xs):
VmapMLP = nn.vmap(MLP, variable_axes={'params': 0}, split_rngs={'params': True}, in_axes=0)
return VmapMLP(name='mlp')(xs)
variables = LinenVmapMLP().init(random.key(0), xs)
print(jax.tree_util.tree_map(jnp.shape, variables['params']))
"""==>
{
mlp: {
Dense_0: {
bias: (3, 4),
kernel: (3, 2, 4),
},
Dense_1: {
bias: (3, 1),
kernel: (3, 4, 1),
},
},
}
"""
这里我们使用 variable_axes={'params': 0}
来表示参数是向量化的而不是共享的,而 split_rngs={'params': True}
意味着每组参数都是独立初始化的。
我们还可以通过添加一个 BatchNorm
层来扩展这个例子,使其包含一些内部状态
class StatefulMLP(nn.Module):
@nn.compact
def __call__(self, x, *, train):
h = nn.Dense(4, name='hidden')(x)
h = nn.BatchNorm(axis_name='batch')(h, use_running_average=not train)
h = nn.relu(h)
return nn.Dense(1, name='out')(h)
class LinenStatefulVmapMLP(nn.Module):
@nn.compact
def __call__(self, xs, *, train):
VmapMLP = nn.vmap(StatefulMLP, variable_axes={'params': 0, 'batch_stats': 0}, split_rngs={'params': True}, in_axes=0)
return VmapMLP(name='mlp')(xs, train=train)
variables = LinenStatefulVmapMLP().init(random.key(0), xs)
我们只需要在 nn.vmap
中添加 'batch_stats': 0
,表示批统计信息是向量化的,而不是沿着第一个轴共享的。