将我的代码库升级到 Optax#
我们已在 2021 年提出用 Optax 替换 flax.optim
,详见 FLIP #1009,并且 Flax 优化器已在 v0.6.0 中删除 - 本指南面向 flax.optim
用户,帮助他们将代码更新为 Optax。
另请参阅 Optax 的快速入门文档:https://optax.readthedocs.io/en/latest/getting_started.html
用 optax
替换 flax.optim
#
Optax 为所有 Flax 的优化器提供了直接替换。有关 API 详细信息,请参阅 Optax 的文档 常用优化器。
用法非常相似,不同之处在于 optax
不会保留 params
的副本,因此需要单独传递它们。Flax 提供了实用程序 TrainState
,用于将优化器状态、参数和其他相关数据存储在单个数据类中(在下面的代码中未使用)。
@jax.jit
def train_step(optimizer, batch):
grads = jax.grad(loss)(optimizer.target, batch)
return optimizer.apply_gradient(grads)
optimizer_def = flax.optim.Momentum(
learning_rate, momentum)
optimizer = optimizer_def.create(variables['params'])
for batch in get_ds_train():
optimizer = train_step(optimizer, batch)
@jax.jit
def train_step(params, opt_state, batch):
grads = jax.grad(loss)(params, batch)
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return params, opt_state
tx = optax.sgd(learning_rate, momentum)
params = variables['params']
opt_state = tx.init(params)
for batch in ds_train:
params, opt_state = train_step(params, opt_state, batch)
可组合的梯度变换#
上面代码片段中使用的函数 optax.sgd()
只是两个梯度变换的顺序应用的包装器。通常不使用此别名,而是使用 optax.chain()
来组合多个这些通用构建块。
# Note that the aliases follow the convention to use positive
# values for the learning rate by default.
tx = optax.sgd(learning_rate, momentum)
#
tx = optax.chain(
# 1. Step: keep a trace of past updates and add to gradients.
optax.trace(decay=momentum),
# 2. Step: multiply result from step 1 with negative learning rate.
# Note that `optax.apply_updates()` simply adds the final updates to the
# parameters, so we must make sure to flip the sign here for gradient
# descent.
optax.scale(-learning_rate),
)
权重衰减#
Flax 的一些优化器还包括权重衰减。在 Optax 中,一些优化器也有权重衰减参数(例如 optax.adamw()
),对于其他优化器,可以将权重衰减添加为另一个“梯度变换” optax.add_decayed_weights()
,该变换会添加从参数派生的更新。
optimizer_def = flax.optim.Adam(
learning_rate, weight_decay=weight_decay)
optimizer = optimizer_def.create(variables['params'])
# (Note that you could also use `optax.adamw()` in this case)
tx = optax.chain(
optax.scale_by_adam(),
optax.add_decayed_weights(weight_decay),
# params -= learning_rate * (adam(grads) + params * weight_decay)
optax.scale(-learning_rate),
)
# Note that you'll need to specify `params` when computing the udpates:
# tx.update(grads, opt_state, params)
梯度裁剪#
通过将梯度裁剪为全局范数可以稳定训练(Pascanu 等人,2012)。在 Flax 中,通常在将梯度传递给优化器之前对其进行处理。对于 Optax,这只是另一个梯度变换 optax.clip_by_global_norm()
。
def train_step(optimizer, batch):
grads = jax.grad(loss)(optimizer.target, batch)
grads_flat, _ = jax.tree_util.tree_flatten(grads)
global_l2 = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads_flat]))
g_factor = jnp.minimum(1.0, grad_clip_norm / global_l2)
grads = jax.tree_util.tree_map(lambda g: g * g_factor, grads)
return optimizer.apply_gradient(grads)
tx = optax.chain(
optax.clip_by_global_norm(grad_clip_norm),
optax.trace(decay=momentum),
optax.scale(-learning_rate),
)
学习率调度#
对于学习率调度,Flax 允许在应用梯度时覆盖超参数。Optax 维护一个步长计数器,并将其作为参数传递给一个函数,用于缩放使用 optax.scale_by_schedule()
添加的更新。Optax 还允许指定一个函数,通过 optax.inject_hyperparams()
为其他梯度更新注入任意标量值。
在 lr_schedule 指南中阅读有关学习率调度的更多信息。
在 优化器调度 下阅读有关 Optax 中定义的调度的更多信息。标准优化器(如 optax.adam()
、optax.sgd()
等)也接受学习率调度作为 learning_rate
的参数。
def train_step(step, optimizer, batch):
grads = jax.grad(loss)(optimizer.target, batch)
return step + 1, optimizer.apply_gradient(grads, learning_rate=schedule(step))
tx = optax.chain(
optax.trace(decay=momentum),
# Note that we still want a negative value for scaling the updates!
optax.scale_by_schedule(lambda step: -schedule(step)),
)
多个优化器 / 更新参数子集#
在 Flax 中,使用遍历来指定应由优化器更新的参数。你可以使用 flax.optim.MultiOptimizer
组合遍历,以便在不同的参数上应用不同的优化器。Optax 中的等效项是 optax.masked()
和 optax.chain()
。
请注意,下面的示例使用 flax.traverse_util
来创建 optax.masked()
所需的布尔掩码 - 或者你也可以手动创建它们,或者使用 optax.multi_transform()
,该函数接受一个多值 pytree 来指定梯度变换。
请注意,optax.masked()
在内部展平 pytree,并且内部梯度变换将仅使用参数/梯度的部分展平视图来调用。这通常不是问题,但这使得难以嵌套多个级别的掩码梯度变换(因为内部掩码将期望根据部分展平的视图定义掩码,而该视图在外部掩码之外不容易获得)。
kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)
kernel_opt = flax.optim.Momentum(learning_rate, momentum)
bias_opt = flax.optim.Momentum(learning_rate * 0.1, momentum)
optimizer = flax.optim.MultiOptimizer(
(kernels, kernel_opt),
(biases, bias_opt)
).create(variables['params'])
kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)
all_false = jax.tree_util.tree_map(lambda _: False, params)
kernels_mask = kernels.update(lambda _: True, all_false)
biases_mask = biases.update(lambda _: True, all_false)
tx = optax.chain(
optax.trace(decay=momentum),
optax.masked(optax.scale(-learning_rate), kernels_mask),
optax.masked(optax.scale(-learning_rate * 0.1), biases_mask),
)
结语#
当然,以上所有模式也可以混合使用,Optax 使得可以将所有这些变换封装到主训练循环之外的单个位置,从而使测试更加容易。