[Question] Can unet_traced.pt (TORCH.JIT.TRACE) make diffusers faster?
I saw it here https://huggingface.co/riffusion/riffusion-model-v1/tree/main/unet_traced and found here https://huggingface.co/docs/diffusers/optimization/fp16 and here https://pytorch.org/docs/stable/generated/torch.jit.trace.html

If it is a compiled version of the model, it should be faster, right? Is this like jax.jit?
We can save it, but how can we use it with diffusers?
unet_traced = torch.jit.trace(unet, inputs)
unet_traced.save("unet_traced.pt")
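For anyone who wants to try the save/load round trip without downloading a model, here is a minimal sketch. It assumes a tiny stand-in module (`TinyNet` is hypothetical, not the real diffusers UNet) and only shows that `torch.jit.trace` records the operations run on the example inputs, and that the saved file can be loaded back with `torch.jit.load` and called like the original module. Whether it is actually faster has to be measured on the real model.

```python
import os
import tempfile

import torch


class TinyNet(torch.nn.Module):
    """Hypothetical stand-in for the UNet; NOT the real diffusers model."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, sample, timestep):
        # Toy computation so the trace has something to record.
        return self.proj(sample) + timestep


torch.manual_seed(0)
net = TinyNet().eval()

# Example inputs: tracing records the ops executed for these inputs.
example_inputs = (torch.randn(2, 8), torch.tensor(1.0))

with torch.no_grad():
    traced = torch.jit.trace(net, example_inputs)

# Save to a temporary location and load it back as a TorchScript module.
path = os.path.join(tempfile.mkdtemp(), "tiny_traced.pt")
traced.save(path)
loaded = torch.jit.load(path)

# The loaded module is called with the same positional inputs as the original.
with torch.no_grad():
    same = torch.allclose(net(*example_inputs), loaded(*example_inputs))
print(same)
```

The loaded object is a plain TorchScript module, so it carries the recorded computation and its weights but none of the extra Python attributes of the original class.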
I found it: https://github.com/riffusion/riffusion-inference/blob/40e1e51c6a4e2c97cf1bc1193820862a941df62a/riffusion/server.py#L89-L114
I tested it:
T4, with unet_traced, without xformers: 30/50 [00:04<00:01, 6.51it/s]
T4, with xformers, without unet_traced: 30/50 [00:05<00:02, 6.38it/s]
Interested!
@camenduru what setup do you use exactly (which PyTorch version, etc.)?
Also could you post a code snippet the community could play with?
Hi @patrickvonplaten, this is the VM I tested on, and I am preparing a Colab.
update: it is working with this pipeline but not StableDiffusionPipeline
Hi @NouamaneTazi, I am trying to run your code from https://github.com/huggingface/diffusers/blob/d07f73003d4d077854869b8f73275657f280334c/docs/source/optimization/fp16.mdx?plain=1#L287-L321. How did you run this code and get
| traced UNet | 3.21s | x2.96 |
Please tell me how I can run your code without getting the error: 'TracedUNet' object has no attribute 'config'
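One possible workaround for the missing attribute, as a sketch only: a traced module does not keep the Python attributes of the original model, so a wrapper can copy `config` over from the original before the pipeline's UNet is replaced. Everything below uses toy stand-ins (`ToyUNet`, `UNetOutput`, and the `SimpleNamespace` config are hypothetical); in a real pipeline the config would presumably come from `pipe.unet.config`, and a given diffusers version may read other attributes as well, which this sketch does not cover.

```python
from dataclasses import dataclass
from types import SimpleNamespace

import torch


@dataclass
class UNetOutput:
    """Hypothetical output container exposing `.sample`."""
    sample: torch.Tensor


class ToyUNet(torch.nn.Module):
    """Hypothetical stand-in for the real UNet, with a `config` attribute."""

    def __init__(self):
        super().__init__()
        self.config = SimpleNamespace(in_channels=4)
        self.conv = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1)

    def forward(self, latents, t, encoder_hidden_states):
        # Returns a tuple, mimicking a model called with return_dict=False.
        return (self.conv(latents) + t + encoder_hidden_states.mean(),)


class TracedUNet(torch.nn.Module):
    """Wraps a traced module and re-exposes the original model's config."""

    def __init__(self, traced, config):
        super().__init__()
        self.traced = traced
        self.config = config  # the attribute the error message says is missing

    def forward(self, latents, t, encoder_hidden_states, **kwargs):
        # Extra keyword arguments are ignored: the trace only knows 3 inputs.
        sample = self.traced(latents, t, encoder_hidden_states)[0]
        return UNetOutput(sample=sample)


torch.manual_seed(0)
unet = ToyUNet().eval()
inputs = (torch.randn(1, 4, 8, 8), torch.tensor(10.0), torch.randn(1, 3, 16))

with torch.no_grad():
    traced = torch.jit.trace(unet, inputs)
    wrapped = TracedUNet(traced, unet.config)
    out = wrapped(*inputs)

print(wrapped.config.in_channels, tuple(out.sample.shape))
```

With a real pipeline, the wrapper would be assigned in place of the original UNet after copying its config; that step is left out here because it depends on the diffusers version in use.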
Good news: it is working with custom_pipeline="interpolate_stable_diffusion", but it is not that fast.
51/51 [00:08<00:00, 5.72it/s] Maybe I am creating unet_traced.pt wrong, I don't know. I will ask the riffusion team.
colab: https://github.com/camenduru/notebooks/blob/main/camenduru's_unet_traced.ipynb
update: only Colab is slow; I tested with the same T4 VM.

xformers vs JIT

colab T4
