[refactor] Making the xformers mem-efficient attention activation recursive
- move the enable/disable call into the base DiffusionPipeline (removes a bunch of duplicates)
- make the call recursive across all the modules in the model graph, so that exposing `set_use_memory_efficient_attention_xformers` in a leaf module is all it takes for it to be picked up (important for some pipelines, like superres, which are not properly covered right now - see for instance #1492); a minimal sketch follows this list
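For context, a minimal sketch of the traversal (illustrative only; the helper name `enable_xformers_recursively` and the component discovery via `vars(self)` are assumptions made for the sketch, only the `set_use_memory_efficient_attention_xformers` hook name comes from the PR):

```python
import torch

def enable_xformers_recursively(module: torch.nn.Module, valid: bool) -> None:
    # Any module that exposes the hook gets the enable/disable message,
    # however deeply it is nested in the model graph.
    if hasattr(module, "set_use_memory_efficient_attention_xformers"):
        module.set_use_memory_efficient_attention_xformers(valid)
    for child in module.children():
        enable_xformers_recursively(child, valid)

class DiffusionPipeline:
    # ... rest of the pipeline machinery ...

    def enable_xformers_memory_efficient_attention(self):
        # Lives on the base class, so every pipeline (superres included)
        # inherits it instead of duplicating the enable/disable logic.
        for component in vars(self).values():
            if isinstance(component, torch.nn.Module):
                enable_xformers_recursively(component, True)

    def disable_xformers_memory_efficient_attention(self):
        for component in vars(self).values():
            if isinstance(component, torch.nn.Module):
                enable_xformers_recursively(component, False)
```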
cc @patrickvonplaten, discussed a couple of days ago. Note that there do not seem to be unit tests covering this part, unless I missed them.
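If tests get added later, something along these lines could cover it (a hypothetical sketch, not part of this PR; the `_use_memory_efficient_attention_xformers` attribute name is an assumption about the attention modules' internal flag, and it needs a CUDA box with xformers installed):

```python
import torch
from diffusers import StableDiffusionPipeline

def test_xformers_flag_propagates_recursively():
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    pipe.enable_xformers_memory_efficient_attention()

    # The flag should reach every attention module, however deeply nested.
    flagged = [
        m for m in pipe.unet.modules()
        if hasattr(m, "_use_memory_efficient_attention_xformers")
    ]
    assert flagged, "expected at least one module exposing the flag"
    assert all(m._use_memory_efficient_attention_xformers for m in flagged)
```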
Open for feedback - this is a suggestion of course, @patrickvonplaten @kashif.
I tested:

```python
from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()

with torch.inference_mode():
    sample = pipe("a small cat")
    sample[0][0].save("cat.png")
```
which works fine with this PR
PR looks very nice to me! Given that xformers can essentially be used with every attention layer, every UNet pretty much has an attention layer, and every pipeline has at least one UNet, I think it's a good idea to make it a "global" method by adding it to DiffusionPipeline - what do the others think here?
If you check a PR like this one, the changes here make it a lot easier and would remove two-thirds of the lines of code.