Flax 模块生命周期#

本设计说明适用于已经熟悉 Flax Linen 模块但想了解该抽象背后的设计原则的用户。本说明应该让您很好地理解模块 API 构建所基于的假设和保证。如果您还没有模块的实际经验,请查看快速入门指南

Flax Linen 模块在 Flax 核心之上提供了 Pythonic 抽象。模块抽象允许您在 JAX 之上创建具有状态、参数和随机性的类。这是一个关于 Module 类的设计和行为的实用指南。最后,您应该可以轻松地脱离常规,以新的方式使用模块。

概述#

定义#

让我们从模块生命周期的高级概述开始。首先,定义一个简单的模块

class MLP(nn.Module):
  # 1. Attribute annotations
  hidden_size: int
  out_size: int

  # 2. The ``setup`` method
  def setup(self):
    self.hidden = nn.Dense(self.hidden_size)
    self.out = nn.Dense(self.out_size)

  # 3. User methods
  def __call__(self, x):
    a = self.hidden(x)
    h = nn.relu(a)
    return self.out(h)

这个模块由以下部分组成:

  1. 属性注释,定义为数据类字段。这些注释自动定义一个构造函数。

  2. ``setup`` 方法,用于创建子模块并将它们分配给属性。

  3. 用户方法。按照惯例,大多数模块只有一个 __call__ 方法,但您可以定义多个方法或使用不同的方法名称。

构造/初始化#

现在我们想构造和使用 MLP 模块

mlp = MLP(hidden_size=5, out_size=3)
x = jax.numpy.ones((1, 2))
variables = mlp.init(random.key(0), x)
y = mlp.apply(variables, x)

首先,我们构造一个 MLP 的实例并传递构造属性。请注意,如果您不习惯函数式编程模式,这里的构造与您期望的有所不同。MLP 构造函数实际上不创建任何变量或任何内部状态。最好将其视为模块的规范或模板,其中包含功能但没有数据。

让我们仔细看看初始化。令人惊讶的是,Flax 中实际上没有单独的初始化路径。调用 init 只是 apply 的一种特殊情况,您也可以将其写成

# equivalent to: variables = mlp.init(random.key(0), x)
_, variables = mlp.apply({}, x, rngs={"params": random.key(0)}, mutable=True)

因此,init 只是 apply 的一个包装器,其中

  1. 我们调用一个没有初始变量的模块(一个空字典)。

  2. 始终传递一个名为 "params" 的 PRNG 生成器,用于随机初始化参数(使用参数初始化函数)。

  3. 所有变量集合都设置为可变(mutable=True)。当集合是可变的时,可以更新现有变量,并且可以创建新变量。因此,在 init 内部,可以在任何变量集合中初始化变量,并且它们都添加到返回的变量字典中。

生命周期#

现在您已经了解了 initapply 的一种特殊情况,让我们更详细地看看 .apply(...)。实际上,模块的大部分复杂性都存在于 apply 方法中。“模块生命周期”包括构造和 apply -ing 一个模块。我们可以将模块生命周期总结如下:

  1. 我们构造 mlp = MLP(hidden_size=5, out_size=3),使得 mlp.hidden_size=5mlp.out_size=3

  2. 然后,调用 mlp.apply,它会

    1. 创建一个 mlp 的克隆,我们称之为 mlp_copy

    2. 调用 mlp_copy.setup()

    3. 返回 mlp_copy.__call__() 的输出,并可选地返回使用关键字参数 mutable= 指定为可变的变量集合。

请注意,生命周期包括克隆模块实例。这样做是为了确保 apply 可以被视为纯函数(即,如果您传入相同的参数,它将返回相同的输出)。您将在后面的顶级模块部分中了解更多详细信息。

变量#

“变量”这个词在编程和数学中无处不在。但是,重要的是要很好地理解变量在 JAX 和 Flax 的上下文中的含义。在 Flax 模块内部,变量的行为就像您对 Python 的期望一样。它们被初始化一次,读取,甚至可能不时更新。但是,JAX 没有变量的概念。相反,值存储在类似于 NumPy 数组的数组中 - 但有一个重要的区别:它们是不可变的。

