
Not able to run FlaxStableDiffusionPipeline example of CompVis/stable-diffusion-v1-4

Open dzlab opened this issue 3 years ago • 0 comments

Describe the bug

I can't run the Flax pipeline example for CompVis/stable-diffusion-v1-4 on Colab with a GPU runtime. It fails with this error:

```
ValueError: pmap got inconsistent sizes for array axes to be mapped:
  * most axes (1530 of them) had size 1, e.g. axis 0 of args[0] of type int32[1,1,77];
  * one axis had size 8: axis 0 of args[2] of type uint32[8,2]
```
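The error is raised by `jax.pmap`, which maps every array argument over its leading axis and therefore requires all of them to have the same leading-axis size. A minimal sketch, independent of diffusers, that triggers the same kind of error on any machine by giving one argument one extra slice along that axis (`step`, `x` and `keys` are illustrative names, not taken from the pipeline):

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()  # number of devices pmap will map over

def step(x, key):
    # Per-device body: echo the inputs back unchanged.
    return x, key

x = jnp.ones((n, 77), dtype=jnp.int32)                 # leading axis: n
keys = jax.random.split(jax.random.PRNGKey(0), n + 1)  # leading axis: n + 1

raised = False
try:
    jax.pmap(step)(x, keys)
except ValueError as err:
    raised = True
    print(err)  # JAX reports the inconsistent mapped-axis sizes
```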

Reproduction

I ran the Flax example from the CompVis/stable-diffusion-v1-4 model page.

First, I install the dependencies:

```
%%capture
%%bash

pip install --upgrade diffusers transformers scipy
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade flax
```
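After the install it can help to confirm which devices JAX actually sees, since the example derives `num_samples` from `jax.device_count()`. A minimal check using only standard JAX calls; a single-GPU runtime would be expected to report one device:

```python
import jax

backend = jax.default_backend()  # "cpu", "gpu" or "tpu"
devices = jax.devices()          # devices on the default backend
count = jax.device_count()       # the value the example uses for num_samples

print(f"backend={backend} device_count={count}")
for d in devices:
    print(d)
```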

Then I log in to the Hugging Face Hub:

```
!huggingface-cli login
```

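Next, I download the pipeline:

```python
# Imports used by this cell and the next one
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline

pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="flax", dtype=jax.numpy.bfloat16
)
```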

Then I run the pipeline on my prompt:

```python
prompt = "a photo of an astronaut riding a horse on mars"

prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()  # 1 on this single-GPU Colab runtime
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

# shard inputs and rng
params = replicate(pipeline_params)
prng_seed = jax.random.split(prng_seed, 8)  # always 8 keys, regardless of num_samples
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```
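The failing call mixes leading-axis sizes: `shard(prompt_ids)` produces one slice per device (`int32[1,1,77]` here), while `jax.random.split(prng_seed, 8)` produces `uint32[8,2]`. A sketch of the shape bookkeeping with the number of keys tied to the device count instead of a fixed 8, using plain arrays in place of the pipeline. `shard_leading` is a hypothetical helper that applies the same reshape as `flax.training.common_utils.shard`, and the zero-filled `prompt_ids` stands in for the output of `pipeline.prepare_inputs`; a single host is assumed, so `jax.device_count()` equals the local device count:

```python
import jax
import numpy as np

def shard_leading(x: np.ndarray, n_devices: int) -> np.ndarray:
    # Hypothetical helper: split the batch axis into (n_devices, per_device_batch, ...).
    return x.reshape((n_devices, -1) + x.shape[1:])

n_devices = jax.device_count()  # single host assumed
num_samples = n_devices         # one prompt per device, as in the example

# Stand-in for pipeline.prepare_inputs(prompt): (num_samples, 77) token ids.
prompt_ids = np.zeros((num_samples, 77), dtype=np.int32)
prompt_ids = shard_leading(prompt_ids, n_devices)  # (n_devices, 1, 77)

# One key per device rather than a hardcoded 8.
prng_seed = jax.random.split(jax.random.PRNGKey(0), n_devices)  # (n_devices, 2)

# Every argument mapped by pmap now shares the same leading-axis size.
assert prompt_ids.shape[0] == prng_seed.shape[0] == n_devices
```

In the notebook cell itself, the corresponding change would presumably be `prng_seed = jax.random.split(prng_seed, jax.device_count())`; whether that alone lets the whole pipeline call succeed on Colab is an assumption, not something this issue confirms.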

Logs

```
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-4-8859a334219d> in <module>
     14 
---> 15 images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
     16 images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

/usr/local/lib/python3.7/dist-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py in __call__(self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, return_dict, jit, debug, neg_prompt_ids, **kwargs)
    352                 debug,
--> 353                 neg_prompt_ids,
    354             )

/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
   2233     p = _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
-> 2234                       donate_tuple, global_arg_shapes, devices, args, kwargs)
   2235     for arg in p.flat_args:

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, global_arg_shapes, in_devices, args, kwargs)
   2059       kws=True))
-> 2060   local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")
   2061 

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in _mapped_axis_size(fn, tree, vals, dims, name)
   1753       msg.append(f"  * some axes ({ct} of them) had size {sz}, e.g. axis {ax} of {ex};\n")
-> 1754   raise ValueError(''.join(msg)[:-2])  # remove last semicolon and newline
   1755 

UnfilteredStackTrace: ValueError: pmap got inconsistent sizes for array axes to be mapped:
  * most axes (1530 of them) had size 1, e.g. axis 0 of args[0] of type int32[1,1,77];
  * one axis had size 8: axis 0 of args[2] of type uint32[8,2]

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
<ipython-input-4-8859a334219d> in <module>
     13 prompt_ids = shard(prompt_ids)
     14 
---> 15 images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
     16 images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

/usr/local/lib/python3.7/dist-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py in __call__(self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, return_dict, jit, debug, neg_prompt_ids, **kwargs)
    351                 latents,
    352                 debug,
--> 353                 neg_prompt_ids,
    354             )
    355         else:

ValueError: pmap got inconsistent sizes for array axes to be mapped:
  * most axes (1530 of them) had size 1, e.g. axis 0 of args[0] of type int32[1,1,77];
  * one axis had size 8: axis 0 of args[2] of type uint32[8,2]
```
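In the traceback, `args[0]` (`int32[1,1,77]`) matches the sharded `prompt_ids` and `args[2]` (`uint32[8,2]`) matches `prng_seed` after `jax.random.split(prng_seed, 8)`; the 1530 size-1 axes are most likely `prompt_ids` plus the leaves of the replicated `params`. A pre-flight check over the mapped arguments can surface such a mismatch before `pmap` is reached. The sketch below uses a hypothetical helper, `leading_axis_sizes` (not a JAX or diffusers API), and toy arrays shaped like the ones in the error:

```python
from collections import Counter

import jax
import numpy as np

def leading_axis_sizes(*args) -> Counter:
    # Hypothetical helper: tally the axis-0 size of every array leaf across
    # the arguments that pmap would map. Scalar leaves are skipped.
    sizes = Counter()
    for leaf in jax.tree_util.tree_leaves(args):
        shape = np.shape(leaf)
        if shape:
            sizes[shape[0]] += 1
    return sizes

# Toy arguments mirroring the traceback on a 1-device runtime.
prompt_ids = np.zeros((1, 1, 77), dtype=np.int32)
params = {"unet": {"kernel": np.zeros((1, 4, 4), dtype=np.float32)}}  # stand-in pytree
prng_seed = np.zeros((8, 2), dtype=np.uint32)

sizes = leading_axis_sizes(prompt_ids, params, prng_seed)
if len(sizes) > 1:
    print(f"inconsistent leading-axis sizes: {dict(sizes)}")
```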

System Info

Google Colab with a GPU runtime (Python 3.7, according to the paths in the traceback).
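Package versions are usually the first thing asked for in JAX/CUDA reports like this one. A small sketch that collects them for the packages installed above, using only the standard library. It assumes Python 3.8+ for `importlib.metadata`; on the Python 3.7 runtime shown in the traceback, the `importlib_metadata` backport would be needed instead:

```python
import platform
from importlib.metadata import PackageNotFoundError, version

def installed(pkg: str) -> str:
    # Return the installed version string, or a marker if the package is absent.
    try:
        return version(pkg)
    except PackageNotFoundError:
        return "not installed"

info = {"python": platform.python_version()}
for pkg in ("diffusers", "transformers", "flax", "jax", "jaxlib"):
    info[pkg] = installed(pkg)

for key, value in info.items():
    print(f"{key}: {value}")
```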

dzlab, Nov 28 '22 12:11