Flax nnx ConvTranspose Does Not Restore Input Shape When Used with Conv (Unexpected Behavior)
When Flax’s Conv and ConvTranspose layers are used as a pair, the ConvTranspose does not correctly restore the original input shape, even when the parameters are set in a way that should theoretically allow this. This behavior differs from PyTorch, where ConvXd and ConvTransposeXd used together reliably restore the input shape.
The ConvTranspose layer appears to produce incorrect output shapes, sometimes with spatial dimensions collapsing to zero. This is not just a mismatch with PyTorch; it makes the layer effectively unusable in certain cases.
Reproduction Example
from jax import random
from flax import nnx
import torch
from torch import nn
key = random.PRNGKey(42)
batch_size = 4
in_channels = 128
out_channels = 32
i = 4
k = 3
s = 1
p = 0
# ============= Flax ===========================
x = random.uniform(key, shape=(batch_size, i, i, in_channels))
conv = nnx.Conv(in_features=in_channels,
                out_features=out_channels,
                kernel_size=(k, k),
                strides=(s, s),
                padding=p,
                rngs=nnx.Rngs(0))
y = conv(x)
print(y.shape) # (4, 2, 2, 32)
assert y.shape[2] == 2
tconv = nnx.ConvTranspose(in_features=out_channels,
                          out_features=in_channels,
                          kernel_size=(k, k),
                          strides=(s, s),
                          padding=p,
                          rngs=nnx.Rngs(0))
z = tconv(y)
print(z.shape) # (4, 0, 0, 128)
if z.shape[2] != i:
    print(f"Flax transConv failed to restore original input shape.")
# ============= PyTorch ========================
x = torch.rand(batch_size, in_channels, i, i)
conv = nn.Conv2d(in_channels=in_channels,
                 out_channels=out_channels,
                 kernel_size=k,
                 stride=s,
                 padding=p)
y = conv(x)
print(y.shape) # torch.Size([4, 32, 2, 2])
assert y.shape == (batch_size, out_channels, 2, 2)
kp = k
sp = s
pp = k - 1
ip = 2
op = ip + (k-1)
tconv = nn.ConvTranspose2d(in_channels=out_channels,
                           out_channels=in_channels,
                           kernel_size=k,
                           stride=s,
                           padding=p)
z = tconv(y)
print(z.shape) # torch.Size([4, 128, 4, 4])
assert z.shape[2] == i
Hey @Stella-S-Yan, using p = 'VALID' will give you the shapes you specify there:
import jax
from flax import nnx
batch_size = 4
in_channels = 128
out_channels = 32
i = 4
k = 3
s = 1
p = 'VALID'
key = jax.random.key(0)
# ============= Flax ===========================
x = jax.random.uniform(key, shape=(batch_size, i, i, in_channels))
print(f'{x.shape = }')
conv = nnx.Conv(
    in_features=in_channels,
    out_features=out_channels,
    kernel_size=(k, k),
    strides=(s, s),
    padding=p,
    rngs=nnx.Rngs(0),
)
y = conv(x)
print(y.shape) # (4, 2, 2, 32)
# assert y.shape[2] == 2
tconv = nnx.ConvTranspose(
    in_features=out_channels,
    out_features=in_channels,
    kernel_size=(k, k),
    strides=(s, s),
    padding=p,
    rngs=nnx.Rngs(0),
)
z = tconv(y)
print(z.shape) # (4, 4, 4, 128)
if z.shape[2] != i:
    print(f'Flax transConv failed to restore original input shape.')
Setting to "VALID" means zero padding. But explicitly setting to 0 padding gives wrong result, and this is something we should fix. Here is another example, that no matter setting to "VALID" or "SAME" or "CIRCULAR" will not work.
from jax import random
from flax import nnx
import torch
from torch import nn
key = random.PRNGKey(42)
batch_size = 4
in_channels = 128
out_channels = 32
i = 5
k = 4
s = 1
p = 2
ip = 6
# ============= Flax ===========================
x = random.uniform(key, shape=(batch_size, i, i, in_channels))
conv = nnx.Conv(in_features=in_channels,
                out_features=out_channels,
                kernel_size=(k, k),
                strides=(s, s),
                padding=p,
                rngs=nnx.Rngs(0))
y = conv(x)
print(y.shape)
assert y.shape[2] == ip
tconv = nnx.ConvTranspose(in_features=out_channels,
                          out_features=in_channels,
                          kernel_size=(k, k),
                          strides=(s, s),
                          padding="VALID",
                          rngs=nnx.Rngs(0))
z = tconv(y)
print(z.shape)
if z.shape[2] != i:
    print(f"Flax transConv failed to restore original input shape.")
# ============= PyTorch ========================
x = torch.rand(batch_size, in_channels, i, i)
conv = nn.Conv2d(in_channels=in_channels,
                 out_channels=out_channels,
                 kernel_size=k,
                 stride=s,
                 padding=p)
y = conv(x)
print(y.shape) # torch.Size([4, 32, 6, 6])
assert y.shape[2] == ip
kp = k
sp = s
pp = k - 1
ip = 2
op = ip + (k-1)
tconv = nn.ConvTranspose2d(in_channels=out_channels,
                           out_channels=in_channels,
                           kernel_size=k,
                           stride=s,
                           padding=p)
z = tconv(y)
print(z.shape)
assert z.shape[2] == i
The padding argument to nnx.ConvTranspose serves a different role from the padding argument of torch.nn.ConvTranspose2d. According to the pytorch docs:
The padding argument effectively adds dilation * (kernel_size - 1) - padding amount of zero padding to both sizes of the input. This is set so that when a Conv2d and a ConvTranspose2d are initialized with same parameters, they are inverses of each other in regard to the input and output shapes.
On the other hand, for nnx.ConvTranspose, the padding argument literally means the amount of padding to add to both sides of the input. It is passed more or less directly to lax.conv_transpose. In your first example, the padding to ConvTranspose should be 1 * (3 - 1) - 0 = 2 to match what pytorch would do.
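For example, here is a minimal sketch (not from the original report; it just reuses the sizes from the first example) that passes that explicit padding of 2 to nnx.ConvTranspose, which should give back the original (4, 4, 4, 128) shape:
import jax
from flax import nnx
# Shape of the Conv output from the first example (i=4, k=3, s=1, p=0).
y = jax.random.uniform(jax.random.key(0), shape=(4, 2, 2, 32))
# Explicit padding of dilation * (k - 1) - p = 1 * (3 - 1) - 0 = 2 per side,
# mirroring what PyTorch's ConvTranspose2d applies implicitly.
tconv = nnx.ConvTranspose(in_features=32,
                          out_features=128,
                          kernel_size=(3, 3),
                          strides=(1, 1),
                          padding=2,
                          rngs=nnx.Rngs(0))
z = tconv(y)
print(z.shape)  # expected: (4, 4, 4, 128)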
Counterintuitively, 'VALID' does not mean zero padding for lax.conv_transpose (and therefore for nnx.ConvTranspose). According to the jax docs for the padding argument:
padding (str | Sequence[tuple[int, int]]) – ‘SAME’, ‘VALID’ will set as transpose of corresponding forward conv, or a sequence of n integer 2-tuples describing before-and-after padding for each n spatial dimension.
So in your initial case, using 'VALID' with the ConvTranspose would actually mean using a padding of 2, which is the padding we'd need to turn the output of an ordinary Conv with padding="VALID" back into the shape of its input. Now, is this convention within jax highly unintuitive? Absolutely. But given that this is the official behavior within jax, I don't know if flax should behave any differently.
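Applying the same rule to the second example (again a sketch, not part of the original report): the forward Conv there uses k = 4 and p = 2, so the explicit padding the transposed conv needs is 1 * (4 - 1) - 2 = 1, and passing that integer instead of a padding string should restore the original spatial size of 5:
import jax
from flax import nnx
# Shape of the Conv output from the second example (i=5, k=4, s=1, p=2).
y = jax.random.uniform(jax.random.key(0), shape=(4, 6, 6, 32))
tconv = nnx.ConvTranspose(in_features=32,
                          out_features=128,
                          kernel_size=(4, 4),
                          strides=(1, 1),
                          padding=1,  # 1 * (k - 1) - p = 1, per the PyTorch formula above
                          rngs=nnx.Rngs(0))
z = tconv(y)
print(z.shape)  # expected: (4, 5, 5, 128)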
For reference, I found the discussion of padding in transposed convolutions with jax here to be helpful.
After some discussion, it seems like this unintuitive behavior of lax.conv_transpose (in which string paddings are interpreted as the padding of the corresponding forward convolution, but integer paddings are interpreted as the padding of the transposed convolution itself) was not something the jax authors intended. It's a bug, just like you thought. The PR above (for the jax repo) should address the issue.