
[Question] Can unet_traced.pt (torch.jit.trace) make diffusers faster?

Open • camenduru opened this issue 3 years ago • 5 comments

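I saw a traced UNet in the riffusion model repo (https://huggingface.co/riffusion/riffusion-model-v1/tree/main/unet_traced), and then found it described in the diffusers fp16 optimization docs (https://huggingface.co/docs/diffusers/optimization/fp16) and in the PyTorch torch.jit.trace docs (https://pytorch.org/docs/stable/generated/torch.jit.trace.html).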

Screenshot 2022-12-15 203423

If it is a compiled version of the model, it should be faster, right? Is this like jax.jit?

We can save it, but how can we use it with diffusers?

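```python
import torch
# `unet` is pipe.unet; `inputs` is an example (sample, timestep, encoder_hidden_states) tuple.
unet_traced = torch.jit.trace(unet, inputs)
unet_traced.save("unet_traced.pt")
```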

camenduru, Dec 17 '22 17:12
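Regarding the jax.jit question: `torch.jit.trace` runs the module once on the example inputs and records the tensor operations that actually executed into a TorchScript graph, which can then be saved and reloaded without the original Python class. Tracing with example inputs is similar in spirit to jax.jit, but one visible difference is Python control flow: jax.jit raises an error when an `if` depends on a traced value, whereas `torch.jit.trace` only emits a `TracerWarning` and keeps the branch taken during tracing. A minimal sketch with a hypothetical toy module `Gate` (not part of diffusers):

```python
import warnings

import torch


class Gate(torch.nn.Module):
    """Hypothetical toy module whose output depends on a Python branch."""

    def forward(self, x):
        # This `if` runs in Python; during tracing it is evaluated exactly once.
        if x.sum() > 0:
            return x * 2
        return x - 1


with warnings.catch_warnings():
    # Tracing converts `x.sum() > 0` to a Python bool and emits a TracerWarning.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(Gate(), torch.ones(3))

pos = torch.ones(3)
neg = torch.full((3,), -3.0)

print(traced(pos))   # same as eager mode: the recorded x * 2 branch is the right one here
print(traced(neg))   # still applies x * 2, giving -6 in every position
print(Gate()(neg))   # eager mode takes the x - 1 branch, giving -4 in every position
```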

I found where riffusion loads and uses the traced UNet 🎉: https://github.com/riffusion/riffusion-inference/blob/40e1e51c6a4e2c97cf1bc1193820862a941df62a/riffusion/server.py#L89-L114

camenduru, Dec 19 '22 23:12
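The pattern used in the linked riffusion code and in the diffusers fp16 docs loads the traced file with `torch.jit.load` and wraps it in a small `torch.nn.Module` named `TracedUNet`. The wrapper is needed because the pipeline reads `.sample` from the UNet's output and reads attributes such as `in_channels` from the UNet, while a traced module returns a plain tuple and has no such attributes. The sketch below reproduces that round trip (trace, save, load, wrap) on CPU with a hypothetical `ToyUNet` that has the same three positional inputs as the real UNet, so it runs without downloading a model. With the real model, the UNet can be made to return a plain tuple by calling it with `return_dict=False`, and the wrapper is then assigned to `pipe.unet`.

```python
import os
import tempfile
from dataclasses import dataclass

import torch


class ToyUNet(torch.nn.Module):
    """Hypothetical stand-in for the UNet, with the same three positional inputs."""

    def __init__(self, in_channels=4, context_dim=8):
        super().__init__()
        self.in_channels = in_channels
        self.proj = torch.nn.Linear(context_dim, in_channels)

    def forward(self, sample, timestep, encoder_hidden_states):
        context = self.proj(encoder_hidden_states).mean(dim=1)  # (batch, in_channels)
        # Return a plain tuple, which torch.jit.trace handles directly.
        return (sample * timestep + context[:, :, None, None],)


@dataclass
class UNet2DConditionOutput:
    # Mirrors the field name the pipeline reads from the UNet's output.
    sample: torch.Tensor


class TracedUNet(torch.nn.Module):
    """Wraps a loaded TorchScript UNet so pipeline code can keep calling it as before."""

    def __init__(self, traced, original):
        super().__init__()
        self.traced = traced
        # Copy attributes the calling code reads from the original UNet.
        self.in_channels = original.in_channels

    def forward(self, latent_model_input, t, encoder_hidden_states):
        sample = self.traced(latent_model_input, t, encoder_hidden_states)[0]
        return UNet2DConditionOutput(sample=sample)


unet = ToyUNet().eval()
inputs = (torch.randn(2, 4, 16, 16), torch.tensor([0.5]), torch.randn(2, 7, 8))

with torch.no_grad():
    unet_traced = torch.jit.trace(unet, inputs)

path = os.path.join(tempfile.mkdtemp(), "unet_traced.pt")
unet_traced.save(path)

# Later, e.g. in a server process: load the file and use the wrapper in place of the UNet.
wrapped = TracedUNet(torch.jit.load(path), unet)
with torch.no_grad():
    out = wrapped(*inputs).sample
print(out.shape)
```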

I tested it 🧪

T4, with unet_traced, without xformers: 30/50 [00:04<00:01, 6.51it/s] 😲

T4, with xformers, without unet_traced: 30/50 [00:05<00:02, 6.38it/s] 😲

camenduru, Dec 19 '22 23:12
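To put these two readings in proportion: the progress bar's it/s counts denoising steps per second, so the traced-UNet run is about 2% faster than the xformers run. This compares two different configurations (traced UNet without xformers versus xformers without tracing) from single mid-run progress bars, so a gap this small may be within run-to-run noise. A minimal calculation (the helper name `compare_throughput` is illustrative):

```python
def compare_throughput(fast_its, slow_its, steps=50):
    """Return the speedup ratio, seconds per run for each setting, and seconds saved per run."""
    fast_s = steps / fast_its
    slow_s = steps / slow_its
    return fast_its / slow_its, fast_s, slow_s, slow_s - fast_s


# Readings from the T4 runs above: traced UNet vs. xformers.
ratio, traced_s, xformers_s, saved_s = compare_throughput(6.51, 6.38)
print(f"speedup x{ratio:.3f}")
print(f"50 steps: {traced_s:.2f}s vs {xformers_s:.2f}s, saves {saved_s:.2f}s")
```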

Interested!

@camenduru, what setup do you use exactly (which PyTorch version, etc.)?

Also, could you post a code snippet the community could play with?

patrickvonplaten, Dec 20 '22 00:12
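One way to answer the setup question in a reproducible form is to print the installed package versions alongside the GPU model. A minimal sketch using only the standard library; the package list is an assumption, so add whatever else is relevant. `python -m torch.utils.collect_env` prints a fuller PyTorch and CUDA report, and recent diffusers releases also provide `diffusers-cli env`.

```python
import platform
from importlib import metadata


def package_versions(names=("torch", "diffusers", "transformers", "xformers")):
    """Map each package name to its installed version, or None if it is not installed."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


print("python", platform.python_version())
for name, version in package_versions().items():
    print(name, version or "not installed")
```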

Hi @patrickvonplaten 👋, this is the VM I tested on, and I am preparing a Colab.

camenduru, Dec 20 '22 13:12

Update: it works with this pipeline but not with StableDiffusionPipeline.

Hi @NouamaneTazi 👋, I am trying to run your code from https://github.com/huggingface/diffusers/blob/d07f73003d4d077854869b8f73275657f280334c/docs/source/optimization/fp16.mdx?plain=1#L287-L321. How did you run it and get this result?

traced UNet 3.21s x2.96

Please tell me how I can run your code without getting the error 'TracedUNet' object has no attribute 'config'.

camenduru, Dec 20 '22 15:12
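The error comes from the wrapper, not from the traced file. Recent StableDiffusionPipeline versions read UNet settings through `unet.config` (for example, the default height and width are derived from `unet.config.sample_size`), and a wrapper that only copies `in_channels` has no `config` attribute. One workaround, sketched here under the assumption that the pipeline only needs read access to the config, is to copy the original UNet's `config` onto the wrapper. `OriginalUNet` and `latent_shape` are hypothetical stand-ins for the diffusers model and for the pipeline code that reads the config.

```python
import types

import torch


class OriginalUNet(torch.nn.Module):
    """Hypothetical stand-in: diffusers models keep their settings on `.config`."""

    def __init__(self):
        super().__init__()
        self.config = types.SimpleNamespace(in_channels=4, sample_size=64)
        self.in_channels = self.config.in_channels


class TracedUNet(torch.nn.Module):
    """Hypothetical wrapper (the traced-module call is omitted; only attributes matter here)."""

    def __init__(self, original, copy_config=True):
        super().__init__()
        self.in_channels = original.in_channels
        if copy_config:
            # The workaround: expose the original model's config on the wrapper.
            self.config = original.config


def latent_shape(unet, batch_size=1):
    """Hypothetical helper mimicking pipeline code that reads `unet.config`."""
    size = unet.config.sample_size
    return (batch_size, unet.config.in_channels, size, size)


try:
    latent_shape(TracedUNet(OriginalUNet(), copy_config=False))
except AttributeError as err:
    print(err)  # the same kind of error as reported above

print(latent_shape(TracedUNet(OriginalUNet())))  # (1, 4, 64, 64)
```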

Good news 🎉: it works with custom_pipeline="interpolate_stable_diffusion", but it is not that fast 😭

51/51 [00:08<00:00, 5.72it/s]. Maybe I am creating unet_traced.pt wrong, I don't know 🤔. I will ask the riffusion team.

colab: https://github.com/camenduru/notebooks/blob/main/camenduru's_unet_traced.ipynb

camenduru, Dec 21 '22 18:12

Update: it is just Colab that is slow; I tested on the same T4 VM 🎉🎉🎉

Screenshot 2022-12-21 215634

camenduru, Dec 21 '22 19:12

xformers vs JIT

Screenshot 2022-12-21 224006

camenduru, Dec 21 '22 19:12
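When comparing xformers against the traced UNet, or Colab against a VM, timings are only comparable if each variant is warmed up first (early calls include one-time setup costs) and, on GPU, if the clock is read only after queued CUDA work has finished, since kernels run asynchronously. A minimal timing helper under those assumptions; `time_per_call` is a hypothetical name, and `fn` can be any callable, such as a function that runs one UNet call for a given variant.

```python
import time

import torch


def time_per_call(fn, warmup=3, iters=10):
    """Average wall-clock seconds per call of `fn`, after `warmup` untimed calls."""
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    with torch.no_grad():
        for _ in range(warmup):
            fn()
        sync()  # make sure warmup work has finished before starting the clock
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        sync()  # wait for the timed work to finish before stopping the clock
    return (time.perf_counter() - start) / iters


# Example on a small layer; for the comparison above, time one UNet call per variant.
layer = torch.nn.Linear(256, 256)
x = torch.randn(64, 256)
print(f"{time_per_call(lambda: layer(x)) * 1e3:.3f} ms per call")
```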

Colab T4:

Screenshot 2022-12-21 232002

camenduru, Dec 21 '22 20:12