Slow SDXL inference with JAX on Cloud TPU v5e for sizes other than 1024x1024

Open zstiggz opened this issue 2 years ago • 5 comments

Describe the bug

I followed the blog post Accelerating Stable Diffusion XL Inference with JAX on Cloud TPU v5e. It worked very well until I tried to generate an image at a different size. At 1024x1024, inference latency is ~3s per image (compared to ~8s on the NVIDIA A10G). After changing the resolution to 1280x960, however, we see next to no improvement.

Reproduction

Use the same code as in the blog post: https://huggingface.co/blog/sdxl_jax

Changes:

def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
    width=1024,
    height=1024,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=neg_prompt_ids,
        guidance_scale=guidance_scale,
        width=width,
        height=height,
        jit=True,
    ).images

    # Merge the leading (device, per-device batch) axes, then convert the images to PIL.
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))


# First call at this resolution: the measured time includes JAX compilation.
start = time.time()
print("Compiling ...")
generate(default_prompt, default_neg_prompt, width=960, height=1280)
print(f"Compiled in {time.time() - start}")

# Second call at the same resolution: the measured time is inference only.
start = time.time()
print("starting")
prompt = "llama in ancient Greece, oil on canvas"
neg_prompt = "cartoon, illustration, animation"
images = generate(prompt, neg_prompt, width=960, height=1280)
print(f"Inference in {time.time() - start}")
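
The reproduction above times one compile call and one inference call. A minimal sketch of a reusable helper that separates the first call at a resolution from repeated later calls is shown here. The names `time_calls` and `fake_generate` are hypothetical and not part of diffusers or the blog post; `fake_generate` is only a stand-in that sleeps longer the first time it sees a shape, to imitate per-shape compilation. When timing real JAX code, the result must be materialized before the clock is stopped because JAX dispatches work asynchronously; the `generate` function above already does this through `np.array(images)`.

```python
import time
from typing import Any, Callable, List, Tuple


def time_calls(
    fn: Callable[..., Any], *args: Any, repeats: int = 3, **kwargs: Any
) -> Tuple[float, List[float]]:
    """Return (seconds for the first call, list of seconds for later calls)."""
    start = time.perf_counter()
    fn(*args, **kwargs)
    first = time.perf_counter() - start

    later = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args, **kwargs)
        later.append(time.perf_counter() - start)
    return first, later


# Hypothetical stand-in for generate(): slow the first time a shape is seen.
_seen_shapes = set()


def fake_generate(width: int, height: int) -> None:
    if (width, height) not in _seen_shapes:
        _seen_shapes.add((width, height))
        time.sleep(0.05)  # imitates one-off compilation for a new shape
    time.sleep(0.005)  # imitates steady-state inference


if __name__ == "__main__":
    for width, height in [(1024, 1024), (960, 1280)]:
        first, later = time_calls(fake_generate, width=width, height=height)
        print(f"{width}x{height}: first={first:.3f}s, later min={min(later):.3f}s")
```

With the real pipeline, the call would look like `time_calls(generate, prompt, neg_prompt, width=960, height=1280)`. Reporting the minimum (or median) of the later calls for each resolution makes the steady-state comparison between sizes less sensitive to one-off effects.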

Logs

No response

System Info

Python: 3.10.6
Diffusers: 0.26.2
Torch: 2.2.0+cu121
Jax: 0.4.23
Flax: 0.8.0

Who can help?

@patrickvonplaten @yiyixuxu @DN6

zstiggz avatar Feb 07 '24 01:02 zstiggz

cc: @pcuenca for visibility here

DN6 avatar Feb 07 '24 13:02 DN6

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Mar 13 '24 15:03 github-actions[bot]

This could be for a number of reasons, but unfortunately I don't currently have access to TPU v5e instances to test. I'll see if we can get one to verify.

pcuenca avatar Mar 13 '24 16:03 pcuenca

I have a similar issue with SD1.5: inference at custom resolutions takes disproportionately long.

512x512, 50 steps ⇒ 1.06 sec
512w x 640h, 50 steps ⇒ 3.36 sec
640x640, 50 steps ⇒ 5.02 sec
768x768, 50 steps ⇒ 4.34 sec

TPU-v5e-1chip
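
To make the mismatch concrete, here is a small sketch that compares each reported timing with the 512x512 baseline, both in pixel count and in wall-clock time. The helper name `relative_cost` is hypothetical; the numbers are the ones reported in this comment. Pixel count is used only as a rough yardstick for the amount of work, not as an exact cost model.

```python
# Timings reported in this comment (SD1.5, 50 steps, TPU v5e, 1 chip):
# (width, height, seconds)
timings = [
    (512, 512, 1.06),
    (512, 640, 3.36),
    (640, 640, 5.02),
    (768, 768, 4.34),
]


def relative_cost(rows, baseline_index=0):
    """Return (width, height, pixel_ratio, time_ratio) relative to the baseline row."""
    base_w, base_h, base_t = rows[baseline_index]
    base_pixels = base_w * base_h
    result = []
    for w, h, seconds in rows:
        result.append((w, h, (w * h) / base_pixels, seconds / base_t))
    return result


for w, h, pixel_ratio, time_ratio in relative_cost(timings):
    print(f"{w}x{h}: {pixel_ratio:.2f}x pixels, {time_ratio:.2f}x time")
```

For these numbers, 512x640 has 1.25x the pixels of 512x512 but takes about 3.17x as long, and 640x640 has about 1.56x the pixels but takes about 4.74x as long. 768x768 has 2.25x the pixels yet is faster than 640x640, so the slowdown does not track image size in any simple way.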

huseyintemiz avatar Jul 24 '24 09:07 huseyintemiz