处理整个数据集#

出于效率原因,我们形成包含多个示例的批次并并行处理它们。特别是在评估模型时,重要的是我们处理所有示例并**避免丢失**末尾无法形成完整批次的剩余示例。

问题#

当在单个设备上进行评估时,可以删除最后一个不完整的批次,或者可以形成与前面的批次形状不同的最后一个批次。后者的缺点是这将触发 eval_step() 的**重新编译**,因为 XLA 不是形状多态的。

collections.Counter(
    tuple(batch['image'].shape)
    for batch in tfds.load('mnist', split='test').batch(per_device_batch_size)
)
# output:
# Counter({(272, 28, 28, 1): 1, (512, 28, 28, 1): 19})

当使用多个设备进行数据并行时,问题会更加突出。如果批次大小**不能被设备数量整除**,则最后一步必须在单个设备(或设备子集)上执行。通常会删除最后一个批次,但这会导致不正确的结果。

sum(
    np.prod(batch['label'].shape)
    for batch in tfds.load('mnist', split='test')
        .batch(per_device_batch_size, drop_remainder=True)
        .batch(jax.local_device_count())
)
# output:
# 9728

使用多个主机进一步使情况复杂化,因为 JAX 使用 SPMD 范式,每个主机必须执行相同的程序。我们通常会使用 tfds.split_for_jax_process() 为不同的主机形成非重叠的分割,但这可能导致**不同主机的数量不同**,当要处理所有示例时,会导致不同的 JAX 程序。

process_count = 6
[
    len(tfds.load(dataset_name, split=tfds.split_for_jax_process(
        'test', process_index=process_index, process_count=process_count)))
    for process_index in range(process_count)
]
# output:
# [1667, 1667, 1667, 1667, 1666, 1666]

解决方案:填充#

即使可以通过巧妙地调整不同设备在不同主机上执行的批次数量来解决此问题,但这样的解决方案很快就会变得复杂,并且使主评估循环难以阅读,其中包含大量繁琐的逻辑。

解决此问题的更直接的方法是在数据集末尾使用填充,以确保最后一个批次与前面的批次具有相同的大小。

手动实现#

手动填充最后一个批次,使其包含与前面批次中相同数量的示例。从计算中丢弃填充示例的预测。

shard = lambda x: einops.rearrange(
    x, '(d b) ... -> d b ...', d=jax.local_device_count())
unshard = lambda x: einops.rearrange(x, 'd b ... -> (d b) ...')

correct = total = 0
for batch in ds.as_numpy_iterator():
  images = batch['image']
  n = len(images)
  padding = np.zeros([per_host_batch_size - n, *images.shape[1:]], images.dtype)
  padded_images = np.concatenate([images, padding])
  preds = unshard(get_preds(variables, shard(padded_images)))[:n]
  total += n
  correct += (batch['label'] == preds.argmax(axis=-1)).sum()

使用 pad_shard_unpad()#

上述模式,即 pad→shard→predict→unshard→unpad 序列,可以提取到一个实用程序包装器 pad_shard_unpad() 中,这大大简化了上述评估循环。

correct = total = 0
for batch in ds.as_numpy_iterator():
  preds = flax.jax_utils.pad_shard_unpad(get_preds)(
      vs, batch['image'], min_device_batch=per_device_batch_size)
  total += len(batch['image'])
  correct += (batch['label'] == preds.argmax(axis=-1)).sum()

eval_step() 中计算指标#

我们通常希望使指标计算成为评估步骤的一部分,而不是返回预测并在主评估循环中计算指标,特别是在使用诸如 clu.metricsclu.metrics 之类的库时。

在这种情况下,我们希望将指标作为 static_argnums 传递(即不进行分片/填充),并将返回值也视为 static_return (即不进行取消分片或取消填充)。

def eval_step(metrics, variables, batch):
  print('retrigger compilation', {k: v.shape for k, v in batch.items()})
  preds = model.apply(variables, batch['image'])
  correct = (batch['mask'] & (batch['label'] == preds.argmax(axis=-1))).sum()
  total = batch['mask'].sum()
  return dict(
      correct=metrics['correct'] + jax.lax.psum(correct, axis_name='batch'),
      total=metrics['total'] + jax.lax.psum(total, axis_name='batch'),
  )

eval_step = jax.pmap(eval_step, axis_name='batch')
eval_step = flax.jax_utils.pad_shard_unpad(
    eval_step, static_argnums=(0, 1), static_return=True)

添加“无限填充”#

上述解决方案在大多数情况下都有效,但它有一些限制

  1. 在极少数情况下,即使在多个主机上均匀分割数据集也会导致批次数量不同。想象一下,有一个包含 n=4097 个示例的数据集,并在 h=8 上进行评估,每个主机都有 d=8 个本地设备,并形成 b=128 的设备上批次大小。通过均匀数据集分割,第一个主机将获得 4096/8+1==513 个示例,而所有其他主机将获得 4096/8==512 个示例。形成每个主机的 d*b==512 批次,这将导致第一个主机上有两个批次,而所有其他主机上只有一个批次,违反了 SPMD 原则,并在最后一个 psum() 指令中挂起多主机设置(该指令仅由第一个主机执行,而其他主机则不执行)。

  2. 当使用 ds.filter() 动态删除示例时。

在这些更复杂的情况下,我们可以独立地在每个主机上的数据集添加“无限填充”,并继续处理示例,直到所有主机都用完未填充的示例为止。

correct = total = 0
for batch in ds.as_numpy_iterator():
  n = count_p(batch['mask'])[0].item()  # adds sync barrier
  if not n: break

  preds = get_preds(vs, batch['image']).argmax(axis=-1)
  total += n
  correct += count_correct_p(batch['label'], preds, batch['mask'])[0]