initapply 方法将变量作为嵌套字典返回,其中字符串键和 JAX 数组位于叶节点。在顶层,每个键对应一个变量集合。在每个集合内部,嵌套的字典结构与 Module 层次结构相对应。变量字典是不可变的,因此实际上只是变量所处状态的快照。当再次调用 apply 时,变量字典将作为参数传递。这样,变量的状态与上一次 init / apply 调用完成时的状态相同。

注意

模块字段使用 field_name: TypeHint 语法声明(与数据类相同)。如果没有类型提示,则属性被视为该类的静态属性。如果您无法指定类型,则可以使用 typing.Any 作为通配符类型。

紧凑模块#

Linen 提供了一种替代 API,可以更紧凑地定义模块。这对于模块仅包含一个使用参数和/或子模块的方法的常见情况尤其有用。使用紧凑的 API,可以将 MLP 重写如下:

class CompactMLP(nn.Module):
  hidden_size: int
  out_size: int

  @nn.compact
  def __call__(self, x):
    a = nn.Dense(self.hidden_size)(x)
    h = nn.relu(a)
    return nn.Dense(self.out_size)(h)

一个紧凑的 Module 在精神上类似于一个函数。它提供了一种简洁的表示法,并将外部交互限制在函数的输入和返回值上。在这种情况下,简洁的表示法可能使其他人更容易理解模块的作用。无需在 setup__call__ 方法之间来回跳转来理解子模块正在做什么。相反,只需从上到下阅读一次 __call__ 方法,就可以获得一个简洁的概述。如果您正在实现具有许多超参数的复杂模块,这可能会产生重大影响。有关如何在 setup 和 compact 之间做出选择的实用指南,请参阅setup 或 compact

内联定义子模块和/或变量的另一个好处是,您可以在构造变量时向方法添加参数。最常见的例子是使用形状信息来确定参数的形状,如下所示

class CompactScaledMLP(nn.Module):
  hidden_size: int
  out_size: int

  @nn.compact
  def __call__(self, x):
    scale = self.param("scale", nn.initializers.ones_init(), x.shape[-1:])
    x *= scale[None]
    a = nn.Dense(self.hidden_size)(x)
    h = nn.relu(a)
    return nn.Dense(self.out_size)(h)

许多标准的 Linen 模块,如 nn.Dense,已经使用了形状推断,从而避免了指定输入形状(如 Dense 层的输入特征数量)的需求。

紧凑的控制流#

如果您没有显式提供子模块的名称(使用传递给模块构造函数的 name= 关键字参数),则您定义子模块的顺序将决定子模块的名称。由于 name 决定了参数如何映射到子模块,因此您必须小心地将控制流与自动生成的名称混合使用。使用控制流可能会更改顺序或完全删除某些子模块。如果子模块仅应在某些构造参数存在的情况下才存在,则此功能很有用。但是,当控制流取决于模块的输入参数时,您应该小心。例如,以下模块会中断

class WrongModule(nn.Module):
  @nn.compact
  def __call__(self, x, mode):
    if mode == "encode":
      return nn.Dense(features=8)(x)
    elif mode == "decode":
      return nn.Dense(features=4)(x)

上面的模块会中断,因为编码器或解码器路径都会构造一个名为“Dense_0”的模块。这意味着这两个模块将共享参数,这不是预期的。实际上,这两个模块不能共享参数,因为它们各自具有不同的特征数量。

这个问题可以通过多种方式解决
  • 提供显式名称

  • setup 中创建模块

  • 或将构造函数移出控制流。

后者如下所示

class CorrectModule(nn.Module):
  @nn.compact
  def __call__(self, x, mode):
    encoder = nn.Dense(8)
    decoder = nn.Dense(4)
    if mode == "encode":
      return encoder(x)
    elif mode == "decode":
      return decoder(x)

在上面的示例中,构造顺序是固定的。构造完成后,可以以任意顺序使用子模块。

注意

紧凑的模块与 React hooks 非常相似。

顶层模块#

当在“顶层”创建模块实例时,它将处于“未绑定”状态,也就是说,它没有附加任何变量。“顶层”意味着它不是作为另一个模块类中的子模块构建的。除了调用 initapply 之外,您对未绑定的模块无能为力。还要注意,不会在未绑定的模块上调用 setup,因此您只能访问构造参数。请参阅 未来工作 部分,了解将来这种情况可能会如何变化。

为什么顶层模块始终未绑定?#

