Open in Colab Open On GitHub

Flax 基础#

本笔记本将引导您完成以下工作流程

  • 从 Flax 内置层或第三方模型实例化模型。

  • 初始化模型的参数和手动编写的训练。

  • 使用 Flax 提供的优化器来简化训练。

  • 序列化参数和其他对象。

  • 创建您自己的模型并管理状态。

设置我们的环境#

在这里,我们提供了设置笔记本环境所需的代码。

# Install the latest JAXlib version.
!pip install --upgrade -q pip jax jaxlib
# Install Flax at head:
!pip install --upgrade -q git+https://github.com/google/flax.git
WARNING: Running pip as root will break packages and permissions. You should install packages reliably by using venv: https://pip.pythonlang.cn/warnings/venv
WARNING: Running pip as root will break packages and permissions. You should install packages reliably by using venv: https://pip.pythonlang.cn/warnings/venv
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn

使用 Flax 的线性回归#

在之前的不耐烦的 JAX 笔记本中,我们以一个线性回归示例结束。众所周知,线性回归也可以写成一个单一的密集神经网络层,我们将在下面展示,以便我们可以比较它的完成方式。

密集层是一个具有核参数\(W\in\mathcal{M}_{m,n}(\mathbb{R})\)的层,其中\(m\)是模型输出的特征数量,而\(n\)是输入的维度,以及偏置参数\(b\in\mathbb{R}^m\)。密集层从输入\(x\in\mathbb{R}^n\)返回\(Wx+b\)

Flax 的 flax.linen 模块(此处导入为 nn)中已经提供了这个密集层。

# We create one dense layer instance (taking 'features' parameter as input)
model = nn.Dense(features=5)

层(以及通常的模型,我们从现在开始使用这个词)是 linen.Module 类的子类。

模型参数 & 初始化#

参数不会与模型本身一起存储。您需要使用 PRNGKey 和虚拟输入数据调用 init 函数来初始化参数。

key1, key2 = random.split(random.key(0))
x = random.normal(key1, (10,)) # Dummy input data
params = model.init(key2, x) # Initialization call
jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes
{'params': {'bias': (5,), 'kernel': (10, 5)}}

注意:JAX 和 Flax 与 NumPy 一样,是基于行的系统,这意味着向量表示为行向量而不是列向量。这可以在此处内核的形状中看到。

结果是我们所期望的:正确大小的偏置和内核参数。在后台

  • 虚拟输入数据 x 用于触发形状推断:我们只声明了模型输出中所需的特征数量,而不是输入的大小。Flax 自己找出内核的正确大小。

  • 随机 PRNG 密钥用于触发初始化函数(这些函数具有此处模块提供的默认值)。

  • 调用初始化函数来生成模型将使用的初始参数集。这些函数将 (PRNG Key、形状、dtype) 作为参数,并返回形状为 shape 的数组。

  • init 函数返回初始化的参数集(您也可以通过使用 init_with_output 方法而不是 init,以相同的语法获得虚拟输入的正向传递输出)。

要使用给定的一组参数(这些参数永远不会与模型一起存储)执行模型的正向传递,我们只需使用 apply 方法,并向其提供要使用的参数以及输入

model.apply(params, x)
Array([-1.3721193 ,  0.61131495,  0.6442836 ,  2.2192965 , -1.1271116 ],      dtype=float32)

梯度下降#

如果您直接跳到这里,而没有浏览 JAX 部分,以下是我们要使用的线性回归公式:从一组数据点 \(\{(x_i,y_i), i\in \{1,\ldots, k\}, x_i\in\mathbb{R}^n,y_i\in\mathbb{R}^m\}\) 中,我们尝试找到一组参数 \(W\in \mathcal{M}_{m,n}(\mathbb{R}), b\in\mathbb{R}^m\),使得函数 \(f_{W,b}(x)=Wx+b\) 最小化均方误差

\[\mathcal{L}(W,b)\rightarrow\frac{1}{k}\sum_{i=1}^{k} \frac{1}{2}\|y_i-f_{W,b}(x_i)\|^2_2\]

在这里,我们看到元组 \((W,b)\) 与 Dense 层的参数匹配。我们将使用这些参数执行梯度下降。让我们首先生成我们将使用的虚拟数据。数据与 JAX 部分的线性回归 pytree 示例中的数据完全相同。

# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.key(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a FrozenDict pytree.
true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
x shape: (20, 10) ; y shape: (20, 5)

我们复制我们在 JAX pytree 线性回归示例中使用的相同训练循环,其中包含 jax.value_and_grad(),但在这里我们可以使用 model.apply(),而无需定义我们自己的前馈函数(JAX 示例中的 predict_pytree())。

# Same as JAX version but using model.apply().
@jax.jit
def mse(params, x_batched, y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x, y):
    pred = model.apply(params, x)
    return jnp.inner(y-pred, y-pred) / 2.0
  # Vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)

最后执行梯度下降。

learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
  params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return params

for i in range(101):
  # Perform one gradient update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)
Loss for "true" W,b:  0.023639796
Loss step 0:  35.343876
Loss step 10:  0.5143468
Loss step 20:  0.11384157
Loss step 30:  0.039326735
Loss step 40:  0.019916197
Loss step 50:  0.014209114
Loss step 60:  0.012425648
Loss step 70:  0.01185039
Loss step 80:  0.011661778
Loss step 90:  0.011599409
Loss step 100:  0.011578697

使用 Optax 进行优化#

Flax 过去使用自己的 flax.optim 包进行优化,但随着 FLIP #1009,它已被弃用,转而使用 Optax

Optax 的基本用法很简单

  1. 选择一种优化方法(例如 optax.adam)。

  2. 从参数创建优化器状态(对于 Adam 优化器,此状态将包含动量值)。

  3. 使用 jax.value_and_grad() 计算损失的梯度。

  4. 在每次迭代时,调用 Optax update 函数来更新内部优化器状态并创建参数的更新。然后使用 Optax 的 apply_updates 方法将更新添加到参数。

请注意,Optax 可以做更多事情:它旨在将简单的梯度转换组合成更复杂的转换,从而允许实现各种优化器。它还支持随时间更改优化器超参数(“计划”)、对参数树的不同部分应用不同的更新(“掩码”)等等。有关详细信息,请参阅官方文档

import optax
tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)
for i in range(101):
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  if i % 10 == 0:
    print('Loss step {}: '.format(i), loss_val)
Loss step 0:  0.011577629
Loss step 10:  0.26143134
Loss step 20:  0.076747075
Loss step 30:  0.036439072
Loss step 40:  0.022011759
Loss step 50:  0.01617833
Loss step 60:  0.013002962
Loss step 70:  0.01202613
Loss step 80:  0.0117645
Loss step 90:  0.011646036
Loss step 100:  0.011585514

序列化结果#

现在我们对训练结果感到满意,我们可能希望保存模型参数以便稍后重新加载。Flax 提供了一个序列化包,使您能够做到这一点。

from flax import serialization
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
print('Dict output')
print(dict_output)
print('Bytes output')
print(bytes_output)
Dict output
{'params': {'bias': Array([-1.4555763, -2.027799 ,  2.0790977,  1.2186142, -0.9980988],      dtype=float32), 'kernel': Array([[ 1.0098811 ,  0.1893436 ,  0.04455061, -0.92802244,  0.34784058],
       [ 1.7298452 ,  0.9879369 ,  1.1640465 ,  1.1006078 , -0.1065392 ],
       [-1.202946  ,  0.28635207,  1.415598  ,  0.11870954, -1.3141488 ],
       [-1.1941487 , -0.18958527,  0.03413866,  1.3169426 ,  0.08060387],
       [ 0.13852389,  1.371304  , -1.3187188 ,  0.5315267 , -2.2404993 ],
       [ 0.5629402 ,  0.8122313 ,  0.31751987,  0.534551  ,  0.9050044 ],
       [-0.37925997,  1.7410395 ,  1.0790284 , -0.5039832 ,  0.92830735],
       [ 0.970649  , -1.3153405 ,  0.33681503,  0.80993414, -1.2018454 ],
       [ 1.0194316 , -0.62024766,  1.081883  , -1.8389739 , -0.4580481 ],
       [-0.6436535 ,  0.45666716, -1.1329136 , -0.6853864 ,  0.1682897 ]],      dtype=float32)}}
