RNNCellBase 升级指南

RNNCellBase 升级指南#

RNNCellBase API 经历了一些旨在增强可用性的关键更新。

  • initialize_carry 方法已从类方法转换为实例方法,简化了其应用。

  • 所有必要的元数据现在直接存储在单元实例中,提供了简化的方法签名。

本指南将引导您了解这些更改,演示如何更新现有代码以符合这些增强功能。

基本用法#

让我们首先定义一些变量和一个代表序列批次的示例输入

batch_size = 32
seq_len = 10
in_features = 64
out_features = 128

x = jnp.ones((batch_size, seq_len, in_features))

首先,重要的是要注意,所有元数据,包括特征数量、携带初始化器等,现在都存储在单元实例中

cell = nn.LSTMCell()
cell = nn.LSTMCell(features=out_features)

一个重要的更改是 initialize_carry 已转换为实例方法。鉴于单元实例现在包含所有元数据,initialize_carry 方法的签名仅需要 PRNG 密钥和示例输入

carry = nn.LSTMCell.initialize_carry(jax.random.key(0), (batch_size,), out_features)
carry = cell.initialize_carry(jax.random.key(0), x[:, 0].shape)

在这里,x[:, 0].shape 表示单元的输入(不包括时间维度)。您也可以在更方便时直接创建输入形状

carry = cell.initialize_carry(jax.random.key(0), (batch_size, in_features))

升级模式#

以下部分将演示一些有用的模式,用于更新您的代码以符合新的 API。

首先,我们将展示如何升级一个包装单元的 Module,在 __call__ 期间应用扫描逻辑,并具有静态 initialize_carry 方法。在这里,我们将尝试对代码进行最少的更改以使其工作,尽管不是最惯用的方式

class SimpleLSTM(nn.Module):

  @functools.partial(
    nn.transforms.scan,
    variable_broadcast='params',
    in_axes=1, out_axes=1,
    split_rngs={'params': False})
  @nn.compact
  def __call__(self, carry, x):

    return nn.OptimizedLSTMCell()(carry, x)

  @staticmethod
  def initialize_carry(batch_dims, hidden_size):
    return nn.OptimizedLSTMCell.initialize_carry(
      jax.random.key(0), batch_dims, hidden_size)
class SimpleLSTM(nn.Module):

  @functools.partial(
    nn.transforms.scan,
    variable_broadcast='params',
    in_axes=1, out_axes=1,
    split_rngs={'params': False})
  @nn.compact
  def __call__(self, carry, x):
    features = carry[0].shape[-1]
    return nn.OptimizedLSTMCell(features)(carry, x)

  @staticmethod
  def initialize_carry(batch_dims, hidden_size):
    return nn.OptimizedLSTMCell(hidden_size, parent=None).initialize_carry(
      jax.random.key(0), (*batch_dims, hidden_size))

请注意,在新版本中,我们必须在 __call__ 期间从携带中提取特征数量,并在 initialize_carry 期间使用 parent=None 以避免一些潜在的副作用。

接下来,我们将展示一种更惯用的方式来编写类似的 LSTM 模块。这里的主要更改是将 features 属性添加到模块,并使用它在 setup 方法中初始化 nn.scan 扫描的单元版本

class SimpleLSTM(nn.Module):

  @functools.partial(
    nn.transforms.scan,
    variable_broadcast='params',
    in_axes=1, out_axes=1,
    split_rngs={'params': False})
  @nn.compact
  def __call__(self, carry, x):
    return nn.OptimizedLSTMCell()(carry, x)

  @staticmethod
  def initialize_carry(batch_dims, hidden_size):
    return nn.OptimizedLSTMCell.initialize_carry(
      jax.random.key(0), batch_dims, hidden_size)

model = SimpleLSTM()
carry = SimpleLSTM.initialize_carry((batch_size,), out_features)
variables = model.init(jax.random.key(0), carry, x)
class SimpleLSTM(nn.Module):
  features: int

  def setup(self):
    self.scan_cell = nn.transforms.scan(
      nn.OptimizedLSTMCell,
      variable_broadcast='params',
      in_axes=1, out_axes=1,
      split_rngs={'params': False})(self.features)


  @nn.compact
  def __call__(self, x):
    carry = self.scan_cell.initialize_carry(jax.random.key(0), x[:, 0].shape)
    return self.scan_cell(carry, x)[1]  # only return the output


model = SimpleLSTM(features=out_features)
variables = model.init(jax.random.key(0), x)

由于可以从示例输入轻松初始化 carry,我们可以将对 initialize_carry 的调用移动到 __call__ 方法中,从而简化代码。

开发笔记#

在开发新单元时,请考虑以下事项

  • 将必要的元数据作为实例属性包含在内。

  • initialize_carry 现在仅需要 PRNG 密钥和示例输入。

  • 需要一个新的 num_feature_axes 属性来指定特征维度的数量。

class LSTMCell(nn.RNNCellBase):
  features: int # ← All metadata is now stored within the cell instance
  ... #              ↓
  carry_init: Initializer

  def initialize_carry(self, rng, input_shape) -> Carry:
    ...

  @property
  def num_feature_axes(self):
    return 1

num_feature_axes 是一个新的 API 功能,允许代码处理任意 RNNCellBase 实例,例如 RNN 模块,以推断批次维度的数量并确定时间轴的位置。