提取中间值#
本指南将向您展示如何从模块中提取中间值。让我们从这个使用 nn.compact
的简单 CNN 开始。
from flax import linen as nn
import jax
import jax.numpy as jnp
from typing import Sequence
class CNN(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
x = nn.log_softmax(x)
return x
因为此模块使用 nn.compact
,所以我们无法直接访问中间值。有几种方法可以公开它们
将中间值存储在新变量集合中#
可以通过调用 sow
来增强 CNN,以如下方式存储中间值
class CNN(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
x = nn.log_softmax(x)
return x
class SowCNN(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
self.sow('intermediates', 'features', x)
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
x = nn.log_softmax(x)
return x
当变量集合不可变时,sow
充当空操作。因此,它非常适合调试和可选的中间值跟踪。“intermediates” 集合也由 capture_intermediates
API 使用(请参阅 使用 capture_intermediates 部分)。
请注意,默认情况下,sow
每次调用时都会追加值
这是必要的,因为一旦实例化,一个模块可以在其父模块中被多次调用,并且我们希望捕获所有播种的值。
因此,您要确保不要将中间值反馈到
variables
中。否则,每次调用都会增加该元组的长度并触发重新编译。要覆盖默认的追加行为,请指定
init_fn
和reduce_fn
- 请参阅Module.sow()
。
class SowCNN2(nn.Module):
@nn.compact
def __call__(self, x):
mod = SowCNN(name='SowCNN')
return mod(x) + mod(x) # Calling same module instance twice.
@jax.jit
def init(key, x):
variables = SowCNN2().init(key, x)
# By default the 'intermediates' collection is not mutable during init.
# So variables will only contain 'params' here.
return variables
@jax.jit
def predict(variables, x):
# If mutable='intermediates' is not specified, then .sow() acts as a noop.
output, mod_vars = SowCNN2().apply(variables, x, mutable='intermediates')
features = mod_vars['intermediates']['SowCNN']['features']
return output, features
batch = jnp.ones((1,28,28,1))
variables = init(jax.random.key(0), batch)
preds, feats = predict(variables, batch)
assert len(feats) == 2 # Tuple with two values since module was called twice.
将模块重构为子模块#
对于那些清楚知道如何拆分子模块的情况,这是一种有用的模式。您在 setup
中公开的任何子模块都可以直接使用。在限制情况下,您可以在 setup
中定义所有子模块,并完全避免使用 nn.compact
。
class RefactoredCNN(nn.Module):
def setup(self):
self.features = Features()
self.classifier = Classifier()
def __call__(self, x):
x = self.features(x)
x = self.classifier(x)
return x
class Features(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
return x
class Classifier(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
x = nn.log_softmax(x)
return x
@jax.jit
def init(key, x):
variables = RefactoredCNN().init(key, x)
return variables['params']
@jax.jit
def features(params, x):
return RefactoredCNN().apply({"params": params}, x,
method=lambda module, x: module.features(x))
params = init(jax.random.key(0), batch)
features(params, batch)
使用 capture_intermediates
#
Linen 支持自动捕获子模块的中间返回值,而无需任何代码更改。此模式应被视为捕获中间值的“大锤”方法。作为调试和检查工具,它非常有用,但是使用本指南中描述的其他模式将使您对要提取的中间值进行更精细的控制。
在以下代码示例中,我们检查是否有任何中间激活是非有限的(NaN 或无穷大)
@jax.jit
def init(key, x):
variables = CNN().init(key, x)
return variables
@jax.jit
def predict(variables, x):
y, state = CNN().apply(variables, x, capture_intermediates=True, mutable=["intermediates"])
intermediates = state['intermediates']
fin = jax.tree_util.tree_map(lambda xs: jnp.all(jnp.isfinite(xs)), intermediates)
return y, fin
variables = init(jax.random.key(0), batch)
y, is_finite = predict(variables, batch)
all_finite = all(jax.tree_util.tree_leaves(is_finite))
assert all_finite, "non-finite intermediate detected!"
默认情况下,仅收集 __call__
方法的中间值。或者,您可以基于 Module
实例和方法名称传递自定义筛选函数。
filter_Dense = lambda mdl, method_name: isinstance(mdl, nn.Dense)
filter_encodings = lambda mdl, method_name: method_name == "encode"
y, state = CNN().apply(variables, batch, capture_intermediates=filter_Dense, mutable=["intermediates"])
dense_intermediates = state['intermediates']
请注意,capture_intermediates
仅适用于层。您可以使用 self.sow
手动存储非层中间值,但筛选函数不会应用于它。
class Model(nn.Module):
@nn.compact
def __call__(self, x):
a = nn.Dense(4)(x) # Dense_0
b = nn.Dense(4)(x) # Dense_1
c = a + b # not a Flax layer, so won't be stored as an intermediate
d = nn.Dense(4)(c) # Dense_2
return d
@jax.jit
def init(key, x):
variables = Model().init(key, x)
return variables['params']
@jax.jit
def predict(params, x):
return Model().apply({"params": params}, x, capture_intermediates=True)
batch = jax.random.uniform(jax.random.key(1), (1,3))
params = init(jax.random.key(0), batch)
preds, feats = predict(params, batch)
feats # intermediate c in Model was not stored because it's not a Flax layer
class Model(nn.Module):
@nn.compact
def __call__(self, x):
a = nn.Dense(4)(x) # Dense_0
b = nn.Dense(4)(x) # Dense_1
c = a + b
self.sow('intermediates', 'c', c) # store intermediate c
d = nn.Dense(4)(c) # Dense_2
return d
@jax.jit
def init(key, x):
variables = Model().init(key, x)
return variables['params']
@jax.jit
def predict(params, x):
# filter specifically for only the Dense_0 and Dense_2 layer
filter_fn = lambda mdl, method_name: isinstance(mdl.name, str) and (mdl.name in {'Dense_0', 'Dense_2'})
return Model().apply({"params": params}, x, capture_intermediates=filter_fn)
batch = jax.random.uniform(jax.random.key(1), (1,3))
params = init(jax.random.key(0), batch)
preds, feats = predict(params, batch)
feats # intermediate c in Model is stored and isn't filtered out by the filter function
为了将从 self.sow
提取的中间值与从 capture_intermediates
提取的中间值分开,我们可以定义一个单独的集合,例如 self.sow('sow_intermediates', 'c', c)
,或者在调用 .apply()
后手动筛选出中间值。例如
flattened_dict = flax.traverse_util.flatten_dict(feats['intermediates'], sep='/')
flattened_dict['c']
在效率方面,只要所有内容都经过 jit 处理,那么任何您最终没有使用的中间值都应该被 XLA 优化掉。
使用 Sequential
#
您还可以使用 Sequential
组合器的简单实现来定义 CNN
(这在更多有状态的方法中很常见)。这对于非常简单的模型可能很有用,并为您提供任意的模型手术。但是,它可能非常有限 – 如果您甚至想添加一个条件,您将被迫从 Sequential
重构,并更明确地构造您的模型。
class Sequential(nn.Module):
layers: Sequence[nn.Module]
def __call__(self, x):
for layer in self.layers:
x = layer(x)
return x
def SeqCNN():
return Sequential([
nn.Conv(features=32, kernel_size=(3, 3)),
nn.relu,
lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
nn.Conv(features=64, kernel_size=(3, 3)),
nn.relu,
lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
lambda x: x.reshape((x.shape[0], -1)), # flatten
nn.Dense(features=256),
nn.relu,
nn.Dense(features=10),
nn.log_softmax,
])
@jax.jit
def init(key, x):
variables = SeqCNN().init(key, x)
return variables['params']
@jax.jit
def features(params, x):
return Sequential(SeqCNN().layers[0:7]).apply({"params": params}, x)
batch = jnp.ones((1,28,28,1))
params = init(jax.random.key(0), batch)
features(params, batch)
提取中间值的梯度#
出于调试目的,提取中间值的梯度可能很有用。这可以通过在所需值上使用 Module.perturb()
方法来完成。
class Model(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.relu(nn.Dense(8)(x))
x = self.perturb('hidden', x)
x = nn.Dense(2)(x)
x = self.perturb('logits', x)
return x
perturb
默认情况下将变量添加到 perturbations
集合中,它的行为类似于恒等函数,并且扰动的梯度与输入的梯度匹配。要获得扰动,只需初始化模型
x = jnp.empty((1, 4)) # random data
y = jnp.empty((1, 2)) # random data
model = Model()
variables = model.init(jax.random.key(1), x)
params, perturbations = variables['params'], variables['perturbations']
最后,计算损失相对于扰动的梯度,这些将与中间值的梯度匹配
def loss_fn(params, perturbations, x, y):
y_pred = model.apply({'params': params, 'perturbations': perturbations}, x)
return jnp.mean((y_pred - y) ** 2)
intermediate_grads = jax.grad(loss_fn, argnums=1)(params, perturbations, x, y)