flax.errors 包#

Flax 具有以下几类错误。

exception flax.errors.AlreadyExistsError(path)[源代码]#

尝试通过复制覆盖文件。

你可以传递 overwrite=True 来禁用此行为并覆盖现有的文件。

exception flax.errors.ApplyModuleInvalidMethodError(method)[源代码]#

当调用 Module.apply() 时,你可以指定

使用参数 method 来应用的方法。如果提供的参数不是模块中的方法,也不是至少有一个参数的函数,则会抛出此错误。

Module.apply() 的参考文档中了解更多信息。

exception flax.errors.ApplyScopeInvalidVariablesStructureError(variables)[源代码]#

当传递给 apply() 的 variables 字典具有

额外的“params”层时,即 {‘params’: {‘params’: …}} 时,会抛出此错误。有关变量字典的更多说明,请参阅 flax.core.variables

exception flax.errors.ApplyScopeInvalidVariablesTypeError[源代码]#

当调用 Module.apply() 时,第一个

参数应为变量字典。有关变量字典的更多说明,请参阅 flax.core.variables

exception flax.errors.AssignSubModuleError(cls)[源代码]#

你只能在两个地方创建子模块

  1. 如果你的模块是非紧凑的:在 Module.setup() 中。

  2. 如果你的模块是紧凑的:在用 nn.compact() 包裹的方法内。

例如,以下代码会抛出此错误,因为 nn.Conv 是在 __call__ 中创建的,而 __call__ 未标记为紧凑

class Foo(nn.Module):
  def setup(self):
    pass

  def __call__(self, x):
    conv = nn.Conv(features=3, kernel_size=3)

Foo().init(random.key(0), jnp.zeros((1,)))

请注意,如果在 setup 中部分定义了一个模块,也会抛出此错误

class Foo(nn.Module):
  def setup(self):
    self.conv = functools.partial(nn.Conv, features=3)

  def __call__(self, x):
    x = self.conv(kernel_size=4)(x)
    return x

Foo().init(random.key(0), jnp.zeros((1,)))

在这种情况下,self.conv(kernel_size=4) 是从 __call__ 调用的,这是不允许的,因为它既不在 setup 中,也不是在用 x``nn.compact`` 包裹的方法中。

exception flax.errors.CallCompactUnboundModuleError[源代码]#

当你尝试直接调用模块而不是

通过 Module.apply() 调用时,会发生此错误。例如,当尝试运行此代码时,会引发该错误

from flax import linen as nn
import jax.numpy as jnp

test_dense = nn.Dense(10)
test_dense(jnp.ones((5,5)))

相反,你应该通过 Module.apply() 传递变量(参数和其他状态)(或使用 Module.init() 获取初始变量)

from jax import random
variables = test_dense.init(random.key(0), jnp.ones((5,5)))

y = test_dense.apply(variables, jnp.ones((5,5)))
exception flax.errors.CallSetupUnboundModuleError[源代码]#

当你尝试直接调用 .setup() 时,会发生此错误。

例如,当尝试运行此代码时,会引发该错误

from flax import linen as nn
import jax.numpy as jnp

class MyModule(nn.Module):
  def setup(self):
    self.submodule = MySubModule()

module = MyModule()
module.setup() # <-- ERROR!
submodule = module.submodule

一般来说,你不应该自己调用 .setup(),如果你需要访问在 setup 中定义的字段或子模块,你可以创建一个函数来提取它并将其传递给 nn.apply

# setup() will be called automatically by ``nn.apply``
def get_submodule(module):
  return module.submodule.clone() # avoid leaking the Scope

empty_variables = {} # you can also use the real variables
submodule = nn.apply(get_submodule, module)(empty_variables)
exception flax.errors.CallShareScopeOnUnboundModuleError[源代码]#

当你尝试在未绑定的模块上调用 nn.share_scope 时,会发生此错误。例如,当你尝试在顶层使用 nn.share_scope

