Convert PyTorch Models to Flax

We will show how to convert PyTorch models to Flax. We will cover convolutions, fc layers, batch norm, average pooling, and transposed convolutions.

FC Layers

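Let's start with fc layers. The only thing to be aware of here is that the PyTorch kernel has shape [outC, inC], while the Flax kernel has shape [inC, outC]. Transposing the kernel will do the trick.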
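
To see why a single transpose is enough, here is a minimal framework-free sketch in NumPy. It assumes the documented semantics of the two layers: torch.nn.Linear computes x @ weight.T + bias, and nn.Dense computes x @ kernel + bias. `pytorch_linear` and `flax_dense` are hypothetical stand-ins written for this illustration, not library functions.

```python
import numpy as np

def pytorch_linear(x, weight, bias):
    # Hypothetical stand-in for torch.nn.Linear; weight has shape [outC, inC].
    return x @ weight.T + bias

def flax_dense(x, kernel, bias):
    # Hypothetical stand-in for flax.linen.Dense; kernel has shape [inC, outC].
    return x @ kernel + bias

rng = np.random.default_rng(0)
weight = rng.normal(size=(4, 3))  # PyTorch layout [outC, inC]
bias = rng.normal(size=(4,))
x = rng.normal(size=(1, 3))

# [outC, inC] -> [inC, outC]
kernel = np.transpose(weight, (1, 0))

torch_style = pytorch_linear(x, weight, bias)
flax_style = flax_dense(x, kernel, bias)
print(np.allclose(torch_style, flax_style))
```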

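import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
import numpy as np
import torch

t_fc = torch.nn.Linear(in_features=3, out_features=4)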

kernel = t_fc.weight.detach().cpu().numpy()
bias = t_fc.bias.detach().cpu().numpy()

# [outC, inC] -> [inC, outC]
kernel = jnp.transpose(kernel, (1, 0))

key = random.key(0)
x = random.normal(key, (1, 3))

variables = {'params': {'kernel': kernel, 'bias': bias}}
j_fc = nn.Dense(features=4)
j_out = j_fc.apply(variables, x)

t_x = torch.from_numpy(np.array(x))
t_out = t_fc(t_x)
t_out = t_out.detach().cpu().numpy()

np.testing.assert_almost_equal(j_out, t_out, decimal=6)

Convolutions

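Let's now look at 2D convolutions. PyTorch uses the NCHW format and Flax uses NHWC, and the kernel layouts differ accordingly. The PyTorch kernel has shape [outC, inC, kH, kW], while the Flax kernel has shape [kH, kW, inC, outC]. Transposing the kernel will do the trick. When comparing results, the input and output activations also have to be transposed between the two layouts, as done below.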
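
The same equivalence can be checked without either framework. The NumPy sketch below implements a stride-1, 'valid' 2D cross-correlation once for each layout. Both torch.nn.Conv2d and nn.Conv compute a cross-correlation, so the kernel is not flipped. `conv2d_nchw` and `conv2d_nhwc` are hypothetical helpers, and `sliding_window_view` requires NumPy 1.20 or newer.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def conv2d_nchw(x, w):
    # 'valid' cross-correlation in PyTorch layout: x [N, C, H, W], w [outC, inC, kH, kW].
    kh, kw = w.shape[2:]
    patches = sliding_window_view(x, (kh, kw), axis=(2, 3))  # [N, C, oH, oW, kH, kW]
    return np.einsum('ncyxij,ocij->noyx', patches, w)

def conv2d_nhwc(x, k):
    # 'valid' cross-correlation in Flax layout: x [N, H, W, C], k [kH, kW, inC, outC].
    kh, kw = k.shape[:2]
    patches = sliding_window_view(x, (kh, kw), axis=(1, 2))  # [N, oH, oW, C, kH, kW]
    return np.einsum('nyxcij,ijco->nyxo', patches, k)

rng = np.random.default_rng(0)
w = rng.normal(size=(4, 3, 2, 2))       # PyTorch kernel [outC, inC, kH, kW]
x_nhwc = rng.normal(size=(1, 6, 6, 3))  # Flax input [N, H, W, C]

# [outC, inC, kH, kW] -> [kH, kW, inC, outC]
k = np.transpose(w, (2, 3, 1, 0))

y_flax = conv2d_nhwc(x_nhwc, k)
# Run the PyTorch-layout version on NCHW input, then bring the result back to NHWC.
y_torch = conv2d_nchw(np.transpose(x_nhwc, (0, 3, 1, 2)), w)
y_torch = np.transpose(y_torch, (0, 2, 3, 1))
print(y_flax.shape, np.allclose(y_flax, y_torch))
```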

t_conv = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding='valid')

kernel = t_conv.weight.detach().cpu().numpy()
bias = t_conv.bias.detach().cpu().numpy()

# [outC, inC, kH, kW] -> [kH, kW, inC, outC]
kernel = jnp.transpose(kernel, (2, 3, 1, 0))

key = random.key(0)
x = random.normal(key, (1, 6, 6, 3))

variables = {'params': {'kernel': kernel, 'bias': bias}}
j_conv = nn.Conv(features=4, kernel_size=(2, 2), padding='valid')
j_out = j_conv.apply(variables, x)

# [N, H, W, C] -> [N, C, H, W]
t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
t_out = t_conv(t_x)
# [N, C, H, W] -> [N, H, W, C]
t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))

np.testing.assert_almost_equal(j_out, t_out, decimal=6)

Convolutions and FC Layers

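We have to be careful when a model uses convolutions followed by fc layers (ResNet, VGG, etc.). In PyTorch, the activations have shape [N, C, H, W] after the convolutions and are reshaped to [N, C * H * W] before being fed to the fc layers. When we port the weights from PyTorch to Flax, the activations after the convolutions have shape [N, H, W, C] in Flax. Flattening these directly would order the features differently from what the ported fc weights expect, so before reshaping the activations for the fc layers we have to transpose them to [N, C, H, W].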
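
A small NumPy sketch makes the ordering problem concrete. It uses the shapes of the example model in this section (4 channels and a 5x5 spatial size after the convolution) with random activations. It then compares two flattenings of the NHWC activations against what the PyTorch fc layer receives.

```python
import numpy as np

N, C, H, W = 1, 4, 5, 5
rng = np.random.default_rng(0)
act_nchw = rng.normal(size=(N, C, H, W))         # activations as PyTorch sees them
act_nhwc = np.transpose(act_nchw, (0, 2, 3, 1))  # the same activations in Flax layout

# What the PyTorch fc layer receives.
torch_flat = act_nchw.reshape(N, -1)
# Flattening NHWC directly interleaves the channels: same values, different order.
naive_flat = act_nhwc.reshape(N, -1)
# Transposing back to NCHW before flattening restores PyTorch's feature order.
fixed_flat = np.transpose(act_nhwc, (0, 3, 1, 2)).reshape(N, -1)

print(np.array_equal(naive_flat, torch_flat))
print(np.array_equal(fixed_flat, torch_flat))
```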

Consider this PyTorch model:

class TModel(torch.nn.Module):

  def __init__(self):
    super(TModel, self).__init__()
    self.conv = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding='valid')
    self.fc = torch.nn.Linear(in_features=100, out_features=2)

  def forward(self, x):
    x = self.conv(x)
    x = x.reshape(x.shape[0], -1)
    x = self.fc(x)
    return x


t_model = TModel()

Now, if you want to use the weights from this model in Flax, the corresponding Flax model has to look like this:

class JModel(nn.Module):

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=4, kernel_size=(2, 2), padding='valid', name='conv')(x)
    # [N, H, W, C] -> [N, C, H, W]
    x = jnp.transpose(x, (0, 3, 1, 2))
    x = jnp.reshape(x, (x.shape[0], -1))
    x = nn.Dense(features=2, name='fc')(x)
    return x


j_model = JModel()

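The model looks very similar to the PyTorch model, except that we added a transpose operation before reshaping the activations for the fc layer. We can omit the transpose operation if we apply pooling before the reshape such that the spatial dimensions are 1x1.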
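
Here is a quick NumPy check of that last claim. The pooled shapes are chosen arbitrarily for illustration. When H = W = 1, flattening [N, C, 1, 1] and [N, 1, 1, C] lists the channels in the same order, so the transpose becomes a no-op.

```python
import numpy as np

rng = np.random.default_rng(0)
# After pooling to 1x1: PyTorch layout [N, C, 1, 1] vs. Flax layout [N, 1, 1, C].
pooled_nchw = rng.normal(size=(2, 8, 1, 1))
pooled_nhwc = np.transpose(pooled_nchw, (0, 2, 3, 1))

# With singleton spatial axes, both reshapes yield the same [N, C] matrix.
print(np.array_equal(pooled_nchw.reshape(2, -1), pooled_nhwc.reshape(2, -1)))
```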

Other than the transpose operation before the reshape, we can convert the weights the same way as before:

conv_kernel = t_model.state_dict()['conv.weight'].detach().cpu().numpy()
conv_bias = t_model.state_dict()['conv.bias'].detach().cpu().numpy()
fc_kernel = t_model.state_dict()['fc.weight'].detach().cpu().numpy()
fc_bias = t_model.state_dict()['fc.bias'].detach().cpu().numpy()

# [outC, inC, kH, kW] -> [kH, kW, inC, outC]
conv_kernel = jnp.transpose(conv_kernel, (2, 3, 1, 0))

# [outC, inC] -> [inC, outC]
fc_kernel = jnp.transpose(fc_kernel, (1, 0))

variables = {'params': {'conv': {'kernel': conv_kernel, 'bias': conv_bias},
                        'fc': {'kernel': fc_kernel, 'bias': fc_bias}}}

key = random.key(0)
x = random.normal(key, (1, 6, 6, 3))

j_out = j_model.apply(variables, x)

# [N, H, W, C] -> [N, C, H, W]
t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
t_out = t_model(t_x)
t_out = t_out.detach().cpu().numpy()

np.testing.assert_almost_equal(j_out, t_out, decimal=6)

Batch Norm

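torch.nn.BatchNorm2d uses 0.1 as the default value for the momentum parameter, while nn.BatchNorm uses 0.9. However, both correspond to the same computation. PyTorch multiplies the estimated statistic by (1 - momentum) and the new observed value by momentum. Flax multiplies the estimated statistic by momentum and the new observed value by (1 - momentum). The momentum only affects how the running statistics are updated during training. The comparison below runs both layers in inference mode (t_bn.eval() and use_running_average=True), so only the stored scale, bias, mean, and variance matter.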
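
The two conventions can be compared in plain Python. `torch_update` and `flax_update` are hypothetical helpers that encode the update rules described above for a single running statistic. Feeding both the same observations shows that PyTorch's momentum=0.1 tracks Flax's momentum=0.9 up to floating-point rounding.

```python
import math

def torch_update(running, observed, momentum=0.1):
    # PyTorch convention: `momentum` weights the new observation.
    return (1 - momentum) * running + momentum * observed

def flax_update(running, observed, momentum=0.9):
    # Flax convention: `momentum` weights the existing running estimate.
    return momentum * running + (1 - momentum) * observed

running_t = running_f = 0.0
for observed in [1.0, 2.0, 3.0]:
    running_t = torch_update(running_t, observed)
    running_f = flax_update(running_f, observed)

# Equal up to rounding, since 1 - 0.9 is not exactly 0.1 in binary floating point.
print(running_t, running_f, math.isclose(running_t, running_f))
```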

t_bn = torch.nn.BatchNorm2d(num_features=3, momentum=0.1)
t_bn.eval()

scale = t_bn.weight.detach().cpu().numpy()
bias = t_bn.bias.detach().cpu().numpy()
mean = t_bn.running_mean.detach().cpu().numpy()
var = t_bn.running_var.detach().cpu().numpy()

variables = {'params': {'scale': scale, 'bias': bias},
             'batch_stats': {'mean': mean, 'var': var}}

key = random.key(0)
x = random.normal(key, (1, 6, 6, 3))

j_bn = nn.BatchNorm(momentum=0.9, use_running_average=True)

j_out = j_bn.apply(variables, x)

# [N, H, W, C] -> [N, C, H, W]
t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
t_out = t_bn(t_x)
# [N, C, H, W] -> [N, H, W, C]
t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))

np.testing.assert_almost_equal(j_out, t_out, decimal=6)

Average Pooling

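torch.nn.AvgPool2d and nn.avg_pool() compute the same result for matching window, stride, and padding settings. Note that their stride defaults differ: PyTorch's stride defaults to kernel_size, whereas Flax's strides default to 1, so it is safest to pass the stride explicitly on both sides. However, torch.nn.AvgPool2d has a parameter count_include_pad. When count_include_pad=False, the zero padding is not counted in the average. nn.avg_pool() has no similar parameter, but we can easily implement a wrapper around the pooling operation. nn.pool() is the core function behind nn.avg_pool() and nn.max_pool().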
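
The sketch below shows, in NumPy for a single 2x2 channel, the sum-then-divide idea that the avg_pool wrapper in this section implements with nn.pool. Window sums are taken over the zero-padded input. They are then divided either by the full window size (count_include_pad=True) or by the number of real entries in each window (count_include_pad=False). That count comes from pooling an array of ones. `window_sums` is a hypothetical helper and assumes stride 1.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def window_sums(x, window, pad):
    # Zero-pad both spatial axes of an [H, W] array, then sum every window (stride 1).
    padded = np.pad(x, pad)
    return sliding_window_view(padded, window).sum(axis=(-2, -1))

x = np.ones((2, 2))
window, pad = (2, 2), 1

sums = window_sums(x, window, pad)
# count_include_pad=True: divide every window sum by the full window size.
include_pad = sums / (window[0] * window[1])
# count_include_pad=False: divide by the number of real entries in each window,
# obtained by pooling an array of ones over the same padded grid.
counts = window_sums(np.ones_like(x), window, pad)
exclude_pad = sums / counts
print(include_pad)
print(exclude_pad)
```

Since every input is 1, the corner windows average to 0.25 when the padding is counted and to 1.0 when it is not.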

def avg_pool(inputs, window_shape, strides=None, padding='VALID'):
  """
  Pools the input by taking the average over a window.
  In comparison to nn.avg_pool(), this pooling operation does not
  consider the padded zeros for the average computation.
  """
  assert len(window_shape) == 2

  # Sum over each window; padded positions contribute zeros.
  y = nn.pool(inputs, 0., jax.lax.add, window_shape, strides, padding)
  # Pool an array of ones to count the non-padded entries in each window.
  counts = nn.pool(jnp.ones_like(inputs), 0., jax.lax.add, window_shape, strides, padding)
  y = y / counts
  return y


key = random.key(0)
x = random.normal(key, (1, 6, 6, 3))

j_out = avg_pool(x, window_shape=(2, 2), strides=(1, 1), padding=((1, 1), (1, 1)))
t_pool = torch.nn.AvgPool2d(kernel_size=2, stride=1, padding=1, count_include_pad=False)

# [N, H, W, C] -> [N, C, H, W]
t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
t_out = t_pool(t_x)
# [N, C, H, W] -> [N, H, W, C]
t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))

np.testing.assert_almost_equal(j_out, t_out, decimal=6)

Transposed Convolutions

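torch.nn.ConvTranspose2d and nn.ConvTranspose are not directly compatible. nn.ConvTranspose is a wrapper around jax.lax.conv_transpose, which computes a fractionally strided convolution, while torch.nn.ConvTranspose2d computes a gradient-based transposed convolution. Currently, there is no implementation of gradient-based transposed convolutions in JAX. However, there is a pending pull request that contains an implementation.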

To load torch.nn.ConvTranspose2d parameters into Flax, we need to use the transpose_kernel argument of Flax's nn.ConvTranspose layer.
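
The following 1D, single-channel NumPy sketch explains the padding relationship used in the code below. It assumes dilation 1 and no output_padding; `conv_transpose_scatter` and `conv_transpose_via_padding` are hypothetical helpers. A gradient-based transposed convolution, as in PyTorch, scatters a scaled copy of the kernel for every input element and then trims `padding` entries from each end. The same output comes from three steps: dilate the input by the stride, zero-pad it by k - 1 - padding on each side, and cross-correlate with the spatially flipped kernel. That flip is what `transpose_kernel=True` supplies on the Flax side.

```python
import numpy as np

def conv_transpose_scatter(x, w, stride=1, padding=0):
    # Gradient-style transposed convolution (PyTorch semantics) in 1D:
    # every input element adds a scaled copy of the kernel to the output.
    k = len(w)
    full = np.zeros((len(x) - 1) * stride + k)
    for i, xi in enumerate(x):
        full[i * stride:i * stride + k] += xi * w
    # PyTorch's `padding` trims that many entries from both ends.
    return full[padding:len(full) - padding]

def conv_transpose_via_padding(x, w, stride=1, padding=0):
    # Equivalent formulation: dilate the input by the stride, zero-pad it by
    # k - 1 - padding on each side, then cross-correlate with the flipped kernel.
    k = len(w)
    assert 0 <= padding <= k - 1, "this sketch only covers padding <= k - 1"
    dilated = np.zeros((len(x) - 1) * stride + 1)
    dilated[::stride] = x
    padded = np.pad(dilated, k - 1 - padding)
    return np.correlate(padded, w[::-1], mode='valid')

rng = np.random.default_rng(0)
x = rng.normal(size=6)
w = rng.normal(size=3)
for stride in (1, 2):
    for padding in (0, 1, 2):
        a = conv_transpose_scatter(x, w, stride, padding)
        b = conv_transpose_via_padding(x, w, stride, padding)
        print(stride, padding, a.shape, np.allclose(a, b))
```

With kernel_size=2, k - 1 - padding reduces to 1 - padding, which matches `flax_padding = 1 - torch_padding` below.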

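# Padding is inverted: PyTorch's `padding` trims the output, while Flax's explicit
# padding is added to the input. For kernel size k (dilation 1) this gives
# k - 1 - torch_padding, i.e. 1 - torch_padding for kernel_size=2.
torch_padding = 0
flax_padding = 1 - torch_padding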

t_conv = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=torch_padding)

kernel = t_conv.weight.detach().cpu().numpy()
bias = t_conv.bias.detach().cpu().numpy()

# [inC, outC, kH, kW] -> [kH, kW, outC, inC]
kernel = jnp.transpose(kernel, (2, 3, 1, 0))

key = random.key(0)
x = random.normal(key, (1, 6, 6, 3))

variables = {'params': {'kernel': kernel, 'bias': bias}}
# ConvTranspose expects the kernel to be [kH, kW, inC, outC],
# but with `transpose_kernel=True`, it expects [kH, kW, outC, inC] instead
j_conv = nn.ConvTranspose(features=4, kernel_size=(2, 2), padding=flax_padding, transpose_kernel=True)
j_out = j_conv.apply(variables, x)

# [N, H, W, C] -> [N, C, H, W]
t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
t_out = t_conv(t_x)
# [N, C, H, W] -> [N, H, W, C]
t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))

np.testing.assert_almost_equal(j_out, t_out, decimal=6)