提升的变换#

⚠️ 高级主题 ⚠️

此设计说明解释了 flax.linen.transform 的底层实现,它使得在 Flax Module 中使用 JAX 变换成为可能。

简介#

JAX 使用函数式 API,这意味着只有在使用没有副作用的函数时才能保证正确的行为(JAX 文档)。通常,这些副作用是由于改变函数外部的对象而产生的。

函数式范式有一些优点,例如能够明确地推理状态和随机性。只有当输入参数发生变化时,函数输出才会改变。因此,保证函数以确定性的方式运行。

但是纯函数为 JAX 提供了另一个巨大的优势:具体来说,它们使函数变换成为可能。例如,jax.vmap(f) 将向量化函数 f。由于 f 不能有副作用,因此 f 的向量化/并行版本是明确定义的。要了解为什么我们需要这个限制,请考虑如果 f 会递增计数器或抽取随机数会发生什么。对于向量中的每个项目,f 会抽取相同还是不同的随机数?批次中的每个项目都有自己的计数器,还是计数器在项目之间共享?如果 f 并行计算,计数器以什么顺序递增?所有这些问题的答案是“这取决于”。行为是模糊的,函数式约束优雅地避免了这个问题。

Flax 引入了一种安全的方式,以 JAX 兼容的形式拥有有限的随机性和有状态变量。Flax 中的状态没有问题的原因是它是局部的:在 Flax Module 内部有变量和 PRNG 序列,但在外部只有 JAX 数组和 PRNG 密钥。

在大多数用例中,Flax 用于以有状态的方式定义模型。由于 Module 在外部表现得像一个纯函数,因此我们可以充分利用 JAX 及其所有转换。但是,在某些情况下,我们希望将转换和 Module 结合使用,以获得两全其美的效果。本设计说明解释了我们如何扩展 JAX 的函数变换以使其在具有内部状态和随机性的 Module 上工作。

函数化#

在深入细节之前,让我们考虑一个简单的示例,其中我们希望在 Module 内部使用 vmap

首先,我们定义一个简单的 MLP,没有任何转换

import jax
from jax import random, numpy as jnp
from flax import linen as nn

class MLP(nn.Module):
  @nn.compact
  def __call__(self, xs):
    h = nn.Dense(4, name='hidden')(xs)
    h = nn.relu(h)
    return nn.Dense(1, name='out')(h)

现在,如果我们希望 xs 中的每个项目都有单独的 MLP 参数,该怎么办?如果这是“原始 JAX”,我们可以想象编写类似 jax.vmap(apply_mlp)(mlp_params, xs) 的内容。但是在 Linen 中执行类似的操作实际上会失败

class NaiveVmapMLP(nn.Module):
  @nn.compact
  def __call__(self, xs):
    mlp = MLP()
    return jax.vmap(lambda mlp, x: mlp(x))(mlp, xs)  # fails

当在 mlp 上使用 vmap 时,JAX 将引发错误,因为它不是 JAX 数组或简单的数组容器。我们不能真的责怪 JAX 拒绝执行这项未明确指定的工作。毕竟,甚至不清楚这里应该发生什么。MLP 内部的参数尚未初始化,我们需要为每组参数单独的 PRNG 密钥。jax.vmap 只能广播或映射到一个轴上,但它不能自动拆分 PRNG 密钥。因此,我们必须手动调用 jax.random.split

我们可以通过首先将 MLP 转换为纯初始化和应用函数来解决此问题。之后,我们使用 param 方法存储参数

class ManualVmapMLP(nn.Module):
  @nn.compact
  def __call__(self, xs):
    mlp = MLP(parent=None)
    init_fn = lambda rng, xs: jax.vmap(mlp.init, in_axes=0)(random.split(rng, xs.shape[0]), xs)['params']
    apply_fn = jax.vmap(mlp.apply, in_axes=0)
    mlp_params = self.param('mlp', init_fn, xs)
    return apply_fn({'params': mlp_params}, xs)