from flax import linen as nn

class CustomDense(nn.Dense):
  def __call__(self, x):
    return super().__call__(x) + 1

custom_dense = CustomDense(5)
dense = nn.Dense(5)  # has the parameters

nn.share_scope(custom_dense, dense) # <-- ERROR!
exception flax.errors.CallUnbindOnUnboundModuleError[源代码]#

当你尝试在未绑定的模块上调用 .unbind() 时,会发生此错误。例如,当你尝试运行以下示例时,会引发错误

from flax import linen as nn

class MyModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dense(features=10)(x)

module = MyModule()
module.unbind() # <-- ERROR!

相反,你应该在调用 .unbind() 之前将模块 bind 到变量集合

bound_module = module.bind(variables)
... # do something with bound_module
module = bound_module.unbind() # <-- OK!
exception flax.errors.CursorFindError(cursor=None, cursor2=None)[源代码]#

调用 Cursor.find() 时出错。

如果根据 cond_fn 的条件没有找到对象或找到了多个对象,则会发生此错误。

exception flax.errors.DescriptorAttributeError[source]#

当您尝试访问一个访问不存在属性的属性时,会发生此错误。

例如,当尝试运行以下代码时会引发此错误

class Foo(nn.Module):
    @property
    def prop(self):
        return self.non_existent_field # ERROR!
    def __call__(self, x):
        return self.prop

foo = Foo()
variables = foo.init(jax.random.key(0), jnp.ones(shape=(1, 8)))
exception flax.errors.IncorrectPostInitOverrideError[source]#

当您重写 .__post_init__() 而没有调用 super().__post_init__() 时,会发生此错误。

例如,当尝试运行以下代码时会引发此错误

from flax import linen as nn
import jax.numpy as jnp
import jax
class A(nn.Module):
  x: float
  def __post_init__(self):
    self.x_square = self.x ** 2
    # super().__post_init__() <-- forgot to add this line
  @nn.compact
  def __call__(self, input):
    return input + 3

r = A(x=3)
r.init(jax.random.key(2), jnp.ones(3))
exception flax.errors.InvalidCheckpointError(path, step)[source]#

检查点不能存储在已在当前或后续步骤具有检查点的目录中。

在当前或后续步骤具有检查点的目录中。

您可以传递 overwrite=True 来禁用此行为并覆盖目标目录中的现有检查点。

exception flax.errors.InvalidFilterError(filter_like)[source]#

过滤器应为布尔值、字符串或容器对象。

exception flax.errors.InvalidInstanceModuleError[source]#

当您尝试在 Module 类本身而不是 Module 类的实例上调用 .init().init_with_output().apply().bind() 时,会发生此错误。 例如,当尝试运行以下代码时会引发此错误

在 Module 类本身而不是 Module 类的实例上调用时会发生此错误。例如,当尝试运行以下代码时会引发此错误

class B(nn.Module):
  @nn.compact
  def __call__(self, x):
    return x

k = random.key(0)
x = random.uniform(random.key(1), (2,))
B.init(k, x)   # B is module class, not B() a module instance
B.apply(vs, x)   # similar issue with apply called on class instead of
instance.
exception flax.errors.InvalidRngError(msg)[source]#

Module 中使用的所有 rng 都应适当地传递给

Module.init()Module.apply()。 我们使用以下示例分别解释这两者

class Bar(nn.Module):
  @nn.compact
  def __call__(self, x):
    some_param = self.param('some_param', nn.initializers.zeros_init(), (1, ))
    dropout_rng = self.make_rng('dropout')
    x = nn.Dense(features=4)(x)
    ...

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = Bar()(x)
    ...

用于 Module.init() 的 PRNG

在此示例中,使用了两个 rng

  • params 用于初始化模型的参数。 此 rng 用于初始化 some_params 参数,并用于初始化 Bar 中使用的 Dense Module 的权重。

  • dropout 用于 Bar 中使用的 dropout rng。

