将我的代码库升级到 Linen#

从 Flax v0.4.0 开始,flax.nn 不再存在,并被新的 Linen API(位于 flax.linen)所取代。如果您的代码库仍在使用旧的 API,您可以使用此升级指南将其升级到 Linen。

定义简单的 Flax 模块#

from flax import nn

class Dense(base.Module):
  def apply(self,
            inputs,
            features,
            use_bias=True,
            kernel_init=default_kernel_init,
            bias_init=initializers.zeros_init()):

    kernel = self.param('kernel',
      (inputs.shape[-1], features), kernel_init)
    y = jnp.dot(inputs, kernel)
    if use_bias:
      bias = self.param(
        'bias', (features,), bias_init)
      y = y + bias
    return y
from flax import linen as nn  # [1]

class Dense(nn.Module):
  features: int  # [2]
  use_bias: bool = True
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init()

  @nn.compact
  def __call__(self, inputs):  # [3]
    kernel = self.param('kernel',
      self.kernel_init, (inputs.shape[-1], self.features))  # [4]
    y = jnp.dot(inputs, kernel)
    if self.use_bias:
      bias = self.param(
        'bias', self.bias_init, (self.features,))  # [5]
      y = y + bias
    return y

  1. from flax import nn 替换为 from flax import linen as nn

  2. apply 的参数移动到数据类属性中。添加类型注释(或使用类型 Any 跳过)。

  3. 将方法 apply 重命名为 __call__,并(可选)使用 @compact 包装。用 @compact 包装的方法可以直接在该方法中定义子模块(如在旧的 Flax 中)。您只能用 @compact 包装一个方法。或者,您可以定义一个 setup 方法。有关更多详细信息,请参阅我们的其他 HOWTO 我应该使用 setup 还是 nn.compact?

  4. 通过 self.<attr> 在方法内部访问数据类属性值,例如 self.features

  5. 将形状移动到 self.param 参数的末尾(初始化器函数可以接受任意参数列表)。

在其他模块中使用 Flax 模块#

class Encoder(nn.Module):

  def apply(self, x):
    x = nn.Dense(x, 500)
    x = nn.relu(x)
    z = nn.Dense(x, 500, name="latents")
    return z
class Encoder(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(500)(x)  # [1]
    x = nn.relu(x)
    z = nn.Dense(500, name='latents')(x)  # [2]
    return z

  1. 模块构造函数不再返回输出。相反,它们像普通构造函数一样工作并返回模块实例。这些实例可以像在普通 Python 中一样共享(而不是在旧的 Flax 中使用 .shared())。由于大多数模块都实现了 __call__,因此您可以保留旧 Flax 的简洁性。

  2. 名称可以选择传递给所有模块构造函数。

共享子模块和定义多个方法#

class AutoEncoder(nn.Module):
  def _create_submodules(self):
    return Decoder.shared(name="encoder")

  def apply(self, x, z_rng, latents=20):
    decoder = self._create_decoder()
    z = Encoder(x, latents, name="encoder")
    return decoder(z)

  @nn.module_method
  def generate(self, z, **unused_kwargs):
    decoder = self._create_decoder()
    return nn.sigmoid(decoder(z))
class AutoEncoder(nn.Module):
  latents: int = 20

  def setup(self):  # [1]
    self.encoder = Encoder(self.latents)  # [2]
    self.decoder = Decoder()

  def __call__(self, x):  # [3]
    z = self.encoder(x)
    return self.decoder(z)

  def generate(self, z):  # [4]
    return nn.sigmoid(self.decoder(z))

  1. 使用 setup 而不是 __init__,后者已在数据类库中定义。Flax 会在模块准备好使用后立即调用 setup。(如果您喜欢,可以对所有模块执行此操作,而不是使用 @compact,但我们喜欢 @compact 如何共同定位模块的定义和使用位置,特别是当您有循环或条件时)。

  2. 像常规 Python 一样,通过在初始化期间分配给 self 来共享子模块。与 PyTorch 类似,self.encoder 会自动具有名称 "encoder"

  3. 我们在这里不使用 @compact,因为我们没有定义任何内联子模块(所有子模块都在 setup 中定义)。

  4. 像在常规 Python 中一样定义其他方法。

在其他模块中使用 Module.partial#

# no import

class ResNet(nn.Module):
  """ResNetV1."""


  def apply(self, x,
            stage_sizes,
            num_filters=64,
            train=True):
    conv = nn.Conv.partial(bias=False)
    norm = nn.BatchNorm.partial(
        use_running_average=not train,
        momentum=0.9, epsilon=1e-5)

    x = conv(x, num_filters, (7, 7), (2, 2),
            padding=[(3, 3), (3, 3)],
            name='conv_init')
    x = norm(x, name='bn_init')

    # [...]
    return x
from functools import partial

class ResNet(nn.Module):
  """ResNetV1."""
  stage_sizes: Sequence[int]
  num_filters: int = 64
  train: bool = True

  @nn.compact
  def __call__(self, x):
    conv = partial(nn.Conv, use_bias=False)
    norm = partial(nn.BatchNorm,
                  use_running_average=not self.train,
                  momentum=0.9, epsilon=1e-5)

    x = conv(self.num_filters, (7, 7), (2, 2),
            padding=[(3, 3), (3, 3)],
            name='conv_init')(x)
    x = norm(name='bn_init')(x)

    # [...]
    return x

使用正常的 functools.partial 而不是 Module.partial。其余部分保持不变。

顶层训练代码模式#

def create_model(key):
  _, initial_params = CNN.init_by_shape(
    key, [((1, 28, 28, 1), jnp.float32)])
  model = nn.Model(CNN, initial_params)
  return model

def create_optimizer(model, learning_rate):
  optimizer_def = optim.Momentum(learning_rate=learning_rate)
  optimizer = optimizer_def.create(model)
  return optimizer

def cross_entropy_loss(*, logits, labels):
  one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
  return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))

