RNNCellBase 升级指南#
RNNCellBase
API 经历了一些旨在增强可用性的关键更新。
initialize_carry
方法已从类方法转换为实例方法,简化了其应用。所有必要的元数据现在直接存储在单元实例中,提供了简化的方法签名。
本指南将引导您了解这些更改,演示如何更新现有代码以符合这些增强功能。
基本用法#
让我们首先定义一些变量和一个代表序列批次的示例输入
batch_size = 32
seq_len = 10
in_features = 64
out_features = 128
x = jnp.ones((batch_size, seq_len, in_features))
首先,重要的是要注意,所有元数据,包括特征数量、携带初始化器等,现在都存储在单元实例中
cell = nn.LSTMCell()
cell = nn.LSTMCell(features=out_features)
一个重要的更改是 initialize_carry
已转换为实例方法。鉴于单元实例现在包含所有元数据,initialize_carry
方法的签名仅需要 PRNG 密钥和示例输入
carry = nn.LSTMCell.initialize_carry(jax.random.key(0), (batch_size,), out_features)
carry = cell.initialize_carry(jax.random.key(0), x[:, 0].shape)
在这里,x[:, 0].shape
表示单元的输入(不包括时间维度)。您也可以在更方便时直接创建输入形状
carry = cell.initialize_carry(jax.random.key(0), (batch_size, in_features))
升级模式#
以下部分将演示一些有用的模式,用于更新您的代码以符合新的 API。
首先,我们将展示如何升级一个包装单元的 Module
,在 __call__
期间应用扫描逻辑,并具有静态 initialize_carry
方法。在这里,我们将尝试对代码进行最少的更改以使其工作,尽管不是最惯用的方式
class SimpleLSTM(nn.Module):
@functools.partial(
nn.transforms.scan,
variable_broadcast='params',
in_axes=1, out_axes=1,
split_rngs={'params': False})
@nn.compact
def __call__(self, carry, x):
return nn.OptimizedLSTMCell()(carry, x)
@staticmethod
def initialize_carry(batch_dims, hidden_size):
return nn.OptimizedLSTMCell.initialize_carry(
jax.random.key(0), batch_dims, hidden_size)
class SimpleLSTM(nn.Module):
@functools.partial(
nn.transforms.scan,
variable_broadcast='params',
in_axes=1, out_axes=1,
split_rngs={'params': False})
@nn.compact
def __call__(self, carry, x):
features = carry[0].shape[-1]
return nn.OptimizedLSTMCell(features)(carry, x)
@staticmethod
def initialize_carry(batch_dims, hidden_size):
return nn.OptimizedLSTMCell(hidden_size, parent=None).initialize_carry(
jax.random.key(0), (*batch_dims, hidden_size))
请注意,在新版本中,我们必须在 __call__
期间从携带中提取特征数量,并在 initialize_carry
期间使用 parent=None
以避免一些潜在的副作用。
接下来,我们将展示一种更惯用的方式来编写类似的 LSTM 模块。这里的主要更改是将 features
属性添加到模块,并使用它在 setup
方法中初始化 nn.scan
扫描的单元版本
class SimpleLSTM(nn.Module):
@functools.partial(
nn.transforms.scan,
variable_broadcast='params',
in_axes=1, out_axes=1,
split_rngs={'params': False})
@nn.compact
def __call__(self, carry, x):
return nn.OptimizedLSTMCell()(carry, x)
@staticmethod
def initialize_carry(batch_dims, hidden_size):
return nn.OptimizedLSTMCell.initialize_carry(
jax.random.key(0), batch_dims, hidden_size)
model = SimpleLSTM()
carry = SimpleLSTM.initialize_carry((batch_size,), out_features)
variables = model.init(jax.random.key(0), carry, x)
class SimpleLSTM(nn.Module):
features: int
def setup(self):
self.scan_cell = nn.transforms.scan(
nn.OptimizedLSTMCell,
variable_broadcast='params',
in_axes=1, out_axes=1,
split_rngs={'params': False})(self.features)
@nn.compact
def __call__(self, x):
carry = self.scan_cell.initialize_carry(jax.random.key(0), x[:, 0].shape)
return self.scan_cell(carry, x)[1] # only return the output
model = SimpleLSTM(features=out_features)
variables = model.init(jax.random.key(0), x)
由于可以从示例输入轻松初始化 carry
,我们可以将对 initialize_carry
的调用移动到 __call__
方法中,从而简化代码。
开发笔记#
在开发新单元时,请考虑以下事项
将必要的元数据作为实例属性包含在内。
initialize_carry
现在仅需要 PRNG 密钥和示例输入。需要一个新的
num_feature_axes
属性来指定特征维度的数量。
class LSTMCell(nn.RNNCellBase):
features: int # ← All metadata is now stored within the cell instance
... # ↓
carry_init: Initializer
def initialize_carry(self, rng, input_shape) -> Carry:
...
@property
def num_feature_axes(self):
return 1
num_feature_axes
是一个新的 API 功能,允许代码处理任意 RNNCellBase
实例,例如 RNN
模块,以推断批次维度的数量并确定时间轴的位置。