xs = jnp.ones((3, 4))
variables = ManualVmapMLP().init(random.key(0), xs)
print(jax.tree_util.tree_map(jnp.shape, variables['params']))
"""==>
{
    mlp: {
        hidden: {
            bias: (3, 4),
            kernel: (3, 4, 4),
        },
        out: {
            bias: (3, 1),
            kernel: (3, 4, 1),
        },
    },
}
"""

在此,MLP(parent=None) 创建 MLP 的分离实例。这避免了在当前模块内部为子模块保留名称。虽然不是绝对必要的,但这也可以确保我们不会意外地以有状态的方式使用 MLP 实例,并且我们被迫通过 .init.apply 来使用它。

此示例仍然相对简洁,但它已经需要一些额外的“簿记”语句才能使其工作。但是,此实现存在许多限制

  1. 在初始化期间,我们通过 init_fnapply_fn 调用子模块两次。如果子模块使用相同的技巧来执行函数转换,那么随着模块调用次数像 2^d 一样增长,其中 d 是嵌套函数转换的数量,我们将最终执行大量代码。

  2. 该实现假设子模块仅需要参数 RNG 序列。

  3. 该实现假设我们仅在 init 期间在“params”集合中创建变量。但是,它不支持其他变量集合,也不支持在 apply 中创建/更新变量。

特别是第 3 点使得手动函数化很麻烦。请随意尝试在 MLP 模块中使用 nn.BatchNorm 层来扩展上面的示例。这将需要处理一些额外的复杂性,例如存储更新的批量统计信息,并确保当批量统计信息应该是不可变的(例如:评估模式)时,它在 vmap 内部不是可变的。

我们将把有状态的模块转换为纯函数的过程称为“函数化”。通过临时将有状态的 Module 转换为函数,我们使其与 JAX 的函数转换兼容。

提升#

Flax 为手动函数化提供了一种替代方案,我们称之为提升的变换。提升的变换在 flax.core.lift 中定义。所有提升的 JAX 变换都使用称为 pack 的单个通用提升 API 来定义。

为了定义 pack,必须做出许多决策。pack 的实现控制着变量和 rng 如何提升以及用户控制的精细程度。它还必须决定是在变量还是在变换定义时做出提升决策。

提升粒度#

通过 Linen API,用户可以定义任意的变量集合和 PRNG 序列。集合中的每个变量都以相同的方式提升。

集合通常会赋予具有语义意义的名称,如 “params” 或 “batch_stats”,而不是像 “state” 这样的通用名称。由于集合带有语义意义,我们可以在转换级别决定如何提升每个集合。例如,当我们在模型中添加批次维度时,我们希望共享所有参数变量。

同时,我们可以编写使用转换的通用代码,而无需确切知道子模块将创建哪种变量。因此,集合在细粒度控制和通用性之间取得了平衡。我们也避免了脆弱的字符串匹配代码,这些代码会循环遍历所有变量,并尝试根据诸如“目标名称前缀为‘kernel’的所有变量”这样的命名约定,以临时的方式拆分集合。如果需要更细粒度的控制,用户可以简单地将一组变量拆分到多个应区别处理的集合中。

转换与变量控制#

提升行为可以在转换级别或变量定义期间定义。我们使用转换级别的提升行为定义。选择这种方式的原因是存在许多具有各种行为的不同转换。例如:vmap 具有广播和矢量化的参数,而 scan 具有扫描、携带和广播参数。一个变量必须定义其对所有这些转换的行为,否则 Module 将与这些转换不兼容。或者,我们必须对如何处理转换做出默认决策。但是,这可能会导致潜在的错误,因为该行为可能实际上不符合用户的意图。

提升包还提供了一个通用的 transform,它允许任意函数转换变量集合。例如,这可以用于通过转置权重来绑定绑定的自动编码器中的权重。如果提升决策是在变量定义时做出的,则尚不清楚是否可以定义类似的通用转换。

