检查#
- flax.linen.tabulate(module, rngs, depth=None, show_repeated=False, mutable=DenyList(deny='intermediates'), console_kwargs=None, table_kwargs=mappingproxy({}), column_kwargs=mappingproxy({}), compute_flops=False, compute_vjp_flops=False, **kwargs)[source]#
返回一个函数,该函数创建模块的摘要,表示为表格。
此函数接受大多数相同的参数,并在内部调用 Module.init,除了它返回 (*args, **kwargs) -> str 形式的函数,其中 *args 和 **kwargs 在前向传递期间传递给 method (例如 __call__)。
tabulate 在底层使用 jax.eval_shape 来运行前向计算,而不会消耗任何 FLOP 或分配内存。
其他参数可以传递到 console_kwargs 参数中,例如 {'width': 120}。有关 console_kwargs 参数的完整列表,请参阅:https://rich.pythonlang.cn/en/stable/reference/console.html#rich.console.Console
示例
>>> import flax.linen as nn >>> import jax, jax.numpy as jnp >>> class Foo(nn.Module): ... @nn.compact ... def __call__(self, x): ... h = nn.Dense(4)(x) ... return nn.Dense(2)(h) >>> x = jnp.ones((16, 9)) >>> tabulate_fn = nn.tabulate( ... Foo(), jax.random.key(0), compute_flops=True, compute_vjp_flops=True) >>> # print(tabulate_fn(x))
这会给出以下输出
Foo Summary ┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ path ┃ module ┃ inputs ┃ outputs ┃ flops ┃ vjp_flops ┃ params ┃ ┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ │ Foo │ float32[16,9] │ float32[16,2] │ 1504 │ 4460 │ │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ Dense_0 │ Dense │ float32[16,9] │ float32[16,4] │ 1216 │ 3620 │ bias: │ │ │ │ │ │ │ │ float32[4] │ │ │ │ │ │ │ │ kernel: │ │ │ │ │ │ │ │ float32[9,4] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 40 (160 B) │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ Dense_1 │ Dense │ float32[16,4] │ float32[16,2] │ 288 │ 840 │ bias: │ │ │ │ │ │ │ │ float32[2] │ │ │ │ │ │ │ │ kernel: │ │ │ │ │ │ │ │ float32[4,2] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 10 (40 B) │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ │ │ │ │ │ Total │ 50 (200 B) │ └─────────┴────────┴───────────────┴───────────────┴───────┴───────────┴─────────────────┘ Total Parameters: 50 (200 B)
注意:表格中的行顺序不代表执行顺序,而是与 variables 中的键顺序一致,这些键按字母顺序排序。
注意:如果模块不可微,则 vjp_flops 返回 0。
- 参数
module – 要制表的模块。
rngs – 传递给 Module.init 的变量集合的 rng。
depth – 控制摘要可以深入到多少个子模块。默认情况下,它是 None,这意味着没有限制。如果由于深度限制而未显示子模块,则其参数计数和字节将添加到其第一个显示的祖先的行,以便所有行的总和始终加起来等于模块的参数总数。
show_repeated – 如果为 True,则表格中会显示对同一模块的重复调用,否则只会显示第一次调用。默认为 False。
mutable – 可以是 bool、str 或列表。指定应将哪些集合视为可变的:
bool
:所有/无集合是可变的。str
:单个可变集合的名称。list
:可变集合的名称列表。默认情况下,除“intermediates”之外的所有集合都是可变的。console_kwargs – 一个可选字典,其中包含在渲染表格时传递给 rich.console.Console 的其他关键字参数。默认参数为 {'force_terminal': True, 'force_jupyter': False}。
table_kwargs – 一个可选字典,其中包含传递给 rich.table.Table 构造函数的其他关键字参数。
column_kwargs – 一个可选字典,其中包含在向表格添加列时传递给 rich.table.Table.add_column 的其他关键字参数。
compute_flops – 是否在表格中包含一个 flops 列,列出每个模块前向传递的估计 FLOP 成本。确实会产生实际的设备上计算/编译/内存分配,但仍会为大型模块引入开销(例如,Stable Diffusion 的 UNet 会额外增加 20 秒,而其他情况下,制表将在 5 秒内完成)。
compute_vjp_flops – 是否在表格中包含一个 vjp_flops 列,列出每个模块后向传递的估计 FLOP 成本。引入了大约 2-3 倍于 compute_flops 的计算开销。
**kwargs – 传递给 Module.init 的其他参数。
- 返回值
一个接受前向传递 (method) 的相同 *args 和 **kwargs 的函数,并返回一个包含模块表格表示形式的字符串。