用 Jax+Flax 编写的神经网络期望其输入数据为 jax.numpy
数组实例。 因此,从任何来源加载数据集都很简单:只需将其转换为 jax.numpy 类型,并重塑为适合网络的维度即可。
例如,本指南演示如何使用 Torchvision、Tensorflow 和 Hugging Face 的 API 导入 MNIST。 我们会将整个数据集加载到内存中。 对于不适合内存的数据集,该过程是类似的,但应以分批方式完成。
MNIST 数据集由 28x28 像素的灰度手写数字图像组成,并具有指定的 60k/10k 训练/测试拆分。 任务是预测每张图像的正确类别(数字 0, ..., 9)。
假设使用基于 CNN 的分类器,输入数据的形状应为 (B, 28, 28, 1),其中末尾的单一维度表示灰度图像的通道。
标签只是表示与图像对应的数字的整数。 因此,标签的形状应为 (B,),以便可以使用 optax.softmax_cross_entropy_with_integer_labels 计算损失。
import numpy as np
import jax.numpy as jnp
从 torchvision.datasets 加载
import torchvision
def get_dataset_torch():
    """Load MNIST through torchvision and return it as jax.numpy arrays.

    Both splits are requested with ``download=True`` under ``./data`` and are
    held entirely in memory.

    Returns:
        A ``(train, test)`` tuple of dicts. Each dict has an ``'image'``
        entry (float32, rescaled from [0, 255] to [0, 1], with a trailing
        channel axis appended) and a ``'label'`` entry (int16).
    """
    mnist = {
        'train': torchvision.datasets.MNIST('./data', train=True, download=True),
        'test': torchvision.datasets.MNIST('./data', train=False, download=True),
    }

    ds = {}

    for split in ['train', 'test']:
        ds[split] = {
            'image': mnist[split].data.numpy(),
            'label': mnist[split].targets.numpy(),
        }

        # cast from np to jnp and rescale the pixel values from [0,255] to [0,1]
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])

        # torchvision returns shape (B, 28, 28).
        # hence, append the trailing channel dimension.
        ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3)

    return ds['train'], ds['test']
train, test = get_dataset_torch()
# Report shape and dtype of each array: train split first, images before labels.
for _split in (train, test):
    for _field in ('image', 'label'):
        print(_split[_field].shape, _split[_field].dtype)
(60000, 28, 28, 1) float32
(60000,) int16
(10000, 28, 28, 1) float32
(10000,) int16
从 tensorflow_datasets 加载
import tensorflow_datasets as tfds
def get_dataset_tf():
    """Load MNIST through tensorflow_datasets and return it as jax.numpy arrays.

    Each split is read as a single batch (``batch_size=-1``), so the whole
    dataset is held in memory.

    Returns:
        A ``(train, test)`` tuple of dicts. Each dict has an ``'image'``
        entry (float32, rescaled to [0, 1]) and a ``'label'`` entry (int16).
    """
    mnist = tfds.builder('mnist')
    # The builder must have the data on disk before as_dataset() can read it.
    mnist.download_and_prepare()

    ds = {}

    for split in ['train', 'test']:
        ds[split] = tfds.as_numpy(mnist.as_dataset(split=split, batch_size=-1))

        # cast to jnp and rescale pixel values
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])

    return ds['train'], ds['test']
train, test = get_dataset_tf()
# Report shape and dtype of each array: train split first, images before labels.
for _split in (train, test):
    for _field in ('image', 'label'):
        print(_split[_field].shape, _split[_field].dtype)
(60000, 28, 28, 1) float32
(60000,) int16
(10000, 28, 28, 1) float32
(10000,) int16
从 🤗 Hugging Face datasets 加载
#!pip install datasets # datasets isn't preinstalled on Colab; uncomment to install
from datasets import load_dataset
def get_dataset_hf():
    """Load MNIST through Hugging Face ``datasets`` and return jax.numpy arrays.

    Every image of a split is converted to a NumPy array and stacked, so the
    whole dataset is held in memory.

    Returns:
        A ``(train, test)`` tuple of dicts. Each dict has an ``'image'``
        entry (float32, rescaled to [0, 1], with a trailing channel axis
        appended) and a ``'label'`` entry (int16).
    """
    mnist = load_dataset("mnist")

    ds = {}

    for split in ['train', 'test']:
        ds[split] = {
            'image': np.array([np.array(im) for im in mnist[split]['image']]),
            'label': np.array(mnist[split]['label']),
        }

        # cast to jnp and rescale pixel values
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])

        # append trailing channel dimension
        ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3)

    return ds['train'], ds['test']
train, test = get_dataset_hf()
# Report shape and dtype of each array: train split first, images before labels.
for _split in (train, test):
    for _field in ('image', 'label'):
        print(_split[_field].shape, _split[_field].dtype)
(60000, 28, 28, 1) float32
(60000,) int16
(10000, 28, 28, 1) float32
(10000,) int16