迁移学习#

本指南演示了 Flax 迁移学习工作流程的各个部分。根据任务的不同,预训练模型可以用作特征提取器,也可以作为更大模型的一部分进行微调。

本指南演示如何

  • 从 HuggingFace Transformers 加载预训练模型,并从该预训练模型中提取特定的子模块。

  • 创建分类器模型。

  • 将预训练的参数传输到新的模型结构中。

  • 使用 Optax 为模型的不同部分分别创建优化器。

  • 设置模型进行训练。

性能提示

根据您的任务,本指南中的某些内容可能不是最优的。例如,如果您只打算在预训练模型之上训练线性分类器,那么最好只提取一次特征嵌入,这样可以大大加快训练速度,并且可以使用专门的算法进行线性回归或逻辑分类。本指南展示了如何使用所有模型参数进行迁移学习。


设置#

# Note that the Transformers library doesn't use the latest Flax version.
! pip install -q "transformers[flax]"
# Install/upgrade Flax and JAX. For JAX installation with GPU/TPU support,
# visit https://github.com/jax-ml/jax#installation.
! pip install -U -q flax jax jaxlib

创建用于模型加载的函数#

为了方便加载预训练的分类器,首先创建一个返回 Flax Module 及其预训练变量的函数。

在下面的代码中,load_model 函数使用 HuggingFace 的 FlaxCLIPVisionModel 模型(来自 Transformers 库)并提取 FlaxCLIPModule 模块。

%%capture
from IPython.display import clear_output
from transformers import FlaxCLIPModel

# Note: FlaxCLIPModel is not a Flax Module
def load_model():
  clip = FlaxCLIPModel.from_pretrained('openai/clip-vit-base-patch32')
  clear_output(wait=False) # Clear the loading messages
  module = clip.module # Extract the Flax Module
  variables = {'params': clip.params} # Extract the parameters
  return module, variables

请注意,FlaxCLIPVisionModel 本身不是 Flax Module,这就是我们需要执行此额外步骤的原因。

提取子模块#

从上面的代码片段调用 load_model 会返回 FlaxCLIPModule,它由 text_modelvision_model 子模块组成。

提取在 .setup() 中定义的 vision_model 子模块及其变量的简单方法是,在 clip 模块上使用 flax.linen.Module.bind,紧接着在 vision_model 子模块上使用 flax.linen.Module.unbind

import flax.linen as nn

clip, clip_variables = load_model()
vision_model, vision_model_vars = clip.bind(clip_variables).vision_model.unbind()

创建分类器#

要创建分类器,请定义一个新的 Flax Module,其中包含一个 backbone(预训练的视觉模型)和一个 head(分类器)子模块。

from typing import Callable
import jax.numpy as jnp
import jax

class Classifier(nn.Module):
  num_classes: int
  backbone: nn.Module
  

  @nn.compact
  def __call__(self, x):
    x = self.backbone(x).pooler_output
    x = nn.Dense(
      self.num_classes, name='head', kernel_init=nn.zeros)(x)
    return x

要构建分类器 modelvision_model 模块作为 backbone 传递给 Classifier。然后,可以通过传递用于推断参数形状的伪数据来随机初始化模型的 params

num_classes = 3
model = Classifier(num_classes=num_classes, backbone=vision_model)

x = jnp.empty((1, 224, 224, 3))
variables = model.init(jax.random.key(1), x)
params = variables['params']

传输参数#

由于 params 当前是随机的,因此必须将 vision_model_vars 中的预训练参数传输到适当位置(即 backbone)的 params 结构中。

params['backbone'] = vision_model_vars['params']

注意:如果模型包含其他变量集合(例如 batch_stats),也必须传输这些变量。

优化#

如果您需要单独训练模型的不同部分,则有三个选择

  1. 使用 stop_gradient

  2. 过滤 jax.grad 的参数。

  3. 为不同的参数使用多个优化器。

在大多数情况下,我们建议通过 Optaxmulti_transform 使用多个优化器,因为它既高效又可以轻松扩展以实现许多微调策略。

optax.multi_transform#

要使用 optax.multi_transform,必须定义以下内容

  1. 参数分区。

  2. 分区及其优化器之间的映射。

  3. 与参数形状相同的 PyTree,但其叶子包含相应的分区标签。

要使用上面的模型通过 optax.multi_transform 冻结层,可以使用以下设置

  • 定义 trainablefrozen 参数分区。

  • 对于 trainable 参数,选择 Adam (optax.adam) 优化器。

  • 对于 frozen 参数,选择 optax.set_to_zero 优化器。此虚拟优化器将梯度归零,因此不进行任何训练。

  • 使用 flax.traverse_util.path_aware_map 将参数映射到分区,将 backbone 中的叶子标记为 frozen,其余的标记为 trainable

from flax import traverse_util
import optax

partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()}
param_partitions = traverse_util.path_aware_map(
  lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params)
tx = optax.multi_transform(partition_optimizers, param_partitions)

# visualize a subset of the param_partitions structure
flat = list(traverse_util.flatten_dict(param_partitions).items())
traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:]))
FrozenDict({
    backbone: {
        embeddings: {
            class_embedding: 'frozen',
            patch_embedding: {
                kernel: 'frozen',
            },
        },
    },
    head: {
        bias: 'trainable',
        kernel: 'trainable',
    },
})

为了实现差分学习率optax.set_to_zero 可以替换为任何其他优化器,并且可以根据任务选择不同的优化器和分区方案。有关高级优化器的更多信息,请参阅 Optax 的 组合优化器文档。

创建 TrainState#

一旦定义了模块、参数和优化器,就可以像往常一样构建 TrainState

from flax.training.train_state import TrainState

state = TrainState.create(
  apply_fn=model.apply,
  params=params,
  tx=tx)

由于优化器负责冻结或微调策略,train_step 不需要额外的更改,训练可以正常进行。