检查

目录

检查#

flax.linen.tabulate(module, rngs, depth=None, show_repeated=False, mutable=DenyList(deny='intermediates'), console_kwargs=None, table_kwargs=mappingproxy({}), column_kwargs=mappingproxy({}), compute_flops=False, compute_vjp_flops=False, **kwargs)[source]#

返回一个函数,该函数创建模块的摘要,表示为表格。

此函数接受大多数相同的参数,并在内部调用 Module.init,除了它返回 (*args, **kwargs) -> str 形式的函数,其中 *args**kwargs 在前向传递期间传递给 method (例如 __call__)。

tabulate 在底层使用 jax.eval_shape 来运行前向计算,而不会消耗任何 FLOP 或分配内存。

其他参数可以传递到 console_kwargs 参数中,例如 {'width': 120}。有关 console_kwargs 参数的完整列表,请参阅:https://rich.pythonlang.cn/en/stable/reference/console.html#rich.console.Console

示例

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))
>>> tabulate_fn = nn.tabulate(
...     Foo(), jax.random.key(0), compute_flops=True, compute_vjp_flops=True)

>>> # print(tabulate_fn(x))

这会给出以下输出

                                       Foo Summary
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ path    ┃ module ┃ inputs        ┃ outputs       ┃ flops ┃ vjp_flops ┃ params          ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│         │ Foo    │ float32[16,9] │ float32[16,2] │ 1504  │ 4460      │                 │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│ Dense_0 │ Dense  │ float32[16,9] │ float32[16,4] │ 1216  │ 3620      │ bias:           │
│         │        │               │               │       │           │ float32[4]      │
│         │        │               │               │       │           │ kernel:         │
│         │        │               │               │       │           │ float32[9,4]    │
│         │        │               │               │       │           │                 │
│         │        │               │               │       │           │ 40 (160 B)      │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│ Dense_1 │ Dense  │ float32[16,4] │ float32[16,2] │ 288   │ 840       │ bias:           │
│         │        │               │               │       │           │ float32[2]      │
│         │        │               │               │       │           │ kernel:         │
│         │        │               │               │       │           │ float32[4,2]    │
│         │        │               │               │       │           │                 │
│         │        │               │               │       │           │ 10 (40 B)       │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│         │        │               │               │       │     Total │ 50 (200 B)      │
└─────────┴────────┴───────────────┴───────────────┴───────┴───────────┴─────────────────┘

                               Total Parameters: 50 (200 B)

注意:表格中的行顺序不代表执行顺序,而是与 variables 中的键顺序一致,这些键按字母顺序排序。

注意:如果模块不可微,则 vjp_flops 返回 0

参数
  • module – 要制表的模块。

  • rngs – 传递给 Module.init 的变量集合的 rng。

  • depth – 控制摘要可以深入到多少个子模块。默认情况下,它是 None,这意味着没有限制。如果由于深度限制而未显示子模块,则其参数计数和字节将添加到其第一个显示的祖先的行,以便所有行的总和始终加起来等于模块的参数总数。

  • show_repeated – 如果为 True,则表格中会显示对同一模块的重复调用,否则只会显示第一次调用。默认为 False

  • mutable – 可以是 bool、str 或列表。指定应将哪些集合视为可变的:bool:所有/无集合是可变的。str:单个可变集合的名称。list:可变集合的名称列表。默认情况下,除“intermediates”之外的所有集合都是可变的。

  • console_kwargs – 一个可选字典,其中包含在渲染表格时传递给 rich.console.Console 的其他关键字参数。默认参数为 {'force_terminal': True, 'force_jupyter': False}

  • table_kwargs – 一个可选字典,其中包含传递给 rich.table.Table 构造函数的其他关键字参数。

  • column_kwargs – 一个可选字典,其中包含在向表格添加列时传递给 rich.table.Table.add_column 的其他关键字参数。

  • compute_flops – 是否在表格中包含一个 flops 列,列出每个模块前向传递的估计 FLOP 成本。确实会产生实际的设备上计算/编译/内存分配,但仍会为大型模块引入开销(例如,Stable Diffusion 的 UNet 会额外增加 20 秒,而其他情况下,制表将在 5 秒内完成)。

  • compute_vjp_flops – 是否在表格中包含一个 vjp_flops 列,列出每个模块后向传递的估计 FLOP 成本。引入了大约 2-3 倍于 compute_flops 的计算开销。

  • **kwargs – 传递给 Module.init 的其他参数。

返回值

一个接受前向传递 (method) 的相同 *args**kwargs 的函数,并返回一个包含模块表格表示形式的字符串。