flax.cursor 包#
Cursor API 允许 pytree 的可变性。与进行许多嵌套的 dataclasses.replace
调用相比,此 API 提供了一种更符合人体工程学的解决方案,可以对深度嵌套的不可变数据结构进行部分更新。
为了说明,请考虑下面的示例
>>> from flax.cursor import cursor
>>> import dataclasses
>>> from typing import Any
>>> @dataclasses.dataclass(frozen=True)
>>> class A:
... x: Any
>>> a = A(A(A(A(A(A(A(0)))))))
要使用 dataclasses.replace
替换整数 0
,我们需要编写许多嵌套的调用
>>> a2 = dataclasses.replace(
... a,
... x=dataclasses.replace(
... a.x,
... x=dataclasses.replace(
... a.x.x,
... x=dataclasses.replace(
... a.x.x.x,
... x=dataclasses.replace(
... a.x.x.x.x,
... x=dataclasses.replace(
... a.x.x.x.x.x,
... x=dataclasses.replace(a.x.x.x.x.x.x, x=1),
... ),
... ),
... ),
... ),
... ),
... )
使用 Cursor API 可以更简单地实现等效效果
>>> a3 = cursor(a).x.x.x.x.x.x.x.set(1)
>>> assert a2 == a3
Cursor 对象会跟踪对其所做的更改,并在调用 .build
时,生成一个包含累积更改的新对象。基本用法包括将对象包装在 Cursor 中,对 Cursor 对象进行更改,并生成一个包含累积更改的原始对象的新副本。
- flax.cursor.cursor(obj)[source]#
在
obj
上包装Cursor
并返回它。然后可以通过以下方式将更改应用于 Cursor 对象通过
.set
方法进行单行更改进行多次更改,然后调用
.build
方法通过
.apply_update
方法,根据 pytree 路径和节点值有条件地进行多次更改,然后调用.build
方法
.set
示例>>> from flax.cursor import cursor >>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]} >>> modified_dict_obj = cursor(dict_obj)['b'][0].set(10) >>> assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]}
.build
示例>>> from flax.cursor import cursor >>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]} >>> c = cursor(dict_obj) >>> c['b'][0] = 10 >>> c['a'] = (100, 200) >>> modified_dict_obj = c.build() >>> assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}
.apply_update
示例>>> from flax.cursor import cursor >>> from flax.training import train_state >>> import optax >>> def update_fn(path, value): ... '''Replace params with empty dictionary.''' ... if 'params' in path: ... return {} ... return value >>> state = train_state.TrainState.create( ... apply_fn=lambda x: x, ... params={'a': 1, 'b': 2}, ... tx=optax.adam(1e-3), ... ) >>> c = cursor(state) >>> state2 = c.apply_update(update_fn).build() >>> assert state2.params == {} >>> assert state.params == {'a': 1, 'b': 2} # make sure original params are unchanged
如果底层
obj
是list
或tuple
,也可以迭代 Cursor 对象以获取子 Cursor>>> from flax.cursor import cursor >>> c = cursor(((1, 2), (3, 4))) >>> for child_c in c: ... child_c[1] *= -1 >>> assert c.build() == ((1, -2), (3, -4))
查看每个方法的文档字符串以查看更多用法示例。
- 参数
obj – 你想要在其中包装 Cursor 的对象
- 返回
围绕 obj 包装的 Cursor 对象。
- class flax.cursor.Cursor(obj, parent_key)[source]#
- apply_update(update_fn)[source]#
遍历 Cursor 对象,并通过
update_fn
递归记录有条件的更改。更改记录在 Cursor 对象的._changes
字典中。要生成具有累积更改的原始对象的副本,请在调用.apply_update
后调用.build
方法。update_fn
的函数签名是(str, Any) -> Any
输入参数是当前键路径(以
'/'
分隔的字符串形式)和当前键路径的值输出是新值(由
update_fn
修改,或者如果条件未满足则与输入值相同)
注意
如果
update_fn
返回修改后的值,则此方法不会进一步向下递归该分支以记录更改。例如,如果我们打算用 int 替换指向字典的属性,则无需在字典内部寻找进一步的更改,因为无论如何都会替换该字典。is
运算符用于确定返回值是否被修改(通过将其与输入值进行比较)。因此,如果update_fn
修改了可变容器(例如列表、字典等)并返回相同的容器,则.apply_update
会将返回值视为未修改,因为它包含相同的id
。为了避免这种情况,请返回修改值的副本。.apply_update
不会为 pytree 最高级别的值(即根节点)调用update_fn
。update_fn
将首先在根节点的子节点上调用,然后 pytree 遍历将从那里继续递归。
示例
>>> import flax.linen as nn >>> from flax.cursor import cursor >>> import jax, jax.numpy as jnp >>> class Model(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(3)(x) ... x = nn.relu(x) ... x = nn.Dense(3)(x) ... x = nn.relu(x) ... x = nn.Dense(3)(x) ... x = nn.relu(x) ... return x >>> params = Model().init(jax.random.key(0), jnp.empty((1, 2)))['params'] >>> def update_fn(path, value): ... '''Multiply all dense kernel params by 2 and add 1. ... Subtract the Dense_1 bias param by 1.''' ... if 'kernel' in path: ... return value * 2 + 1 ... elif 'Dense_1' in path and 'bias' in path: ... return value - 1 ... return value >>> c = cursor(params) >>> new_params = c.apply_update(update_fn).build() >>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'): ... assert (new_params[layer]['kernel'] == 2 * params[layer]['kernel'] + 1).all() ... if layer == 'Dense_1': ... assert (new_params[layer]['bias'] == params[layer]['bias'] - 1).all() ... else: ... assert (new_params[layer]['bias'] == params[layer]['bias']).all() >>> assert jax.tree_util.tree_all( ... jax.tree_util.tree_map( ... lambda x, y: (x == y).all(), ... params, ... Model().init(jax.random.key(0), jnp.empty((1, 2)))[ ... 'params' ... ], ... ) ... ) # make sure original params are unchanged
- 参数
update_fn – 将有条件地记录对 Cursor 对象更改的函数
- 返回
当前 Cursor 对象,其中包含
update_fn
指定的已记录的有条件更改。要生成具有累积更改的原始对象的副本,请在调用.apply_update
后调用.build
方法。
- build()[source]#
创建并返回一个具有累积更改的原始对象的副本。此方法在对 Cursor 对象进行更改后调用。
注意
新对象是自下而上构建的,更改将首先应用于叶节点,然后应用于其父节点,一直到根节点。
示例
>>> from flax.cursor import cursor >>> from flax.training import train_state >>> import optax >>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]} >>> c = cursor(dict_obj) >>> c['b'][0] = 10 >>> c['a'] = (100, 200) >>> modified_dict_obj = c.build() >>> assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]} >>> state = train_state.TrainState.create( ... apply_fn=lambda x: x, ... params=dict_obj, ... tx=optax.adam(1e-3), ... ) >>> new_fn = lambda x: x + 1 >>> c = cursor(state) >>> c.params['b'][1] = 10 >>> c.apply_fn = new_fn >>> modified_state = c.build() >>> assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]} >>> assert modified_state.apply_fn == new_fn
- 返回
具有累积更改的原始对象的副本。
- find(cond_fn)[source]#
遍历 Cursor 对象,并返回满足
cond_fn
中条件的子 Cursor 对象。cond_fn
的函数签名是(str, Any) -> bool
输入参数是当前键路径(以
'/'
分隔的字符串形式)和当前键路径的值输出是一个布尔值,表示是否在此路径返回子 Cursor 对象
如果未找到任何对象或找到多个满足
cond_fn
条件的对象,则会引发CursorFindError
异常。之所以抛出异常,是因为用户应始终期望此方法返回唯一一个其对应键路径和值满足cond_fn
条件的对象。注意
如果在特定键路径上
cond_fn
的计算结果为 True,则此方法不会在该分支中进一步递归;即,此方法将查找并返回特定键路径中满足cond_fn
条件的“最早”子节点。.find
不会搜索 pytree 最顶层(即根节点)的值。cond_fn
将从根节点的子节点开始递归评估。
示例
>>> import flax.linen as nn >>> from flax.cursor import cursor >>> import jax, jax.numpy as jnp >>> class Model(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(3)(x) ... x = nn.relu(x) ... x = nn.Dense(3)(x) ... x = nn.relu(x) ... x = nn.Dense(3)(x) ... x = nn.relu(x) ... return x >>> params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params'] >>> def cond_fn(path, value): ... '''Find the second dense layer params.''' ... return 'Dense_1' in path >>> new_params = cursor(params).find(cond_fn)['bias'].set(params['Dense_1']['bias'] + 1) >>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'): ... if layer == 'Dense_1': ... assert (new_params[layer]['bias'] == params[layer]['bias'] + 1).all() ... else: ... assert (new_params[layer]['bias'] == params[layer]['bias']).all() >>> c = cursor(params) >>> c2 = c.find(cond_fn) >>> c2['kernel'] += 2 >>> c2['bias'] += 2 >>> new_params = c.build() >>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'): ... if layer == 'Dense_1': ... assert (new_params[layer]['kernel'] == params[layer]['kernel'] + 2).all() ... assert (new_params[layer]['bias'] == params[layer]['bias'] + 2).all() ... else: ... assert (new_params[layer]['kernel'] == params[layer]['kernel']).all() ... assert (new_params[layer]['bias'] == params[layer]['bias']).all() >>> assert jax.tree_util.tree_all( ... jax.tree_util.tree_map( ... lambda x, y: (x == y).all(), ... params, ... Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[ ... 'params' ... ], ... ) ... ) # make sure original params are unchanged
- 参数
cond_fn – 用于有条件地查找子 Cursor 对象的函数
- 返回
满足
cond_fn
中条件的子 Cursor 对象。
- find_all(cond_fn)[source]#
遍历 Cursor 对象并返回一个子 Cursor 对象生成器,这些对象满足
cond_fn
中的条件。cond_fn
的函数签名是(str, Any) -> bool
输入参数是当前键路径(以
'/'
分隔的字符串形式)和当前键路径的值输出是一个布尔值,表示是否在此路径返回子 Cursor 对象
注意
如果在特定键路径上
cond_fn
的计算结果为 True,则此方法不会在该分支中进一步递归;即,此方法将查找并返回特定键路径中满足cond_fn
条件的“最早”子节点。.find_all
不会搜索 pytree 最顶层(即根节点)的值。cond_fn
将从根节点的子节点开始递归评估。
示例
>>> import flax.linen as nn >>> from flax.cursor import cursor >>> import jax, jax.numpy as jnp >>> class Model(nn.Module): ... @nn.compact ... def __call__(self, x): ... x = nn.Dense(3)(x) ... x = nn.relu(x) ... x = nn.Dense(3)(x) ... x = nn.relu(x) ... x = nn.Dense(3)(x) ... x = nn.relu(x) ... return x >>> params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params'] >>> def cond_fn(path, value): ... '''Find all dense layer params.''' ... return 'Dense' in path >>> c = cursor(params) >>> for dense_params in c.find_all(cond_fn): ... dense_params['bias'] += 1 >>> new_params = c.build() >>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'): ... assert (new_params[layer]['bias'] == params[layer]['bias'] + 1).all() >>> assert jax.tree_util.tree_all( ... jax.tree_util.tree_map( ... lambda x, y: (x == y).all(), ... params, ... Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[ ... 'params' ... ], ... ) ... ) # make sure original params are unchanged
- 参数
cond_fn – 用于有条件地查找子 Cursor 对象的函数
- 返回
满足
cond_fn
中条件的子 Cursor 对象生成器。
- set(value)[source]#
为 Cursor 对象中的属性、属性、元素或条目设置新值,并返回原始对象的副本,其中包含新的设置值。
示例
>>> from flax.cursor import cursor >>> from flax.training import train_state >>> import optax >>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]} >>> modified_dict_obj = cursor(dict_obj)['b'][0].set(10) >>> assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]} >>> state = train_state.TrainState.create( ... apply_fn=lambda x: x, ... params=dict_obj, ... tx=optax.adam(1e-3), ... ) >>> modified_state = cursor(state).params['b'][1].set(10) >>> assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}
- 参数
value – 用于设置 Cursor 对象中的属性、属性、元素或条目的值
- 返回
具有新设置值的原始对象的副本。