处理 Flax 模块参数

处理 Flax 模块参数#

简介#

在 Flax Linen 中,我们可以将 Module 参数定义为 dataclass 属性或方法的参数(通常是 __call__)。通常,这种区别很明显

  • 完全固定的属性,例如内核初始化器的选择或输出特征的数量,是超参数,应定义为 dataclass 属性。通常,具有不同超参数的两个模块实例不能以有意义的方式共享。

  • 动态属性,例如输入数据和顶层“模式开关”,例如 train=True/False,应作为参数传递给 __call__ 或其他方法。

但是,有些情况不太明确。例如,以 Dropout 模块为例。我们有一些清晰的超参数

  1. dropout 率

  2. 生成 dropout 掩码的轴

以及一些清晰的调用时参数

  1. 应该使用 dropout 掩码的输入

  2. 用于采样随机掩码的(可选)rng

但是,有一个属性不明确 - Dropout 模块中的 deterministic 属性。

如果 deterministicTrue,则不采样 dropout 掩码。这通常在模型评估期间使用。但是,如果我们向顶层模块传递 eval=Truetrain=False,则 deterministic 参数需要在所有地方应用,并且布尔参数需要传递给所有可能使用 Dropout 的层。如果 deterministic 是 dataclass 属性,我们可以执行以下操作

from functools import partial
from flax import linen as nn

class ResidualModel(nn.Module):
  drop_rate: float

  @nn.compact
  def __call__(self, x, *, train):
    dropout = partial(nn.Dropout, rate=self.drop_rate, deterministic=not train)
    for i in range(10):
      x += ResidualBlock(dropout=dropout, ...)(x)

在这里将 determinstic 传递给构造函数是有意义的,因为这样我们可以将 dropout 模板传递给子模块。现在子模块不再需要处理训练与评估模式,只需使用 dropout 参数即可。请注意,由于 dropout 层只能在子模块中构造,因此我们只能将 deterministic 部分应用于构造函数,而不是应用于 __call__

但是,如果 deterministic 是 dataclass 属性,我们在使用 setup 模式时会遇到问题。我们**希望**像这样编写我们的模块代码

class SomeModule(nn.Module):
  drop_rate: float

  def setup(self):
    self.dropout = nn.Dropout(rate=self.drop_rate)

  @nn.compact
  def __call__(self, x, *, train):
    # ...
    x = self.dropout(x, deterministic=not train)
    # ...

但是,如上所述,deterministic 将是一个属性,因此这不起作用。在这里,在 __call__ 期间传递 deterministic 是有意义的,因为它取决于 train 参数。

解决方案#

我们可以通过允许某些属性作为 dataclass 属性或作为方法参数传递(但不能同时传递!)来支持之前描述的两种用例。这可以按如下方式实现

class MyDropout(nn.Module):
  drop_rate: float
  deterministic: Optional[bool] = None

  @nn.compact
  def __call__(self, x, deterministic=None):
    deterministic = nn.merge_param('deterministic', self.deterministic, deterministic)
    # ...

在此示例中,nn.merge_param 将确保设置 self.deterministicdeterministic,但不能同时设置两者。如果两个值都为 None 或两个值都不为 None,则会引发错误。这避免了代码的两个不同部分设置相同参数,并且一个被另一个覆盖的混乱行为。它还避免了默认值,默认值可能会导致训练过程的训练步骤或评估步骤被破坏。

函数式核心#

函数式核心定义函数而不是类。因此,超参数和调用时参数之间没有明确的区别。预先确定超参数的唯一方法是使用 partial。另一方面,不存在方法参数也可能是属性的模糊情况。