Ensembling on multiple devices
We show how to train an ensemble of CNNs on the MNIST dataset, where the size of the ensemble equals the number of available devices (see the short check after this list). In a nutshell, the change can be described as:

- parallelizing several functions using jax.pmap(),
- splitting the random seed to obtain different parameter initializations,
- replicating the inputs and unreplicating the outputs where necessary,
- averaging the probabilities across devices to compute the predictions.
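Each ensemble member lives on its own device, so the ensemble size is simply the number of XLA devices JAX can see. A quick check (illustrative only, not part of the MNIST example):

import jax

# One ensemble member per device: the ensemble size equals jax.device_count().
print(f'Training an ensemble of {jax.device_count()} models.')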
In this HOWTO we omit some of the code, such as the imports, the CNN module, and the metrics computation, but it can all be found in the MNIST example.
Parallel functions
We start by creating a parallel version of create_train_state(), which retrieves the initial parameters of the model. We do this using jax.pmap(). The effect of "pmapping" a function is that it is compiled with XLA (similar to jax.jit()), but executed in parallel on XLA devices (e.g., GPUs/TPUs).
# Single-model version.
@jax.jit
def create_train_state(rng, learning_rate, momentum):
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)
# Ensemble version: pmapped so that each device initializes its own model.
@functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))
def create_train_state(rng, learning_rate, momentum):
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)
Note that for the single-model code above we use jax.jit() to lazily initialize the model (see the documentation of Module.init for more details). For the ensemble case, jax.pmap() by default maps over the first axis of the provided argument rng, so we should make sure to provide a different value for each device when we call this function later on.

Also note how we specify that learning_rate and momentum are static arguments, which means the concrete values of these arguments are used rather than abstract shapes. This is necessary because the provided arguments will be scalar values. For more details see JIT mechanics: tracing and static variables.
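To make this concrete, here is a minimal sketch of how the pmapped create_train_state() is meant to be called, mirroring the training code further below; the learning rate of 0.01 and momentum of 0.9 are made-up values for illustration:

# One RNG key per device along the leading axis; the scalar hyperparameters
# are passed as static arguments (0.01 and 0.9 are illustrative values).
rng = jax.random.key(0)
device_rngs = jax.random.split(rng, jax.device_count())  # leading axis = number of devices
state = create_train_state(device_rngs, 0.01, 0.9)       # state now has a leading device axis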
Next, we simply do the same for the functions apply_model() and update_model(). To compute the ensemble prediction, we take the mean of the individual probabilities. We use jax.lax.pmean() to compute that mean across devices. This also requires us to specify an axis_name for both jax.pmap() and jax.lax.pmean().
# Single-model versions.
@jax.jit
def apply_model(state, images, labels):
  def loss_fn(params):
    logits = CNN().apply({'params': params}, images)
    one_hot = jax.nn.one_hot(labels, 10)
    loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return grads, loss, accuracy

@jax.jit
def update_model(state, grads):
  return state.apply_gradients(grads=grads)
# Ensemble versions: pmapped, averaging probabilities across devices.
@functools.partial(jax.pmap, axis_name='ensemble')
def apply_model(state, images, labels):
  def loss_fn(params):
    logits = CNN().apply({'params': params}, images)
    one_hot = jax.nn.one_hot(labels, 10)
    loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  probs = jax.lax.pmean(jax.nn.softmax(logits), axis_name='ensemble')
  accuracy = jnp.mean(jnp.argmax(probs, -1) == labels)
  return grads, loss, accuracy

@jax.pmap
def update_model(state, grads):
  return state.apply_gradients(grads=grads)
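If jax.lax.pmean() is unfamiliar, the following self-contained sketch (not part of the MNIST example) shows what it does under jax.pmap(): each device contributes its own value and every device receives the mean, which is exactly how the per-model probabilities are averaged above.

import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name='ensemble')
def device_mean(x):
  # Average x across all devices participating in the 'ensemble' axis.
  return jax.lax.pmean(x, axis_name='ensemble')

xs = jnp.arange(jax.device_count(), dtype=jnp.float32)  # one value per device
print(device_mean(xs))  # every entry equals the mean of xs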
Training the ensemble
Next we transform the train_epoch() function. When calling the pmapped functions above, the main things we need to take care of are replicating the arguments to all devices where necessary and unreplicating the return values.
# Single-model version.
def train_epoch(state, train_ds, batch_size, rng):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))

  epoch_loss = []
  epoch_accuracy = []

  for perm in perms:
    batch_images = train_ds['image'][perm, ...]
    batch_labels = train_ds['label'][perm, ...]
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
    state = update_model(state, grads)
    epoch_loss.append(loss)
    epoch_accuracy.append(accuracy)
  train_loss = np.mean(epoch_loss)
  train_accuracy = np.mean(epoch_accuracy)
  return state, train_loss, train_accuracy
# Ensemble version: batches are replicated on the way in, metrics unreplicated
# on the way out.
def train_epoch(state, train_ds, batch_size, rng):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))

  epoch_loss = []
  epoch_accuracy = []

  for perm in perms:
    batch_images = jax_utils.replicate(train_ds['image'][perm, ...])
    batch_labels = jax_utils.replicate(train_ds['label'][perm, ...])
    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
    state = update_model(state, grads)
    epoch_loss.append(jax_utils.unreplicate(loss))
    epoch_accuracy.append(jax_utils.unreplicate(accuracy))
  train_loss = np.mean(epoch_loss)
  train_accuracy = np.mean(epoch_accuracy)
  return state, train_loss, train_accuracy
As you can see, we don't need to change any of the logic around state. This is because, as we will see in the training code below, the train state is already replicated, so passing it to the pmapped apply_model() and update_model() just works. The train dataset, however, is not replicated yet, so we do that here. Since replicating the entire train dataset would consume too much memory, we do it at the batch level.
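For reference, flax.jax_utils.replicate() copies a pytree to every local device, adding a leading device axis, and flax.jax_utils.unreplicate() returns the first copy. A small sketch of the shapes involved, assuming MNIST-shaped images and a hypothetical batch size of 32:

from flax import jax_utils
import jax.numpy as jnp

batch = jnp.ones([32, 28, 28, 1])             # a batch of MNIST-shaped images (batch size 32 is illustrative)
replicated = jax_utils.replicate(batch)       # shape (num_devices, 32, 28, 28, 1)
restored = jax_utils.unreplicate(replicated)  # shape (32, 28, 28, 1) again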
Now we can rewrite the actual training logic. This consists of two simple changes: making sure we pass one RNG key per device to create_train_state() (by splitting the key accordingly), and replicating the test dataset, which is much smaller than the train dataset, so we can simply do this for the entire dataset directly.
# Single-model version.
train_ds, test_ds = get_datasets()
rng = jax.random.key(0)

rng, init_rng = jax.random.split(rng)
state = create_train_state(init_rng, learning_rate, momentum)

for epoch in range(1, num_epochs + 1):
  rng, input_rng = jax.random.split(rng)
  state, train_loss, train_accuracy = train_epoch(
      state, train_ds, batch_size, input_rng)
  _, test_loss, test_accuracy = apply_model(
      state, test_ds['image'], test_ds['label'])

  logging.info(
      'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, '
      'test_loss: %.4f, test_accuracy: %.2f'
      % (epoch, train_loss, train_accuracy * 100, test_loss,
         test_accuracy * 100))
# Ensemble version: the RNG is split across devices and the test set replicated.
train_ds, test_ds = get_datasets()
test_ds = jax_utils.replicate(test_ds)
rng = jax.random.key(0)

rng, init_rng = jax.random.split(rng)
state = create_train_state(jax.random.split(init_rng, jax.device_count()),
                           learning_rate, momentum)

for epoch in range(1, num_epochs + 1):
  rng, input_rng = jax.random.split(rng)
  state, train_loss, train_accuracy = train_epoch(
      state, train_ds, batch_size, input_rng)
  _, test_loss, test_accuracy = jax_utils.unreplicate(
      apply_model(state, test_ds['image'], test_ds['label']))

  logging.info(
      'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, '
      'test_loss: %.4f, test_accuracy: %.2f'
      % (epoch, train_loss, train_accuracy * 100, test_loss,
         test_accuracy * 100))