学习率调度
学习率被认为是训练深度神经网络时最重要的超参数之一,但为其选择合适的值可能相当困难。与其简单地使用固定学习率,不如使用学习率调度器。在本例中,我们将使用余弦调度器。在余弦调度器生效之前,我们首先进行所谓的预热(warmup)阶段:学习率在 `warmup_epochs` 个 epoch 内从 0 线性增加到基础学习率。有关余弦调度器的更多信息,请参阅论文《SGDR: Stochastic Gradient Descent with Warm Restarts》(SGDR:带热重启的随机梯度下降)。
我们将向您展示如何:
- 定义学习率调度
- 使用该调度训练一个简单的模型
def create_learning_rate_fn(config, base_learning_rate, steps_per_epoch):
    """Creates learning rate schedule.

    The schedule has two phases. First, a linear warm-up from 0 to
    `base_learning_rate` over `config.warmup_epochs` epochs. Then a cosine
    decay that starts at `base_learning_rate` and spans
    `config.num_epochs - config.warmup_epochs` epochs, clamped to at least
    one epoch.

    Args:
      config: object exposing `warmup_epochs` and `num_epochs`.
      base_learning_rate: peak learning rate, reached at the end of warm-up.
      steps_per_epoch: number of optimizer steps in one epoch.

    Returns:
      An optax schedule mapping a step count to a learning rate.
    """
    warmup_steps = config.warmup_epochs * steps_per_epoch
    # Never let the decay phase collapse to zero epochs, even when
    # warmup_epochs >= num_epochs.
    decay_epochs = max(config.num_epochs - config.warmup_epochs, 1)

    warmup_phase = optax.linear_schedule(
        init_value=0.,
        end_value=base_learning_rate,
        transition_steps=warmup_steps,
    )
    decay_phase = optax.cosine_decay_schedule(
        init_value=base_learning_rate,
        decay_steps=decay_epochs * steps_per_epoch,
    )
    # Hand over from warm-up to cosine decay once `warmup_steps` is reached.
    return optax.join_schedules(
        schedules=[warmup_phase, decay_phase],
        boundaries=[warmup_steps],
    )
要使用该调度,我们需要先将超参数传递给 `create_learning_rate_fn` 函数来创建学习率函数,再将该函数传递给 Optax 优化器。例如,要在 MNIST 上使用此调度,需要修改 `train_step` 函数(下面依次给出修改前和修改后的版本):
@jax.jit
def train_step(state, batch):
    """Apply one gradient update to `state` using `batch`.

    Args:
      state: current train state. Its `params` are differentiated, and the
        state is advanced with `apply_gradients`.
      batch: mapping with an 'image' array fed to the CNN and a 'label'
        array that is one-hot encoded over 10 classes (presumably integer
        class ids).

    Returns:
      A pair `(new_state, metrics)`, where `metrics` is what
      `compute_metrics` produces for this batch's logits and labels.
    """
    images, labels = batch['image'], batch['label']

    def cross_entropy_loss(params):
        # The logits are returned as auxiliary output so the metrics below
        # can reuse them instead of running a second forward pass.
        predictions = CNN().apply({'params': params}, images)
        targets = jax.nn.one_hot(labels, 10)
        mean_loss = jnp.mean(optax.softmax_cross_entropy(predictions, targets))
        return mean_loss, predictions

    loss_and_grad = jax.value_and_grad(cross_entropy_loss, has_aux=True)
    (_, predictions), gradients = loss_and_grad(state.params)
    updated_state = state.apply_gradients(grads=gradients)
    return updated_state, compute_metrics(predictions, labels)
@functools.partial(jax.jit, static_argnums=2)
def train_step(state, batch, learning_rate_fn):
    """Apply one gradient update and report the learning rate used.

    Args:
      state: current train state. Its `params` are differentiated, and the
        state is advanced with `apply_gradients`.
      batch: mapping with an 'image' array fed to the CNN and a 'label'
        array that is one-hot encoded over 10 classes.
      learning_rate_fn: schedule mapping a step count to a learning rate.
        It is a static jit argument (`static_argnums=2`), so it must be
        hashable.

    Returns:
      A pair `(new_state, metrics)`. `metrics` is the output of
      `compute_metrics`, plus a 'learning_rate' entry.
    """
    def objective(params):
        # Logits ride along as aux output so they can feed compute_metrics.
        logits = CNN().apply({'params': params}, batch['image'])
        labels_one_hot = jax.nn.one_hot(batch['label'], 10)
        per_example = optax.softmax_cross_entropy(logits, labels_one_hot)
        return jnp.mean(per_example), logits

    (_, logits), grads = jax.value_and_grad(objective, has_aux=True)(
        state.params)
    next_state = state.apply_gradients(grads=grads)
    step_metrics = compute_metrics(logits, batch['label'])
    # Evaluate the schedule at the incoming `state.step`, i.e. the step
    # count before this update is applied, not at `next_state.step`.
    step_metrics['learning_rate'] = learning_rate_fn(state.step)
    return next_state, step_metrics
