处理 Flax 模块参数#
简介#
在 Flax Linen 中,我们可以将 Module
参数定义为 dataclass 属性或方法的参数(通常是 __call__
)。通常,这种区别很明显
完全固定的属性,例如内核初始化器的选择或输出特征的数量,是超参数,应定义为 dataclass 属性。通常,具有不同超参数的两个模块实例不能以有意义的方式共享。
动态属性,例如输入数据和顶层“模式开关”,例如
train=True/False
,应作为参数传递给__call__
或其他方法。
但是,有些情况不太明确。例如,以 Dropout
模块为例。我们有一些清晰的超参数
dropout 率
生成 dropout 掩码的轴
以及一些清晰的调用时参数
应该使用 dropout 掩码的输入
用于采样随机掩码的(可选)rng
但是,有一个属性不明确 - Dropout 模块中的 deterministic
属性。
如果 deterministic
为 True
,则不采样 dropout 掩码。这通常在模型评估期间使用。但是,如果我们向顶层模块传递 eval=True
或 train=False
,则 deterministic
参数需要在所有地方应用,并且布尔参数需要传递给所有可能使用 Dropout
的层。如果 deterministic
是 dataclass 属性,我们可以执行以下操作
from functools import partial
from flax import linen as nn
class ResidualModel(nn.Module):
drop_rate: float
@nn.compact
def __call__(self, x, *, train):
dropout = partial(nn.Dropout, rate=self.drop_rate, deterministic=not train)
for i in range(10):
x += ResidualBlock(dropout=dropout, ...)(x)
在这里将 determinstic
传递给构造函数是有意义的,因为这样我们可以将 dropout 模板传递给子模块。现在子模块不再需要处理训练与评估模式,只需使用 dropout
参数即可。请注意,由于 dropout 层只能在子模块中构造,因此我们只能将 deterministic
部分应用于构造函数,而不是应用于 __call__
。
但是,如果 deterministic
是 dataclass 属性,我们在使用 setup 模式时会遇到问题。我们**希望**像这样编写我们的模块代码
class SomeModule(nn.Module):
drop_rate: float
def setup(self):
self.dropout = nn.Dropout(rate=self.drop_rate)
@nn.compact
def __call__(self, x, *, train):
# ...
x = self.dropout(x, deterministic=not train)
# ...
但是,如上所述,deterministic
将是一个属性,因此这不起作用。在这里,在 __call__
期间传递 deterministic
是有意义的,因为它取决于 train
参数。
解决方案#
我们可以通过允许某些属性作为 dataclass 属性或作为方法参数传递(但不能同时传递!)来支持之前描述的两种用例。这可以按如下方式实现
class MyDropout(nn.Module):
drop_rate: float
deterministic: Optional[bool] = None
@nn.compact
def __call__(self, x, deterministic=None):
deterministic = nn.merge_param('deterministic', self.deterministic, deterministic)
# ...
在此示例中,nn.merge_param
将确保设置 self.deterministic
或 deterministic
,但不能同时设置两者。如果两个值都为 None
或两个值都不为 None
,则会引发错误。这避免了代码的两个不同部分设置相同参数,并且一个被另一个覆盖的混乱行为。它还避免了默认值,默认值可能会导致训练过程的训练步骤或评估步骤被破坏。
函数式核心#
函数式核心定义函数而不是类。因此,超参数和调用时参数之间没有明确的区别。预先确定超参数的唯一方法是使用 partial
。另一方面,不存在方法参数也可能是属性的模糊情况。