Bytes output
b'\x81\xa6params\x82\xa4bias\xc7!\x01\x93\x91\x05\xa7float32\xc4\x14SP\xba\xbfu\xc7\x01\xc0\xf0\x0f\x05@\x8d\xfb\x9b?g\x83\x7f\xbf\xa6kernel\xc7\xd6\x01\x93\x92\n\x05\xa7float32\xc4\xc8\xc9C\x81?J\xe3A>\xb2z6=\xe1\x92m\xbf)\x18\xb2>\x91k\xdd?o\xe9|?z\xff\x94?\xb7\xe0\x8c?:1\xda\xbd"\xfa\x99\xbf\xbd\x9c\x92>Q2\xb5?\xfd\x1d\xf3=\x076\xa8\xbf\xdd\xd9\x98\xbf\xa4"B\xbe\xfc\xd4\x0b=\x93\x91\xa8?\xa4\x13\xa5=5\xd9\r>\xe4\x86\xaf?\xc7\xcb\xa8\xbf"\x12\x08?Wd\x0f\xc0\xd9\x1c\x10?d\xeeO?\xf7\x91\xa2>V\xd8\x08?^\xaeg?].\xc2\xbeb\xda\xde?\x9a\x1d\x8a?\x0b\x05\x01\xbf\x8d\xa5m?t|x?\x14]\xa8\xbf\x05s\xac>\xd8WO?\x12\xd6\x99\xbf\xbc|\x82?\x8d\xc8\x1e\xbf${\x8a?\x7fc\xeb\xbfH\x85\xea\xbez\xc6$\xbfG\xd0\xe9>P\x03\x91\xbf|u/\xbf#T,>'

要将模型重新加载回来,您需要使用模型参数结构的模板,就像您从模型初始化中获得的一样。在这里,我们使用先前生成的 params 作为模板。请注意,这将产生一个新的变量结构,而不是就地改变。

通过模板强制执行结构的目的在于避免用户在下游出现问题,因此您首先需要具有生成参数结构的正确模型。

serialization.from_bytes(params, bytes_output)
{'params': {'bias': array([-1.4555763, -2.027799 ,  2.0790977,  1.2186142, -0.9980988],
        dtype=float32),
  'kernel': array([[ 1.0098811 ,  0.1893436 ,  0.04455061, -0.92802244,  0.34784058],
         [ 1.7298452 ,  0.9879369 ,  1.1640465 ,  1.1006078 , -0.1065392 ],
         [-1.202946  ,  0.28635207,  1.415598  ,  0.11870954, -1.3141488 ],
         [-1.1941487 , -0.18958527,  0.03413866,  1.3169426 ,  0.08060387],
         [ 0.13852389,  1.371304  , -1.3187188 ,  0.5315267 , -2.2404993 ],
         [ 0.5629402 ,  0.8122313 ,  0.31751987,  0.534551  ,  0.9050044 ],
         [-0.37925997,  1.7410395 ,  1.0790284 , -0.5039832 ,  0.92830735],
         [ 0.970649  , -1.3153405 ,  0.33681503,  0.80993414, -1.2018454 ],
         [ 1.0194316 , -0.62024766,  1.081883  , -1.8389739 , -0.4580481 ],
         [-0.6436535 ,  0.45666716, -1.1329136 , -0.6853864 ,  0.1682897 ]],
        dtype=float32)}}

定义自己的模型#

Flax 允许您定义自己的模型,这应该比线性回归复杂一点。在本节中,我们将向您展示如何构建简单模型。为此,您需要创建基本 nn.Module 类的子类。

请记住,我们导入了 linen as nn ,这仅适用于新的 linen API

模块基础#

模型的基础抽象是 nn.Module 类,Flax 中每种预定义的层类型(如之前的 Dense)都是 nn.Module 的子类。让我们来看一下,首先定义一个简单但自定义的多层感知器,即一系列密集层与非线性激活函数的调用交错排列。

class ExplicitMLP(nn.Module):
  features: Sequence[int]

  def setup(self):
    # we automatically know what to do with lists, dicts of submodules
    self.layers = [nn.Dense(feat) for feat in self.features]
    # for single submodules, we would just write:
    # self.layer1 = nn.Dense(feat1)

  def __call__(self, inputs):
    x = inputs
    for i, lyr in enumerate(self.layers):
      x = lyr(x)
      if i != len(self.layers) - 1:
        x = nn.relu(x)
    return x

key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))

