flax.cursor 包#

Cursor API 允许 pytree 的可变性。与进行许多嵌套的 dataclasses.replace 调用相比,此 API 提供了一种更符合人体工程学的解决方案,可以对深度嵌套的不可变数据结构进行部分更新。

为了说明,请考虑下面的示例

>>> from flax.cursor import cursor
>>> import dataclasses
>>> from typing import Any

>>> @dataclasses.dataclass(frozen=True)
>>> class A:
...   x: Any

>>> a = A(A(A(A(A(A(A(0)))))))

要使用 dataclasses.replace 替换整数 0,我们需要编写许多嵌套的调用

>>> a2 = dataclasses.replace(
...   a,
...   x=dataclasses.replace(
...     a.x,
...     x=dataclasses.replace(
...       a.x.x,
...       x=dataclasses.replace(
...         a.x.x.x,
...         x=dataclasses.replace(
...           a.x.x.x.x,
...           x=dataclasses.replace(
...             a.x.x.x.x.x,
...             x=dataclasses.replace(a.x.x.x.x.x.x, x=1),
...           ),
...         ),
...       ),
...     ),
...   ),
... )

使用 Cursor API 可以更简单地实现等效效果

>>> a3 = cursor(a).x.x.x.x.x.x.x.set(1)
>>> assert a2 == a3

Cursor 对象会跟踪对其所做的更改,并在调用 .build 时,生成一个包含累积更改的新对象。基本用法包括将对象包装在 Cursor 中,对 Cursor 对象进行更改,并生成一个包含累积更改的原始对象的新副本。

flax.cursor.cursor(obj)[source]#

obj 上包装 Cursor 并返回它。然后可以通过以下方式将更改应用于 Cursor 对象

  • 通过 .set 方法进行单行更改

  • 进行多次更改,然后调用 .build 方法

  • 通过 .apply_update 方法,根据 pytree 路径和节点值有条件地进行多次更改,然后调用 .build 方法

.set 示例

>>> from flax.cursor import cursor

>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> modified_dict_obj = cursor(dict_obj)['b'][0].set(10)
>>> assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]}

.build 示例

>>> from flax.cursor import cursor

>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> c = cursor(dict_obj)
>>> c['b'][0] = 10
>>> c['a'] = (100, 200)
>>> modified_dict_obj = c.build()
>>> assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}

.apply_update 示例

>>> from flax.cursor import cursor
>>> from flax.training import train_state
>>> import optax

>>> def update_fn(path, value):
...   '''Replace params with empty dictionary.'''
...   if 'params' in path:
...     return {}
...   return value

>>> state = train_state.TrainState.create(
...     apply_fn=lambda x: x,
...     params={'a': 1, 'b': 2},
...     tx=optax.adam(1e-3),
... )
>>> c = cursor(state)
>>> state2 = c.apply_update(update_fn).build()
>>> assert state2.params == {}
>>> assert state.params == {'a': 1, 'b': 2} # make sure original params are unchanged

如果底层 objlisttuple,也可以迭代 Cursor 对象以获取子 Cursor

>>> from flax.cursor import cursor

>>> c = cursor(((1, 2), (3, 4)))
>>> for child_c in c:
...   child_c[1] *= -1
>>> assert c.build() == ((1, -2), (3, -4))

查看每个方法的文档字符串以查看更多用法示例。

参数

obj – 你想要在其中包装 Cursor 的对象

返回

围绕 obj 包装的 Cursor 对象。

class flax.cursor.Cursor(obj, parent_key)[source]#
apply_update(update_fn)[source]#

遍历 Cursor 对象,并通过 update_fn 递归记录有条件的更改。更改记录在 Cursor 对象的 ._changes 字典中。要生成具有累积更改的原始对象的副本,请在调用 .apply_update 后调用 .build 方法。

update_fn 的函数签名是 (str, Any) -> Any

  • 输入参数是当前键路径(以 '/' 分隔的字符串形式)和当前键路径的值

  • 输出是新值(由 update_fn 修改,或者如果条件未满足则与输入值相同)

