从 Haiku 迁移到 Flax#

本指南将逐步介绍将 Haiku 模型迁移到 Flax 的过程,并强调两个库之间的差异。

基本示例#

要创建自定义模块,你需要在 Haiku 和 Flax 中都从 Module 基类继承。但是,Haiku 类使用常规的 __init__ 方法,而 Flax 类是 dataclasses,这意味着你定义一些用于自动生成构造函数的类属性。此外,所有 Flax 模块都接受一个 name 参数而无需定义它,而在 Haiku 中,name 必须在构造函数签名中显式定义并传递给超类构造函数。

import haiku as hk

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class Model(hk.Module):
  def __init__(self, dmid: int, dout: int, name=None):
    super().__init__(name=name)
    self.dmid = dmid
    self.dout = dout

  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = hk.Linear(self.dout)(x)
    return x
import flax.linen as nn

class Block(nn.Module):
  features: int


  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5, deterministic=not training)(x)
    x = jax.nn.relu(x)
    return x

class Model(nn.Module):
  dmid: int
  dout: int


  @nn.compact
  def __call__(self, x, training: bool):
    x = Block(self.dmid)(x, training)
    x = nn.Dense(self.dout)(x)
    return x

在两个库中,__call__ 方法看起来非常相似,但是,在 Flax 中,你必须使用 @nn.compact 装饰器才能内联定义子模块。在 Haiku 中,这是默认行为。

现在,Haiku 和 Flax 差异很大的地方在于你如何构建模型。在 Haiku 中,你在调用你的模块的函数上使用 hk.transformtransform 将返回一个带有 initapply 方法的对象。在 Flax 中,你只需实例化你的模块。

def forward(x, training: bool):
  return Model(256, 10)(x, training)

model = hk.transform(forward)
...


model = Model(256, 10)

要在两个库中获取模型参数,你使用带有 random.key 以及一些运行模型的输入的 init 方法。这里的主要区别在于,Flax 返回从集合名称到嵌套数组字典的映射,params 只是这些可能的集合之一。在 Haiku 中,你直接获取 params 结构。

sample_x = jax.numpy.ones((1, 784))
params = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
...
sample_x = jax.numpy.ones((1, 784))
variables = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
params = variables["params"]

需要注意的一个非常重要的事情是,在 Flax 中,参数结构是分层的,每个嵌套模块一级,参数名称最后一级。在 Haiku 中,参数结构是一个 python 字典,具有两层层次结构:完全限定的模块名称映射到参数名称。模块名称由所有嵌套模块的 / 分隔的字符串路径组成。

...
{
  'model/block/linear': {
    'b': (256,),
    'w': (784, 256),
  },
  'model/linear': {
    'b': (10,),
    'w': (256, 10),
  }
}
...
FrozenDict({
  Block_0: {
    Dense_0: {
      bias: (256,),
      kernel: (784, 256),
    },
  },
  Dense_0: {
    bias: (10,),
    kernel: (256, 10),
  },
})

在两个框架中的训练期间,你将参数结构传递给 apply 方法以运行前向传递。由于我们正在使用 dropout,因此在这两种情况下,我们都必须向 apply 提供一个 key 以生成随机 dropout 掩码。

def train_step(key, params, inputs, labels):
  def loss_fn(params):
      logits = model.apply(
        params,
        key,
        inputs, training=True # <== inputs
      )
      return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params
def train_step(key, params, inputs, labels):
  def loss_fn(params):
      logits = model.apply(
        {'params': params},
        inputs, training=True, # <== inputs
        rngs={'dropout': key}
      )
      return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

  grads = jax.grad(loss_fn)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params

最显著的区别是,在 Flax 中,你必须将参数传递到带有 params 键的字典中,并将键传递到带有 dropout 键的字典中。这是因为在 Flax 中,你可以有许多类型的模型状态和随机状态。在 Haiku 中,你只需直接传递参数和键。

处理状态#

