Flax 的理念#

排名不分先后

  • 库代码应该易于阅读和理解。

  • 宁愿复制代码也不要使用糟糕的抽象。

  • 一般来说,宁愿复制代码也不要向函数添加选项。

  • 注释驱动的设计:如果难以记录你的代码,请考虑更改设计。

  • 单元测试驱动的设计:如果难以测试你的代码,请考虑更改设计。

  • 人们通过复制现有实现来启动项目——使基本实现变得优秀。

  • 如果我们向开发者公开一个抽象,我们就承担了心理负担。

  • 面向开发者的函数式编程抽象会使一些用户感到困惑,在收益高的地方公开它们。

  • “阅读手册”不是对开发者困惑的恰当回应。框架应该通过断言和错误消息等方式引导开发者找到好的解决方案。

  • 无用的错误消息是一个 bug。

  • “调试的难度是编写代码的两倍。因此,如果你尽可能巧妙地编写代码,那么根据定义,你还不够聪明去调试它。” ——Brian Kernighan

设计原则#

Flax 是一个构建在 JAX 之上的神经网络库,已被越来越多的用户采用,最引人注目的是在 MLPerf 0.7 基准测试中提交的 JAX 代码。我们在过去一年中的经验(以及与用户和 JAX 核心开发人员的多次对话)指导了名为 Linenflax.linen)的 API 的重新设计,以响应以下基本设计问题。

构建在 JAX 之上的神经网络库如何受益,并利用 JAX 的独特优势?#

世界上已经有了 TensorFlow 和 PyTorch,没有必要构建它们的克隆版本。我们认为,JAX 采用的可组合函数转换方法为使神经网络代码比现有库更易于维护、更具可扩展性和更高性能开辟了新的前沿。虽然我们努力提供熟悉 Keras/Sonnet/PyTorch 的用户熟悉的 API,但 Linen 本质上是一个用于在 JAX 中定义神经网络的功能系统。以下是我们认为以 JAX 为目标的库可以实现的一些示例

  • 将模型编写为“单例”代码,并使用 jax.vmap 自动引入批处理。

  • 自动处理 NLP 中的不规则批次和其他掩码问题。

  • 通过为大型卷积网络利用重新物化的 scan 来创建高效的编译时和运行时模型。

  • 通过启用简单的重新物化、可逆性和模型并行数据分片来消除内存问题。

如何与 JAX 转换互操作?#

可以说,神经网络库的全部意义在于提供一个隐式变量管理 API,以使用户不必手动将数千个变量穿过复杂的函数树。然而,JAX 操作的是纯函数。为了处理当前和未来的 JAX 转换(以任何方式配置和组合),Linen 模块被直接“函数化”,即自动就地转换为以下形式的显式函数

\[f \left( v_{in}, x \right) \rightarrow v_{out}, y\]

其中 \(v_{in}\) 是模型使用的变量集合和 PRNG 状态,\(v_{out}\) 是变异的输出变量集合,\(x\) 是输入数据,\(y\) 是输出数据。应用 JAX 转换然后简单地简化为指定对各种变量集合和 PRNG 状态的任何特定于参数的转换选项。这释放了 JAX 转换的灵活性和强大功能——例如,可以通过以不同的方式使用 jax.pmap 来实现设备并行训练或按设备集成,而无需任何显式库支持。此外,**在模块**中,我们公开了围绕复杂 JAX 转换的轻量级包装器,例如 jax.vmapjax.lax.scan,它们注释每个变量集合如何被 JAX 转换。重要的是,我们正确地处理了在映射和循环转换下创建新变量和转换变量的非平凡情况,以进行初始化和应用。

如何表示参数,以及如何处理更新有状态变量的通用“可微算法”?#

我们遵循 JAX 函数式惯例,将数据存储在“pytrees”中:嵌套元组、列表、字典中包含的 JAX 数组。由于研究人员不可避免地会手动与此数据交互,因此我们使用具有有意义默认键的嵌套字典,并提供几个实用程序(遍历等)来直接处理它们。Linen 使用 Python 冻结字典的加速版本,该版本缓存其 JAX 扁平化形式,以加快 jit 函数调用开销。

Flax 通过允许模型接受几种不同“类型”的集合来概括神经网络的操作:参数、批归一化统计信息、自回归缓存、调试信息、细粒度的超参数等。每个集合都存储在与模型结构相同的嵌套字典中。重要的是,我们*不*将这些不同的类型混淆在单个模糊的“状态”标题下,而是将不同逻辑类型的变量分开,这些变量可以在 JAX 转换和突变(例如,训练与预测)下进行不同的处理。同样,我们允许在模块内部有多个单独命名的 PRNG 链,用于对不同的应用(如初始化、dropout、采样等)进行不同的随机性处理。

在每个阶段,与神经网络关联的数据都不会保存在自定义对象层次结构中,而是保留在显式的 Python 和 JAX 本机形式中,易于内省和修改。用户已经利用它来将 TF 和 PyTorch 检查点映射到 Flax,实现特定于子模型的损失项,并执行快速模型手术等。为了保存此数据,大多数 Flax 示例通过高效的“msgpack”二进制格式存储这些嵌套字典——但是由于变量只是 Python 字典,因此你可以直接使用任何(非 JAX 感知的)序列化库。

如何与纯函数式的 JAX 代码互操作?#

为了对 JAX 生态系统广泛有用,用户不必为了给定的数值任务添加“可训练性”而对其代码进行大量重构。*“库不应该妨碍用户。”*从 Linen 中利用纯函数式代码非常简单:模块实现只是带有命名变量的 JAX 代码。在其他纯函数式代码内部使用 Linen 模块可以像使用单个顶层模块转换一样简单,以允许初始化和纯应用可能包含各种可训练部分的任何 JAX 程序。