注意

  • 如果 update_fn 返回修改后的值,则此方法不会进一步向下递归该分支以记录更改。例如,如果我们打算用 int 替换指向字典的属性,则无需在字典内部寻找进一步的更改,因为无论如何都会替换该字典。

  • is 运算符用于确定返回值是否被修改(通过将其与输入值进行比较)。因此,如果 update_fn 修改了可变容器(例如列表、字典等)并返回相同的容器,则 .apply_update 会将返回值视为未修改,因为它包含相同的 id。为了避免这种情况,请返回修改值的副本。

  • .apply_update 不会为 pytree 最高级别的值(即根节点)调用 update_fnupdate_fn 将首先在根节点的子节点上调用,然后 pytree 遍历将从那里继续递归。

示例

>>> import flax.linen as nn
>>> from flax.cursor import cursor
>>> import jax, jax.numpy as jnp

>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     return x

>>> params = Model().init(jax.random.key(0), jnp.empty((1, 2)))['params']

>>> def update_fn(path, value):
...   '''Multiply all dense kernel params by 2 and add 1.
...   Subtract the Dense_1 bias param by 1.'''
...   if 'kernel' in path:
...     return value * 2 + 1
...   elif 'Dense_1' in path and 'bias' in path:
...     return value - 1
...   return value

>>> c = cursor(params)
>>> new_params = c.apply_update(update_fn).build()
>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
...   assert (new_params[layer]['kernel'] == 2 * params[layer]['kernel'] + 1).all()
...   if layer == 'Dense_1':
...     assert (new_params[layer]['bias'] == params[layer]['bias'] - 1).all()
...   else:
...     assert (new_params[layer]['bias'] == params[layer]['bias']).all()

>>> assert jax.tree_util.tree_all(
...       jax.tree_util.tree_map(
...           lambda x, y: (x == y).all(),
...           params,
...           Model().init(jax.random.key(0), jnp.empty((1, 2)))[
...               'params'
...           ],
...       )
...   ) # make sure original params are unchanged
参数

update_fn – 将有条件地记录对 Cursor 对象更改的函数

返回

当前 Cursor 对象,其中包含 update_fn 指定的已记录的有条件更改。要生成具有累积更改的原始对象的副本,请在调用 .apply_update 后调用 .build 方法。

build()[source]#

创建并返回一个具有累积更改的原始对象的副本。此方法在对 Cursor 对象进行更改后调用。

注意

新对象是自下而上构建的,更改将首先应用于叶节点,然后应用于其父节点,一直到根节点。

示例

>>> from flax.cursor import cursor
>>> from flax.training import train_state
>>> import optax

>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> c = cursor(dict_obj)
>>> c['b'][0] = 10
>>> c['a'] = (100, 200)
>>> modified_dict_obj = c.build()
>>> assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}

>>> state = train_state.TrainState.create(
...     apply_fn=lambda x: x,
...     params=dict_obj,
...     tx=optax.adam(1e-3),
... )
>>> new_fn = lambda x: x + 1
>>> c = cursor(state)
>>> c.params['b'][1] = 10
>>> c.apply_fn = new_fn
>>> modified_state = c.build()
>>> assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}
>>> assert modified_state.apply_fn == new_fn
返回

具有累积更改的原始对象的副本。

find(cond_fn)[source]#

遍历 Cursor 对象,并返回满足 cond_fn 中条件的子 Cursor 对象。cond_fn 的函数签名是 (str, Any) -> bool

  • 输入参数是当前键路径(以 '/' 分隔的字符串形式)和当前键路径的值

  • 输出是一个布尔值,表示是否在此路径返回子 Cursor 对象

如果未找到任何对象或找到多个满足 cond_fn 条件的对象,则会引发 CursorFindError 异常。之所以抛出异常,是因为用户应始终期望此方法返回唯一一个其对应键路径和值满足 cond_fn 条件的对象。

注意

  • 如果在特定键路径上 cond_fn 的计算结果为 True,则此方法不会在该分支中进一步递归;即,此方法将查找并返回特定键路径中满足 cond_fn 条件的“最早”子节点。

  • .find 不会搜索 pytree 最顶层(即根节点)的值。 cond_fn 将从根节点的子节点开始递归评估。

