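Convert PyTorch Models to Flax
We will show how to convert PyTorch models to Flax. We will cover convolutions, fc layers, batch norm, average pooling, and transposed convolutions.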
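FC Layers
Let's start with fc layers. The only thing to be aware of here is that the PyTorch kernel has shape [outC, inC], whereas the Flax kernel has shape [inC, outC]. Transposing the kernel will do the trick.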
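import jax
import jax.numpy as jnp
import numpy as np
import torch
from flax import linen as nn
from jax import random

t_fc = torch.nn.Linear(in_features=3, out_features=4)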
kernel = t_fc.weight.detach().cpu().numpy()
bias = t_fc.bias.detach().cpu().numpy()
# [outC, inC] -> [inC, outC]
kernel = jnp.transpose(kernel, (1, 0))
key = random.key(0)
x = random.normal(key, (1, 3))
variables = {'params': {'kernel': kernel, 'bias': bias}}
j_fc = nn.Dense(features=4)
j_out = j_fc.apply(variables, x)
t_x = torch.from_numpy(np.array(x))
t_out = t_fc(t_x)
t_out = t_out.detach().cpu().numpy()
np.testing.assert_almost_equal(j_out, t_out, decimal=6)
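The transpose is enough because of the formulas the two layers compute. torch.nn.Linear computes x @ W.T + b with W of shape [outC, inC], while nn.Dense computes x @ K + b with K of shape [inC, outC]. The minimal NumPy sketch below restates both formulas. The function names are hypothetical stand-ins, not the library layers. It checks that K = W.T makes the two formulas agree:

```python
import numpy as np

rng = np.random.default_rng(0)
in_c, out_c = 3, 4

W = rng.normal(size=(out_c, in_c))  # PyTorch Linear weight layout: [outC, inC]
b = rng.normal(size=(out_c,))
x = rng.normal(size=(2, in_c))      # a batch of 2 inputs

def torch_style_linear(x, W, b):
    # Formula used by torch.nn.Linear: x @ W^T + b
    return x @ W.T + b

def flax_style_dense(x, K, b):
    # Formula used by flax.linen.Dense: x @ K + b, with K of shape [inC, outC]
    return x @ K + b

K = W.T  # [outC, inC] -> [inC, outC]
y_torch = torch_style_linear(x, W, b)
y_flax = flax_style_dense(x, K, b)
print(np.allclose(y_torch, y_flax))  # True
```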
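Convolutions
Let's now look at 2D convolutions. PyTorch uses the NCHW format, while Flax uses NHWC. As a result, the kernels have different layouts. In PyTorch the kernel has shape [outC, inC, kH, kW], while the Flax kernel has shape [kH, kW, inC, outC]. Transposing the kernel will do the trick. The inputs and outputs also have to be transposed between the two layouts when comparing results.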
t_conv = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding='valid')
kernel = t_conv.weight.detach().cpu().numpy()
bias = t_conv.bias.detach().cpu().numpy()
# [outC, inC, kH, kW] -> [kH, kW, inC, outC]
kernel = jnp.transpose(kernel, (2, 3, 1, 0))
key = random.key(0)
x = random.normal(key, (1, 6, 6, 3))
variables = {'params': {'kernel': kernel, 'bias': bias}}
j_conv = nn.Conv(features=4, kernel_size=(2, 2), padding='valid')
j_out = j_conv.apply(variables, x)
# [N, H, W, C] -> [N, C, H, W]
t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
t_out = t_conv(t_x)
# [N, C, H, W] -> [N, H, W, C]
t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))
np.testing.assert_almost_equal(j_out, t_out, decimal=6)
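Neither framework flips the kernel in a regular convolution. Both compute a cross-correlation, so axis order is the only thing that differs. The sketch below shows this with naive NumPy convolutions. It assumes valid padding and stride 1, and the helper names are hypothetical. It runs the same weights once in the PyTorch layout and once in the Flax layout:

```python
import numpy as np

def conv2d_nchw(x, w):
    """Valid, stride-1 cross-correlation: x [N, C, H, W], w [outC, inC, kH, kW]."""
    n, c, h, wd = x.shape
    o, _, kh, kw = w.shape
    out = np.zeros((n, o, h - kh + 1, wd - kw + 1))
    for i in range(h - kh + 1):
        for j in range(wd - kw + 1):
            patch = x[:, :, i:i + kh, j:j + kw]  # [N, C, kH, kW]
            out[:, :, i, j] = np.einsum('nchw,ochw->no', patch, w)
    return out

def conv2d_nhwc(x, k):
    """Valid, stride-1 cross-correlation: x [N, H, W, C], k [kH, kW, inC, outC]."""
    n, h, wd, c = x.shape
    kh, kw, _, o = k.shape
    out = np.zeros((n, h - kh + 1, wd - kw + 1, o))
    for i in range(h - kh + 1):
        for j in range(wd - kw + 1):
            patch = x[:, i:i + kh, j:j + kw, :]  # [N, kH, kW, C]
            out[:, i, j, :] = np.einsum('nhwc,hwco->no', patch, k)
    return out

rng = np.random.default_rng(0)
x_nhwc = rng.normal(size=(1, 6, 6, 3))
w_torch = rng.normal(size=(4, 3, 2, 2))       # [outC, inC, kH, kW]
k_flax = np.transpose(w_torch, (2, 3, 1, 0))  # -> [kH, kW, inC, outC]

out_nchw = conv2d_nchw(np.transpose(x_nhwc, (0, 3, 1, 2)), w_torch)
out_nhwc = conv2d_nhwc(x_nhwc, k_flax)
print(np.allclose(np.transpose(out_nchw, (0, 2, 3, 1)), out_nhwc))  # True
```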
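Convolutions and FC Layers
We have to be careful when a model uses convolutions followed by fc layers (ResNet, VGG, etc.). In PyTorch, the activations have shape [N, C, H, W] after the convolutions and are then reshaped to [N, C * H * W] before being fed to the fc layers. When we port the weights from PyTorch to Flax, the activations after the convolutions have shape [N, H, W, C] in Flax. Before we reshape them for the fc layers, we have to transpose them to [N, C, H, W]. Otherwise the flattened features come out in a different order than the one the fc weights expect.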
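The ordering problem can be shown with plain NumPy arrays standing in for the conv activations. In this sketch, flattening the NHWC tensor directly produces a different feature order than PyTorch's NCHW flatten. Transposing back to NCHW first restores the order:

```python
import numpy as np

n, c, h, w = 1, 4, 5, 5
rng = np.random.default_rng(0)
act_nchw = rng.normal(size=(n, c, h, w))         # conv output as PyTorch sees it
act_nhwc = np.transpose(act_nchw, (0, 2, 3, 1))  # the same values as Flax sees them

flat_torch = act_nchw.reshape(n, -1)  # feature order: c, then h, then w
flat_naive = act_nhwc.reshape(n, -1)  # feature order: h, then w, then c
flat_fixed = np.transpose(act_nhwc, (0, 3, 1, 2)).reshape(n, -1)  # transpose back first

print(np.array_equal(flat_torch, flat_naive))  # False
print(np.array_equal(flat_torch, flat_fixed))  # True
```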
Consider the following PyTorch model:
class TModel(torch.nn.Module):

    def __init__(self):
        super(TModel, self).__init__()
        self.conv = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding='valid')
        self.fc = torch.nn.Linear(in_features=100, out_features=2)

    def forward(self, x):
        x = self.conv(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
        return x
t_model = TModel()
Now, if you want to use the weights of this model in Flax, the corresponding Flax model has to look like this:
class JModel(nn.Module):

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=4, kernel_size=(2, 2), padding='valid', name='conv')(x)
        # [N, H, W, C] -> [N, C, H, W]
        x = jnp.transpose(x, (0, 3, 1, 2))
        x = jnp.reshape(x, (x.shape[0], -1))
        x = nn.Dense(features=2, name='fc')(x)
        return x
j_model = JModel()
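The model looks very similar to the PyTorch model, except that we added a transpose operation before reshaping the activations for the fc layer. The transpose can be omitted if we apply pooling before the reshape so that the spatial dimensions are 1x1.
Apart from the transpose before the reshape, we can convert the weights the same way as before: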
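The Flax model can also be left unchanged, with no runtime transpose. One way to do this is to permute the input dimension of the fc kernel once, during conversion, so that it accepts the NHWC-flattened features directly. This is a different technique from the one used in this guide and is shown only as a NumPy sketch on random arrays. The last lines also check the 1x1 case mentioned above:

```python
import numpy as np

n, c, h, w, out_features = 2, 4, 5, 5, 2
rng = np.random.default_rng(0)
act_nhwc = rng.normal(size=(n, h, w, c))
W = rng.normal(size=(out_features, c * h * w))  # torch fc.weight, expects NCHW-flattened input
b = rng.normal(size=(out_features,))

# Reference: transpose the activations back to NCHW, as in the guide.
ref = np.transpose(act_nhwc, (0, 3, 1, 2)).reshape(n, -1) @ W.T + b

# Alternative: permute the fc kernel so it accepts NHWC-flattened input.
# [outF, C*H*W] -> [outF, C, H, W] -> [H, W, C, outF] -> [H*W*C, outF]
K = W.reshape(out_features, c, h, w).transpose(2, 3, 1, 0).reshape(h * w * c, out_features)
alt = act_nhwc.reshape(n, -1) @ K + b
print(np.allclose(ref, alt))  # True

# With 1x1 spatial dims (e.g. after global average pooling) both flatten orders coincide.
pooled = act_nhwc.mean(axis=(1, 2), keepdims=True)  # [N, 1, 1, C]
print(np.array_equal(pooled.reshape(n, -1),
                     np.transpose(pooled, (0, 3, 1, 2)).reshape(n, -1)))  # True
```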
conv_kernel = t_model.state_dict()['conv.weight'].detach().cpu().numpy()
conv_bias = t_model.state_dict()['conv.bias'].detach().cpu().numpy()
fc_kernel = t_model.state_dict()['fc.weight'].detach().cpu().numpy()
fc_bias = t_model.state_dict()['fc.bias'].detach().cpu().numpy()
# [outC, inC, kH, kW] -> [kH, kW, inC, outC]
conv_kernel = jnp.transpose(conv_kernel, (2, 3, 1, 0))
# [outC, inC] -> [inC, outC]
fc_kernel = jnp.transpose(fc_kernel, (1, 0))
variables = {'params': {'conv': {'kernel': conv_kernel, 'bias': conv_bias},
'fc': {'kernel': fc_kernel, 'bias': fc_bias}}}
key = random.key(0)
x = random.normal(key, (1, 6, 6, 3))
j_out = j_model.apply(variables, x)
# [N, H, W, C] -> [N, C, H, W]
t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
t_out = t_model(t_x)
t_out = t_out.detach().cpu().numpy()
np.testing.assert_almost_equal(j_out, t_out, decimal=6)
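For models with more layers, the per-parameter conversion above can be wrapped in a loop over the state dict. The following is a minimal sketch of such a helper. The name torch_state_dict_to_flax is hypothetical, and it is not part of Flax or PyTorch. It assumes the following:

* The values have already been converted with .detach().cpu().numpy().
* Each PyTorch module name (e.g. 'conv', 'fc') matches the name= of the corresponding Flax submodule.
* 4-D weights come from Conv2d, 2-D weights from Linear, and 1-D weights from BatchNorm.

ConvTranspose2d weights use a different layout and are not handled.

```python
import numpy as np

def torch_state_dict_to_flax(state_dict):
    """Hypothetical helper: map {'layer.param': ndarray} (PyTorch naming) to Flax variables."""
    params, batch_stats = {}, {}
    for key, value in state_dict.items():
        layer, name = key.rsplit('.', 1)
        if name == 'num_batches_tracked':
            continue  # PyTorch-only bookkeeping with no Flax counterpart
        if name in ('running_mean', 'running_var'):
            stats_name = 'mean' if name == 'running_mean' else 'var'
            batch_stats.setdefault(layer, {})[stats_name] = value
            continue
        target = params.setdefault(layer, {})
        if name == 'bias':
            target['bias'] = value
        elif name == 'weight' and value.ndim == 4:
            # Conv2d: [outC, inC, kH, kW] -> [kH, kW, inC, outC]
            target['kernel'] = np.transpose(value, (2, 3, 1, 0))
        elif name == 'weight' and value.ndim == 2:
            # Linear: [outC, inC] -> [inC, outC]
            target['kernel'] = np.transpose(value, (1, 0))
        elif name == 'weight' and value.ndim == 1:
            # BatchNorm gamma is called 'scale' in Flax
            target['scale'] = value
        else:
            raise ValueError(f'unhandled parameter {key} with shape {value.shape}')
    variables = {'params': params}
    if batch_stats:
        variables['batch_stats'] = batch_stats
    return variables

# Shapes matching TModel above (zeros stand in for real weights).
example = {
    'conv.weight': np.zeros((4, 3, 2, 2)), 'conv.bias': np.zeros(4),
    'fc.weight': np.zeros((2, 100)), 'fc.bias': np.zeros(2),
}
variables = torch_state_dict_to_flax(example)
print(variables['params']['conv']['kernel'].shape)  # (2, 2, 3, 4)
print(variables['params']['fc']['kernel'].shape)    # (100, 2)
```

With a real model, the input would be something like {k: v.detach().cpu().numpy() for k, v in t_model.state_dict().items()}. Compare the outputs with an assertion, as above, before trusting the result.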
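Batch Norm
torch.nn.BatchNorm2d uses 0.1 as the default value of the momentum parameter, while nn.BatchNorm defaults to 0.99. The two libraries use opposite conventions. PyTorch multiplies the estimated statistic by (1 − momentum) and the new observed value by momentum. Flax multiplies the estimated statistic by momentum and the new observed value by (1 − momentum). A PyTorch momentum of m therefore corresponds to a Flax momentum of 1 − m. To reproduce PyTorch's default of 0.1, pass momentum=0.9 to nn.BatchNorm explicitly. In inference mode (use_running_average=True, as below), the momentum has no effect because only the stored statistics are used.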
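The two update rules can be restated in a small NumPy sketch. The function names are hypothetical, and the code is not taken from either library. It shows that PyTorch's momentum=0.1 and Flax's momentum=0.9 produce the same running statistic:

```python
import numpy as np

def torch_running_update(running, observed, momentum):
    # PyTorch convention: new = (1 - momentum) * running + momentum * observed
    return (1 - momentum) * running + momentum * observed

def flax_running_update(running, observed, momentum):
    # Flax convention: new = momentum * running + (1 - momentum) * observed
    return momentum * running + (1 - momentum) * observed

running = np.array([0.0, 1.0, 2.0])
observed = np.array([1.0, 1.0, 0.0])
t = torch_running_update(running, observed, momentum=0.1)
f = flax_running_update(running, observed, momentum=0.9)
print(np.allclose(t, f))  # True
```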
t_bn = torch.nn.BatchNorm2d(num_features=3, momentum=0.1)
t_bn.eval()
scale = t_bn.weight.detach().cpu().numpy()
bias = t_bn.bias.detach().cpu().numpy()
mean = t_bn.running_mean.detach().cpu().numpy()
var = t_bn.running_var.detach().cpu().numpy()
variables = {'params': {'scale': scale, 'bias': bias},
'batch_stats': {'mean': mean, 'var': var}}
key = random.key(0)
x = random.normal(key, (1, 6, 6, 3))
j_bn = nn.BatchNorm(momentum=0.9, use_running_average=True)
j_out = j_bn.apply(variables, x)
# [N, H, W, C] -> [N, C, H, W]
t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
t_out = t_bn(t_x)
# [N, C, H, W] -> [N, H, W, C]
t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))
np.testing.assert_almost_equal(j_out, t_out, decimal=6)
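A freshly constructed BatchNorm2d has mean 0, variance 1, scale 1 and bias 0, which makes the layer nearly an identity. The comparison above is therefore a weak check. In eval mode, both layers compute (x − mean) / sqrt(var + eps) * scale + bias per channel, and both default to eps = 1e-5. The NumPy sketch below uses hypothetical helper names and random, non-trivial statistics. It shows that the channel axis is the only layout difference. With the real layers, you could overwrite running_mean and running_var with random values before extracting them in the same way.

```python
import numpy as np

EPS = 1e-5  # default eps of torch.nn.BatchNorm2d and epsilon of flax.linen.BatchNorm

def bn_eval_nchw(x, scale, bias, mean, var, eps=EPS):
    # Statistics broadcast over channel axis 1 (PyTorch layout).
    s = (1, -1, 1, 1)
    return (x - mean.reshape(s)) / np.sqrt(var.reshape(s) + eps) * scale.reshape(s) + bias.reshape(s)

def bn_eval_nhwc(x, scale, bias, mean, var, eps=EPS):
    # Statistics broadcast over the last axis (Flax layout).
    return (x - mean) / np.sqrt(var + eps) * scale + bias

rng = np.random.default_rng(0)
c = 3
scale, bias = rng.normal(size=c), rng.normal(size=c)
mean, var = rng.normal(size=c), rng.uniform(0.5, 2.0, size=c)
x_nhwc = rng.normal(size=(1, 6, 6, c))

y_flax = bn_eval_nhwc(x_nhwc, scale, bias, mean, var)
y_torch = bn_eval_nchw(np.transpose(x_nhwc, (0, 3, 1, 2)), scale, bias, mean, var)
print(np.allclose(y_flax, np.transpose(y_torch, (0, 2, 3, 1))))  # True
```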
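Average Pooling
torch.nn.AvgPool2d and nn.avg_pool() are compatible when using default parameters. However, torch.nn.AvgPool2d has a parameter count_include_pad. When count_include_pad=False, the zero padding is not counted in the average. There is no similar parameter for nn.avg_pool(). However, we can easily implement a wrapper around the pooling operation: nn.pool() is the core function behind nn.avg_pool() and nn.max_pool().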
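The difference between the two settings shows up only in windows that overlap the padding. This small NumPy sketch uses a hypothetical helper, assumes stride 1, and works on a single 2-D channel of ones. Dividing the window sums by the full window size behaves like count_include_pad=True. Dividing by the number of real elements in each window, found by pooling an array of ones, behaves like count_include_pad=False:

```python
import numpy as np

def sum_pool_2d(x, k, pad):
    """Stride-1 sum pooling over a 2-D array with symmetric zero padding."""
    xp = np.pad(x, pad)  # zero padding on both sides of both axes
    h, w = xp.shape[0] - k + 1, xp.shape[1] - k + 1
    return np.array([[xp[i:i + k, j:j + k].sum() for j in range(w)] for i in range(h)])

x = np.ones((3, 3))
k, pad = 2, 1
sums = sum_pool_2d(x, k, pad)

include_pad = sums / (k * k)                               # like count_include_pad=True
exclude_pad = sums / sum_pool_2d(np.ones_like(x), k, pad)  # like count_include_pad=False
print(include_pad[0, 0], exclude_pad[0, 0])  # 0.25 1.0
```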
def avg_pool(inputs, window_shape, strides=None, padding='VALID'):
    """
    Pools the input by taking the average over a window.
    In comparison to nn.avg_pool(), this pooling operation does not
    consider the padded zeros for the average computation.
    """
    assert len(window_shape) == 2
    # Sum over each window; padded positions contribute 0.
    y = nn.pool(inputs, 0., jax.lax.add, window_shape, strides, padding)
    # Number of real (non-padded) elements that fall into each window.
    counts = nn.pool(jnp.ones_like(inputs), 0., jax.lax.add, window_shape, strides, padding)
    y = y / counts
    return y
key = random.key(0)
x = random.normal(key, (1, 6, 6, 3))
j_out = avg_pool(x, window_shape=(2, 2), strides=(1, 1), padding=((1, 1), (1, 1)))
t_pool = torch.nn.AvgPool2d(kernel_size=2, stride=1, padding=1, count_include_pad=False)
# [N, H, W, C] -> [N, C, H, W]
t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
t_out = t_pool(t_x)
# [N, C, H, W] -> [N, H, W, C]
t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))
np.testing.assert_almost_equal(j_out, t_out, decimal=6)
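Transposed Convolutions
torch.nn.ConvTranspose2d and nn.ConvTranspose are not directly compatible. nn.ConvTranspose is a wrapper around jax.lax.conv_transpose, which computes a fractionally strided convolution. torch.nn.ConvTranspose2d computes a gradient-based transposed convolution. Currently, there is no implementation of the gradient-based transposed convolution in JAX, although there is a pending pull request that contains one.
To load torch.nn.ConvTranspose2d parameters into Flax, we need to use the transpose_kernel argument of Flax's nn.ConvTranspose layer.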
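A single-channel 1-D NumPy sketch helps explain why the conversion needs transpose_kernel and an inverted padding. The helper names are hypothetical. The gradient-style transposed convolution scatters a scaled copy of the kernel for each input element, and PyTorch's padding then crops the result. The fractionally strided formulation reaches the same values in three steps. It dilates the input with zeros, pads it by kernel_size − 1 − padding on each side, and cross-correlates it with the flipped kernel. With multiple channels, transpose_kernel=True additionally swaps the kernel's input and output channel axes, which is why the converted kernel is laid out as [kH, kW, outC, inC].

```python
import numpy as np

def conv_transpose_scatter(x, w, stride, padding):
    """Gradient-style transposed conv (PyTorch semantics), 1-D, one channel."""
    k = len(w)
    out = np.zeros((len(x) - 1) * stride + k)
    for i, xi in enumerate(x):
        out[i * stride:i * stride + k] += xi * w  # scatter a scaled kernel copy
    return out[padding:len(out) - padding]        # PyTorch's padding crops the output

def conv_transpose_dilate(x, w, stride, padding):
    """Fractionally strided formulation: dilate, pad by k - 1 - padding, flip the kernel."""
    k = len(w)
    z = np.zeros((len(x) - 1) * stride + 1)
    z[::stride] = x                   # insert stride - 1 zeros between inputs
    z = np.pad(z, k - 1 - padding)    # inverted padding
    w_flipped = w[::-1]               # spatial flip, as transpose_kernel=True does
    return np.array([z[j:j + k] @ w_flipped for j in range(len(z) - k + 1)])

rng = np.random.default_rng(0)
x, w = rng.normal(size=6), rng.normal(size=3)
for stride in (1, 2):
    for padding in (0, 1):
        a = conv_transpose_scatter(x, w, stride, padding)
        b = conv_transpose_dilate(x, w, stride, padding)
        print(stride, padding, a.shape == b.shape and np.allclose(a, b))  # last column: True
```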
# Padding is inverted: flax_padding = kernel_size - 1 - torch_padding (here kernel_size = 2)
torch_padding = 0
flax_padding = 1 - torch_padding
t_conv = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=torch_padding)
kernel = t_conv.weight.detach().cpu().numpy()
bias = t_conv.bias.detach().cpu().numpy()
# [inC, outC, kH, kW] -> [kH, kW, outC, inC]
kernel = jnp.transpose(kernel, (2, 3, 1, 0))
key = random.key(0)
x = random.normal(key, (1, 6, 6, 3))
variables = {'params': {'kernel': kernel, 'bias': bias}}
# ConvTranspose expects the kernel to be [kH, kW, inC, outC],
# but with `transpose_kernel=True`, it expects [kH, kW, outC, inC] instead
j_conv = nn.ConvTranspose(features=4, kernel_size=(2, 2), padding=flax_padding, transpose_kernel=True)
j_out = j_conv.apply(variables, x)
# [N, H, W, C] -> [N, C, H, W]
t_x = torch.from_numpy(np.transpose(np.array(x), (0, 3, 1, 2)))
t_out = t_conv(t_x)
# [N, C, H, W] -> [N, H, W, C]
t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))
np.testing.assert_almost_equal(j_out, t_out, decimal=6)
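The example above covers only kernel_size=2, stride 1 and no output padding. The following sketch generalizes the padding conversion under a few assumptions. It assumes dilation 1. It also assumes that nn.ConvTranspose, like jax.lax.conv_transpose, applies an explicit (low, high) padding to the stride-dilated input. Under those assumptions, each side gets kernel_size − 1 − padding, and PyTorch's output_padding is added to the high side. The helper names are hypothetical, and the code only checks that the output sizes agree. Verify the values with an assert_almost_equal comparison like the one above before relying on it.

```python
def torch_to_flax_conv_transpose_padding(kernel_size, padding, output_padding=0):
    """Hypothetical helper: (low, high) padding for one spatial dim of
    nn.ConvTranspose(..., transpose_kernel=True), assuming dilation=1 and
    padding <= kernel_size - 1."""
    low = kernel_size - 1 - padding
    return (low, low + output_padding)

def torch_conv_transpose_out_size(in_size, kernel_size, stride, padding, output_padding=0):
    # Output-size formula of torch.nn.ConvTranspose2d with dilation=1.
    return (in_size - 1) * stride - 2 * padding + kernel_size + output_padding

def flax_explicit_padding_out_size(in_size, kernel_size, stride, pad):
    # Input dilated by `stride`, padded by (low, high), then a VALID window of size kernel_size.
    low, high = pad
    return (in_size - 1) * stride + 1 + low + high - kernel_size + 1

for k, s, p, op in [(2, 1, 0, 0), (3, 2, 1, 1), (4, 2, 1, 0)]:
    pad = torch_to_flax_conv_transpose_padding(k, p, op)
    same = torch_conv_transpose_out_size(6, k, s, p, op) == flax_explicit_padding_out_size(6, k, s, pad)
    print(pad, same)  # each line ends with True
```

For a 2-D layer, you would pass one pair per spatial dimension. For example, pass padding=[pad, pad] together with strides=(s, s) and transpose_kernel=True to nn.ConvTranspose.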