迁移学习#
本指南演示了 Flax 迁移学习工作流程的各个部分。根据任务的不同,预训练模型可以用作特征提取器,也可以作为更大模型的一部分进行微调。
本指南演示如何
从 HuggingFace Transformers 加载预训练模型,并从该预训练模型中提取特定的子模块。
创建分类器模型。
将预训练的参数传输到新的模型结构中。
使用 Optax 为模型的不同部分分别创建优化器。
设置模型进行训练。
性能提示
根据您的任务,本指南中的某些内容可能不是最优的。例如,如果您只打算在预训练模型之上训练线性分类器,那么最好只提取一次特征嵌入,这样可以大大加快训练速度,并且可以使用专门的算法进行线性回归或逻辑分类。本指南展示了如何使用所有模型参数进行迁移学习。
设置#
# Note that the Transformers library doesn't use the latest Flax version.
! pip install -q "transformers[flax]"
# Install/upgrade Flax and JAX. For JAX installation with GPU/TPU support,
# visit https://github.com/jax-ml/jax#installation.
! pip install -U -q flax jax jaxlib
创建用于模型加载的函数#
为了方便加载预训练的分类器,首先创建一个返回 Flax Module
及其预训练变量的函数。
在下面的代码中,load_model
函数使用 HuggingFace 的 FlaxCLIPVisionModel
模型(来自 Transformers 库)并提取 FlaxCLIPModule
模块。
%%capture
from IPython.display import clear_output
from transformers import FlaxCLIPModel
# Note: FlaxCLIPModel is not a Flax Module
def load_model():
clip = FlaxCLIPModel.from_pretrained('openai/clip-vit-base-patch32')
clear_output(wait=False) # Clear the loading messages
module = clip.module # Extract the Flax Module
variables = {'params': clip.params} # Extract the parameters
return module, variables
请注意,FlaxCLIPVisionModel
本身不是 Flax Module
,这就是我们需要执行此额外步骤的原因。
提取子模块#
从上面的代码片段调用 load_model
会返回 FlaxCLIPModule
,它由 text_model
和 vision_model
子模块组成。
提取在 .setup()
中定义的 vision_model
子模块及其变量的简单方法是,在 clip
模块上使用 flax.linen.Module.bind
,紧接着在 vision_model
子模块上使用 flax.linen.Module.unbind
。
import flax.linen as nn
clip, clip_variables = load_model()
vision_model, vision_model_vars = clip.bind(clip_variables).vision_model.unbind()
创建分类器#
要创建分类器,请定义一个新的 Flax Module
,其中包含一个 backbone
(预训练的视觉模型)和一个 head
(分类器)子模块。
from typing import Callable
import jax.numpy as jnp
import jax
class Classifier(nn.Module):
num_classes: int
backbone: nn.Module
@nn.compact
def __call__(self, x):
x = self.backbone(x).pooler_output
x = nn.Dense(
self.num_classes, name='head', kernel_init=nn.zeros)(x)
return x
要构建分类器 model
,vision_model
模块作为 backbone
传递给 Classifier
。然后,可以通过传递用于推断参数形状的伪数据来随机初始化模型的 params
。
num_classes = 3
model = Classifier(num_classes=num_classes, backbone=vision_model)
x = jnp.empty((1, 224, 224, 3))
variables = model.init(jax.random.key(1), x)
params = variables['params']
传输参数#
由于 params
当前是随机的,因此必须将 vision_model_vars
中的预训练参数传输到适当位置(即 backbone
)的 params
结构中。
params['backbone'] = vision_model_vars['params']
注意:如果模型包含其他变量集合(例如 batch_stats
),也必须传输这些变量。
优化#
如果您需要单独训练模型的不同部分,则有三个选择
使用
stop_gradient
。过滤
jax.grad
的参数。为不同的参数使用多个优化器。
在大多数情况下,我们建议通过 Optax 的 multi_transform
使用多个优化器,因为它既高效又可以轻松扩展以实现许多微调策略。
optax.multi_transform#
要使用 optax.multi_transform
,必须定义以下内容
参数分区。
分区及其优化器之间的映射。
与参数形状相同的 PyTree,但其叶子包含相应的分区标签。
要使用上面的模型通过 optax.multi_transform
冻结层,可以使用以下设置
定义
trainable
和frozen
参数分区。对于
trainable
参数,选择 Adam (optax.adam
) 优化器。
对于
frozen
参数,选择optax.set_to_zero
优化器。此虚拟优化器将梯度归零,因此不进行任何训练。使用
flax.traverse_util.path_aware_map
将参数映射到分区,将backbone
中的叶子标记为frozen
,其余的标记为trainable
。
from flax import traverse_util
import optax
partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()}
param_partitions = traverse_util.path_aware_map(
lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params)
tx = optax.multi_transform(partition_optimizers, param_partitions)
# visualize a subset of the param_partitions structure
flat = list(traverse_util.flatten_dict(param_partitions).items())
traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:]))
FrozenDict({
backbone: {
embeddings: {
class_embedding: 'frozen',
patch_embedding: {
kernel: 'frozen',
},
},
},
head: {
bias: 'trainable',
kernel: 'trainable',
},
})
为了实现差分学习率,optax.set_to_zero
可以替换为任何其他优化器,并且可以根据任务选择不同的优化器和分区方案。有关高级优化器的更多信息,请参阅 Optax 的 组合优化器文档。
创建 TrainState
#
一旦定义了模块、参数和优化器,就可以像往常一样构建 TrainState
from flax.training.train_state import TrainState
state = TrainState.create(
apply_fn=model.apply,
params=params,
tx=tx)
由于优化器负责冻结或微调策略,train_step
不需要额外的更改,训练可以正常进行。