现在让我们看看两个库中如何处理可变状态。我们将采用与之前相同的模型,但现在我们将用 BatchNorm 替换 Dropout。

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.BatchNorm(
      create_scale=True, create_offset=True, decay_rate=0.99
    )(x, is_training=training)
    x = jax.nn.relu(x)
    return x
class Block(nn.Module):
  features: int


  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.features)(x)
    x = nn.BatchNorm(
      momentum=0.99
    )(x, use_running_average=not training)
    x = jax.nn.relu(x)
    return x

在这种情况下,代码非常相似,因为这两个库都提供了 BatchNorm 层。最显著的区别在于,Haiku 使用 is_training 来控制是否更新运行统计数据,而 Flax 使用 use_running_average 来达到相同的目的。

要实例化 Haiku 中的有状态模型,你使用 hk.transform_with_state,它会更改 initapply 的签名以接受和返回状态。如前所述,在 Flax 中,你直接构造模块。

def forward(x, training: bool):
  return Model(256, 10)(x, training)

model = hk.transform_with_state(forward)
...


model = Model(256, 10)

要初始化参数和状态,你只需像之前一样调用 init 方法。但是,在 Haiku 中,你现在会得到 state 作为第二个返回值,而在 Flax 中,你会在 variables 字典中获得一个新的 batch_stats 集合。请注意,由于 hk.BatchNorm 仅在 is_training=True 时初始化批量统计数据,因此我们在初始化带有 hk.BatchNorm 层的 Haiku 模型的参数时必须设置 training=True。在 Flax 中,我们可以像往常一样设置 training=False

sample_x = jax.numpy.ones((1, 784))
params, state = model.init(
  random.key(0),
  sample_x, training=True # <== inputs
)
...
sample_x = jax.numpy.ones((1, 784))
variables = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
params, batch_stats = variables["params"], variables["batch_stats"]

一般来说,在 Flax 中,你可能会在 variables 字典中找到其他状态集合,例如用于自回归转换器模型的 cache、用于使用 Module.sow 添加的中间值的 intermediates,或者自定义层定义的其他集合名称。Haiku 仅区分 params(在运行 apply 时不会更改的变量)和 state(在运行 apply 时可以更改的变量)。

现在,训练在两个框架中看起来非常相似,因为你使用相同的 apply 方法来运行前向传递。在 Haiku 中,现在将 state 作为第二个参数传递给 apply,并获得新的状态作为第二个返回值。在 Flax 中,你将 batch_stats 作为新键添加到输入字典中,并获得 updates 变量字典作为第二个返回值。