Linen#

提升模块不知道 Linen Module API。相反,它直接对 flax.core.Scope 的实例进行操作。Scope 实例包含 Module 的变量和 PRNG 序列。如果 Module 实例有父级,或者它是使用 initapply 创建的,则每个 Module 实例在 .scope 字段中都有一个 Scope 实例。通常,您在其中调用 initapply 的顶层 Module 实例是唯一没有绑定 ScopeModule 实例。

Module 被转换时,我们使用 flax.core.lift API 来提升作用域,并使用 Module.clone() 来创建一个新的 Module 实例,其中绑定了提升的作用域。

flax.linen.transforms 公开了 flax.core.lift 中转换的包装器。核心提升 API 对函数进行操作,而 Linen 包装器可以转换 Module 类或 Module 方法。

因此,提升的实现独立于 Linen API。这种关注点分离简化了实现,同时可能允许替代的 Module 抽象构建在用于提升和状态管理的公共核心之上。

实现#

pack(fn, in_vars, out_vars, rngs) API 经历以下阶段

  1. 作用域去重

    此阶段仅在多个作用域一起提升时相关。在这种情况下,我们必须首先找到根作用域的集合。如果其祖先中没有需要提升的作用域,则该作用域是根作用域。

    通过仅提升根,我们避免了两次提升相同的变量。

    对于非根作用域,我们存储对其祖先作用域的引用和路径,以便以后可以重建它(阶段 4)。

  2. 过滤阶段

    变量和 PRNG 序列被分成组。这样,fn 可以将每个组分别提升到转换中。组由指定为的过滤器定义

    • 集合/prng 名称列表

    • True (匹配所有内容)

    • False (不匹配任何内容)

    • DenyList(filter) (匹配除指定集合以外的所有内容(例如:DenyList(['params']) 匹配除“params”集合以外的所有内容))。

    一个集合或 PRNG 序列只能放入一个组中。如果一个集合匹配多个过滤器,它将放入第一个匹配过滤器的组中。如果一个集合或 PRNG 序列与任何过滤器都不匹配,则不会提升。这意味着它不能在转换内部使用,并且尝试这样做会导致引发错误。例如,in_vars = (["params"], True) 将导致 “params” 集合放入第一组,而所有其他集合放入第二组。

    对于匹配的每个 PRNG 序列,我们通过调用 make_rng 来播种一个新的 PRNG 序列。这避免了在提升的转换完成后更新 PRNG 状态的需要。

  3. 特定于转换的提升

    使用变量和 PRNG 组调用 fn。JAX 转换具有不同的签名和提升选项。可以说最清晰的例子是 vmap。在 vmap 的情况下,函数参数、PRNG 和变量集合被传递给 jax.vmap 包装的函数。

  4. 作用域重建

    现在变量和 PRNG 已在转换内部提升,我们希望重建提升的作用域。Pack 使用 scope_fn 调用 fn,该 scope_fn 接受提升的变量和 PRNG,并返回带有提升的变量和 rng 序列的重建的作用域。

  5. 重新打包阶段

    在我们使用提升的作用域后,我们必须检索更新的变量(PRNG 序列可以简单地丢弃)。pack 传递 repack_fn 以支持此操作。此阶段类似于阶段 2,不同之处在于我们仅提升变量,并且忽略不可变的变量。不可变变量无法更新。因此,不应从转换后的函数返回它们。

  6. 提交阶段

    pack 期望 fn 返回一对,其中第一项将简单地从 pack 返回,第二项应是重新打包的变量。更新的变量存储在原始/未提升的作用域中,以便在转换完成后,转换内部发生的变异仍然存在。

使用 pack 示例#

使用 pack 来转置变量集合中每个矩阵的最小示例

from flax.core import lift
from flax.core import Scope, init, apply, nn as core_nn