示例

>>> import flax.linen as nn
>>> from flax.cursor import cursor
>>> import jax, jax.numpy as jnp

>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     return x

>>> params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params']

>>> def cond_fn(path, value):
...   '''Find the second dense layer params.'''
...   return 'Dense_1' in path

>>> new_params = cursor(params).find(cond_fn)['bias'].set(params['Dense_1']['bias'] + 1)

>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
...   if layer == 'Dense_1':
...     assert (new_params[layer]['bias'] == params[layer]['bias'] + 1).all()
...   else:
...     assert (new_params[layer]['bias'] == params[layer]['bias']).all()

>>> c = cursor(params)
>>> c2 = c.find(cond_fn)
>>> c2['kernel'] += 2
>>> c2['bias'] += 2
>>> new_params = c.build()

>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
...   if layer == 'Dense_1':
...     assert (new_params[layer]['kernel'] == params[layer]['kernel'] + 2).all()
...     assert (new_params[layer]['bias'] == params[layer]['bias'] + 2).all()
...   else:
...     assert (new_params[layer]['kernel'] == params[layer]['kernel']).all()
...     assert (new_params[layer]['bias'] == params[layer]['bias']).all()

>>> assert jax.tree_util.tree_all(
...       jax.tree_util.tree_map(
...           lambda x, y: (x == y).all(),
...           params,
...           Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[
...               'params'
...           ],
...       )
...   ) # make sure original params are unchanged
参数

cond_fn – 用于有条件地查找子 Cursor 对象的函数

返回

满足 cond_fn 中条件的子 Cursor 对象。

find_all(cond_fn)[source]#

遍历 Cursor 对象并返回一个子 Cursor 对象生成器,这些对象满足 cond_fn 中的条件。 cond_fn 的函数签名是 (str, Any) -> bool

  • 输入参数是当前键路径(以 '/' 分隔的字符串形式)和当前键路径的值

  • 输出是一个布尔值,表示是否在此路径返回子 Cursor 对象

注意

  • 如果在特定键路径上 cond_fn 的计算结果为 True,则此方法不会在该分支中进一步递归;即,此方法将查找并返回特定键路径中满足 cond_fn 条件的“最早”子节点。

  • .find_all 不会搜索 pytree 最顶层(即根节点)的值。 cond_fn 将从根节点的子节点开始递归评估。

示例

>>> import flax.linen as nn
>>> from flax.cursor import cursor
>>> import jax, jax.numpy as jnp

>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     return x

>>> params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params']

>>> def cond_fn(path, value):
...   '''Find all dense layer params.'''
...   return 'Dense' in path

>>> c = cursor(params)
>>> for dense_params in c.find_all(cond_fn):
...   dense_params['bias'] += 1
>>> new_params = c.build()

>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
...   assert (new_params[layer]['bias'] == params[layer]['bias'] + 1).all()

>>> assert jax.tree_util.tree_all(
...       jax.tree_util.tree_map(
...           lambda x, y: (x == y).all(),
...           params,
...           Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[
...               'params'
...           ],
...       )
...   ) # make sure original params are unchanged
参数

cond_fn – 用于有条件地查找子 Cursor 对象的函数

返回

满足 cond_fn 中条件的子 Cursor 对象生成器。

set(value)[source]#

为 Cursor 对象中的属性、属性、元素或条目设置新值,并返回原始对象的副本,其中包含新的设置值。

示例

>>> from flax.cursor import cursor
>>> from flax.training import train_state
>>> import optax

>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> modified_dict_obj = cursor(dict_obj)['b'][0].set(10)
>>> assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]}

>>> state = train_state.TrainState.create(
...     apply_fn=lambda x: x,
...     params=dict_obj,
...     tx=optax.adam(1e-3),
... )
>>> modified_state = cursor(state).params['b'][1].set(10)
>>> assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}
参数

value – 用于设置 Cursor 对象中的属性、属性、元素或条目的值

返回

具有新设置值的原始对象的副本。