flaxmodels
flaxmodels copied to clipboard
StyleGAN2 model fails with TypeError in JAX v0.4.36 and newer (`jnp.reshape` no longer accepts `newshape`)
Downloading: "https://www.dropbox.com/s/e8de1peq7p8gq9d/stylegan2_generator_ffhq.h5" to /tmp/flaxmodels/stylegan2_generator_ffhq.h5
100%|██████████| 133M/133M [00:02<00:00, 59.2MiB/s]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
[/tmp/ipython-input-4126621917.py](https://localhost:8080/#) in <cell line: 0>()
13 # ['afhqcat', 'afhqdog', 'afhqwild', 'brecahad', 'car', 'cat', 'church', 'cifar10', 'ffhq', 'horse', 'metfaces']
14 generator = fm.stylegan2.Generator(pretrained='ffhq')
---> 15 params = generator.init(key, z)
16 images = generator.apply(params, z, train=False)
17
[... skipping hidden 9 frame]
8 frames
[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/generator.py](https://localhost:8080/#) in __call__(self, z, c, truncation_psi, truncation_cutoff, skip_w_avg_update, train, noise_mode, rng)
699 name='mapping_network')(z, c, truncation_psi, truncation_cutoff, skip_w_avg_update, train)
700
--> 701 x = SynthesisNetwork(resolution=self.resolution_,
702 num_channels=self.num_channels,
703 w_dim=self.w_dim,
[... skipping hidden 2 frame]
[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/generator.py](https://localhost:8080/#) in __call__(self, dlatents_in, noise_mode, rng)
562 for res in range(2, resolution_log2 + 1):
563 init_rng, init_key = random.split(init_rng)
--> 564 x, y = SynthesisBlock(fmaps=nf(res - 1),
565 res=res,
566 num_layers=1 if res == 2 else 2,
[... skipping hidden 2 frame]
[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/generator.py](https://localhost:8080/#) in __call__(self, x, y, dlatents_in, noise_mode, rng)
433 for i in range(self.num_layers):
434 init_rng, init_key = random.split(init_rng)
--> 435 x = SynthesisLayer(fmaps=self.fmaps,
436 kernel=3,
437 layer_idx=self.res * 2 - (5 - i) if self.res > 2 else 0,
[... skipping hidden 2 frame]
[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/generator.py](https://localhost:8080/#) in __call__(self, x, dlatents, noise_mode, rng)
287 b = ops.equalize_lr_bias(b, self.lr_multiplier)
288
--> 289 x = ops.modulated_conv2d_layer(x=x,
290 w=w,
291 s=s,
[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/ops.py](https://localhost:8080/#) in modulated_conv2d_layer(x, w, s, fmaps, kernel, up, down, demodulate, resample_kernel, fused_modconv)
459
460 # 2D convolution.
--> 461 x = conv2d(x, w.astype(x.dtype), up=up, down=down, resample_kernel=resample_kernel)
462
463 # Reshape/scale output.
[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/ops.py](https://localhost:8080/#) in conv2d(x, w, up, down, resample_kernel, padding)
415 w = w.astype(x.dtype)
416 if up:
--> 417 x = upsample_conv_2d(x, w, k=resample_kernel, padding=padding)
418 elif down:
419 x = conv_downsample_2d(x, w, k=resample_kernel, padding=padding)
[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/ops.py](https://localhost:8080/#) in upsample_conv_2d(x, w, k, factor, gain, padding)
401 pad0 = (k.shape[0] + factor - cw) // 2 + padding
402 pad1 = (k.shape[0] - factor - cw + 3) // 2 + padding
--> 403 x = upfirdn2d(x=x, f=k, padding=(pad0, pad1, pad0, pad1))
404 return x
405
[/usr/local/lib/python3.12/dist-packages/flaxmodels/stylegan2/ops.py](https://localhost:8080/#) in upfirdn2d(x, f, padding, up, down, strides, flip_filter, gain)
168
169 # upsample by inserting zeros
--> 170 x = jnp.reshape(x, newshape=(B, H, 1, W, 1, C))
171 x = jnp.pad(x, pad_width=((0, 0), (0, 0), (0, up - 1), (0, 0), (0, up - 1), (0, 0)))
172 x = jnp.reshape(x, newshape=(B, H * up, W * up, C))
[/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/lax_numpy.py](https://localhost:8080/#) in reshape(***failed resolving arguments***)
2024 # TODO(jakevdp): finalized 2024-12-2; remove argument after JAX v0.4.40.
2025 if not isinstance(newshape, DeprecatedArg):
-> 2026 raise TypeError("The newshape argument to jnp.reshape was removed in JAX v0.4.36."
2027 " Use shape instead.")
2028 if shape is None:
TypeError: The newshape argument to jnp.reshape was removed in JAX v0.4.36. Use shape instead.
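flaxmodels needs a code change to run with current JAX versions. JAX v0.4.36 removed the `newshape` keyword of `jnp.reshape`. As a result, calls such as `jnp.reshape(x, newshape=(B, H, 1, W, 1, C))` in `flaxmodels/stylegan2/ops.py` (`upfirdn2d`) now raise the TypeError above.

The fix is to pass the target shape with `shape=` instead. Alternatively, pass it positionally, e.g. `jnp.reshape(x, (B, H, 1, W, 1, C))`, which also keeps working on older JAX releases.

I hit this while running the StyleGAN2 demo notebook.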