model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.0072379  -0.00810347 -0.02550939  0.02151716 -0.01261241]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]

正如我们所见,nn.Module 子类由以下组成:

  • 数据字段的集合(nn.Module 是 Python 数据类) - 这里我们只有类型为 Sequence[int]features 字段。

  • 一个 setup() 方法,它在 __postinit__ 的末尾被调用,你可以在其中注册子模块、变量和模型中需要的参数。

  • 一个 __call__ 函数,它返回模型从给定输入得到的输出。

  • 模型结构定义了一个参数的 pytree,其结构与模型相同:参数树为每一层包含一个 layers_n 子字典,并且每个子字典都包含相关密集层的参数。布局非常清晰。

注意:列表的管理方式大多符合你的预期(正在开发中),有一些你需要注意的边缘情况,如 此处 所述

由于模块结构及其参数彼此不相关联,因此你不能直接对给定输入调用 model(x),因为它会返回错误。__call__ 函数被包装在 apply 函数中,apply 函数才是要对输入调用的函数。

try:
    y = model(x) # Returns an error
except AttributeError as e:
    print(e)
"ExplicitMLP" object has no attribute "layers". If "layers" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.

由于这里我们有一个非常简单的模型,我们可以使用另一种(但等效的)方式在 __call__ 中使用 @nn.compact 注释内联声明子模块,如下所示:

class SimpleMLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, feat in enumerate(self.features):
      x = nn.Dense(feat, name=f'layers_{i}')(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
      # providing a name is optional though!
      # the default autonames would be "Dense_0", "Dense_1", ...
    return x

key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.0072379  -0.00810347 -0.02550939  0.02151716 -0.01261241]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]

然而,你需要在两种声明模式之间注意一些差异:

  • setup 中,你可以命名一些子层并将它们保留以供进一步使用(例如,自编码器中的编码器/解码器方法)。

  • 如果你想有多个方法,那么你需要使用 setup 来声明模块,因为 @nn.compact 注释只允许注释一个方法。

  • 最后一个初始化将以不同的方式处理。有关更多详细信息,请参阅这些说明(TODO:添加说明链接)。

模块参数#

在之前的 MLP 示例中,我们仅依赖于预定义的层和运算符(Denserelu)。假设你没有 Flax 提供的 Dense 层,并且你想自己编写它。下面是使用 @nn.compact 方法声明新模块的样子:

class SimpleDense(nn.Module):
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()
  bias_init: Callable = nn.initializers.zeros_init()

  @nn.compact
  def __call__(self, inputs):
    kernel = self.param('kernel',
                        self.kernel_init, # Initialization function
                        (inputs.shape[-1], self.features))  # shape info.
    y = jnp.dot(inputs, kernel)
    bias = self.param('bias', self.bias_init, (self.features,))
    y = y + bias
    return y

key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameters:\n', params)
print('output:\n', y)
initialized parameters:
 {'params': {'kernel': Array([[ 0.61506   , -0.22728713,  0.6054702 ],
       [-0.29617992,  1.1232013 , -0.879759  ],
       [-0.35162622,  0.3806491 ,  0.6893246 ],
       [-0.1151355 ,  0.04567898, -1.091212  ]], dtype=float32), 'bias': Array([0., 0., 0.], dtype=float32)}}
output:
 [[-0.02996204  1.102088   -0.6660265 ]
 [-0.31092793  0.6323942  -0.53678817]
 [ 0.01424007  0.9424717  -0.6356147 ]
 [ 0.36818963  0.3586519  -0.00459214]]

在这里,我们看到了如何使用 self.param 方法声明参数并将其分配给模型。它接受 (name, init_fn, *init_args, **init_kwargs) 作为输入。

  • name 只是最终出现在参数结构中的参数名称。

  • init_fn 是一个函数,其输入为 (PRNGKey, *init_args, **init_kwargs),返回一个 Array,其中 init_argsinit_kwargs 是调用初始化函数所需的参数。

  • init_argsinit_kwargs 是提供给初始化函数的参数。

这些参数也可以在 setup 方法中声明;它将无法使用形状推断,因为 Flax 在第一次调用时使用延迟初始化。

变量和变量集合#

到目前为止,我们已经看到,使用模型意味着使用:

  • nn.Module 的子类;

  • 模型的参数的 pytree(通常来自 model.init());