def loss_fn(model):
  logits = model(batch['image'])
  one_hot = jax.nn.one_hot(batch['label'], num_classes=10)
  loss = -jnp.mean(jnp.sum(one_hot_labels * batch['label'],
                           axis=-1))
  return loss, logits
def create_train_state(rng, config):  # [1]
  variables = CNN().init(rng, jnp.ones([1, 28, 28, 1]))  # [2]
  params = variables['params']  # [3]
  tx = optax.sgd(config.learning_rate, config.momentum)  # [4]
  return train_state.TrainState.create(
      apply_fn=CNN.apply, params=params, tx=tx)


def loss_fn(params):
  logits = CNN().apply({'params': params}, batch['image'])  # [5]
  one_hot = jax.nn.one_hot(batch['label'], 10)
  loss = jnp.mean(optax.softmax_cross_entropy(logits=logits,
                                              labels=one_hot))
  return loss, logits

  1. 我们不再使用 Model 抽象 – 而是直接传递参数,通常封装在 TrainState 对象中,该对象可以直接传递给 JAX 转换。

  2. 要计算初始参数,请构造一个模块实例并调用 initinit_with_output。我们没有移植 init_by_shape,因为此函数执行了一些我们不喜欢的神奇操作(它按形状评估函数,但仍然返回真实值)。因此,您现在应该将具体值传递给初始化器函数,并且可以通过使用 jax.jit 包装它来优化初始化,强烈建议这样做以避免运行完整的正向传递。

  3. Linen 将参数泛化为变量。参数是变量的“集合”之一。变量是嵌套字典,其中顶层键反映不同的变量集合,“param”是其中之一。有关更多详细信息,请参阅 变量文档

  4. 我们建议使用 Optax 优化器。有关更多详细信息,请参阅我们单独的 HOWTO,名为 将我的代码库升级到 Optax

  5. 要使用您的模型进行预测,请在顶层创建一个实例(这是免费的 – 只是构造函数属性的包装器)并调用 apply 方法(它将在内部调用 __call__)。

不可训练的变量(“状态”):在模块中使用#

class BatchNorm(nn.Module):
  def apply(self, x):
    # [...]
    ra_mean = self.state(
      'mean', (x.shape[-1], ), initializers.zeros_init())
    ra_var = self.state(
      'var', (x.shape[-1], ), initializers.ones_init())
    # [...]
class BatchNorm(nn.Module):
  def __call__(self, x):
    # [...]
    ra_mean = self.variable(
      'batch_stats', 'mean', initializers.zeros_init(), (x.shape[-1], ))
    ra_var = self.variable(
      'batch_stats', 'var', initializers.ones_init(), (x.shape[-1], ))
    # [...]

第一个参数是变量集合的名称(“param”是唯一始终可用的变量集合)。一些集合可能被视为可变的,而另一些集合可能被视为顶层训练代码中的不可变的(有关详细信息,请参阅下一节)。当在模块内部使用 JAX 转换时,Flax 还允许您以不同的方式处理每个变量集合。

不可训练的变量(“状态”):顶层训练代码模式#