def train_step(params, state, inputs, labels):
  def loss_fn(params):
    logits, new_state = model.apply(
      params, state,
      None, # <== rng
      inputs, training=True # <== inputs
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, new_state

  grads, new_state = jax.grad(loss_fn, has_aux=True)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params, new_state
def train_step(params, batch_stats, inputs, labels):
  def loss_fn(params):
    logits, updates = model.apply(
      {'params': params, 'batch_stats': batch_stats},
      inputs, training=True, # <== inputs
      mutable='batch_stats',
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, updates["batch_stats"]

  grads, batch_stats = jax.grad(loss_fn, has_aux=True)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

  return params, batch_stats

一个主要的区别是,在 Flax 中,状态集合可以是可变的或不可变的。在 init 期间,所有集合默认都是可变的,但是,在 apply 期间,你必须明确指定哪些集合是可变的。在这个例子中,我们指定 batch_stats 是可变的。这里传递了一个字符串,但如果有多个可变集合,也可以传递列表。如果没有这样做,当试图修改 batch_stats 时,会在运行时引发错误。此外,当 mutable 不是 False 时,updates 字典会作为 apply 的第二个返回值返回,否则只返回模型输出。Haiku 通过使用 params(不可变)和 state(可变),以及使用 hk.transformhk.transform_with_state 来区分可变/不可变。

使用多种方法#

在本节中,我们将了解如何在 Haiku 和 Flax 中使用多种方法。作为一个例子,我们将实现一个具有三种方法的自动编码器模型:encodedecode__call__

在 Haiku 中,我们可以直接在 __init__ 中定义 encodedecode 需要的子模块,在这种情况下,每个子模块都只使用一个 Linear 层。在 Flax 中,我们会在 setup 中预先定义一个 encoder 和一个 decoder 模块,并在 encodedecode 中分别使用它们。

class AutoEncoder(hk.Module):


  def __init__(self, embed_dim: int, output_dim: int, name=None):
    super().__init__(name=name)
    self.encoder = hk.Linear(embed_dim, name="encoder")
    self.decoder = hk.Linear(output_dim, name="decoder")

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x
class AutoEncoder(nn.Module):
  embed_dim: int
  output_dim: int

  def setup(self):
    self.encoder = nn.Dense(self.embed_dim)
    self.decoder = nn.Dense(self.output_dim)

  def encode(self, x):
    return self.encoder(x)

  def decode(self, x):
    return self.decoder(x)

  def __call__(self, x):
    x = self.encode(x)
    x = self.decode(x)
    return x

请注意,在 Flax 中,setup 不会在 __init__ 之后运行,而是在调用 initapply 时运行。

现在,我们希望能够从我们的 AutoEncoder 模型中调用任何方法。在 Haiku 中,我们可以通过 hk.multi_transform 为一个模块定义多个 apply 方法。传递给 multi_transform 的函数定义了如何初始化模块以及生成哪些不同的应用方法。

def forward():
  module = AutoEncoder(256, 784)
  init = lambda x: module(x)
  return init, (module.encode, module.decode)

model = hk.multi_transform(forward)
...




model = AutoEncoder(256, 784)

要初始化我们模型的参数,可以使用 init 来触发 __call__ 方法,该方法同时使用 encodedecode 方法。这将为模型创建所有必要的参数。

params = model.init(
  random.key(0),
  x=jax.numpy.ones((1, 784)),
)
...
variables = model.init(
  random.key(0),
  x=jax.numpy.ones((1, 784)),
)
params = variables["params"]

这将生成以下参数结构。

{
    'auto_encoder/~/decoder': {
        'b': (784,),
        'w': (256, 784)
    },
    'auto_encoder/~/encoder': {
        'b': (256,),
        'w': (784, 256)
    }
}
FrozenDict({
    decoder: {
        bias: (784,),
        kernel: (256, 784),
    },
    encoder: {
        bias: (256,),
        kernel: (784, 256),
    },
})

最后,让我们探讨一下如何使用 apply 函数来调用 encode 方法

encode, decode = model.apply
z = encode(
  params,
  None, # <== rng
  x=jax.numpy.ones((1, 784)),

)
...
z = model.apply(
  {"params": params},

  x=jax.numpy.ones((1, 784)),
  method="encode",
)

因为 Haiku 的 apply 函数是通过 hk.multi_transform 生成的,所以它是一个由两个函数组成的元组,我们可以将其解包为 encodedecode 函数,它们对应于 AutoEncoder 模块上的方法。在 Flax 中,我们通过传递方法名称作为字符串来调用 encode 方法。另一个值得注意的区别是,在 Haiku 中,rng 需要显式传递,即使该模块在 apply 期间不使用任何随机操作。在 Flax 中,这是不必要的(请查看 Flax 中的随机性和 PRNG)。这里的 Haiku rng 设置为 None,但你也可以在 apply 函数上使用 hk.without_apply_rng 来删除 rng 参数。

提升的变换#

Flax 和 Haiku 都提供了一组变换,我们将其称为提升的变换,它们以一种可以与模块一起使用的方式封装了 JAX 变换,有时还提供额外的功能。在本节中,我们将了解如何在 Flax 和 Haiku 中使用提升版本的 scan 来实现一个简单的 RNN 层。

首先,我们将定义一个 RNNCell 模块,它将包含 RNN 的单步逻辑。我们还将定义一个 initial_state 方法,该方法将用于初始化 RNN 的状态(又名 carry)。与 jax.lax.scan 类似,RNNCell.__call__ 方法将是一个接受 carry 和输入并返回新的 carry 和输出的函数。在这种情况下,carry 和输出是相同的。

class RNNCell(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)
    self.hidden_size = hidden_size

  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = hk.Linear(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nn.Module):
  hidden_size: int


  @nn.compact
  def __call__(self, carry, x):
    x = jnp.concatenate([carry, x], axis=-1)
    x = nn.Dense(self.hidden_size)(x)
    x = jax.nn.relu(x)
    return x, x

  def initial_state(self, batch_size: int):
    return jnp.zeros((batch_size, self.hidden_size))

接下来,我们将定义一个 RNN 模块,它将包含整个 RNN 的逻辑。在 Haiku 中,我们将首先初始化 RNNCell,然后使用它来构造 carry,最后使用 hk.scan 在输入序列上运行 RNNCell。在 Flax 中,它的做法略有不同,我们将使用 nn.scan 来定义一个临时的新的类型,该类型封装了 RNNCell。在这个过程中,我们还将指定指示 nn.scan 广播 params 集合(所有步骤共享相同的参数),并且不拆分 params rng 流(因此所有步骤都使用相同的参数初始化),最后我们将指定我们希望 scan 在输入的第二个轴上运行,并将输出也沿着第二个轴堆叠。然后,我们将立即使用这个临时类型来创建提升的 RNNCell 的实例,并使用它来创建 carry 并运行 __call__ 方法,该方法将 scan 遍历序列。

class RNN(hk.Module):
  def __init__(self, hidden_size: int, name=None):
    super().__init__(name=name)
    self.hidden_size = hidden_size

  def __call__(self, x):
    cell = RNNCell(self.hidden_size)
    carry = cell.initial_state(x.shape[0])
    carry, y = hk.scan(cell, carry, jnp.swapaxes(x, 1, 0))
    y = jnp.swapaxes(y, 0, 1)
    return y
class RNN(nn.Module):
  hidden_size: int


  @nn.compact
  def __call__(self, x):
    rnn = nn.scan(RNNCell, variable_broadcast='params', split_rngs={'params': False},
                  in_axes=1, out_axes=1)(self.hidden_size)
    carry = rnn.initial_state(x.shape[0])
    carry, y = rnn(carry, x)
    return y

总的来说,Flax 和 Haiku 之间提升的变换的主要区别在于,在 Haiku 中,提升的变换不会对状态进行操作,也就是说,Haiku 将以某种方式处理 paramsstate,使其在变换内部和外部保持相同的形状。在 Flax 中,提升的变换可以对变量集合和 rng 流进行操作,用户必须根据变换的语义定义每个变换如何处理不同的集合。

最后,让我们快速查看一下如何在 Haiku 和 Flax 中使用 RNN 模块。

def forward(x):
  return RNN(64)(x)

model = hk.without_apply_rng(hk.transform(forward))

params = model.init(
  random.key(0),
  x=jax.numpy.ones((3, 12, 32)),
)

y = model.apply(
  params,
  x=jax.numpy.ones((3, 12, 32)),
)
...


model = RNN(64)

variables = model.init(
  random.key(0),
  x=jax.numpy.ones((3, 12, 32)),
)
params = variables['params']
y = model.apply(
  {'params': params},
  x=jax.numpy.ones((3, 12, 32)),
)

与前几节中的示例相比,唯一值得注意的变化是,这次我们在 Haiku 中使用了 hk.without_apply_rng,因此我们不必将 rng 参数作为 None 传递给 apply 方法。

扫描层#

scan 的一个非常重要的应用是迭代地将一系列层应用于输入,将每一层的输出作为下一层的输入。这对于减少大型模型的编译时间非常有用。例如,我们将创建一个简单的 Block 模块,然后在 MLP 模块内部使用它,该模块将 Block 模块应用 num_layers 次。

在 Haiku 中,我们像往常一样定义 Block 模块,然后在 MLP 内部,我们将使用 hk.experimental.layer_stackstack_block 函数进行操作,以创建 Block 模块的堆叠。在 Flax 中,Block 的定义略有不同,__call__ 将接受并返回第二个虚拟输入/输出,在这两种情况下都将为 None。在 MLP 中,我们将像前面的示例一样使用 nn.scan,但通过设置 split_rngs={'params': True}variable_axes={'params': 0},我们告诉 nn.scan 为每个步骤创建不同的参数,并沿着第一个轴切分 params 集合,从而有效地实现与 Haiku 中相同的 Block 模块堆叠。

class Block(hk.Module):
  def __init__(self, features: int, name=None):
    super().__init__(name=name)
    self.features = features

  def __call__(self, x, training: bool):
    x = hk.Linear(self.features)(x)
    x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
    x = jax.nn.relu(x)
    return x

class MLP(hk.Module):
  def __init__(self, features: int, num_layers: int, name=None):
      super().__init__(name=name)
      self.features = features
      self.num_layers = num_layers

  def __call__(self, x, training: bool):
    @hk.experimental.layer_stack(self.num_layers)
    def stack_block(x):
      return Block(self.features)(x, training)

    stack = hk.experimental.layer_stack(self.num_layers)
    return stack_block(x)
class Block(nn.Module):
  features: int
  training: bool

  @nn.compact
  def __call__(self, x, _):
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(0.5)(x, deterministic=not self.training)
    x = jax.nn.relu(x)
    return x, None

class MLP(nn.Module):
  features: int
  num_layers: int

  @nn.compact
  def __call__(self, x, training: bool):
    ScanBlock = nn.scan(
      Block, variable_axes={'params': 0}, split_rngs={'params': True},
      length=self.num_layers)

    y, _ = ScanBlock(self.features, training)(x, None)
    return y

请注意,在 Flax 中,我们如何将 None 作为第二个参数传递给 ScanBlock 并忽略其第二个输出。这些表示每个步骤的输入/输出,但它们是 None,因为在这种情况下我们没有任何输入/输出。

初始化每个模型与前面的示例相同。在这种情况下,我们将指定要使用 5 层,每层具有 64 个特征。

def forward(x, training: bool):
  return MLP(64, num_layers=5)(x, training)

model = hk.transform(forward)

sample_x = jax.numpy.ones((1, 64))
params = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
...
...


model = MLP(64, num_layers=5)

sample_x = jax.numpy.ones((1, 64))
variables = model.init(
  random.key(0),
  sample_x, training=False # <== inputs
)
params = variables['params']

当在层上使用 scan 时,您应该注意到的一件事是,所有层都被融合到一个单层中,该单层的参数在第一个轴上有一个额外的“层”维度。在这种情况下,所有参数的形状都将以 (5, ...) 开头,因为我们正在使用 5 层。

...
{
    'mlp/__layer_stack_no_per_layer/block/linear': {
        'b': (5, 64),
        'w': (5, 64, 64)
    }
}
...
FrozenDict({
    ScanBlock_0: {
        Dense_0: {
            bias: (5, 64),
            kernel: (5, 64, 64),
        },
    },
})

顶层 Haiku 函数与顶层 Flax 模块#

在 Haiku 中,可以通过使用原始的 hk.{get,set}_{parameter,state} 来定义/访问模型参数和状态,从而将整个模型编写为单个函数。将顶层“模块”编写为函数是很常见的做法。

Flax 团队推荐一种更以模块为中心的方法,该方法使用 __call__ 来定义前向函数。对应的访问器将是 nn.module.paramnn.module.variable(有关集合的解释,请转到处理状态)。

def forward(x):


  counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
  multiplier = hk.get_parameter('multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones)
  output = x + multiplier * counter
  hk.set_state("counter", counter + 1)

  return output

model = hk.transform_with_state(forward)

params, state = model.init(random.key(0), jax.numpy.ones((1, 64)))
class FooModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    counter = self.variable('counter', 'count', lambda: jnp.ones((), jnp.int32))
    multiplier = self.param('multiplier', nn.initializers.ones_init(), [1,], x.dtype)
    output = x + multiplier * counter.value
    if not self.is_initializing():  # otherwise model.init() also increases it
      counter.value += 1
    return output

model = FooModule()
variables = model.init(random.key(0), jax.numpy.ones((1, 64)))
params, counter = variables['params'], variables['counter']