然而,这不足以涵盖我们机器学习(尤其是神经网络)所需的一切。在某些情况下,你可能希望你的神经网络在运行时跟踪一些内部状态(例如,批量归一化层)。有一种方法可以使用 variable 方法声明模型参数之外的变量。

为了演示目的,我们将实现一种类似于批量归一化的简化机制:我们将存储运行平均值,并在训练时从输入中减去这些平均值。对于正确的 batchnorm,你应该使用(并查看) 此处 的实现。

class BiasAdderWithRunningMean(nn.Module):
  decay: float = 0.99

  @nn.compact
  def __call__(self, x):
    # easy pattern to detect if we're initializing via empty variable tree
    is_initialized = self.has_variable('batch_stats', 'mean')
    ra_mean = self.variable('batch_stats', 'mean',
                            lambda s: jnp.zeros(s),
                            x.shape[1:])
    bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
    if is_initialized:
      ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)

    return x - ra_mean.value + bias


key1, key2 = random.split(random.key(0), 2)
x = jnp.ones((10,5))
model = BiasAdderWithRunningMean()
variables = model.init(key1, x)
print('initialized variables:\n', variables)
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)
initialized variables:
 {'batch_stats': {'mean': Array([0., 0., 0., 0., 0.], dtype=float32)}, 'params': {'bias': Array([0., 0., 0., 0., 0.], dtype=float32)}}
updated state:
 {'batch_stats': {'mean': Array([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32)}}

在这里,updated_state 仅返回在模型应用于数据时被改变的状态变量。要更新变量并获取模型的新参数,我们可以使用以下模式:

for val in [1.0, 2.0, 3.0]:
  x = val * jnp.ones((10,5))
  y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
  old_state, params = flax.core.pop(variables, 'params')
  variables = flax.core.freeze({'params': params, **updated_state})
  print('updated state:\n', updated_state) # Shows only the mutable part
updated state:
 {'batch_stats': {'mean': Array([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32)}}
updated state:
 {'batch_stats': {'mean': Array([[0.0299, 0.0299, 0.0299, 0.0299, 0.0299]], dtype=float32)}}
updated state:
 {'batch_stats': {'mean': Array([[0.059601, 0.059601, 0.059601, 0.059601, 0.059601]], dtype=float32)}}

从这个简化的示例中,你应该能够推导出完整的 BatchNorm 实现,或任何涉及状态的层。最后,让我们添加一个优化器,看看如何使用优化器更新的参数和状态变量。

此示例没有任何实际作用,仅用于演示目的。

from functools import partial

@partial(jax.jit, static_argnums=(0, 1))
def update_step(tx, apply_fn, x, opt_state, params, state):

  def loss(params):
    y, updated_state = apply_fn({'params': params, **state},
                                x, mutable=list(state.keys()))
    l = ((x - y) ** 2).sum()
    return l, updated_state

  (l, state), grads = jax.value_and_grad(loss, has_aux=True)(params)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return opt_state, params, state

x = jnp.ones((10,5))
variables = model.init(random.key(0), x)
state, params = flax.core.pop(variables, 'params')
del variables
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for _ in range(3):
  opt_state, params, state = update_step(tx, model.apply, x, opt_state, params, state)
  print('Updated state: ', state)
Updated state:  {'batch_stats': {'mean': Array([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32)}}
Updated state:  {'batch_stats': {'mean': Array([[0.0199, 0.0199, 0.0199, 0.0199, 0.0199]], dtype=float32)}}
Updated state:  {'batch_stats': {'mean': Array([[0.029701, 0.029701, 0.029701, 0.029701, 0.029701]], dtype=float32)}}

请注意,上述函数具有非常冗长的签名,并且它实际上不能与 jax.jit() 一起使用,因为函数参数不是“有效的 JAX 类型”。

Flax 提供了一个方便的包装器 - TrainState - 它可以简化上述代码。查看 flax.training.train_state.TrainState 以了解更多信息。

使用 jax2tf 导出到 Tensorflow 的 SavedModel#

JAX 发布了一个名为 jax2tf 的实验性转换器,它可以将训练好的 Flax 模型转换为 Tensorflow 的 SavedModel 格式(因此它可以用于 TF HubTF.liteTF.js 或其他下游应用程序)。该存储库包含更多文档,并有各种 Flax 示例。