模型手术#

通常,Flax 模块和优化器会跟踪并更新参数。但有时您可能想进行一些模型手术,并自己调整参数张量。本指南将向您展示如何进行此操作。

设置#

!pip install --upgrade -q pip jax jaxlib flax
import functools

import jax
import jax.numpy as jnp
from flax import traverse_util
from flax import linen as nn
from flax.core import freeze
import jax
import optax

Flax 模块的手术#

让我们创建一个小型卷积神经网络模型进行演示。

像往常一样,您可以运行 CNN.init(...)['params'] 来获取 params,以便在训练的每一步传递和修改它。

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
      x = nn.Conv(features=32, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
      x = nn.Conv(features=64, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
      x = x.reshape((x.shape[0], -1))
      x = nn.Dense(features=256)(x)
      x = nn.relu(x)
      x = nn.Dense(features=10)(x)
      x = nn.log_softmax(x)
      return x

def get_initial_params(key):
    init_shape = jnp.ones((1, 28, 28, 1), jnp.float32)
    initial_params = CNN().init(key, init_shape)['params']
    return initial_params

key = jax.random.key(0)
params = get_initial_params(key)

jax.tree_util.tree_map(jnp.shape, params)
{'Conv_0': {'bias': (32,), 'kernel': (3, 3, 1, 32)},
 'Conv_1': {'bias': (64,), 'kernel': (3, 3, 32, 64)},
 'Dense_0': {'bias': (256,), 'kernel': (3136, 256)},
 'Dense_1': {'bias': (10,), 'kernel': (256, 10)}}

请注意,返回的 params 是一个 FrozenDict,其中包含一些 JAX 数组作为内核和偏差。

FrozenDict 只不过是一个只读字典,Flax 使其只读的原因在于 JAX 的函数式特性:JAX 数组是不可变的,新的 params 需要替换旧的 params。使字典只读可确保在训练和更新期间不会意外发生字典的原地突变。

在 Flax 模块外部实际修改参数的一种方法是显式地展平它并创建一个可变的字典。请注意,您可以使用分隔符 sep 来连接所有嵌套的键。如果没有给定 sep,则键将是所有嵌套键的元组。

# Get a flattened key-value list.
flat_params = traverse_util.flatten_dict(params, sep='/')

jax.tree_util.tree_map(jnp.shape, flat_params)
{'Conv_0/bias': (32,),
 'Conv_0/kernel': (3, 3, 1, 32),
 'Conv_1/bias': (64,),
 'Conv_1/kernel': (3, 3, 32, 64),
 'Dense_0/bias': (256,),
 'Dense_0/kernel': (3136, 256),
 'Dense_1/bias': (10,),
 'Dense_1/kernel': (256, 10)}

现在,您可以对参数执行任何操作。完成后,将其展平还原并在以后的训练中使用它。

# Somehow modify a layer
dense_kernel = flat_params['Dense_1/kernel']
flat_params['Dense_1/kernel'] = dense_kernel / jnp.linalg.norm(dense_kernel)

# Unflatten.
unflat_params = traverse_util.unflatten_dict(flat_params, sep='/')
# Refreeze.
unflat_params = freeze(unflat_params)
jax.tree_util.tree_map(jnp.shape, unflat_params)
FrozenDict({
    Conv_0: {
        bias: (32,),
        kernel: (3, 3, 1, 32),
    },
    Conv_1: {
        bias: (64,),
        kernel: (3, 3, 32, 64),
    },
    Dense_0: {
        bias: (256,),
        kernel: (3136, 256),
    },
    Dense_1: {
        bias: (10,),
        kernel: (256, 10),
    },
})

优化器的手术#

当使用 Optax 作为优化器时,opt_state 实际上是组成优化器的各个梯度转换状态的嵌套元组。这些状态包含反映参数树的 pytree,并且可以采用相同的方式进行修改:展平、修改、展平还原,然后重新创建与原始状态镜像的新优化器状态。

tx = optax.adam(1.0)
opt_state = tx.init(params)

# The optimizer state is a tuple of gradient transformation states.
jax.tree_util.tree_map(jnp.shape, opt_state)
(ScaleByAdamState(count=(), mu={'Conv_0': {'bias': (32,), 'kernel': (3, 3, 1, 32)}, 'Conv_1': {'bias': (64,), 'kernel': (3, 3, 32, 64)}, 'Dense_0': {'bias': (256,), 'kernel': (3136, 256)}, 'Dense_1': {'bias': (10,), 'kernel': (256, 10)}}, nu={'Conv_0': {'bias': (32,), 'kernel': (3, 3, 1, 32)}, 'Conv_1': {'bias': (64,), 'kernel': (3, 3, 32, 64)}, 'Dense_0': {'bias': (256,), 'kernel': (3136, 256)}, 'Dense_1': {'bias': (10,), 'kernel': (256, 10)}}),
 EmptyState())

优化器状态内的 pytree 遵循与参数相同的结构,并且可以完全以相同的方式进行展平/修改。

flat_mu = traverse_util.flatten_dict(opt_state[0].mu, sep='/')
flat_nu = traverse_util.flatten_dict(opt_state[0].nu, sep='/')

jax.tree_util.tree_map(jnp.shape, flat_mu)
{'Conv_0/bias': (32,),
 'Conv_0/kernel': (3, 3, 1, 32),
 'Conv_1/bias': (64,),
 'Conv_1/kernel': (3, 3, 32, 64),
 'Dense_0/bias': (256,),
 'Dense_0/kernel': (3136, 256),
 'Dense_1/bias': (10,),
 'Dense_1/kernel': (256, 10)}

修改后,重新创建优化器状态。将其用于以后的训练。

opt_state = (
    opt_state[0]._replace(
        mu=traverse_util.unflatten_dict(flat_mu, sep='/'),
        nu=traverse_util.unflatten_dict(flat_nu, sep='/'),
    ),
) + opt_state[1:]
jax.tree_util.tree_map(jnp.shape, opt_state)
(ScaleByAdamState(count=(), mu={'Conv_0': {'bias': (32,), 'kernel': (3, 3, 1, 32)}, 'Conv_1': {'bias': (64,), 'kernel': (3, 3, 32, 64)}, 'Dense_0': {'bias': (256,), 'kernel': (3136, 256)}, 'Dense_1': {'bias': (10,), 'kernel': (256, 10)}}, nu={'Conv_0': {'bias': (32,), 'kernel': (3, 3, 1, 32)}, 'Conv_1': {'bias': (64,), 'kernel': (3, 3, 32, 64)}, 'Dense_0': {'bias': (256,), 'kernel': (3136, 256)}, 'Dense_1': {'bias': (10,), 'kernel': (256, 10)}}),
 EmptyState())