词汇表#
有关其他术语,请参阅 Jax 词汇表。
- 绑定模块#
当通过常规 Python 对象构造(例如,module = SomeModule(args…))创建
Module
时,它处于未绑定状态。这意味着仅设置了数据类属性,并且没有变量绑定到该模块。当调用纯函数Module.init()
或Module.apply()
时,Flax 会克隆该模块并将变量绑定到它,并且模块的方法代码在局部绑定状态下执行,从而允许直接调用子模块,而无需提供变量。有关更多详细信息,请参阅 模块生命周期。- 紧凑/非紧凑模块#
具有单个方法的模块可以通过使用
@nn.compact
装饰器内联声明子模块和变量。这些被称为“紧凑式模块”,而定义setup()
方法(通常但不总是带有多个可调用方法)的模块被称为“设置式模块”。要了解更多信息,请参阅 设置与紧凑指南。- 折叠#
给定一个输入 PRNG 密钥和一个整数,生成一个新的 PRNG 密钥。通常用于您想要生成新密钥但仍然能够在此之后使用原始 rng 密钥的情况。您也可以使用 jax.random.split 来完成此操作,但这会有效地创建两个 RNG 密钥,这会更慢。请参阅我们的 RNG 指南,了解 Flax 如何在
Modules
中自动生成新的 PRNG 密钥。- FrozenDict#
一个不可变的字典,可以“解冻”为常规的可变字典。在内部,Flax 使用 FrozenDicts 来确保变量字典不会被意外改变。注意:我们正在考虑从我们的 API 返回常规字典,并且只在内部使用 FrozenDicts。(请参阅 #1223)。
- 功能核心#
flax 核心库实现了简单的容器 Scope API,用于通过模型传递变量和 PRNG,以及转换传递 Scope 对象的函数所需的提升机制。基于 Python 类的模块 API 构建在此核心库之上。
- 延迟初始化#
Flax 中的变量仅在需要时才会被延迟初始化。也就是说,在模块的正常执行过程中,如果未在提供的变量集合数据中找到请求的变量名称,我们会调用初始化器函数来创建它。这允许我们将初始化和应用视为相同的代码路径,从而简化了 JAX 转换与层的使用。
- 提升的转换#
请参阅 Flax 文档。
- 模块#
一个数据类,允许以引用透明的形式定义和初始化参数。它负责存储和更新自身中的变量和参数。模块可以轻松转换为函数,从而使其可以简单地与 JAX 转换(如 vmap 和 scan)一起使用。
- 参数 / parameters#
“params”是变量字典(dict)中的规范变量集合。“params”集合通常包含可训练权重。
- RNG 序列#
在 Flax
Modules
中,您可以通过Module.make_rng()
获取新的 PRNG 密钥。这些密钥可用于通过 JAX 的函数式随机数生成器生成随机数。拥有不同的 RNG 序列(例如,“params”和“dropout”)可以在多主机设置中进行细粒度控制(例如,在不同的主机上相同地初始化参数,但具有不同的 dropout 掩码),并在 提升转换时以不同方式处理这些序列。有关更多详细信息,请参阅 RNG 指南。- 范围#
一个容器类,用于保存每一层的变量和 PRNG 密钥。
- 形状推断#
模块不需要在其定义中指定输入数组的形状。Flax 在初始化时会检查输入数组,并推断模型中参数的正确形状。
- TrainState#
- 变量#
驻留在 变量集合 叶节点中的 权重/参数/数据/数组。变量在模块内部使用
Module.variable()
定义。集合“params”的变量简称为参数,可以使用Module.param()
设置。- 变量集合#
变量字典中的条目,其中包含模型使用的权重/参数/数据/数组。“params”是变量字典中的规范集合。它们通常是可微分的,由外部 SGD 式循环/优化器更新,而不是由前向传递代码直接修改。
- 变量字典#
一个包含 变量集合 的字典。每个变量集合都是从字符串名称(例如,“params”或“batch_stats”)到具有 变量 作为叶节点的(可能是嵌套的)字典的映射,匹配子模块树结构。在 Jax 文档中阅读有关 pytrees 和叶节点的更多信息。