将我的代码库升级到 Optax#

我们已在 2021 年提出用 Optax 替换 flax.optim,详见 FLIP #1009,并且 Flax 优化器已在 v0.6.0 中删除 - 本指南面向 flax.optim 用户,帮助他们将代码更新为 Optax。

另请参阅 Optax 的快速入门文档:https://optax.readthedocs.io/en/latest/getting_started.html

optax 替换 flax.optim#

Optax 为所有 Flax 的优化器提供了直接替换。有关 API 详细信息,请参阅 Optax 的文档 常用优化器

用法非常相似,不同之处在于 optax 不会保留 params 的副本,因此需要单独传递它们。Flax 提供了实用程序 TrainState,用于将优化器状态、参数和其他相关数据存储在单个数据类中(在下面的代码中未使用)。

@jax.jit
def train_step(optimizer, batch):
  grads = jax.grad(loss)(optimizer.target, batch)


  return optimizer.apply_gradient(grads)

optimizer_def = flax.optim.Momentum(
    learning_rate, momentum)
optimizer = optimizer_def.create(variables['params'])

for batch in get_ds_train():
  optimizer = train_step(optimizer, batch)
@jax.jit
def train_step(params, opt_state, batch):
  grads = jax.grad(loss)(params, batch)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state

tx = optax.sgd(learning_rate, momentum)
params = variables['params']
opt_state = tx.init(params)

for batch in ds_train:
  params, opt_state = train_step(params, opt_state, batch)

可组合的梯度变换#

上面代码片段中使用的函数 optax.sgd() 只是两个梯度变换的顺序应用的包装器。通常不使用此别名,而是使用 optax.chain() 来组合多个这些通用构建块。

# Note that the aliases follow the convention to use positive
# values for the learning rate by default.
tx = optax.sgd(learning_rate, momentum)
#

tx = optax.chain(
    # 1. Step: keep a trace of past updates and add to gradients.
    optax.trace(decay=momentum),
    # 2. Step: multiply result from step 1 with negative learning rate.
    # Note that `optax.apply_updates()` simply adds the final updates to the
    # parameters, so we must make sure to flip the sign here for gradient
    # descent.
    optax.scale(-learning_rate),
)

权重衰减#

Flax 的一些优化器还包括权重衰减。在 Optax 中,一些优化器也有权重衰减参数(例如 optax.adamw()),对于其他优化器,可以将权重衰减添加为另一个“梯度变换” optax.add_decayed_weights(),该变换会添加从参数派生的更新。

optimizer_def = flax.optim.Adam(
    learning_rate, weight_decay=weight_decay)
optimizer = optimizer_def.create(variables['params'])
# (Note that you could also use `optax.adamw()` in this case)
tx = optax.chain(
    optax.scale_by_adam(),
    optax.add_decayed_weights(weight_decay),
    # params -= learning_rate * (adam(grads) + params * weight_decay)
    optax.scale(-learning_rate),
)
# Note that you'll need to specify `params` when computing the udpates:
# tx.update(grads, opt_state, params)

梯度裁剪#

通过将梯度裁剪为全局范数可以稳定训练(Pascanu 等人,2012)。在 Flax 中,通常在将梯度传递给优化器之前对其进行处理。对于 Optax,这只是另一个梯度变换 optax.clip_by_global_norm()

def train_step(optimizer, batch):
  grads = jax.grad(loss)(optimizer.target, batch)
  grads_flat, _ = jax.tree_util.tree_flatten(grads)
  global_l2 = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads_flat]))
  g_factor = jnp.minimum(1.0, grad_clip_norm / global_l2)
  grads = jax.tree_util.tree_map(lambda g: g * g_factor, grads)
  return optimizer.apply_gradient(grads)
tx = optax.chain(
    optax.clip_by_global_norm(grad_clip_norm),
    optax.trace(decay=momentum),
    optax.scale(-learning_rate),
)

学习率调度#

对于学习率调度,Flax 允许在应用梯度时覆盖超参数。Optax 维护一个步长计数器,并将其作为参数传递给一个函数,用于缩放使用 optax.scale_by_schedule() 添加的更新。Optax 还允许指定一个函数,通过 optax.inject_hyperparams() 为其他梯度更新注入任意标量值。

lr_schedule 指南中阅读有关学习率调度的更多信息。

优化器调度 下阅读有关 Optax 中定义的调度的更多信息。标准优化器(如 optax.adam()optax.sgd() 等)也接受学习率调度作为 learning_rate 的参数。

def train_step(step, optimizer, batch):
  grads = jax.grad(loss)(optimizer.target, batch)
  return step + 1, optimizer.apply_gradient(grads, learning_rate=schedule(step))
tx = optax.chain(
    optax.trace(decay=momentum),
    # Note that we still want a negative value for scaling the updates!
    optax.scale_by_schedule(lambda step: -schedule(step)),
)

多个优化器 / 更新参数子集#

在 Flax 中,使用遍历来指定应由优化器更新的参数。你可以使用 flax.optim.MultiOptimizer 组合遍历,以便在不同的参数上应用不同的优化器。Optax 中的等效项是 optax.masked()optax.chain()

请注意,下面的示例使用 flax.traverse_util 来创建 optax.masked() 所需的布尔掩码 - 或者你也可以手动创建它们,或者使用 optax.multi_transform(),该函数接受一个多值 pytree 来指定梯度变换。

请注意,optax.masked() 在内部展平 pytree,并且内部梯度变换将仅使用参数/梯度的部分展平视图来调用。这通常不是问题,但这使得难以嵌套多个级别的掩码梯度变换(因为内部掩码将期望根据部分展平的视图定义掩码,而该视图在外部掩码之外不容易获得)。

kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

kernel_opt = flax.optim.Momentum(learning_rate, momentum)
bias_opt = flax.optim.Momentum(learning_rate * 0.1, momentum)


optimizer = flax.optim.MultiOptimizer(
    (kernels, kernel_opt),
    (biases, bias_opt)
).create(variables['params'])
kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

all_false = jax.tree_util.tree_map(lambda _: False, params)
kernels_mask = kernels.update(lambda _: True, all_false)
biases_mask = biases.update(lambda _: True, all_false)

tx = optax.chain(
    optax.trace(decay=momentum),
    optax.masked(optax.scale(-learning_rate), kernels_mask),
    optax.masked(optax.scale(-learning_rate * 0.1), biases_mask),
)

结语#

当然,以上所有模式也可以混合使用,Optax 使得可以将所有这些变换封装到主训练循环之外的单个位置,从而使测试更加容易。