从 Haiku 迁移到 Flax#
本指南将逐步介绍将 Haiku 模型迁移到 Flax 的过程,并强调两个库之间的差异。
基本示例#
要创建自定义模块,你需要在 Haiku 和 Flax 中都从 Module
基类继承。但是,Haiku 类使用常规的 __init__
方法,而 Flax 类是 dataclasses
,这意味着你定义一些用于自动生成构造函数的类属性。此外,所有 Flax 模块都接受一个 name
参数而无需定义它,而在 Haiku 中,name
必须在构造函数签名中显式定义并传递给超类构造函数。
import haiku as hk
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
x = jax.nn.relu(x)
return x
class Model(hk.Module):
def __init__(self, dmid: int, dout: int, name=None):
super().__init__(name=name)
self.dmid = dmid
self.dout = dout
def __call__(self, x, training: bool):
x = Block(self.dmid)(x, training)
x = hk.Linear(self.dout)(x)
return x
import flax.linen as nn
class Block(nn.Module):
features: int
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(self.features)(x)
x = nn.Dropout(0.5, deterministic=not training)(x)
x = jax.nn.relu(x)
return x
class Model(nn.Module):
dmid: int
dout: int
@nn.compact
def __call__(self, x, training: bool):
x = Block(self.dmid)(x, training)
x = nn.Dense(self.dout)(x)
return x
在两个库中,__call__
方法看起来非常相似,但是,在 Flax 中,你必须使用 @nn.compact
装饰器才能内联定义子模块。在 Haiku 中,这是默认行为。
现在,Haiku 和 Flax 差异很大的地方在于你如何构建模型。在 Haiku 中,你在调用你的模块的函数上使用 hk.transform
,transform
将返回一个带有 init
和 apply
方法的对象。在 Flax 中,你只需实例化你的模块。
def forward(x, training: bool):
return Model(256, 10)(x, training)
model = hk.transform(forward)
...
model = Model(256, 10)
要在两个库中获取模型参数,你使用带有 random.key
以及一些运行模型的输入的 init
方法。这里的主要区别在于,Flax 返回从集合名称到嵌套数组字典的映射,params
只是这些可能的集合之一。在 Haiku 中,你直接获取 params
结构。
sample_x = jax.numpy.ones((1, 784))
params = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
...
sample_x = jax.numpy.ones((1, 784))
variables = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
params = variables["params"]
需要注意的一个非常重要的事情是,在 Flax 中,参数结构是分层的,每个嵌套模块一级,参数名称最后一级。在 Haiku 中,参数结构是一个 python 字典,具有两层层次结构:完全限定的模块名称映射到参数名称。模块名称由所有嵌套模块的 /
分隔的字符串路径组成。
...
{
'model/block/linear': {
'b': (256,),
'w': (784, 256),
},
'model/linear': {
'b': (10,),
'w': (256, 10),
}
}
...
FrozenDict({
Block_0: {
Dense_0: {
bias: (256,),
kernel: (784, 256),
},
},
Dense_0: {
bias: (10,),
kernel: (256, 10),
},
})
在两个框架中的训练期间,你将参数结构传递给 apply
方法以运行前向传递。由于我们正在使用 dropout,因此在这两种情况下,我们都必须向 apply
提供一个 key
以生成随机 dropout 掩码。
def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
params,
key,
inputs, training=True # <== inputs
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params
def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
{'params': params},
inputs, training=True, # <== inputs
rngs={'dropout': key}
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params
最显著的区别是,在 Flax 中,你必须将参数传递到带有 params
键的字典中,并将键传递到带有 dropout
键的字典中。这是因为在 Flax 中,你可以有许多类型的模型状态和随机状态。在 Haiku 中,你只需直接传递参数和键。
处理状态#
现在让我们看看两个库中如何处理可变状态。我们将采用与之前相同的模型,但现在我们将用 BatchNorm 替换 Dropout。
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.BatchNorm(
create_scale=True, create_offset=True, decay_rate=0.99
)(x, is_training=training)
x = jax.nn.relu(x)
return x
class Block(nn.Module):
features: int
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(self.features)(x)
x = nn.BatchNorm(
momentum=0.99
)(x, use_running_average=not training)
x = jax.nn.relu(x)
return x
在这种情况下,代码非常相似,因为这两个库都提供了 BatchNorm 层。最显著的区别在于,Haiku 使用 is_training
来控制是否更新运行统计数据,而 Flax 使用 use_running_average
来达到相同的目的。
要实例化 Haiku 中的有状态模型,你使用 hk.transform_with_state
,它会更改 init
和 apply
的签名以接受和返回状态。如前所述,在 Flax 中,你直接构造模块。
def forward(x, training: bool):
return Model(256, 10)(x, training)
model = hk.transform_with_state(forward)
...
model = Model(256, 10)
要初始化参数和状态,你只需像之前一样调用 init
方法。但是,在 Haiku 中,你现在会得到 state
作为第二个返回值,而在 Flax 中,你会在 variables
字典中获得一个新的 batch_stats
集合。请注意,由于 hk.BatchNorm
仅在 is_training=True
时初始化批量统计数据,因此我们在初始化带有 hk.BatchNorm
层的 Haiku 模型的参数时必须设置 training=True
。在 Flax 中,我们可以像往常一样设置 training=False
。
sample_x = jax.numpy.ones((1, 784))
params, state = model.init(
random.key(0),
sample_x, training=True # <== inputs
)
...
sample_x = jax.numpy.ones((1, 784))
variables = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
params, batch_stats = variables["params"], variables["batch_stats"]
一般来说,在 Flax 中,你可能会在 variables
字典中找到其他状态集合,例如用于自回归转换器模型的 cache
、用于使用 Module.sow
添加的中间值的 intermediates
,或者自定义层定义的其他集合名称。Haiku 仅区分 params
(在运行 apply
时不会更改的变量)和 state
(在运行 apply
时可以更改的变量)。
现在,训练在两个框架中看起来非常相似,因为你使用相同的 apply
方法来运行前向传递。在 Haiku 中,现在将 state
作为第二个参数传递给 apply
,并获得新的状态作为第二个返回值。在 Flax 中,你将 batch_stats
作为新键添加到输入字典中,并获得 updates
变量字典作为第二个返回值。
def train_step(params, state, inputs, labels):
def loss_fn(params):
logits, new_state = model.apply(
params, state,
None, # <== rng
inputs, training=True # <== inputs
)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
return loss, new_state
grads, new_state = jax.grad(loss_fn, has_aux=True)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params, new_state
def train_step(params, batch_stats, inputs, labels):
def loss_fn(params):
logits, updates = model.apply(
{'params': params, 'batch_stats': batch_stats},
inputs, training=True, # <== inputs
mutable='batch_stats',
)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
return loss, updates["batch_stats"]
grads, batch_stats = jax.grad(loss_fn, has_aux=True)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params, batch_stats
一个主要的区别是,在 Flax 中,状态集合可以是可变的或不可变的。在 init
期间,所有集合默认都是可变的,但是,在 apply
期间,你必须明确指定哪些集合是可变的。在这个例子中,我们指定 batch_stats
是可变的。这里传递了一个字符串,但如果有多个可变集合,也可以传递列表。如果没有这样做,当试图修改 batch_stats
时,会在运行时引发错误。此外,当 mutable
不是 False
时,updates
字典会作为 apply
的第二个返回值返回,否则只返回模型输出。Haiku 通过使用 params
(不可变)和 state
(可变),以及使用 hk.transform
或 hk.transform_with_state
来区分可变/不可变。
使用多种方法#
在本节中,我们将了解如何在 Haiku 和 Flax 中使用多种方法。作为一个例子,我们将实现一个具有三种方法的自动编码器模型:encode
、decode
和 __call__
。
在 Haiku 中,我们可以直接在 __init__
中定义 encode
和 decode
需要的子模块,在这种情况下,每个子模块都只使用一个 Linear
层。在 Flax 中,我们会在 setup
中预先定义一个 encoder
和一个 decoder
模块,并在 encode
和 decode
中分别使用它们。
class AutoEncoder(hk.Module):
def __init__(self, embed_dim: int, output_dim: int, name=None):
super().__init__(name=name)
self.encoder = hk.Linear(embed_dim, name="encoder")
self.decoder = hk.Linear(output_dim, name="decoder")
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
class AutoEncoder(nn.Module):
embed_dim: int
output_dim: int
def setup(self):
self.encoder = nn.Dense(self.embed_dim)
self.decoder = nn.Dense(self.output_dim)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
请注意,在 Flax 中,setup
不会在 __init__
之后运行,而是在调用 init
或 apply
时运行。
现在,我们希望能够从我们的 AutoEncoder
模型中调用任何方法。在 Haiku 中,我们可以通过 hk.multi_transform
为一个模块定义多个 apply
方法。传递给 multi_transform
的函数定义了如何初始化模块以及生成哪些不同的应用方法。
def forward():
module = AutoEncoder(256, 784)
init = lambda x: module(x)
return init, (module.encode, module.decode)
model = hk.multi_transform(forward)
...
model = AutoEncoder(256, 784)
要初始化我们模型的参数,可以使用 init
来触发 __call__
方法,该方法同时使用 encode
和 decode
方法。这将为模型创建所有必要的参数。
params = model.init(
random.key(0),
x=jax.numpy.ones((1, 784)),
)
...
variables = model.init(
random.key(0),
x=jax.numpy.ones((1, 784)),
)
params = variables["params"]
这将生成以下参数结构。
{
'auto_encoder/~/decoder': {
'b': (784,),
'w': (256, 784)
},
'auto_encoder/~/encoder': {
'b': (256,),
'w': (784, 256)
}
}
FrozenDict({
decoder: {
bias: (784,),
kernel: (256, 784),
},
encoder: {
bias: (256,),
kernel: (784, 256),
},
})
最后,让我们探讨一下如何使用 apply
函数来调用 encode
方法
encode, decode = model.apply
z = encode(
params,
None, # <== rng
x=jax.numpy.ones((1, 784)),
)
...
z = model.apply(
{"params": params},
x=jax.numpy.ones((1, 784)),
method="encode",
)
因为 Haiku 的 apply
函数是通过 hk.multi_transform
生成的,所以它是一个由两个函数组成的元组,我们可以将其解包为 encode
和 decode
函数,它们对应于 AutoEncoder
模块上的方法。在 Flax 中,我们通过传递方法名称作为字符串来调用 encode
方法。另一个值得注意的区别是,在 Haiku 中,rng
需要显式传递,即使该模块在 apply
期间不使用任何随机操作。在 Flax 中,这是不必要的(请查看 Flax 中的随机性和 PRNG)。这里的 Haiku rng
设置为 None
,但你也可以在 apply
函数上使用 hk.without_apply_rng
来删除 rng
参数。
提升的变换#
Flax 和 Haiku 都提供了一组变换,我们将其称为提升的变换,它们以一种可以与模块一起使用的方式封装了 JAX 变换,有时还提供额外的功能。在本节中,我们将了解如何在 Flax 和 Haiku 中使用提升版本的 scan
来实现一个简单的 RNN 层。
首先,我们将定义一个 RNNCell
模块,它将包含 RNN 的单步逻辑。我们还将定义一个 initial_state
方法,该方法将用于初始化 RNN 的状态(又名 carry
)。与 jax.lax.scan
类似,RNNCell.__call__
方法将是一个接受 carry 和输入并返回新的 carry 和输出的函数。在这种情况下,carry 和输出是相同的。
class RNNCell(hk.Module):
def __init__(self, hidden_size: int, name=None):
super().__init__(name=name)
self.hidden_size = hidden_size
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = hk.Linear(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nn.Module):
hidden_size: int
@nn.compact
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = nn.Dense(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
接下来,我们将定义一个 RNN
模块,它将包含整个 RNN 的逻辑。在 Haiku 中,我们将首先初始化 RNNCell
,然后使用它来构造 carry
,最后使用 hk.scan
在输入序列上运行 RNNCell
。在 Flax 中,它的做法略有不同,我们将使用 nn.scan
来定义一个临时的新的类型,该类型封装了 RNNCell
。在这个过程中,我们还将指定指示 nn.scan
广播 params
集合(所有步骤共享相同的参数),并且不拆分 params
rng 流(因此所有步骤都使用相同的参数初始化),最后我们将指定我们希望 scan 在输入的第二个轴上运行,并将输出也沿着第二个轴堆叠。然后,我们将立即使用这个临时类型来创建提升的 RNNCell
的实例,并使用它来创建 carry
并运行 __call__
方法,该方法将 scan
遍历序列。
class RNN(hk.Module):
def __init__(self, hidden_size: int, name=None):
super().__init__(name=name)
self.hidden_size = hidden_size
def __call__(self, x):
cell = RNNCell(self.hidden_size)
carry = cell.initial_state(x.shape[0])
carry, y = hk.scan(cell, carry, jnp.swapaxes(x, 1, 0))
y = jnp.swapaxes(y, 0, 1)
return y
class RNN(nn.Module):
hidden_size: int
@nn.compact
def __call__(self, x):
rnn = nn.scan(RNNCell, variable_broadcast='params', split_rngs={'params': False},
in_axes=1, out_axes=1)(self.hidden_size)
carry = rnn.initial_state(x.shape[0])
carry, y = rnn(carry, x)
return y
总的来说,Flax 和 Haiku 之间提升的变换的主要区别在于,在 Haiku 中,提升的变换不会对状态进行操作,也就是说,Haiku 将以某种方式处理 params
和 state
,使其在变换内部和外部保持相同的形状。在 Flax 中,提升的变换可以对变量集合和 rng 流进行操作,用户必须根据变换的语义定义每个变换如何处理不同的集合。
最后,让我们快速查看一下如何在 Haiku 和 Flax 中使用 RNN
模块。
def forward(x):
return RNN(64)(x)
model = hk.without_apply_rng(hk.transform(forward))
params = model.init(
random.key(0),
x=jax.numpy.ones((3, 12, 32)),
)
y = model.apply(
params,
x=jax.numpy.ones((3, 12, 32)),
)
...
model = RNN(64)
variables = model.init(
random.key(0),
x=jax.numpy.ones((3, 12, 32)),
)
params = variables['params']
y = model.apply(
{'params': params},
x=jax.numpy.ones((3, 12, 32)),
)
与前几节中的示例相比,唯一值得注意的变化是,这次我们在 Haiku 中使用了 hk.without_apply_rng
,因此我们不必将 rng
参数作为 None
传递给 apply
方法。
扫描层#
scan
的一个非常重要的应用是迭代地将一系列层应用于输入,将每一层的输出作为下一层的输入。这对于减少大型模型的编译时间非常有用。例如,我们将创建一个简单的 Block
模块,然后在 MLP
模块内部使用它,该模块将 Block
模块应用 num_layers
次。
在 Haiku 中,我们像往常一样定义 Block
模块,然后在 MLP
内部,我们将使用 hk.experimental.layer_stack
对 stack_block
函数进行操作,以创建 Block
模块的堆叠。在 Flax 中,Block
的定义略有不同,__call__
将接受并返回第二个虚拟输入/输出,在这两种情况下都将为 None
。在 MLP
中,我们将像前面的示例一样使用 nn.scan
,但通过设置 split_rngs={'params': True}
和 variable_axes={'params': 0}
,我们告诉 nn.scan
为每个步骤创建不同的参数,并沿着第一个轴切分 params
集合,从而有效地实现与 Haiku 中相同的 Block
模块堆叠。
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
x = jax.nn.relu(x)
return x
class MLP(hk.Module):
def __init__(self, features: int, num_layers: int, name=None):
super().__init__(name=name)
self.features = features
self.num_layers = num_layers
def __call__(self, x, training: bool):
@hk.experimental.layer_stack(self.num_layers)
def stack_block(x):
return Block(self.features)(x, training)
stack = hk.experimental.layer_stack(self.num_layers)
return stack_block(x)
class Block(nn.Module):
features: int
training: bool
@nn.compact
def __call__(self, x, _):
x = nn.Dense(self.features)(x)
x = nn.Dropout(0.5)(x, deterministic=not self.training)
x = jax.nn.relu(x)
return x, None
class MLP(nn.Module):
features: int
num_layers: int
@nn.compact
def __call__(self, x, training: bool):
ScanBlock = nn.scan(
Block, variable_axes={'params': 0}, split_rngs={'params': True},
length=self.num_layers)
y, _ = ScanBlock(self.features, training)(x, None)
return y
请注意,在 Flax 中,我们如何将 None
作为第二个参数传递给 ScanBlock
并忽略其第二个输出。这些表示每个步骤的输入/输出,但它们是 None
,因为在这种情况下我们没有任何输入/输出。
初始化每个模型与前面的示例相同。在这种情况下,我们将指定要使用 5
层,每层具有 64
个特征。
def forward(x, training: bool):
return MLP(64, num_layers=5)(x, training)
model = hk.transform(forward)
sample_x = jax.numpy.ones((1, 64))
params = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
...
...
model = MLP(64, num_layers=5)
sample_x = jax.numpy.ones((1, 64))
variables = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
params = variables['params']
当在层上使用 scan 时,您应该注意到的一件事是,所有层都被融合到一个单层中,该单层的参数在第一个轴上有一个额外的“层”维度。在这种情况下,所有参数的形状都将以 (5, ...)
开头,因为我们正在使用 5
层。
...
{
'mlp/__layer_stack_no_per_layer/block/linear': {
'b': (5, 64),
'w': (5, 64, 64)
}
}
...
FrozenDict({
ScanBlock_0: {
Dense_0: {
bias: (5, 64),
kernel: (5, 64, 64),
},
},
})
顶层 Haiku 函数与顶层 Flax 模块#
在 Haiku 中,可以通过使用原始的 hk.{get,set}_{parameter,state}
来定义/访问模型参数和状态,从而将整个模型编写为单个函数。将顶层“模块”编写为函数是很常见的做法。
Flax 团队推荐一种更以模块为中心的方法,该方法使用 __call__ 来定义前向函数。对应的访问器将是 nn.module.param 和 nn.module.variable(有关集合的解释,请转到处理状态)。
def forward(x):
counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
multiplier = hk.get_parameter('multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones)
output = x + multiplier * counter
hk.set_state("counter", counter + 1)
return output
model = hk.transform_with_state(forward)
params, state = model.init(random.key(0), jax.numpy.ones((1, 64)))
class FooModule(nn.Module):
@nn.compact
def __call__(self, x):
counter = self.variable('counter', 'count', lambda: jnp.ones((), jnp.int32))
multiplier = self.param('multiplier', nn.initializers.ones_init(), [1,], x.dtype)
output = x + multiplier * counter.value
if not self.is_initializing(): # otherwise model.init() also increases it
counter.value += 1
return output
model = FooModule()
variables = model.init(random.key(0), jax.numpy.ones((1, 64)))
params, counter = variables['params'], variables['counter']