def lift_transpose(fn, target='params', variables=True, rngs=True):
  # by default we transpose 'params' and simply pass through all other variables.
  def wrapper(scope_fn, repack_fn, variable_groups, rng_groups, *args):
    # normally we would first call into a JAX transformed function here...
    target, rest = variable_groups
    def trans(x):
      if x.ndim == 2:
        return x.T
      return x
    target = jax.tree_util.tree_map(trans, target)
    variable_groups = (target, rest)
    scope = scope_fn(variable_groups, rng_groups)
    y = fn(scope, *args)
    out_variables = repack_fn(scope)
    return y, out_variables
  return lift.pack(
      wrapper,
      in_variable_filters=(target, variables),
      out_variable_filters=(variables,),
      rng_filters=(rngs,))

x = jnp.ones((3, 2))
y, params = init(lift_transpose(core_nn.dense))(random.key(0), x, 4)

注意,大多数用户不需要直接与 pack 交互。当您发现现有提升的转换尚不支持的用例时,请打开一个 GitHub 问题。

支持的转换#

Jax 转换

在 Linen 中是否支持?

注释

vmap

scan

携带变量不能在扫描主体内部初始化。

remat

jit

当前实现可能会导致不必要的重新编译。

jvp

vjp

custom_vjp

custom_jvp

while_loop

携带变量不能在 while_loop 主体内部初始化。

cond

变量初始化/变异必须在分支之间结构匹配。

switch

变量初始化/变异必须在分支之间结构匹配。

pmap

xmap

参考

Linen 示例#

回到我们最初的示例,我们现在可以使用 nn.vmap 来简化我们的实现

class LinenVmapMLP(nn.Module):
  @nn.compact
  def __call__(self, xs):
    VmapMLP = nn.vmap(MLP, variable_axes={'params': 0}, split_rngs={'params': True}, in_axes=0)
    return VmapMLP(name='mlp')(xs)

variables = LinenVmapMLP().init(random.key(0), xs)
print(jax.tree_util.tree_map(jnp.shape, variables['params']))
"""==>
{
    mlp: {
        Dense_0: {
            bias: (3, 4),
            kernel: (3, 2, 4),
        },
        Dense_1: {
            bias: (3, 1),
            kernel: (3, 4, 1),
        },
    },
}
"""

这里我们使用 variable_axes={'params': 0} 来表示参数是向量化的而不是共享的,而 split_rngs={'params': True} 意味着每组参数都是独立初始化的。

我们还可以通过添加一个 BatchNorm 层来扩展这个例子,使其包含一些内部状态

class StatefulMLP(nn.Module):
  @nn.compact
  def __call__(self, x, *, train):
    h = nn.Dense(4, name='hidden')(x)
    h = nn.BatchNorm(axis_name='batch')(h, use_running_average=not train)
    h = nn.relu(h)
    return nn.Dense(1, name='out')(h)

class LinenStatefulVmapMLP(nn.Module):
  @nn.compact
  def __call__(self, xs, *, train):
    VmapMLP = nn.vmap(StatefulMLP, variable_axes={'params': 0, 'batch_stats': 0}, split_rngs={'params': True}, in_axes=0)
    return VmapMLP(name='mlp')(xs, train=train)
variables = LinenStatefulVmapMLP().init(random.key(0), xs)

我们只需要在 nn.vmap 中添加 'batch_stats': 0,表示批统计信息是向量化的,而不是沿着第一个轴共享的。

替代方案#

其他数值计算框架将变量视为一等公民。函数式的一个替代方案是使用集成在 JAX 中或在 JAX 之上的变量系统。这样做的一个优点是,每个变量的提升变得更加容易。如果变量是 JAX IR (JAXPR) 的一部分,我们可以检查在特定计算中必须提升哪些变量。可以选择使用集合标签来注释它们,以决定各种提升选项。

这种方法的缺点是变量系统更加复杂。变量是相关引用,打破了函数式编程的核心假设(参见引用透明性)。其他当前具有函数式接口的 API 可能也需要集成(例如:检查点和优化 API)。