以及 `train_epoch` 函数(修改前和修改后):
def train_epoch(state, train_ds, batch_size, epoch, rng):
    """Trains for a single epoch.

    Shuffles the examples with `rng` and drops the trailing partial batch.
    It then runs `train_step` on every full batch and logs the
    epoch-averaged metrics.

    Args:
      state: current training state, threaded through each `train_step`.
      train_ds: dict of arrays indexed along a shared leading example axis.
        Must contain 'image'; the other keys (presumably 'label') are
        sliced and passed through to `train_step`.
      batch_size: number of examples per batch; must be positive.
      epoch: epoch index, used only in the log message.
      rng: PRNG key used to shuffle the examples.

    Returns:
      `(state, epoch_metrics)`: the updated state and, for every metric key
      reported by `train_step`, its mean over the epoch's batches.

    Raises:
      ValueError: if `batch_size` is not positive or is larger than the
        number of training examples. Either case would leave no batch to
        train on.
    """
    train_ds_size = len(train_ds['image'])
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}.')
    steps_per_epoch = train_ds_size // batch_size
    if steps_per_epoch == 0:
        # Without this guard the loop below never runs and the metric
        # averaging fails with an opaque IndexError on batch_metrics[0].
        raise ValueError(
            f'batch_size ({batch_size}) exceeds the number of training '
            f'examples ({train_ds_size}); the epoch would contain no batches.')

    # Shuffle, drop the incomplete final batch, and group indices per step.
    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size]
    perms = perms.reshape((steps_per_epoch, batch_size))

    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch)
        batch_metrics.append(metrics)

    # compute mean of metrics across each batch in epoch.
    batch_metrics = jax.device_get(batch_metrics)
    epoch_metrics = {
        k: np.mean([metrics[k] for metrics in batch_metrics])
        for k in batch_metrics[0]}
    logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                 epoch_metrics['loss'], epoch_metrics['accuracy'] * 100)
    return state, epoch_metrics
def train_epoch(state, train_ds, batch_size, epoch, learning_rate_fn, rng):
    """Trains for a single epoch.

    Shuffles the examples with `rng` and drops the trailing partial batch.
    It then runs `train_step` (with the learning-rate schedule) on every
    full batch and logs the epoch-averaged metrics.

    Args:
      state: current training state, threaded through each `train_step`.
      train_ds: dict of arrays indexed along a shared leading example axis.
        Must contain 'image'; the other keys (presumably 'label') are
        sliced and passed through to `train_step`.
      batch_size: number of examples per batch; must be positive.
      epoch: epoch index, used only in the log message.
      learning_rate_fn: schedule forwarded unchanged to `train_step`.
      rng: PRNG key used to shuffle the examples.

    Returns:
      `(state, epoch_metrics)`: the updated state and, for every metric key
      reported by `train_step`, its mean over the epoch's batches.

    Raises:
      ValueError: if `batch_size` is not positive or is larger than the
        number of training examples. Either case would leave no batch to
        train on.
    """
    train_ds_size = len(train_ds['image'])
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}.')
    steps_per_epoch = train_ds_size // batch_size
    if steps_per_epoch == 0:
        # Without this guard the loop below never runs and the metric
        # averaging fails with an opaque IndexError on batch_metrics[0].
        raise ValueError(
            f'batch_size ({batch_size}) exceeds the number of training '
            f'examples ({train_ds_size}); the epoch would contain no batches.')

    # Shuffle, drop the incomplete final batch, and group indices per step.
    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size]
    perms = perms.reshape((steps_per_epoch, batch_size))

    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch, learning_rate_fn)
        batch_metrics.append(metrics)

    # compute mean of metrics across each batch in epoch.
    batch_metrics = jax.device_get(batch_metrics)
    epoch_metrics = {
        k: np.mean([metrics[k] for metrics in batch_metrics])
        for k in batch_metrics[0]}
    logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                 epoch_metrics['loss'], epoch_metrics['accuracy'] * 100)
    return state, epoch_metrics
以及 `create_train_state` 函数(修改前和修改后):
def create_train_state(rng, config):
    """Build the initial `TrainState` for the CNN with a constant-rate SGD.

    Args:
      rng: PRNG key used to initialise the model parameters.
      config: configuration object providing `learning_rate` and `momentum`.

    Returns:
      A `train_state.TrainState` holding the freshly initialised parameters,
      the model's `apply` function and the optimizer.
    """
    model = CNN()
    # Initialise from a dummy all-ones batch of shape (1, 28, 28, 1),
    # presumably one single-channel 28x28 MNIST image.
    dummy_batch = jnp.ones([1, 28, 28, 1])
    variables = model.init(rng, dummy_batch)
    optimizer = optax.sgd(config.learning_rate, config.momentum)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optimizer,
    )
def create_train_state(rng, config, learning_rate_fn):
    """Build the initial `TrainState` with SGD driven by a learning-rate schedule.

    Args:
      rng: PRNG key used to initialise the model parameters.
      config: configuration object; only `config.momentum` is read here.
      learning_rate_fn: schedule passed to `optax.sgd` in place of a
        constant learning rate.

    Returns:
      A `train_state.TrainState` holding the freshly initialised parameters,
      the model's `apply` function and the scheduled optimizer.
    """
    model = CNN()
    # Initialise from a dummy all-ones batch of shape (1, 28, 28, 1),
    # presumably one single-channel 28x28 MNIST image.
    dummy_batch = jnp.ones([1, 28, 28, 1])
    variables = model.init(rng, dummy_batch)
    scheduled_sgd = optax.sgd(learning_rate_fn, config.momentum)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=scheduled_sgd,
    )