因此,Foo 的初始化方式如下

init_rngs = {'params': random.key(0), 'dropout': random.key(1)}
variables = Foo().init(init_rngs, init_inputs)

如果 Module 仅需要用于 params 的 rng,则可以使用

SomeModule().init(rng, ...)  # Shorthand for {'params': rng}

用于 Module.apply() 的 PRNG

当应用 Foo 时,只需要 dropout 的 rng,因为 params 仅用于初始化 Module 参数

Foo().apply(variables, inputs, rngs={'dropout': random.key(2)})

如果 Module 仅需要用于 params 的 rng,则完全不必为 apply 提供 rng

SomeModule().apply(variables, inputs)  # rngs=None
exception flax.errors.InvalidScopeError(scope_name)[source]#

临时作用域仅在创建它的上下文中有效

with Scope(variables, rngs=rngs).temporary() as root

y = fn(root, *args, **kwargs) # 此处 root 有效。

# 此处 root 无效。

exception flax.errors.JaxTransformError[source]#

JAX 转换和 Flax 模块不能混合使用。

JAX 的函数式转换期望纯函数。 当您想在 Flax 模型 **内部** 使用 JAX 转换时,您应该使用 Flax 转换包装器(例如: flax.linen.vmapflax.linen.scan 等)。

exception flax.errors.LazyInitError(partial_val)[source]#

Lazy Init 函数具有不可计算的返回值。

当使用影响已初始化变量的 jax.ShapeDtypeStruct 将参数传递给 lazy_init 时,会发生这种情况。 确保 init 函数仅使用 shape 和 dtype,或者如果不可能,请传递实际的 JAX 数组。

示例

class Foo(nn.Module):
  @compact
  def __call__(self, x):
    # This parameter depends on the input x
    # this causes an error when using lazy_init.
    k = self.param("kernel", lambda _: x)
    return x * k
Foo().lazy_init(random.key(0), jax.ShapeDtypeStruct((8, 4), jnp.float32))
exception flax.errors.MPACheckpointingRequiredError(path, step)[source]#

要最佳地保存和恢复多进程数组(来自 pjit 的 GDA 或 jax 数组输出),请使用 GlobalAsyncCheckpointManager。

您可以在顶层创建一个 GlobalAsyncCheckpointManager 并将其作为参数传递

from jax.experimental.gda_serialization import serialization as gdas
gda_manager = gdas.GlobalAsyncCheckpointManager()
save_checkpoint(..., gda_manager=gda_manager)
exception flax.errors.MPARestoreDataCorruptedError(step, path)[source]#

存储在 Google Cloud Storage 中的多进程数组不包含 “commit_success.txt” 文件,该文件应在保存结束时写入。

未能找到它可能表明您保存的 GDA 数据已损坏。

exception flax.errors.MPARestoreTargetRequiredError(path, step, key=None)[source]#

使用多进程数组恢复检查点时,请提供有效的目标。

多进程数组需要初始化分片(全局网格和分区规范)。 因此,要恢复包含多进程数组的检查点,请确保您传递的 target 在相应的树结构位置包含有效的多进程数组。 如果您无法提供完整的有效 target,请考虑 allow_partial_mpa_restoration=True

exception flax.errors.ModifyScopeVariableError(col, variable_name, scope_path)[source]#

如果变量所属的集合是不可变的,则无法更新该变量。

当您应用 Module 时,应指定哪些变量集合是可变的

class MyModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    ...
    var = self.variable('batch_stats', 'mean', ...)
    var.value = ...
    ...

v = MyModule.init(...)
...
logits = MyModule.apply(v, batch)  # This throws an error.
logits = MyModule.apply(v, batch, mutable=['batch_stats'])  # This works.
exception flax.errors.MultipleMethodsCompactError[source]#

@compact 装饰器最多只能添加到一个 Flax 方法中

模块。 为了解决这个问题,您可以

  • 移除 @compact,并使用 Module.setup() 定义子模块和变量。

  • 使用两个单独的模块,它们都具有唯一的 @compact 方法。

