常见问题解答 (FAQ)#

这是一系列常见问题解答(FAQ)。您可以通过在 GitHub Discussions 中发起新主题来为 Flax FAQ 做出贡献。

如何获取关于中间值的导数(使用 Module.perturb)?#

要获取模型层内部隐藏/中间激活相对于输出的导数或梯度,您可以使用flax.linen.Module.perturb()。您在正向传播中定义一个与中间激活具有相同形状的零值flax.linen.Module“扰动”参数 – perturb(...),定义以 'perturbations' 作为附加独立参数的损失函数,在扰动参数上执行具有 jax.grad 的 JAX 导数运算。

有关完整示例和详细文档,请访问

Flax Linen 的 remat_scan() 是否与 scan(remat(...)) 相同?#

Flax 的 remat_scan()flax.linen.remat_scan())和 scan(remat(...))flax.linen.scan() over flax.linen.remat())是不同的,并且 remat_scan() 在其支持的情况下是受限制的。也就是说,remat_scan() 将输入和输出视为 carries(在训练循环中传递的隐藏状态)。建议使用 scan(remat(...)),因为通常您需要额外的参数,例如 in_axes(用于输入数组轴)或 out_axes(输出数组轴),而 flax.linen.remat_scan() 不会公开这些参数。