setupcompact#

在 Flax 的模块系统(名为 Linen)中,子模块和变量(参数或其他)可以通过两种方式定义

  1. 显式地(使用 setup

    setup 方法中,将子模块或变量赋值给 self.<attr>。然后在类中定义的任何“前向传递”方法中使用分配给 self.<attr> 的子模块和变量。这类似于在 PyTorch 中定义模块的方式。

  2. 内联地(使用 nn.compact

    直接在用 nn.compact 注释的单个“前向传递”方法中编写网络的逻辑。这允许您在单个方法中定义整个模块,并将子模块和变量“并置”在它们被使用的地方。

这两种方法都是完全有效的,行为方式相同,并且可以与 Flax 的所有部分互操作.

这是一个以两种方式定义的模块的简短示例,它们的功能完全相同。

class MLP(nn.Module):
  def setup(self):
    # Submodule names are derived by the attributes you assign to. In this
    # case, "dense1" and "dense2". This follows the logic in PyTorch.
    self.dense1 = nn.Dense(32)
    self.dense2 = nn.Dense(32)

  def __call__(self, x):
    x = self.dense1(x)
    x = nn.relu(x)
    x = self.dense2(x)
    return x
class MLP(nn.Module):

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(32, name="dense1")(x)
    x = nn.relu(x)
    x = nn.Dense(32, name="dense2")(x)
    return x

那么,您如何决定使用哪种风格呢?这可能取决于个人喜好,但这里有一些优缺点

倾向于使用 nn.compact 的理由:#

  1. 允许在它们被使用的地方旁边定义子模块、参数和其他变量:减少向上/向下滚动以查看所有内容的定义。

  2. 当存在有条件地定义子模块、参数或变量的条件语句或 for 循环时,减少代码重复。

  3. 代码通常看起来更像数学符号:y = self.param('W', ...) @ x + self.param('b', ...) 看起来类似于 \(y=Wx+b\))

  4. 如果您正在使用形状推断,即使用其形状/值取决于输入形状(在初始化时未知)的参数,则使用 setup 是不可能的。

倾向于使用 setup 的理由:#

  1. 更接近 PyTorch 的约定,因此从 PyTorch 移植模型时更容易

  2. 有些人觉得将子模块和变量的定义与它们的使用位置显式分离更自然

  3. 允许定义多个“前向传递”方法(请参阅 MultipleMethodsCompactError