TODO(marcvanzee): 链接到解释其背后动机的设计说明。不需要与 hk.transparent 等效的东西,而且它使子模块更加合理,因为不需要为方法名称添加前缀。

exception flax.errors.NameInUseError(key_type, value, module_name)[source]#

当尝试创建具有现有名称的子模块、参数或变量时,会引发此错误。

它们都被认为是在同一命名空间中。

共享子模块

这是共享子模块的错误模式

y = nn.Dense(feature=3, name='bar')(x)
z = nn.Dense(feature=3, name='bar')(x+epsilon)

相反,模块应该通过实例来共享

dense = nn.Dense(feature=3, name='bar')
y = dense(x)
z = dense(x+epsilon)

如果子模块未提供名称,则会自动为其指定一个唯一的名称

class MyModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = MySubModule()(x)
    x = MySubModule()(x)  # This is fine.
    return x

参数和变量

参数名称可能会与子模块或变量冲突,因为它们都存储在同一个变量字典中

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    bar = self.param('bar', nn.initializers.zeros_init(), (1, ))
    embed = nn.Embed(num_embeddings=2, features=5, name='bar')  # <-- ERROR!

变量也应该具有唯一的名称,即使它们有自己的集合

class Foo(nn.Module):
  @nn.compact
  def __call__(self, inputs):
    _ = self.param('mean', initializers.lecun_normal(), (2, 2))
    _ = self.variable('stats', 'mean', initializers.zeros_init(), (2, 2))
exception flax.errors.PartitioningUnspecifiedError(target)[source]#

当尝试通过以下方式向分区变量添加轴时,会引发此错误

使用转换(例如:scanvmap),而未在 metadata_params 字典中指定“partition_name”。

exception flax.errors.ReservedModuleAttributeError(annotations)[source]#

当创建使用保留属性的模块时,会抛出此错误。

以下属性是保留的

  • parent: 此模块的父模块。

  • name: 此模块的名称。

exception flax.errors.ScopeCollectionNotFound(col_name, var_name, scope_path)[source]#

当尝试从空集合访问变量时,会抛出此错误。

有两种常见原因

  1. 集合未正确传递给 apply
    例如,您可能使用了 module.apply(params, ...) 而不是
    module.apply({'params': params}, ...)
  2. 集合为空,因为变量需要初始化。
    在这种情况下,您应该在
    应用期间使集合可变 (例如:module.apply(variables, ..., mutable=['state']))。
exception flax.errors.ScopeParamNotFoundError(param_name, scope_path)[source]#

当尝试访问不存在的参数时,会抛出此错误。

例如,在下面的代码中,初始化的嵌入名称“embedding”与应用名称“embed”不匹配

class Embed(nn.Module):
  num_embeddings: int
  features: int

  @nn.compact
  def __call__(self, inputs, embed_name='embedding'):
    inputs = inputs.astype('int32')
    embedding = self.param(embed_name,
                           jax.nn.initializers.lecun_normal(),
                           (self.num_embeddings, self.features))
    return embedding[inputs]

model = Embed(4, 8)
variables = model.init(random.key(0), jnp.ones((5, 5, 1)))
_ = model.apply(variables, jnp.ones((5, 5, 1)), 'embed')
exception flax.errors.ScopeParamShapeError(param_name, scope_path, value_shape, init_shape)[source]#

当现有参数的形状与

init_fn 的返回值形状不同时,会抛出此错误。当在 Module.apply() 期间提供的形状与初始化模块时使用的形状不同时,可能会发生这种情况。

