Implement asynchronous module prefetching (QWEN+WAN so far)
This is a draft of a generic module prefetcher. It implements the core feature and gives one example of how to use it with QWEN.
This gets very close to compute saturation, whereas --async-offload as-is still has a few compute stalls.
Leaving as a draft for now, as I am still trying to find a better way.
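For context, here is a minimal sketch of the prefetching idea (not the PR's actual code): while block i computes on the main stream, block i+1's weights are copied host-to-device on a dedicated CUDA stream, with events keeping compute from running ahead of the copies. All names here are illustrative.

```python
import torch
import torch.nn as nn

transfer_stream = torch.cuda.Stream()

def prefetch(module: nn.Module, device: torch.device) -> torch.cuda.Event:
    """Begin copying a module's weights to the GPU on the transfer stream."""
    event = torch.cuda.Event()
    with torch.cuda.stream(transfer_stream):
        for p in module.parameters():
            # non_blocking only actually overlaps when the host tensor is
            # pinned, hence --fast pinned_memory.
            p.data = p.data.to(device, non_blocking=True)
        event.record(transfer_stream)
    return event

def run_blocks(blocks: nn.ModuleList, x: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Run blocks sequentially, prefetching block i+1 while block i computes."""
    events = [prefetch(blocks[0], device)]
    for i, block in enumerate(blocks):
        if i + 1 < len(blocks):
            events.append(prefetch(blocks[i + 1], device))
        # Block the compute stream until this block's weights have landed.
        torch.cuda.current_stream().wait_event(events[i])
        x = block(x)
    return x
```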
To try it out, start ComfyUI with a QWEN workflow using the following startup args:

```
--async-offload --fast pinned_memory --reserve-vram 3
```

It consumes a bit of extra VRAM, so you need --reserve-vram to avoid OOMing.
Added WAN support
I wasn't able to check the PR yet, but have you looked at GroupOffloading from diffusers: https://github.com/huggingface/diffusers/blob/main/src/diffusers/hooks/group_offloading.py ?
It is similar, but should have the advantage of not requiring any model code changes.
I had a very quick skim through it. I see it has awareness of nn.ModuleList, which may actually short-circuit the prefetching block code instrumentation I did here and make it frictionless. It's definitely a good idea if going with long-range prefetchers.
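As a rough illustration (my reading of the idea, not diffusers' actual API), ModuleList awareness could let a hook install prefetching without touching model code: entering block i kicks off the copy of block i+1. `prefetch()` here is the same hypothetical helper as in the earlier sketch.

```python
import torch.nn as nn

def install_prefetch_hooks(model: nn.Module, device) -> None:
    """Walk the model and hook every nn.ModuleList so each block
    prefetches its successor. Synchronization before use is omitted
    for brevity (see the event handling in the earlier sketch)."""
    for child in model.modules():
        if isinstance(child, nn.ModuleList):
            blocks = list(child)
            for i, block in enumerate(blocks[:-1]):
                # Bind the *next* block at hook-creation time via default arg.
                def hook(module, args, nxt=blocks[i + 1]):
                    prefetch(nxt, device)
                block.register_forward_pre_hook(hook)
```

This also shows why multiple or hierarchical lists are a concern: a naive walk like the one above would hook all of them indiscriminately.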
That approach is slightly fragile in that a model author could do something weird or have multiple or hierarchical lists, whereas this open-coded system gives you just that tiny bit of control a model author might want anyway.
The design goal is simplicity at the moment; ideally we get away with totally generic layer-level prefetching as just an incremental improvement to --async-offload.