flaxmodels

StyleGAN2 model broken in JAX v0.4.36 and newer

Status: Open. Skylion007 opened this issue 4 months ago · 0 comments

Downloading: "https://www.dropbox.com/s/e8de1peq7p8gq9d/stylegan2_generator_ffhq.h5" to /tmp/flaxmodels/stylegan2_generator_ffhq.h5

100%|██████████| 133M/133M [00:02<00:00, 59.2MiB/s]

---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

[/tmp/ipython-input-4126621917.py](https://localhost:8080/#) in <cell line: 0>()
     13 # ['afhqcat', 'afhqdog', 'afhqwild', 'brecahad', 'car', 'cat', 'church', 'cifar10', 'ffhq', 'horse', 'metfaces']
     14 generator = fm.stylegan2.Generator(pretrained='ffhq')
---> 15 params = generator.init(key, z)
     16 images = generator.apply(params, z, train=False)
     17 

    [... skipping hidden 9 frame]

8 frames

[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/generator.py](https://localhost:8080/#) in __call__(self, z, c, truncation_psi, truncation_cutoff, skip_w_avg_update, train, noise_mode, rng)
    699                                      name='mapping_network')(z, c, truncation_psi, truncation_cutoff, skip_w_avg_update, train)
    700 
--> 701         x = SynthesisNetwork(resolution=self.resolution_,
    702                              num_channels=self.num_channels,
    703                              w_dim=self.w_dim,

    [... skipping hidden 2 frame]

[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/generator.py](https://localhost:8080/#) in __call__(self, dlatents_in, noise_mode, rng)
    562         for res in range(2, resolution_log2 + 1):
    563             init_rng, init_key = random.split(init_rng)
--> 564             x, y = SynthesisBlock(fmaps=nf(res - 1),
    565                                   res=res,
    566                                   num_layers=1 if res == 2 else 2,

    [... skipping hidden 2 frame]

[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/generator.py](https://localhost:8080/#) in __call__(self, x, y, dlatents_in, noise_mode, rng)
    433         for i in range(self.num_layers):
    434             init_rng, init_key = random.split(init_rng)
--> 435             x = SynthesisLayer(fmaps=self.fmaps, 
    436                                kernel=3,
    437                                layer_idx=self.res * 2 - (5 - i) if self.res > 2 else 0,

    [... skipping hidden 2 frame]

[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/generator.py](https://localhost:8080/#) in __call__(self, x, dlatents, noise_mode, rng)
    287         b = ops.equalize_lr_bias(b, self.lr_multiplier)
    288 
--> 289         x = ops.modulated_conv2d_layer(x=x, 
    290                                        w=w,
    291                                        s=s,

[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/ops.py](https://localhost:8080/#) in modulated_conv2d_layer(x, w, s, fmaps, kernel, up, down, demodulate, resample_kernel, fused_modconv)
    459 
    460     # 2D convolution.
--> 461     x = conv2d(x, w.astype(x.dtype), up=up, down=down, resample_kernel=resample_kernel)
    462 
    463     # Reshape/scale output.

[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/ops.py](https://localhost:8080/#) in conv2d(x, w, up, down, resample_kernel, padding)
    415     w = w.astype(x.dtype)
    416     if up:
--> 417         x = upsample_conv_2d(x, w, k=resample_kernel, padding=padding)
    418     elif down:
    419         x = conv_downsample_2d(x, w, k=resample_kernel, padding=padding)

[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/ops.py](https://localhost:8080/#) in upsample_conv_2d(x, w, k, factor, gain, padding)
    401     pad0 = (k.shape[0] + factor - cw) // 2 + padding
    402     pad1 = (k.shape[0] - factor - cw + 3) // 2 + padding
--> 403     x = upfirdn2d(x=x, f=k, padding=(pad0, pad1, pad0, pad1))
    404     return x
    405 

[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/ops.py](https://localhost:8080/#) in upfirdn2d(x, f, padding, up, down, strides, flip_filter, gain)
    168 
    169     # upsample by inserting zeros
--> 170     x = jnp.reshape(x, newshape=(B, H, 1, W, 1, C))
    171     x = jnp.pad(x, pad_width=((0, 0), (0, 0), (0, up - 1), (0, 0), (0, up - 1), (0, 0)))
    172     x = jnp.reshape(x, newshape=(B, H * up, W * up, C))

[/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/lax_numpy.py](https://localhost:8080/#) in reshape(***failed resolving arguments***)
   2024   # TODO(jakevdp): finalized 2024-12-2; remove argument after JAX v0.4.40.
   2025   if not isinstance(newshape, DeprecatedArg):
-> 2026     raise TypeError("The newshape argument to jnp.reshape was removed in JAX v0.4.36."
   2027                     " Use shape instead.")
   2028   if shape is None:

TypeError: The newshape argument to jnp.reshape was removed in JAX v0.4.36. Use shape instead.

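Running flaxmodels with the latest JAX version requires a code change here. The `jnp.reshape` calls in `flaxmodels/stylegan2/ops.py` (for example in `upfirdn2d`, lines 170 and 172 of the traceback above) pass the `newshape=` keyword, which JAX removed in v0.4.36. These calls must use `shape=` instead, or pass the shape positionally. I hit this error while running the demo StyleGAN2 notebook.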

Posted by Skylion007 on Sep 28 '25 at 21:09.