flax.traverse_util 包#
用于遍历不可变数据结构的实用程序。
Traversal 可用于迭代和更新复杂的数据结构。Traversals 接受一个对象并返回其内容的子集。例如,Traversal 可以选择对象的属性
>>> from flax import traverse_util
>>> import dataclasses
>>> @dataclasses.dataclass
... class Foo:
... foo: int = 0
... bar: int = 0
...
>>> x = Foo(foo=1)
>>> iterator = traverse_util.TraverseAttr('foo').iterate(x)
>>> list(iterator)
[1]
可以使用组合来构建更复杂的遍历。通常从标识遍历开始并使用方法链来构建预期的 Traversal 是有用的
>>> data = [{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 4}]
>>> traversal = traverse_util.t_identity.each()['foo']
>>> iterator = traversal.iterate(data)
>>> list(iterator)
[1, 3]
Traversals 也可以使用 update
方法进行更改
>>> data = {'foo': Foo(bar=2)}
>>> traversal = traverse_util.t_identity['foo'].bar
>>> data = traversal.update(lambda x: x + x, data)
>>> data
{'foo': Foo(foo=0, bar=4)}
Traversals 永远不会更改原始数据。因此,更新本质上返回包含所提供更新的数据副本。
遍历对象#
字典工具#
- flax.traverse_util.flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None)[源代码]#
展平嵌套字典。
嵌套的键被展平为一个元组。请参阅
unflatten_dict
,了解如何恢复嵌套的字典结构。示例
>>> from flax.traverse_util import flatten_dict >>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}} >>> flat_xs = flatten_dict(xs) >>> flat_xs {('foo',): 1, ('bar', 'a'): 2}
请注意,空字典将被忽略,并且不会被
unflatten_dict
恢复。- 参数
xs – 一个嵌套字典
keep_empty_nodes – 将空字典替换为
traverse_util.empty_node
。is_leaf – 一个可选的函数,它接受下一个嵌套字典和嵌套键,如果嵌套字典是一个叶子(即,不应进一步展平),则返回 True。
sep – 如果指定,则返回的字典的键将是
sep
连接的字符串(如果None
,则键将是元组)。
- 返回
扁平化的字典。
- flax.traverse_util.unflatten_dict(xs, sep=None)[源代码]#
取消展平字典。
请参阅
flatten_dict
示例
>>> flat_xs = { ... ('foo',): 1, ... ('bar', 'a'): 2, ... } >>> xs = unflatten_dict(flat_xs) >>> xs {'foo': 1, 'bar': {'a': 2}}
- 参数
xs – 一个展平的字典
sep – 分隔符(与
flatten_dict()
使用的相同)。
- 返回
嵌套的字典。
- flax.traverse_util.path_aware_map(f, nested_dict)[源代码]#
一个映射函数,它在考虑每个叶子路径的情况下对嵌套字典结构进行操作。
示例
>>> import jax.numpy as jnp >>> from flax import traverse_util >>> params = {'a': {'x': 10, 'y': 3}, 'b': {'x': 20}} >>> f = lambda path, x: x + 5 if 'x' in path else -x >>> traverse_util.path_aware_map(f, params) {'a': {'x': 15, 'y': -3}, 'b': {'x': 25}}
- 参数
f – 一个可调用对象,它接受
(path, value)
参数并将它们映射到新值。其中path
是一个字符串元组。nested_dict – 一个嵌套的字典结构。
- 返回
一个新的嵌套字典结构,其中包含映射的值。
模型参数遍历#
- class flax.traverse_util.ModelParamTraversal(*args, **kwargs)[源]#
使用名称过滤器选择模型参数。
此遍历操作于参数的嵌套字典,并基于
filter_fn
参数选择子集。有关如何使用
ModelParamTraversal
通过特定的优化器更新参数树的子集,请参见flax.optim.MultiOptimizer
。