当我们调用 apply 时,会创建一个顶层模块的副本,该副本实际上将保存变量和 PRNG 序列。这种有状态的“绑定”克隆仅在我们执行 apply 方法时存在。这样做的原因是,如果您创建一个有状态的对象并在 apply 函数返回之前将其销毁,则 apply 函数本身的行为就像一个纯函数。纯函数有两个约束

  1. 如果您输入相同的参数,它将返回相同的输出

  2. 它不会更改函数外部的任何内容。这意味着您不能操作在纯函数外部可访问的有状态对象。

纯函数有很多优点,但在使用 JAX 时,它们通常是必不可少的。例如,大多数代码需要使用 jax.jit 进行编译才能快速运行,并且一旦创建了模块,您可能希望使用 jax.grad 优化其参数。但是,这些 API 需要一个纯函数,并且不能直接在有状态的绑定 Module 实例上工作。此外,纯函数允许与其他库进行灵活的互操作性。例如,我们建议使用 Optax 来优化参数。Optax 中的优化器期望并返回 JAX 数组的 PyTree 以进行优化,就像 Linen 模块的 apply 函数一样。

克隆#

为了使这种方法可靠地工作,我们需要明确定义的克隆行为。Flax 没有像 Python 的 deepcopy 那样依赖复杂的嵌套克隆过程,而是强制执行 Module 完全由其构造参数定义。因此,克隆模块简化为使用其原始构造参数调用构造函数。由于 Module 充当不可变的 dataclass,因此构造参数直接映射到实例属性。在 setup__post_init__ 中计算的非构造属性也应仅依赖于构造参数,以确保明确定义的克隆。

绑定#

有时,拥有一个已绑定的顶层模块而无需将代码包装在函数中会很有用。例如:在 Jupyter 笔记本中与模块进行交互。bind 方法返回具有无限生命周期的绑定克隆。缺点是您不能将其与 JAX 转换结合使用,也不能将其集成到需要无状态代码的 vanilla JAX 代码库中。例如,Optax 可以优化参数的 PyTree,但它不能直接优化使用 .bind 创建的绑定的 Module 实例(因为那不是 PyTree)。因此,您不能将 bind API 与像 Optax 这样的函数式优化器 API 结合使用。

设置#

setup 方法通常用作普通 Python 类中的构造函数钩子 (__init__)。但是,对于更高级的用例,最好意识到它与构造函数并不完全相同。

setup 仅在模块绑定后才会被调用。通常,这不是问题,因为大多数模块会(几乎)立即绑定(作为 initapply 的一部分)。在 setup 中,子模块在分配给属性时会变为绑定。在 nn.compact 修饰的方法中,子模块会在构造时立即绑定。如上一节所述,顶层模块永远不会绑定,因此在构造时不会调用 setup。这意味着您无法从未绑定的顶层模块访问在 setup 中分配的属性。

class TopLevelAccess(nn.Module):

  def setup(self):
    self.foo = nn.Dense(2)

mdl = TopLevelAccess()
assert not hasattr(mdl, "foo")  # foo is not defined because setup is not called

setup 方法不是在 Module 绑定后立即调用,而是在您与 Module 实例交互时才调用(例如:调用方法或访问属性)。这不应影响 Module 的行为,但延迟执行有时会影响调试期间的日志语句和堆栈跟踪。有关 函数化 的部分将解释为什么我们需要 setup 首先是延迟的。

函数化#

到目前为止,我们有一个纯 apply 函数,它通常使用一些 JAX 转换进行转换,并且在 apply 内部,我们有一个有状态的模块实例可以使用。换句话说:在模块外部,我们处于一个函数式世界中,我们可以利用 JAX 的函数式转换,而在模块内部,我们可以利用 Flax 的有状态变量和 PRNG 序列,并且 apply 方法是我们这两个世界之间的桥梁。

但是,如果我们想在模块内部使用 JAX 转换怎么办?答案是函数化。

此过程本身很繁琐且容易出错,但由 Flax 在内部处理。在高层,我们可以将其总结如下。对于在模块中定义的方法 fn

  1. 收集应在 JAX 转换内部可用的模块的状态(变量和 PRNG 序列)并拍摄快照。

  2. 使用原始参数和收集的状态调用 JAX 转换。然后在转换内部

    1. 解包状态并重新创建模块

    2. 调用用户代码 fn

    3. 收集更新的变量和 rng,并将其与 fn 中的原始返回值一起返回

  3. 使用从转换返回的更新状态更新原始状态。

