词汇表

词汇表#

有关其他术语,请参阅 Jax 词汇表

绑定模块#

当通过常规 Python 对象构造(例如,module = SomeModule(args…))创建 Module 时,它处于未绑定状态。这意味着仅设置了数据类属性,并且没有变量绑定到该模块。当调用纯函数 Module.init()Module.apply() 时,Flax 会克隆该模块并将变量绑定到它,并且模块的方法代码在局部绑定状态下执行,从而允许直接调用子模块,而无需提供变量。有关更多详细信息,请参阅 模块生命周期

紧凑/非紧凑模块#

具有单个方法的模块可以通过使用 @nn.compact 装饰器内联声明子模块和变量。这些被称为“紧凑式模块”,而定义 setup() 方法(通常但不总是带有多个可调用方法)的模块被称为“设置式模块”。要了解更多信息,请参阅 设置与紧凑指南

折叠#

给定一个输入 PRNG 密钥和一个整数,生成一个新的 PRNG 密钥。通常用于您想要生成新密钥但仍然能够在此之后使用原始 rng 密钥的情况。您也可以使用 jax.random.split 来完成此操作,但这会有效地创建两个 RNG 密钥,这会更慢。请参阅我们的 RNG 指南,了解 Flax 如何在 Modules 中自动生成新的 PRNG 密钥。

FrozenDict#

一个不可变的字典,可以“解冻”为常规的可变字典。在内部,Flax 使用 FrozenDicts 来确保变量字典不会被意外改变。注意:我们正在考虑从我们的 API 返回常规字典,并且只在内部使用 FrozenDicts。(请参阅 #1223)。

功能核心#

flax 核心库实现了简单的容器 Scope API,用于通过模型传递变量和 PRNG,以及转换传递 Scope 对象的函数所需的提升机制。基于 Python 类的模块 API 构建在此核心库之上。

延迟初始化#

Flax 中的变量仅在需要时才会被延迟初始化。也就是说,在模块的正常执行过程中,如果未在提供的变量集合数据中找到请求的变量名称,我们会调用初始化器函数来创建它。这允许我们将初始化和应用视为相同的代码路径,从而简化了 JAX 转换与层的使用。

提升的转换#

请参阅 Flax 文档

模块#

一个数据类,允许以引用透明的形式定义和初始化参数。它负责存储和更新自身中的变量和参数。模块可以轻松转换为函数,从而使其可以简单地与 JAX 转换(如 vmapscan)一起使用。

参数 / parameters#

“params”是变量字典(dict)中的规范变量集合。“params”集合通常包含可训练权重。

RNG 序列#

在 Flax Modules 中,您可以通过 Module.make_rng() 获取新的 PRNG 密钥。这些密钥可用于通过 JAX 的函数式随机数生成器生成随机数。拥有不同的 RNG 序列(例如,“params”和“dropout”)可以在多主机设置中进行细粒度控制(例如,在不同的主机上相同地初始化参数,但具有不同的 dropout 掩码),并在 提升转换时以不同方式处理这些序列。有关更多详细信息,请参阅 RNG 指南

范围#

一个容器类,用于保存每一层的变量和 PRNG 密钥。

形状推断#

模块不需要在其定义中指定输入数组的形状。Flax 在初始化时会检查输入数组,并推断模型中参数的正确形状。

TrainState#

请参阅 flax.training.train_state.TrainState

变量#

驻留在 变量集合 叶节点中的 权重/参数/数据/数组。变量在模块内部使用 Module.variable() 定义。集合“params”的变量简称为参数,可以使用 Module.param() 设置。

变量集合#

变量字典中的条目,其中包含模型使用的权重/参数/数据/数组。“params”是变量字典中的规范集合。它们通常是可微分的,由外部 SGD 式循环/优化器更新,而不是由前向传递代码直接修改。

变量字典#

一个包含 变量集合 的字典。每个变量集合都是从字符串名称(例如,“params”或“batch_stats”)到具有 变量 作为叶节点的(可能是嵌套的)字典的映射,匹配子模块树结构。在 Jax 文档中阅读有关 pytrees 和叶节点的更多信息。