常见问题解答 (FAQ)#
这是一系列常见问题解答(FAQ)。您可以通过在 GitHub Discussions 中发起新主题来为 Flax FAQ 做出贡献。
如何获取关于中间值的导数(使用 Module.perturb
)?#
要获取模型层内部隐藏/中间激活相对于输出的导数或梯度,您可以使用flax.linen.Module.perturb()
。您在正向传播中定义一个与中间激活具有相同形状的零值flax.linen.Module
“扰动”参数 – perturb(...)
,定义以 'perturbations'
作为附加独立参数的损失函数,在扰动参数上执行具有 jax.grad
的 JAX 导数运算。
有关完整示例和详细文档,请访问
Flax Linen 的 remat_scan()
是否与 scan(remat(...))
相同?#
Flax 的 remat_scan()
(flax.linen.remat_scan()
)和 scan(remat(...))
(flax.linen.scan()
over flax.linen.remat()
)是不同的,并且 remat_scan()
在其支持的情况下是受限制的。也就是说,remat_scan()
将输入和输出视为 carries(在训练循环中传递的隐藏状态)。建议使用 scan(remat(...))
,因为通常您需要额外的参数,例如 in_axes
(用于输入数组轴)或 out_axes
(输出数组轴),而 flax.linen.remat_scan()
不会公开这些参数。
有哪些推荐的训练循环库?#
考虑使用 CLU (Common Loop Utils) google/CommonLoopUtils。要开始使用,请访问此CLU Synopsis Colab。您可以在 google/flax GitHub Discussions 上找到关于 CLU 和 Flax 的常见问题解答。
查看官方的 google/flax 示例,了解如何将训练循环与 (CLU) 指标一起使用。例如,这是 Flax ImageNet 的 train.py。
对于计算机视觉研究,请考虑google-research/scenic。Scenic 是一组共享的轻量级库,用于解决训练大规模视觉模型时经常遇到的任务(其中包含多个项目的示例)。Scenic 是在 JAX 中使用 Flax 开发的。要开始使用,请访问 GitHub 上的 README 页面。