有关函数化和提升的更深入解释,请参阅提升的转换设计说明。

实际影响#

在大多数情况下,函数化是自动为你处理的。但仍然有一些约束需要你考虑。最重要的是,Flax 只处理有状态的原语(Linen 变量和 RNG),而不是任意的有状态 Python 代码。最重要的是:你不能闭包有状态对象和 Module 对象,因为它们对 Flax 的内部机制(以及一般的 JAX)是不可见的。

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    dense = nn.Dense(x.shape[-1])
    fn = lambda x: dense(x) + 1
    # simply calling inner works fine
    # return self.inner(x, fn)
    # but applying a transformation doesn't:
    vmap_inner = nn.vmap(Foo.inner, in_axes=0, variable_axes={"params": 0}, split_rngs={"params": True})
    return vmap_inner(self, x, fn)

  def inner(self, x, fn):
    for i in range(3):
      x = fn(x)
    return x

这里 inner 接受一个闭包 Module 实例的函数。在这个例子中,这可以正常工作,因为我们没有使用提升的转换来转换 inner 方法。大多数方法不会被转换,但了解如何使 Module 方法可转换是很好的。

可转换性的主要障碍是 JAX 不识别的类型。JAX 只理解 Pytree 参数;即 (Jax) numpy ndarrays 和 Python 数字/布尔值的任意嵌套的 Python 容器(字典、列表、元组)。Flax 允许使用 flax.struct API 定义与 Pytree 兼容的数据类。

函数闭包是意外地从转换中隐藏 JAX 数组或 Linen Module 的最常见方式。但是,如果你想传递也与 JAX 和 Linen 转换兼容的闭包,则有一个简单的解决方法

class Partial(flax.struct.PyTreeNode):
  fn: Callable = flax.struct.field(pytree_node=False)
  args: Iterable[Any]

  def __call__(self, *args, **kwargs):
    return self.fn(*(tuple(self.args) + args), **kwargs)

class Foo(nn.Module):

  @nn.compact
  def __call__(self, x):
    dense = nn.Dense(x.shape[-1])
    fn = lambda mdl, x: mdl(x) + 1
    vmap_inner = nn.vmap(Foo.inner, in_axes=0, variable_axes={"params": 0}, split_rngs={"params": True})
    return vmap_inner(self, x, Partial(fn, [dense]))

  def inner(self, x, fn):
    for i in range(3):
      x = fn(x)
    return x

这里,闭包是使用 Flax 数据类实现的。函数本身用 flax.struct.field(pytree_node=False) 注释,以表明它不包含 JAX 数组或 Linen Module。另一方面,部分应用的 args 被视为 pytree 容器。我们将闭包重写为使用 Partial。现在可以使用提升的转换来转换 inner 方法。

未来工作#

为未绑定模块设置#

当涉及到构造后初始化字段时,当前的 Module 抽象尤其具有限制性。在当前的 Module API 中,setup 方法是初始化 Module 实例字段的地方。由于 setup 仅在绑定 Module 上调用,因此完整的 Module API 在 setup 内部可用,包括变量声明。但是,通常我们实际上不需要任何有状态的 API 来初始化字段。事实上,最常见的情况是我们只是想声明一个子模块。更重要的是,检查子模块以进行调试或部分运行模型通常很有用。例如考虑

class AutoEncoder(nn.Module):
  def setup(self):
    self.encoder = Encoder(...)
    self.decoder = Decoder(...)

想象一下,我们只想使用 auto_encoder.decoder.apply(decoder_variables, x) 调用解码器。使用当前的 setup API,这是行不通的,因为我们必须先绑定变量,然后才能调用 setup 并定义解码器属性。当然,我们可以使用与 setup 中相同的属性手动构造 Decoder Module,但这在许多情况下并不理想。

有两种可能的解决方案可以使此用例更符合人体工程学。首先,可以在绑定之前立即在构造后运行 setup。这意味着你仍然可以创建子模块,但你不能再定义或操作变量。因此,这将是一个重大更改,并且需要一个新的 API 来延迟定义变量

或者,可以引入一个额外的特殊方法,该方法在 Module 构造后并在绑定之前立即运行。在这种情况下,setup 方法将保留其原始语义。