flax.errors 包#
Flax 具有以下几类错误。
- exception flax.errors.AlreadyExistsError(path)[源代码]#
尝试通过复制覆盖文件。
你可以传递
overwrite=True
来禁用此行为并覆盖现有的文件。
- exception flax.errors.ApplyModuleInvalidMethodError(method)[源代码]#
当调用
Module.apply()
时,你可以指定使用参数
method
来应用的方法。如果提供的参数不是模块中的方法,也不是至少有一个参数的函数,则会抛出此错误。在
Module.apply()
的参考文档中了解更多信息。
- exception flax.errors.ApplyScopeInvalidVariablesStructureError(variables)[源代码]#
当传递给 apply() 的
variables
字典具有额外的“params”层时,即 {‘params’: {‘params’: …}} 时,会抛出此错误。有关变量字典的更多说明,请参阅
flax.core.variables
。
- exception flax.errors.ApplyScopeInvalidVariablesTypeError[源代码]#
当调用
Module.apply()
时,第一个参数应为变量字典。有关变量字典的更多说明,请参阅
flax.core.variables
。
- exception flax.errors.AssignSubModuleError(cls)[源代码]#
你只能在两个地方创建子模块
如果你的模块是非紧凑的:在
Module.setup()
中。如果你的模块是紧凑的:在用
nn.compact()
包裹的方法内。
例如,以下代码会抛出此错误,因为
nn.Conv
是在__call__
中创建的,而__call__
未标记为紧凑class Foo(nn.Module): def setup(self): pass def __call__(self, x): conv = nn.Conv(features=3, kernel_size=3) Foo().init(random.key(0), jnp.zeros((1,)))
请注意,如果在 setup 中部分定义了一个模块,也会抛出此错误
class Foo(nn.Module): def setup(self): self.conv = functools.partial(nn.Conv, features=3) def __call__(self, x): x = self.conv(kernel_size=4)(x) return x Foo().init(random.key(0), jnp.zeros((1,)))
在这种情况下,
self.conv(kernel_size=4)
是从__call__
调用的,这是不允许的,因为它既不在setup
中,也不是在用 x``nn.compact`` 包裹的方法中。
- exception flax.errors.CallCompactUnboundModuleError[源代码]#
当你尝试直接调用模块而不是
通过
Module.apply()
调用时,会发生此错误。例如,当尝试运行此代码时,会引发该错误from flax import linen as nn import jax.numpy as jnp test_dense = nn.Dense(10) test_dense(jnp.ones((5,5)))
相反,你应该通过
Module.apply()
传递变量(参数和其他状态)(或使用Module.init()
获取初始变量)from jax import random variables = test_dense.init(random.key(0), jnp.ones((5,5))) y = test_dense.apply(variables, jnp.ones((5,5)))
- exception flax.errors.CallSetupUnboundModuleError[源代码]#
当你尝试直接调用
.setup()
时,会发生此错误。例如,当尝试运行此代码时,会引发该错误
from flax import linen as nn import jax.numpy as jnp class MyModule(nn.Module): def setup(self): self.submodule = MySubModule() module = MyModule() module.setup() # <-- ERROR! submodule = module.submodule
一般来说,你不应该自己调用
.setup()
,如果你需要访问在setup
中定义的字段或子模块,你可以创建一个函数来提取它并将其传递给nn.apply
# setup() will be called automatically by ``nn.apply`` def get_submodule(module): return module.submodule.clone() # avoid leaking the Scope empty_variables = {} # you can also use the real variables submodule = nn.apply(get_submodule, module)(empty_variables)
当你尝试在未绑定的模块上调用
nn.share_scope
时,会发生此错误。例如,当你尝试在顶层使用nn.share_scope
时from flax import linen as nn class CustomDense(nn.Dense): def __call__(self, x): return super().__call__(x) + 1 custom_dense = CustomDense(5) dense = nn.Dense(5) # has the parameters nn.share_scope(custom_dense, dense) # <-- ERROR!
- exception flax.errors.CallUnbindOnUnboundModuleError[源代码]#
当你尝试在未绑定的模块上调用
.unbind()
时,会发生此错误。例如,当你尝试运行以下示例时,会引发错误from flax import linen as nn class MyModule(nn.Module): @nn.compact def __call__(self, x): return nn.Dense(features=10)(x) module = MyModule() module.unbind() # <-- ERROR!
相反,你应该在调用
.unbind()
之前将模块bind
到变量集合bound_module = module.bind(variables) ... # do something with bound_module module = bound_module.unbind() # <-- OK!
- exception flax.errors.CursorFindError(cursor=None, cursor2=None)[源代码]#
调用
Cursor.find()
时出错。如果根据
cond_fn
的条件没有找到对象或找到了多个对象,则会发生此错误。
- exception flax.errors.DescriptorAttributeError[source]#
当您尝试访问一个访问不存在属性的属性时,会发生此错误。
例如,当尝试运行以下代码时会引发此错误
class Foo(nn.Module): @property def prop(self): return self.non_existent_field # ERROR! def __call__(self, x): return self.prop foo = Foo() variables = foo.init(jax.random.key(0), jnp.ones(shape=(1, 8)))
- exception flax.errors.IncorrectPostInitOverrideError[source]#
当您重写
.__post_init__()
而没有调用super().__post_init__()
时,会发生此错误。例如,当尝试运行以下代码时会引发此错误
from flax import linen as nn import jax.numpy as jnp import jax class A(nn.Module): x: float def __post_init__(self): self.x_square = self.x ** 2 # super().__post_init__() <-- forgot to add this line @nn.compact def __call__(self, input): return input + 3 r = A(x=3) r.init(jax.random.key(2), jnp.ones(3))
- exception flax.errors.InvalidCheckpointError(path, step)[source]#
检查点不能存储在已在当前或后续步骤具有检查点的目录中。
在当前或后续步骤具有检查点的目录中。
您可以传递
overwrite=True
来禁用此行为并覆盖目标目录中的现有检查点。
- exception flax.errors.InvalidInstanceModuleError[source]#
当您尝试在 Module 类本身而不是 Module 类的实例上调用
.init()
、.init_with_output()
、.apply()
或.bind()
时,会发生此错误。 例如,当尝试运行以下代码时会引发此错误在 Module 类本身而不是 Module 类的实例上调用时会发生此错误。例如,当尝试运行以下代码时会引发此错误
class B(nn.Module): @nn.compact def __call__(self, x): return x k = random.key(0) x = random.uniform(random.key(1), (2,)) B.init(k, x) # B is module class, not B() a module instance B.apply(vs, x) # similar issue with apply called on class instead of instance.
- exception flax.errors.InvalidRngError(msg)[source]#
Module 中使用的所有 rng 都应适当地传递给
Module.init()
和Module.apply()
。 我们使用以下示例分别解释这两者class Bar(nn.Module): @nn.compact def __call__(self, x): some_param = self.param('some_param', nn.initializers.zeros_init(), (1, )) dropout_rng = self.make_rng('dropout') x = nn.Dense(features=4)(x) ... class Foo(nn.Module): @nn.compact def __call__(self, x): x = Bar()(x) ...
用于 Module.init() 的 PRNG
在此示例中,使用了两个 rng
params
用于初始化模型的参数。 此 rng 用于初始化some_params
参数,并用于初始化Bar
中使用的Dense
Module 的权重。dropout
用于Bar
中使用的 dropout rng。
因此,
Foo
的初始化方式如下init_rngs = {'params': random.key(0), 'dropout': random.key(1)} variables = Foo().init(init_rngs, init_inputs)
如果 Module 仅需要用于
params
的 rng,则可以使用SomeModule().init(rng, ...) # Shorthand for {'params': rng}
用于 Module.apply() 的 PRNG
当应用
Foo
时,只需要dropout
的 rng,因为params
仅用于初始化 Module 参数Foo().apply(variables, inputs, rngs={'dropout': random.key(2)})
如果 Module 仅需要用于
params
的 rng,则完全不必为 apply 提供 rngSomeModule().apply(variables, inputs) # rngs=None
- exception flax.errors.JaxTransformError[source]#
JAX 转换和 Flax 模块不能混合使用。
JAX 的函数式转换期望纯函数。 当您想在 Flax 模型 **内部** 使用 JAX 转换时,您应该使用 Flax 转换包装器(例如:
flax.linen.vmap
、flax.linen.scan
等)。
- exception flax.errors.LazyInitError(partial_val)[source]#
Lazy Init 函数具有不可计算的返回值。
当使用影响已初始化变量的
jax.ShapeDtypeStruct
将参数传递给 lazy_init 时,会发生这种情况。 确保 init 函数仅使用 shape 和 dtype,或者如果不可能,请传递实际的 JAX 数组。示例
class Foo(nn.Module): @compact def __call__(self, x): # This parameter depends on the input x # this causes an error when using lazy_init. k = self.param("kernel", lambda _: x) return x * k Foo().lazy_init(random.key(0), jax.ShapeDtypeStruct((8, 4), jnp.float32))
- exception flax.errors.MPACheckpointingRequiredError(path, step)[source]#
要最佳地保存和恢复多进程数组(来自 pjit 的 GDA 或 jax 数组输出),请使用 GlobalAsyncCheckpointManager。
您可以在顶层创建一个 GlobalAsyncCheckpointManager 并将其作为参数传递
from jax.experimental.gda_serialization import serialization as gdas gda_manager = gdas.GlobalAsyncCheckpointManager() save_checkpoint(..., gda_manager=gda_manager)
- exception flax.errors.MPARestoreDataCorruptedError(step, path)[source]#
存储在 Google Cloud Storage 中的多进程数组不包含 “commit_success.txt” 文件,该文件应在保存结束时写入。
未能找到它可能表明您保存的 GDA 数据已损坏。
- exception flax.errors.MPARestoreTargetRequiredError(path, step, key=None)[source]#
使用多进程数组恢复检查点时,请提供有效的目标。
多进程数组需要初始化分片(全局网格和分区规范)。 因此,要恢复包含多进程数组的检查点,请确保您传递的
target
在相应的树结构位置包含有效的多进程数组。 如果您无法提供完整的有效target
,请考虑allow_partial_mpa_restoration=True
。
- exception flax.errors.ModifyScopeVariableError(col, variable_name, scope_path)[source]#
如果变量所属的集合是不可变的,则无法更新该变量。
当您应用 Module 时,应指定哪些变量集合是可变的
class MyModule(nn.Module): @nn.compact def __call__(self, x): ... var = self.variable('batch_stats', 'mean', ...) var.value = ... ... v = MyModule.init(...) ... logits = MyModule.apply(v, batch) # This throws an error. logits = MyModule.apply(v, batch, mutable=['batch_stats']) # This works.
- exception flax.errors.MultipleMethodsCompactError[source]#
@compact
装饰器最多只能添加到一个 Flax 方法中模块。 为了解决这个问题,您可以
移除
@compact
,并使用Module.setup()
定义子模块和变量。使用两个单独的模块,它们都具有唯一的
@compact
方法。
TODO(marcvanzee): 链接到解释其背后动机的设计说明。不需要与
hk.transparent
等效的东西,而且它使子模块更加合理,因为不需要为方法名称添加前缀。
- exception flax.errors.NameInUseError(key_type, value, module_name)[source]#
当尝试创建具有现有名称的子模块、参数或变量时,会引发此错误。
它们都被认为是在同一命名空间中。
共享子模块
这是共享子模块的错误模式
y = nn.Dense(feature=3, name='bar')(x) z = nn.Dense(feature=3, name='bar')(x+epsilon)
相反,模块应该通过实例来共享
dense = nn.Dense(feature=3, name='bar') y = dense(x) z = dense(x+epsilon)
如果子模块未提供名称,则会自动为其指定一个唯一的名称
class MyModule(nn.Module): @nn.compact def __call__(self, x): x = MySubModule()(x) x = MySubModule()(x) # This is fine. return x
参数和变量
参数名称可能会与子模块或变量冲突,因为它们都存储在同一个变量字典中
class Foo(nn.Module): @nn.compact def __call__(self, x): bar = self.param('bar', nn.initializers.zeros_init(), (1, )) embed = nn.Embed(num_embeddings=2, features=5, name='bar') # <-- ERROR!
变量也应该具有唯一的名称,即使它们有自己的集合
class Foo(nn.Module): @nn.compact def __call__(self, inputs): _ = self.param('mean', initializers.lecun_normal(), (2, 2)) _ = self.variable('stats', 'mean', initializers.zeros_init(), (2, 2))
- exception flax.errors.PartitioningUnspecifiedError(target)[source]#
当尝试通过以下方式向分区变量添加轴时,会引发此错误
使用转换(例如:
scan
、vmap
),而未在metadata_params
字典中指定“partition_name”。
- exception flax.errors.ReservedModuleAttributeError(annotations)[source]#
当创建使用保留属性的模块时,会抛出此错误。
以下属性是保留的
parent
: 此模块的父模块。name
: 此模块的名称。
- exception flax.errors.ScopeCollectionNotFound(col_name, var_name, scope_path)[source]#
当尝试从空集合访问变量时,会抛出此错误。
有两种常见原因
- 集合未正确传递给
apply
。例如,您可能使用了module.apply(params, ...)
而不是module.apply({'params': params}, ...)
。 - 集合为空,因为变量需要初始化。在这种情况下,您应该在应用期间使集合可变 (例如:
module.apply(variables, ..., mutable=['state'])
)。
- exception flax.errors.ScopeParamNotFoundError(param_name, scope_path)[source]#
当尝试访问不存在的参数时,会抛出此错误。
例如,在下面的代码中,初始化的嵌入名称“embedding”与应用名称“embed”不匹配
class Embed(nn.Module): num_embeddings: int features: int @nn.compact def __call__(self, inputs, embed_name='embedding'): inputs = inputs.astype('int32') embedding = self.param(embed_name, jax.nn.initializers.lecun_normal(), (self.num_embeddings, self.features)) return embedding[inputs] model = Embed(4, 8) variables = model.init(random.key(0), jnp.ones((5, 5, 1))) _ = model.apply(variables, jnp.ones((5, 5, 1)), 'embed')
- exception flax.errors.ScopeParamShapeError(param_name, scope_path, value_shape, init_shape)[source]#
当现有参数的形状与
init_fn
的返回值形状不同时,会抛出此错误。当在Module.apply()
期间提供的形状与初始化模块时使用的形状不同时,可能会发生这种情况。例如,以下代码会抛出此错误,因为应用形状(
(5, 5, 1)
)与初始化形状((5, 5
)不同。因此,init
期间内核的形状为(1, 8)
,而apply
期间的形状为(5, 8)
,这会导致此错误。class NoBiasDense(nn.Module): features: int = 8 @nn.compact def __call__(self, x): kernel = self.param('kernel', lecun_normal(), (x.shape[-1], self.features)) # <--- ERROR y = lax.dot_general(x, kernel, (((x.ndim - 1,), (0,)), ((), ()))) return y variables = NoBiasDense().init(random.key(0), jnp.ones((5, 5, 1))) _ = NoBiasDense().apply(variables, jnp.ones((5, 5)))
- exception flax.errors.ScopeVariableNotFoundError(name, col, scope_path)[source]#
当尝试在不可变的集合中在 Scope 中使用变量时,会抛出此错误。
为了创建此变量,请使用
Module.apply()
中的mutable
关键字显式地将集合标记为可变的。
- exception flax.errors.SetAttributeFrozenModuleError(module_cls, attr_name, attr_val)[source]#
您只能在
self
内部将模块属性分配给Module.setup()
。在该方法之外,模块实例被冻结(即不可变)。此行为类似于冻结的 Python 数据类。例如,在以下情况下会引发此错误
class SomeModule(nn.Module): @nn.compact def __call__(self, x, num_features=10): self.num_features = num_features # <-- ERROR! x = nn.Dense(self.num_features)(x) return x s = SomeModule().init(random.key(0), jnp.ones((5, 5)))
类似地,当尝试在构造子模块后修改子模块的属性时,即使这是在父模块的
setup()
方法中完成的,也会引发此错误class Foo(nn.Module): def setup(self): self.dense = nn.Dense(features=10) self.dense.features = 20 # <--- This is not allowed def __call__(self, x): return self.dense(x)
- exception flax.errors.SetAttributeInModuleSetupError[source]#
不允许在
class Foo(nn.Module): features: int = 6 def setup(self): self.features = 3 # <-- ERROR def __call__(self, x): return nn.Dense(self.features)(x) variables = SomeModule().init(random.key(0), jnp.ones((1, )))
相反,这些属性应该在初始化模块时设置
class Foo(nn.Module): features: int = 6 @nn.compact def __call__(self, x): return nn.Dense(self.features)(x) variables = SomeModule(features=3).init(random.key(0), jnp.ones((1, )))
TODO(marcvanzee): 链接到解释为什么模块必须保持冻结的设计说明(否则我们无法安全地克隆它们,这是我们用于提升转换的方式)。
- exception flax.errors.TransformTargetError(target)[source]#
Linen 转换必须应用于模块类或将模块实例作为第一个参数的函数。
当将无效目标传递给 linen 转换(nn.vmap、nn.scan 等)时,会发生此错误。例如,当尝试转换模块实例时,会发生此错误
nn.vmap(nn.Dense(features))(x) # raises TransformTargetError
您可以直接转换
nn.Dense
类nn.vmap(nn.Dense)(features)(x)
或者,您可以创建一个将模块实例作为第一个参数的函数
class BatchDense(nn.Module): @nn.compact def __call__(self, x): return nn.vmap( lambda mdl, x: mdl(x), variable_axes={'params': 0}, split_rngs={'params': True})(nn.Dense(3), x)