本指南将向您展示如何从模块中提取中间值。让我们从这个使用 nn.compact
的简单 CNN 开始。
from flax import linen as nn
import jax
import jax.numpy as jnp
from typing import Sequence
class CNN(nn.Module):
    """Small convolutional classifier written in the compact style.

    Two Conv -> ReLU -> average-pool stages are followed by a flatten step,
    a 256-feature Dense layer with ReLU, and a 10-feature Dense layer whose
    output goes through ``nn.log_softmax``.
    """

    # Submodules are created inline inside __call__, so the method has to be
    # wrapped in nn.compact.
    @nn.compact
    def __call__(self, x):
        """Apply the network to ``x``.

        Args:
            x: Input batch. The leading axis is preserved by the flatten
                step. Presumably (batch, height, width, channels) -- TODO
                confirm against the caller.

        Returns:
            The output of ``nn.log_softmax`` over the final Dense layer.
        """
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten everything but axis 0
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x
因为此模块使用 nn.compact,我们无法直接访问其中间值。公开中间值的一种方法是将它们存储在新的变量集合中:
可以通过调用 sow 来增强 CNN,以如下方式存储中间值:
class CNN(nn.Module):
    """Baseline compact-style CNN, shown before any ``sow`` call is added.

    The layer stack is Conv(32) -> ReLU -> avg-pool -> Conv(64) -> ReLU ->
    avg-pool -> flatten -> Dense(256) -> ReLU -> Dense(10) -> log-softmax.
    """

    # nn.compact is required because the layers are instantiated inline.
    @nn.compact
    def __call__(self, x):
        """Return ``nn.log_softmax`` of the 10-feature output for ``x``.

        Args:
            x: Input batch; axis 0 is kept when the features are flattened.

        Returns:
            The log-softmax of the final Dense layer's output.
        """
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x
class SowCNN(nn.Module):
    """Compact-style CNN that sows its flattened features.

    Same layer stack as a plain Conv/Dense classifier, with one extra
    ``self.sow`` call that stores the flattened activations under the name
    ``'features'`` in the ``'intermediates'`` collection.
    """

    # nn.compact is required because the layers are instantiated inline.
    @nn.compact
    def __call__(self, x):
        """Apply the network and sow the flattened features.

        Args:
            x: Input batch; axis 0 is kept when the features are flattened.

        Returns:
            The output of ``nn.log_softmax`` over the final Dense layer.
        """
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        # Record the flattened activations before the Dense classifier head.
        self.sow('intermediates', 'features', x)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x
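当变量集合不可变时,sow 充当空操作。因此,它非常适合调试和可选的中间值跟踪。“intermediates” 集合也由 capture_intermediates API 使用(请参阅“使用 capture_intermediates”部分)。
请注意,默认情况下 sow 每次被调用时都会追加值,因此同一个模块实例被多次调用时会得到包含多个值的元组。若要覆盖默认的追加行为,请指定 init_fn 和 reduce_fn — 请参阅 Module.sow()。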
class SowCNN2(nn.Module):
    """Wrapper that calls a single ``SowCNN`` instance twice."""

    # nn.compact is required because SowCNN is instantiated inside __call__.
    @nn.compact
    def __call__(self, x):
        """Return the sum of two applications of the same submodule to ``x``.

        Args:
            x: Input forwarded unchanged to both calls.

        Returns:
            ``mod(x) + mod(x)`` for the submodule named ``'SowCNN'``.
        """
        mod = SowCNN(name='SowCNN')
        return mod(x) + mod(x)  # Calling same module instance twice.
def init(key, x):
    """Initialize ``SowCNN2`` and return its variables.

    Args:
        key: PRNG key forwarded to ``Module.init``.
        x: Example input forwarded to ``Module.init``.

    Returns:
        Whatever ``SowCNN2().init(key, x)`` returns.
    """
    # By default the 'intermediates' collection is not mutable during init,
    # so the result will only contain 'params' here.
    return SowCNN2().init(key, x)
def predict(variables, x):
    """Run ``SowCNN2`` and return its output together with the sown features.

    Args:
        variables: Variable collections passed to ``Module.apply``.
        x: Input batch.

    Returns:
        A ``(output, features)`` pair, where ``features`` is the entry stored
        at ``['intermediates']['SowCNN']['features']`` of the mutated
        collections.
    """
    # If mutable='intermediates' is not specified, .sow() acts as a noop.
    result = SowCNN2().apply(variables, x, mutable='intermediates')
    output, collected = result
    return output, collected['intermediates']['SowCNN']['features']
# Build an all-ones batch, initialize the variables, then run a forward pass
# that also returns the values sown under 'features'.
batch = jnp.ones((1,28,28,1))
variables = init(jax.random.key(0), batch)
preds, feats = predict(variables, batch)
assert len(feats) == 2 # Tuple with two values since module was called twice.
另一种方法是将模块重构为子模块。对于那些清楚知道如何拆分子模块的情况,这是一种有用的模式。您在 setup
中公开的任何子模块都可以直接使用。在极端情况下,您可以在 setup
中定义所有子模块,并完全避免使用 nn.compact。
class RefactoredCNN(nn.Module):
    """CNN split into a ``features`` submodule and a ``classifier`` submodule."""

    def setup(self):
        """Create the two submodules as attributes of this module."""
        self.features = Features()
        self.classifier = Classifier()

    def __call__(self, x):
        """Feed ``x`` through ``features`` and then through ``classifier``."""
        return self.classifier(self.features(x))
class Features(nn.Module):
    """Convolutional feature extractor: two Conv/ReLU/avg-pool stages + flatten."""

    # nn.compact is required because the Conv layers are instantiated inline.
    @nn.compact
    def __call__(self, x):
        """Return the flattened convolutional features of ``x``.

        Args:
            x: Input batch; axis 0 is kept by the final reshape.

        Returns:
            An array reshaped to ``(x.shape[0], -1)``.
        """
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        return x
class Classifier(nn.Module):
    """Dense classifier head: Dense(256) -> ReLU -> Dense(10) -> log-softmax."""

    # nn.compact is required because the Dense layers are instantiated inline.
    @nn.compact
    def __call__(self, x):
        """Return ``nn.log_softmax`` of the 10-feature output for ``x``.

        Args:
            x: Input features passed to the first Dense layer.

        Returns:
            The log-softmax of the final Dense layer's output.
        """
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x
def init(key, x):
    """Initialize ``RefactoredCNN`` and return only its ``'params'`` collection.

    Args:
        key: PRNG key forwarded to ``Module.init``.
        x: Example input forwarded to ``Module.init``.
    """
    return RefactoredCNN().init(key, x)['params']
def features(params, x):
    """Evaluate only the ``features`` submodule of ``RefactoredCNN`` on ``x``.

    Args:
        params: The ``'params'`` collection of a ``RefactoredCNN``.
        x: Input batch.

    Returns:
        The result of calling ``module.features(x)`` through ``apply``.
    """
    def run_features(module, inputs):
        # Passed as `method=` so apply() runs this instead of __call__.
        return module.features(inputs)

    return RefactoredCNN().apply({"params": params}, x, method=run_features)
# Initialize the refactored model's parameters and evaluate only its
# feature extractor on the existing `batch`.
params = init(jax.random.key(0), batch)
features(params, batch)
使用 capture_intermediates
Linen 支持自动捕获子模块的中间返回值,而无需任何代码更改。此模式应被视为捕获中间值的“大锤”方法。作为调试和检查工具,它非常有用,但是使用本指南中描述的其他模式将使您对要提取的中间值进行更精细的控制。
在以下代码示例中,我们检查是否有任何中间激活是非有限的(NaN 或无穷大)
def init(key, x):
    """Initialize ``CNN`` and return all of its variable collections.

    Args:
        key: PRNG key forwarded to ``Module.init``.
        x: Example input forwarded to ``Module.init``.
    """
    return CNN().init(key, x)
def predict(variables, x):
    """Run ``CNN`` and report, per captured intermediate, whether it is finite.

    Args:
        variables: Variable collections passed to ``Module.apply``.
        x: Input batch.

    Returns:
        A ``(y, fin)`` pair: the model output and a tree with the same
        structure as the captured intermediates, holding
        ``jnp.all(jnp.isfinite(leaf))`` for each leaf.
    """
    def leaf_is_finite(xs):
        return jnp.all(jnp.isfinite(xs))

    y, state = CNN().apply(
        variables,
        x,
        capture_intermediates=True,
        mutable=["intermediates"],
    )
    fin = jax.tree_util.tree_map(leaf_is_finite, state['intermediates'])
    return y, fin
# Initialize the CNN, run it with intermediates captured, and fail loudly if
# any captured intermediate contains a non-finite value.
variables = init(jax.random.key(0), batch)
y, is_finite = predict(variables, batch)
# Reduce the per-leaf flags to a single Python bool.
all_finite = all(jax.tree_util.tree_leaves(is_finite))
assert all_finite, "non-finite intermediate detected!"
默认情况下,仅收集 __call__
方法的中间值。或者,您可以基于 Module 实例和方法名称传入自定义过滤器函数:
# Filter that accepts a (module, method name) pair only for nn.Dense modules.
filter_Dense = lambda mdl, method_name: isinstance(mdl, nn.Dense)
# Filter that accepts only methods named "encode"; shown as an example and
# not used by the apply() call below.
filter_encodings = lambda mdl, method_name: method_name == "encode"
# Capture only the intermediates accepted by filter_Dense.
y, state = CNN().apply(variables, batch, capture_intermediates=filter_Dense, mutable=["intermediates"])
dense_intermediates = state['intermediates']
请注意,capture_intermediates 仅适用于层。您可以使用 self.sow 手动存储非层的中间值,但过滤器函数不会应用于这些值。
class Model(nn.Module):
    """Three Dense(4) layers where the third consumes the sum of the first two."""

    # nn.compact is required because the Dense layers are instantiated inline.
    @nn.compact
    def __call__(self, x):
        """Return ``Dense_2(Dense_0(x) + Dense_1(x))``.

        Args:
            x: Input passed to both ``Dense_0`` and ``Dense_1``.

        Returns:
            The output of the third Dense layer.
        """
        a = nn.Dense(4)(x)  # Dense_0
        b = nn.Dense(4)(x)  # Dense_1
        c = a + b  # not a Flax layer, so won't be stored as an intermediate
        d = nn.Dense(4)(c)  # Dense_2
        return d
def init(key, x):
    """Initialize ``Model`` and return only its ``'params'`` collection.

    Args:
        key: PRNG key forwarded to ``Module.init``.
        x: Example input forwarded to ``Module.init``.
    """
    return Model().init(key, x)['params']
def predict(params, x):
    """Apply ``Model`` to ``x`` with ``capture_intermediates=True``.

    Args:
        params: The ``'params'`` collection of a ``Model``.
        x: Input batch.

    Returns:
        Whatever ``Model().apply`` returns for this call; the caller unpacks
        it as a pair.
    """
    variables = {"params": params}
    return Model().apply(variables, x, capture_intermediates=True)
# Draw a uniform random (1, 3) batch, initialize the parameters, and run the
# model while capturing layer intermediates.
batch = jax.random.uniform(jax.random.key(1), (1,3))
params = init(jax.random.key(0), batch)
preds, feats = predict(params, batch)
feats # intermediate c in Model was not stored because it's not a Flax layer
class Model(nn.Module):
    """Three Dense(4) layers; the non-layer sum ``c`` is sown explicitly."""

    # nn.compact is required because the Dense layers are instantiated inline.
    @nn.compact
    def __call__(self, x):
        """Return ``Dense_2(c)`` where ``c = Dense_0(x) + Dense_1(x)``.

        Args:
            x: Input passed to both ``Dense_0`` and ``Dense_1``.

        Returns:
            The output of the third Dense layer.
        """
        a = nn.Dense(4)(x)  # Dense_0
        b = nn.Dense(4)(x)  # Dense_1
        c = a + b
        self.sow('intermediates', 'c', c)  # store intermediate c
        d = nn.Dense(4)(c)  # Dense_2
        return d
def init(key, x):
    """Initialize ``Model`` and hand back just the ``'params'`` collection.

    Args:
        key: PRNG key forwarded to ``Module.init``.
        x: Example input forwarded to ``Module.init``.
    """
    return Model().init(key, x)['params']
def predict(params, x):
    """Apply ``Model`` while capturing only the Dense_0 and Dense_2 layers.

    Args:
        params: The ``'params'`` collection of a ``Model``.
        x: Input batch.

    Returns:
        Whatever ``Model().apply`` returns for this call; the caller unpacks
        it as a pair.
    """
    def filter_fn(mdl, method_name):
        # filter specifically for only the Dense_0 and Dense_2 layer
        if not isinstance(mdl.name, str):
            return False
        return mdl.name in {'Dense_0', 'Dense_2'}

    return Model().apply({"params": params}, x, capture_intermediates=filter_fn)
# Draw a uniform random (1, 3) batch, initialize the parameters, and run the
# model with the name-based capture filter.
batch = jax.random.uniform(jax.random.key(1), (1,3))
params = init(jax.random.key(0), batch)
preds, feats = predict(params, batch)
feats # intermediate c in Model is stored and isn't filtered out by the filter function
为了将从 self.sow
提取的中间值与从 capture_intermediates
提取的中间值分开,我们可以定义一个单独的集合,例如 self.sow('sow_intermediates', 'c', c)
,或者在调用 .apply() 之后手动过滤中间值。例如,我们可以这样做:
# `from flax import linen as nn` does not bind the name `flax`, so import the
# helper module explicitly before using it.
from flax import traverse_util

# Flatten the nested intermediates dict into a single level whose keys are
# the nested keys joined with '/'.
flattened_dict = traverse_util.flatten_dict(feats['intermediates'], sep='/')
在效率方面,只要所有内容都经过 jit 处理,那么任何您最终没有使用的中间值都应该被 XLA 优化掉。
使用 Sequential
您还可以使用 Sequential
组合器的简单实现来定义 CNN
(这在更多有状态的方法中很常见)。这对于非常简单的模型可能很有用,并让您可以进行任意的模型手术。但是,它可能非常有限 – 哪怕只想添加一个条件分支,您也将被迫放弃 Sequential 并重构模型,以更明确的方式组织它。
class Sequential(nn.Module):
    """Minimal combinator that applies a sequence of callables in order."""

    # Callables applied one after another; each receives the previous output.
    layers: Sequence[nn.Module]

    def __call__(self, x):
        """Thread ``x`` through every entry of ``self.layers`` and return it."""
        out = x
        for apply_layer in self.layers:
            out = apply_layer(out)
        return out
def SeqCNN():
    """Build the CNN as a ``Sequential`` of layers and plain functions.

    The list mirrors the compact-style ``CNN``: two Conv -> ReLU -> avg-pool
    stages, a flatten step, then Dense(256) -> ReLU -> Dense(10) ->
    log-softmax. The first seven entries form the feature extractor, which
    is what a ``layers[0:7]`` slice selects.

    Returns:
        A ``Sequential`` module wrapping the layer list.
    """
    return Sequential([
        nn.Conv(features=32, kernel_size=(3, 3)),
        nn.relu,
        lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
        nn.Conv(features=64, kernel_size=(3, 3)),
        nn.relu,
        lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
        lambda x: x.reshape((x.shape[0], -1)),  # flatten
        nn.Dense(features=256),
        nn.relu,
        nn.Dense(features=10),
        nn.log_softmax,
    ])
def init(key, x):
    """Initialize the ``SeqCNN`` model and return only its ``'params'``.

    Args:
        key: PRNG key forwarded to ``Module.init``.
        x: Example input forwarded to ``Module.init``.
    """
    return SeqCNN().init(key, x)['params']
def features(params, x):
    """Apply only the leading ``layers[0:7]`` slice of ``SeqCNN`` to ``x``.

    Args:
        params: The ``'params'`` collection of a ``SeqCNN`` model.
        x: Input batch.

    Returns:
        The output of a ``Sequential`` built from that slice.
    """
    leading_layers = SeqCNN().layers[0:7]
    truncated = Sequential(leading_layers)
    return truncated.apply({"params": params}, x)
# Build an all-ones batch, initialize the Sequential CNN's parameters, and
# evaluate the truncated feature stack on the batch.
batch = jnp.ones((1,28,28,1))
params = init(jax.random.key(0), batch)
features(params, batch)
出于调试目的,提取中间值的梯度可能很有用。这可以通过在所需值上使用 Module.perturb() 方法来实现。
class Model(nn.Module):
    """Two-layer MLP whose hidden activations and logits are perturbed.

    ``self.perturb`` is called on the hidden activations (name ``'hidden'``)
    and on the final output (name ``'logits'``).
    """

    # nn.compact is required because the Dense layers are instantiated inline.
    @nn.compact
    def __call__(self, x):
        """Return the perturbed 2-feature output for ``x``.

        Args:
            x: Input passed to the first Dense layer.

        Returns:
            The result of ``self.perturb('logits', ...)`` on the second
            Dense layer's output.
        """
        x = nn.relu(nn.Dense(8)(x))
        x = self.perturb('hidden', x)
        x = nn.Dense(2)(x)
        x = self.perturb('logits', x)
        return x
perturb 默认情况下将变量添加到 perturbations 集合中;它的行为类似于恒等函数,并且扰动的梯度与输入的梯度一致。要获取扰动变量,只需初始化模型:
# Placeholder input and target arrays.
# NOTE(review): jnp.empty does not draw random values -- confirm whether
# random data was actually intended here.
x = jnp.empty((1, 4)) # random data
y = jnp.empty((1, 2)) # random data
# Initialize the model and split the variables into the two collections
# that loss_fn takes as separate arguments.
model = Model()
variables = model.init(jax.random.key(1), x)
params, perturbations = variables['params'], variables['perturbations']
def loss_fn(params, perturbations, x, y):
    """Mean squared error between the model's prediction for ``x`` and ``y``.

    Args:
        params: The ``'params'`` collection passed to ``model.apply``.
        perturbations: The ``'perturbations'`` collection passed to
            ``model.apply``.
        x: Model input.
        y: Target values subtracted from the prediction.

    Returns:
        ``jnp.mean`` of the squared difference.
    """
    variables = {'params': params, 'perturbations': perturbations}
    residual = model.apply(variables, x) - y
    return jnp.mean(residual ** 2)
# Differentiate the loss with respect to argument 1 (the perturbations).
intermediate_grads = jax.grad(loss_fn, argnums=1)(params, perturbations, x, y)