# initial params and state
def initial_model(key, init_batch):
  with nn.stateful() as initial_state:
    _, initial_params = ResNet.init(key, init_batch)
  model = nn.Model(ResNet, initial_params)
  return model, init_state


# updates batch statistics during training
def loss_fn(model, model_state):
  with nn.stateful(model_state) as new_model_state:
    logits = model(batch['image'])
  # [...]



# reads immutable batch statistics during evaluation
def eval_step(model, model_state, batch):
  with nn.stateful(model_state, mutable=False):
    logits = model(batch['image'], train=False)
  return compute_metrics(logits, batch['label'])
# initial variables ({"param": ..., "batch_stats": ...})
def initial_variables(key, init_batch):
  return ResNet().init(key, init_batch)  # [1]



# updates batch statistics during training
def loss_fn(params, batch_stats):
  variables = {'params': params, 'batch_stats': batch_stats}  # [2]
  logits, new_variables = ResNet(train=true).apply(
    variables, batch['image'], mutable=['batch_stats'])  # [3]
  new_batch_stats = new_variables['batch_stats']
  # [...]


# reads immutable batch statistics during evaluation
def eval_step(params, batch_stats, batch):
  variables = {'params': params, 'batch_stats': batch_stats}
  logits = ResNet(train=False).apply(
    variables, batch['image'], mutable=False)  # [4]
  return compute_metrics(logits, batch['label'])

  1. init 返回一个变量字典,例如 {"param": ..., "batch_stats": ...}(请参阅 变量文档)。

  2. 将不同的变量集合组合成一个变量字典。

  3. 在训练期间,batch_stats 变量集合会发生变化。由于我们在 mutable 参数中指定了这一点,因此 module.apply 的返回值将成为 output, new_variables 的有序对。

  4. 在评估期间,如果我们意外地在训练模式下应用批量归一化,我们希望引发一个错误。通过将 mutable=False 传递到 module.apply 中,我们可以强制执行此操作。由于没有变量发生变异,因此返回值再次只是输出。

加载 pre-Linen 检查点#

虽然大多数 Linen 模块应该能够直接使用预 Linen 的权重而无需任何修改,但有一个需要注意的地方:在预 Linen API 中,子模块的编号是递增的,与子模块的类无关。而在 Linen 中,此行为已更改为每个模块类保持单独的子模块计数。

在预 Linen 中,参数具有以下结构

{'Conv_0': { ... }, 'Dense_1': { ... } }

在 Linen 中,结构变为:

{'Conv_0': { ... }, 'Dense_0': { ... } }

待办事项:在此处添加一个如何加载新的 TrainState 对象的示例。

随机性#

def dropout(inputs, rate, deterministic=False):
  keep_prob = 1. - rate
  if deterministic:
    return inputs
  else:
    mask = random.bernoulli(
    make_rng(), p=keep_prob, shape=inputs.shape)
    return lax.select(
      mask, inputs / keep_prob, jnp.zeros_like(inputs))


def loss_fn(model, dropout_rng):
  with nn.stochastic(dropout_rng):
    logits = model(inputs)
class Dropout(nn.Module):
  rate: float

  @nn.compact
  def __call__(self, inputs, deterministic=False):
    keep_prob = 1. - self.rate
    if deterministic:
      return inputs
    else:
      mask = random.bernoulli(
        self.make_rng('dropout'), p=keep_prob, shape=inputs.shape)  # [1]
      return lax.select(
        mask, inputs / keep_prob, jnp.zeros_like(inputs))


def loss_fn(params, dropout_rng):
  logits = Transformer().apply(
    {'params': params}, inputs, rngs={'dropout': dropout_rng})  # [2]

  1. Linen 中的 RNG 具有“种类”——在本例中是 'dropout'。不同的种类可以在 JAX 转换中被区别对待(例如,你希望序列模型中每个时间步使用相同的 dropout 掩码还是不同的掩码?)

  2. 你不是使用 nn.stochastic 上下文管理器,而是将 RNG 显式传递给 module.apply。在评估期间,你不会传递任何 RNG - 那么如果你在非确定性模式下意外使用 dropout,self.make_rng('dropout') 会引发错误。

提升的转换#

在 Linen 中,我们不是直接使用 JAX 转换,而是使用“提升的转换”,这些转换是应用于 Flax 模块的 JAX 转换。

有关更多信息,请参阅关于 提升转换 的设计说明。

待办事项:给出 jax.scan_in_dim(预 Linen)与 nn.scan(Linen)的示例。