diffusers
Not able to run FlaxStableDiffusionPipeline example of CompVis/stable-diffusion-v1-4
Describe the bug
I'm not able to run the Flax pipeline example of CompVis/stable-diffusion-v1-4 on Colab with a GPU runtime. I get this error:
ValueError: pmap got inconsistent sizes for array axes to be mapped:
* most axes (1530 of them) had size 1, e.g. axis 0 of args[0] of type int32[1,1,77];
* one axis had size 8: axis 0 of args[2] of type uint32[8,2]
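The message comes from `pmap`'s requirement that every mapped argument has the same size along its mapped axis (axis 0 here, one slice per device). Below is a simplified, hypothetical re-creation of that check in NumPy. It is not JAX's implementation. It treats each argument as a single array, whereas real `pmap` flattens pytrees such as the parameter dict and checks every leaf. That flattening is presumably why the message counts 1530 axes of size 1.

```python
import numpy as np

def check_mapped_axis_size(*args):
    """Hypothetical, simplified stand-in for pmap's axis-size check.

    Assumption: each arg is a single array mapped over axis 0. Real pmap
    flattens pytrees (e.g. the params dict) and checks every leaf.
    """
    sizes = {}
    for i, arg in enumerate(args):
        sizes.setdefault(np.shape(arg)[0], []).append(i)
    if len(sizes) != 1:
        detail = "; ".join(f"size {s}: args{idx}" for s, idx in sizes.items())
        raise ValueError(f"inconsistent sizes for axes to be mapped: {detail}")
    return next(iter(sizes))

# Shapes taken from the error message (1-device runtime)
prompt_ids = np.zeros((1, 1, 77), dtype=np.int32)  # prompts sharded over 1 device
prng_bad = np.zeros((8, 2), dtype=np.uint32)       # key split into 8
prng_ok = np.zeros((1, 2), dtype=np.uint32)        # one key per device

try:
    check_mapped_axis_size(prompt_ids, prng_bad)
except ValueError as err:
    print(err)

print(check_mapped_axis_size(prompt_ids, prng_ok))
```

With one key per device, both arguments agree on a leading size of 1 and the check passes.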
Reproduction
I run the Flax example from the CompVis/stable-diffusion-v1-4 model page as-is.
First, I install the dependencies:
%%capture
%%bash
pip install --upgrade diffusers transformers scipy
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade flax
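`pmap` distributes work over exactly the devices JAX can see, so it is worth confirming the device count after installing. The following minimal check assumes the CUDA-enabled `jax` install above succeeded. On a standard single-GPU Colab runtime the count would be expected to be 1, not 8.

```python
import jax

# Which backend did the install pick up? e.g. "gpu" when the CUDA build is
# active, "cpu" if JAX fell back to the CPU.
print("backend:", jax.default_backend())

# Devices pmap will map over
devices = jax.devices()
num_devices = jax.device_count()
print(devices)
print("device count:", num_devices)
```

The example's `num_samples = jax.device_count()` line depends on this value.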
Then, I log in to the Hugging Face Hub:
!huggingface-cli login
Then, I download the pipeline:
import jax
from diffusers import FlaxStableDiffusionPipeline
pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="flax", dtype=jax.numpy.bfloat16
)
Now I run the pipeline on my prompt:
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
prompt = "a photo of an astronaut riding a horse on mars"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(pipeline_params)
prng_seed = jax.random.split(prng_seed, 8)  # 8 keys, regardless of jax.device_count()
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
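The shapes in the error show the mismatch. `prompt_ids` was sharded to `int32[1,1,77]` because `jax.device_count()` is 1 on this runtime, while `prng_seed` was split into a hard-coded 8 keys (`uint32[8,2]`). The example was presumably written for an 8-device TPU host. The following minimal sketch produces consistent shapes, assuming (as the error implies) that the pipeline expects one key per device. The zero-filled `prompt_ids` is a stand-in for `pipeline.prepare_inputs(prompt)`, and the reshape mirrors what flax's `shard()` does.

```python
import jax
import numpy as np

num_samples = jax.device_count()

# One PRNG key per device, so prng_seed's leading axis matches the sharded
# prompts. The model page example hard-codes 8 here instead.
prng_seed = jax.random.PRNGKey(0)
prng_seed = jax.random.split(prng_seed, num_samples)

# Stand-in for pipeline.prepare_inputs(prompt): one row of 77 token ids per
# prompt (77 is the sequence length shown in the error message).
prompt_ids = np.zeros((num_samples, 77), dtype=np.int32)

# The reshape flax's shard() applies: (local devices, per-device batch, ...)
prompt_ids = prompt_ids.reshape((jax.local_device_count(), -1) + prompt_ids.shape[1:])

# Both mapped arguments now share the same leading (device) axis.
assert prompt_ids.shape[0] == prng_seed.shape[0]
print(prompt_ids.shape, prng_seed.shape)
```

On a one-device runtime this gives `(1, 1, 77)` and `(1, 2)`. In the reproduction above, the corresponding change would be replacing `jax.random.split(prng_seed, 8)` with `jax.random.split(prng_seed, num_samples)`. `replicate` already gives `params` the same device axis.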
Logs
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-4-8859a334219d> in <module>
14
---> 15 images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
16 images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
6 frames
/usr/local/lib/python3.7/dist-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py in __call__(self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, return_dict, jit, debug, neg_prompt_ids, **kwargs)
352 debug,
--> 353 neg_prompt_ids,
354 )
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
2233 p = _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
-> 2234 donate_tuple, global_arg_shapes, devices, args, kwargs)
2235 for arg in p.flat_args:
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, global_arg_shapes, in_devices, args, kwargs)
2059 kws=True))
-> 2060 local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")
2061
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in _mapped_axis_size(fn, tree, vals, dims, name)
1753 msg.append(f" * some axes ({ct} of them) had size {sz}, e.g. axis {ax} of {ex};\n")
-> 1754 raise ValueError(''.join(msg)[:-2]) # remove last semicolon and newline
1755
UnfilteredStackTrace: ValueError: pmap got inconsistent sizes for array axes to be mapped:
* most axes (1530 of them) had size 1, e.g. axis 0 of args[0] of type int32[1,1,77];
* one axis had size 8: axis 0 of args[2] of type uint32[8,2]
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
<ipython-input-4-8859a334219d> in <module>
13 prompt_ids = shard(prompt_ids)
14
---> 15 images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
16 images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
/usr/local/lib/python3.7/dist-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py in __call__(self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, return_dict, jit, debug, neg_prompt_ids, **kwargs)
351 latents,
352 debug,
--> 353 neg_prompt_ids,
354 )
355 else:
ValueError: pmap got inconsistent sizes for array axes to be mapped:
* most axes (1530 of them) had size 1, e.g. axis 0 of args[0] of type int32[1,1,77];
* one axis had size 8: axis 0 of args[2] of type uint32[8,2]
System Info
Using Colab with GPU enabled.