例如,以下代码会抛出此错误,因为应用形状((5, 5, 1))与初始化形状((5, 5)不同。因此,init 期间内核的形状为 (1, 8),而 apply 期间的形状为 (5, 8),这会导致此错误。

class NoBiasDense(nn.Module):
  features: int = 8

  @nn.compact
  def __call__(self, x):
    kernel = self.param('kernel',
                        lecun_normal(),
                        (x.shape[-1], self.features))  # <--- ERROR
    y = lax.dot_general(x, kernel,
                        (((x.ndim - 1,), (0,)), ((), ())))
    return y

variables = NoBiasDense().init(random.key(0), jnp.ones((5, 5, 1)))
_ = NoBiasDense().apply(variables, jnp.ones((5, 5)))
exception flax.errors.ScopeVariableNotFoundError(name, col, scope_path)[source]#

当尝试在不可变的集合中在 Scope 中使用变量时,会抛出此错误。

为了创建此变量,请使用 Module.apply() 中的 mutable 关键字显式地将集合标记为可变的。

exception flax.errors.SetAttributeFrozenModuleError(module_cls, attr_name, attr_val)[source]#

您只能在 self 内部将模块属性分配给

Module.setup()。在该方法之外,模块实例被冻结(即不可变)。此行为类似于冻结的 Python 数据类。

例如,在以下情况下会引发此错误

class SomeModule(nn.Module):
  @nn.compact
  def __call__(self, x, num_features=10):
    self.num_features = num_features  # <-- ERROR!
    x = nn.Dense(self.num_features)(x)
    return x

s = SomeModule().init(random.key(0), jnp.ones((5, 5)))

类似地,当尝试在构造子模块后修改子模块的属性时,即使这是在父模块的 setup() 方法中完成的,也会引发此错误

class Foo(nn.Module):
    def setup(self):
      self.dense = nn.Dense(features=10)
      self.dense.features = 20  # <--- This is not allowed

    def __call__(self, x):
      return self.dense(x)
exception flax.errors.SetAttributeInModuleSetupError[source]#

不允许在

Module.setup() 中修改模块类属性:

class Foo(nn.Module):
  features: int = 6

  def setup(self):
    self.features = 3  # <-- ERROR

  def __call__(self, x):
    return nn.Dense(self.features)(x)

variables = SomeModule().init(random.key(0), jnp.ones((1, )))

相反,这些属性应该在初始化模块时设置

class Foo(nn.Module):
  features: int = 6

  @nn.compact
  def __call__(self, x):
    return nn.Dense(self.features)(x)

variables = SomeModule(features=3).init(random.key(0), jnp.ones((1, )))

TODO(marcvanzee): 链接到解释为什么模块必须保持冻结的设计说明(否则我们无法安全地克隆它们,这是我们用于提升转换的方式)。

exception flax.errors.TraceContextError(message)[source]#
exception flax.errors.TransformTargetError(target)[source]#

Linen 转换必须应用于模块类或将模块实例作为第一个参数的函数。

当将无效目标传递给 linen 转换(nn.vmap、nn.scan 等)时,会发生此错误。例如,当尝试转换模块实例时,会发生此错误

nn.vmap(nn.Dense(features))(x)  # raises TransformTargetError

您可以直接转换 nn.Dense

nn.vmap(nn.Dense)(features)(x)

或者,您可以创建一个将模块实例作为第一个参数的函数

class BatchDense(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.vmap(
        lambda mdl, x: mdl(x),
        variable_axes={'params': 0}, split_rngs={'params':
        True})(nn.Dense(3), x)
exception flax.errors.TransformedMethodReturnValueError(name)[source]#

转换后的模块方法不能返回其他模块或变量。

异常 flax.errors.TraverseTreeError(update_fn, cond_fn)[源代码]#

调用 Cursor._traverse_tree() 时发生的错误。此函数有两种模式:

  • 如果 update_fn 不为 None,它将遍历树并返回一个生成器,其中包含应用 update_fn 的路径以及新修改的值的元组。

  • 如果 cond_fn 不为 None,它将遍历树并返回一个生成器,其中包含满足 cond_fn 条件的元组路径。

如果 update_fncond_fn 都为 None,或者都不为 None,则会发生此错误。