Allow passing a `CACHE_MAPPING` of previously downloaded components and their sha256 to avoid duplicate downloads
A common problem with diffusers models/pipelines is that many components of different pipelines share the exact same underlying weights, but it's hard to avoid downloading them twice. We could solve this by passing a `cache_mapping: Dict[str, path]` to `DiffusionPipeline.from_pretrained(...)` that checks whether each file has previously been downloaded; if so, it won't be downloaded again. If not, it will be downloaded and added to `cache_mapping`.
It's quite trivial to look up the sha256 hashes of files before downloading them, e.g.:
```python
from huggingface_hub import model_info

info = model_info("runwayml/stable-diffusion-v1-5", files_metadata=True)
files = info.siblings
shas = {f.rfilename: f.lfs["sha256"] for f in files if f.lfs is not None}
shas
```
gives
```python
{'safety_checker/pytorch_model.bin': '193490b58ef62739077262e833bf091c66c29488058681ac25cf7df3d8190974',
 'text_encoder/pytorch_model.bin': '770a47a9ffdcfda0b05506a7888ed714d06131d60267e6cf52765d61cf59fd67',
 'unet/diffusion_pytorch_model.bin': 'c7da0e21ba7ea50637bee26e81c220844defdf01aafca02b2c42ecdadb813de4',
 'v1-5-pruned-emaonly.ckpt': 'cc6cb27103417325ff94f52b7a5d2dde45a7515b25c255d8e396c90014281516',
 'v1-5-pruned.ckpt': 'e1441589a6f3c5a53f5f54d0975a18a7feb7cdf0b0dee276dfc3331ae376a053',
 'vae/diffusion_pytorch_model.bin': '1b134cded8eb78b184aefb8805b6b572f36fa77b255c483665dda931fa0130c5'}
```
See colab here: https://colab.research.google.com/drive/1WGLdcgnzbIf_dn9QF51TRO_6ogEqVsea?usp=sharing
Now we could integrate this code quite easily into from_pretrained(...) since we're making a call to the Hub anyways already: https://github.com/huggingface/diffusers/blob/f73ed179610653bf100215a54ca2c8a3cba91cf0/src/diffusers/pipelines/pipeline_utils.py#L509
From the user API it could look as follows:
```python
from diffusers import DiffusionPipeline

cache_mapping = {}
pipeline, cache_mapping = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", cache_mapping=cache_mapping
)
# cache_mapping would then look as follows:
# {"193490b58ef62739077262e833bf091c66c29488058681ac25cf7df3d8190974": "./cache/.... <path/to/file>", ...}

# now the safety checker won't be downloaded again:
pipeline, cache_mapping = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", cache_mapping=cache_mapping
)
```
cc @pcuenca @keturn @patil-suraj @anton-l what do you think?
Does this only cover the one weights file, or does it cover accompanying files like config.json or vocab.json?
sha256 is slow for multi-gigabyte files. Consider https://github.com/escherba/python-metrohash or https://github.com/oconnor663/blake3-py (b3sum) or https://github.com/ifduyue/python-xxhash
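For a rough local comparison using only the standard library (blake3/xxhash/metrohash would need third-party packages; this sketch just contrasts `sha256` with the stdlib's faster `blake2b` on a throwaway file):

```python
import hashlib
import os
import tempfile
import time

# Write a throwaway 16 MiB file as a stand-in for a model checkpoint.
with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(os.urandom(16 * 1024 * 1024))
    path = f.name

def hash_file(path, algo):
    """Hash a file in 1 MiB chunks with the given hashlib algorithm."""
    h = hashlib.new(algo)
    with open(path, "rb") as fp:
        for chunk in iter(lambda: fp.read(1024 * 1024), b""):
            h.update(chunk)
    return h.hexdigest()

digests = {}
for algo in ("sha256", "blake2b"):
    start = time.perf_counter()
    digests[algo] = hash_file(path, algo)
    print(f"{algo}: {time.perf_counter() - start:.3f}s")

os.unlink(path)
```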
I'd say it only covers the weight files, as `config.json` and `vocab.json` are "cheap" in terms of disk storage. So the idea would be to simply add cached weight files to the "ignore patterns": https://github.com/huggingface/diffusers/blob/135567f18edc0bf02d515d5c76cc736d1ebddad3/src/diffusers/pipelines/pipeline_utils.py#L484
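To make the "ignore patterns" idea concrete, here's a minimal sketch. The local index and file metadata are faked inline so it runs offline; in `from_pretrained(...)` the `(rfilename, sha256)` pairs would come from `model_info(..., files_metadata=True)`:

```python
# Hypothetical local index: sha256 -> local path of an already-downloaded weight file.
local_index = {
    "193490b58ef62739077262e833bf091c66c29488058681ac25cf7df3d8190974": "/cache/abc/pytorch_model.bin",
}

# Stand-in for the repo's LFS file metadata from the Hub.
remote_files = [
    ("safety_checker/pytorch_model.bin",
     "193490b58ef62739077262e833bf091c66c29488058681ac25cf7df3d8190974"),
    ("unet/diffusion_pytorch_model.bin",
     "c7da0e21ba7ea50637bee26e81c220844defdf01aafca02b2c42ecdadb813de4"),
]

# Any file whose sha256 is already cached becomes an ignore pattern
# that can be passed to `snapshot_download(..., ignore_patterns=...)`.
ignore_patterns = [name for name, sha in remote_files if sha in local_index]
print(ignore_patterns)  # ['safety_checker/pytorch_model.bin']
```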
Regarding sha256 being slow. I don't think this is fully related since we don't need to compute the sha256 - it's already computed when models are uploaded to the Hub, e.g. see: https://huggingface.co/runwayml/stable-diffusion-v1-5/raw/main/unet/diffusion_pytorch_model.bin
What I mean here is that from_pretrained(...) would never compute any hashes, it expects that the hashes are available online.
What do you think?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Is that still of interest?
Regarding sha256 being slow. I don't think this is fully related since we don't need to compute the sha256 - it's already computed when models are uploaded to the Hub,
Oh, is this a side-effect of the way git lfs works? Maybe a little too tied in to that implementation detail, but I suppose as long as we also have access to that on the local side for comparison it would be fine.
And yes, very much still of interest. lstein ran the numbers and figured out that he could save 20% of his model storage with some deduplication. Not to mention the bandwidth savings.
Actually thinking more about it, I wonder whether we can solve this problem directly upstream in huggingface hub - opened an issue there.
Also, I'm not sure whether a 20% reduction in model storage is really a big quality-of-life improvement for users - in my experience, if people are limited by storage, a 20% improvement usually doesn't help; much more drastic storage reductions are needed once storage becomes a problem.
As your retrieval and storage backend seems disinclined to make cross-project links in the storage layer, one alternative would be to push those links to the application layer:
allow a model_index or config.json file to link to another project, so it could say something like
vae: https://huggingface.co/stabilityai/sd-vae-ft-mse
safety_checker: https://huggingface.co/CompVis/stable-diffusion-v1-4/tree/4d5006d3bac581bad4474213863450187137f58e/safety_checker
That would make it explicitly clear what the provenance of the file is, where to obtain it, and would be completely compatible with the current huggingface-hub fetching and caching.
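One hypothetical shape this could take, extending the component entries of `model_index.json` with a source URL (the third element of each entry is an illustration, not part of the current schema):

```json
{
  "_class_name": "StableDiffusionPipeline",
  "vae": ["diffusers", "AutoencoderKL",
          "https://huggingface.co/stabilityai/sd-vae-ft-mse"],
  "safety_checker": ["stable_diffusion", "StableDiffusionSafetyChecker",
          "https://huggingface.co/CompVis/stable-diffusion-v1-4/tree/4d5006d3bac581bad4474213863450187137f58e/safety_checker"]
}
```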
Yeah, I think we should then probably add some functionality to diffusers right away. Hmm, not too keen on changing the format of existing configs though.
I was thinking about a `smart_cache=True` flag here because we can't really expect the user to know which model checkpoints are duplicated. So I'd propose that if `smart_cache=True`, then `from_pretrained(...)` will create & look up a local index JSON file with the mapping:

`<hash_key> -> <path_to_file>`

and if a file is already present, it won't be downloaded again.
E.g. we could do something like the following:
```python
if smart_cache:
    # load the local index
    local_checkpoints = ...  # {sha256: path_to_checkpoint}

    # retrieve the requested shas from the Hub
    info = model_info("CompVis/stable-diffusion-v1-4", files_metadata=True)
    files_info = info.siblings
    checkpoint_shas = {
        f.lfs["sha256"]: f.rfilename for f in files_info if f.lfs is not None
    }

    # figure out which requested shas match existing local_checkpoints;
    # for every match we add an `ignore_pattern` to snapshot download
    # and then create symlinks in the downloaded folder
```
- Add a `smart_cache=True` flag to `from_pretrained(...)`
As Wauplin pointed out, symlinks may get you in trouble when the link target disappears from the cache. Hard links will work better, and I think all our expected platforms support hard links these days? (I'm not entirely sure about Windows but search suggests it should.)
The "corrupting one file will corrupt all its links" problem is still a potential concern. But as this is supposed to be a cache of exact copies of a remote resource, I'm tempted to hand-wave that one away by setting the files to read-only and saying anyone that tampers with them is making their own trouble.
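A minimal sketch of the hard-link approach (the paths are hypothetical; this just demonstrates that both names share one inode and that the blob can be made read-only):

```python
import os
import stat
import tempfile

cache_dir = tempfile.mkdtemp()
src = os.path.join(cache_dir, "blob")                # the one real copy
dst = os.path.join(cache_dir, "pytorch_model.bin")   # per-pipeline entry

with open(src, "wb") as f:
    f.write(b"fake weights")

os.link(src, dst)  # hard link: both names point at the same inode

# Make the blob read-only so accidental writes through either name fail.
os.chmod(src, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)

print(os.stat(src).st_ino == os.stat(dst).st_ino)  # True
print(os.stat(src).st_nlink)  # 2
```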
You may also need to adjust the scan-cache command to account for the links, otherwise the "total usage" it reports will include the duplicates.
Sadly I won't find time for this anytime soon I'm afraid. It's still on my TODO-List but it's not a prio right now