[Core] support saving and loading of sharded checkpoints
What does this PR do?
Follow-up of https://github.com/huggingface/diffusers/pull/6396.
This PR adds support for saving a big model's state dict into multiple shards for efficient portability and loading. Adds support for loading the sharded checkpoints, too.
This is much akin to how big models like T5XXL are handled.
Also added a test to ensure that models with `_no_split_modules` specified can be sharded and loaded back to perform inference, with numerical assertions.
Here's a real use-case. Consider this Transformer2DModel checkpoint: https://huggingface.co/sayakpaul/actual_bigger_transformer/.
It was serialized like so:
```python
from diffusers import Transformer2DModel
from accelerate.utils import compute_module_sizes, shard_checkpoint
from accelerate import init_empty_weights
import torch.nn as nn


def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"


with init_empty_weights():
    pixart_transformer = Transformer2DModel.from_config("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer")
    bigger_transformer = Transformer2DModel.from_config(
        pixart_transformer.config, num_layers=72, num_attention_heads=36, cross_attention_dim=2592,
    )

module_size = bytes_to_giga_bytes(compute_module_sizes(bigger_transformer)[""])
print(f"{module_size=} GB")
pytorch_total_params = sum(p.numel() for p in bigger_transformer.parameters()) / 1e9
print(f"{pytorch_total_params=} B")

model = nn.Sequential(*[nn.Linear(8944, 8944) for _ in range(1000)])
module_size = bytes_to_giga_bytes(compute_module_sizes(model)[""])
print(f"{module_size=} GB")
pytorch_total_params = sum(p.numel() for p in model.parameters()) / 1e9
print(f"{pytorch_total_params=} B")

actual_bigger_transformer = Transformer2DModel.from_config(
    pixart_transformer.config, num_layers=72, num_attention_heads=36, cross_attention_dim=2592
)
actual_bigger_transformer.save_pretrained("/raid/.cache/actual_bigger_transformer", max_shard_size="10GB", push_to_hub=True)
```
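Alongside the shard files, `save_pretrained` also writes an index JSON (e.g. `diffusion_pytorch_model.safetensors.index.json`) mapping each weight name to the shard that contains it. A purely illustrative sketch of that structure (the key names, sizes, and file names below are hypothetical, not taken from the actual repo):

```python
import json

# Hypothetical index for a 3-shard checkpoint: "weight_map" tells the loader
# which file each parameter lives in, "metadata" records the total byte size.
index = {
    "metadata": {"total_size": 24_000_000_000},
    "weight_map": {
        "transformer_blocks.0.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00003.safetensors",
        "transformer_blocks.71.ff.net.2.weight": "diffusion_pytorch_model-00003-of-00003.safetensors",
    },
}

# The set of shard files to fetch is just the distinct values of the weight map.
shards = sorted(set(index["weight_map"].values()))
print(shards)
print(json.dumps(index["metadata"]))
```

At load time, only the shards referenced by the requested weights need to be opened.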
As we can see from the Hub repo, its state dict is sharded. To perform inference with the model, all we have to do is this:
```python
from diffusers import Transformer2DModel
import tempfile
import torch
import os


def get_inputs():
    sample = torch.randn(1, 4, 128, 128)
    timestep = torch.randint(0, 1000, size=(1,))
    encoder_hidden_states = torch.randn(1, 120, 4096)
    resolution = torch.tensor([1024, 1024]).repeat(1, 1)
    aspect_ratio = torch.tensor([1.0]).repeat(1, 1)
    added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
    return sample, timestep, encoder_hidden_states, added_cond_kwargs


with torch.no_grad():
    # max_memory = {0: "15GB"}  # reasonable estimate for a consumer GPU.
    with tempfile.TemporaryDirectory() as tmp_dir:
        new_model = Transformer2DModel.from_pretrained(
            "sayakpaul/actual_bigger_transformer",
            device_map="auto",
        )
        sample, timestep, encoder_hidden_states, added_cond_kwargs = get_inputs()
        out = new_model(
            hidden_states=sample,
            encoder_hidden_states=encoder_hidden_states,
            timestep=timestep,
            added_cond_kwargs=added_cond_kwargs,
        ).sample
        print(f"{out.shape=}, {out.device=}")
```
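Under the hood, loading a sharded checkpoint boils down to reading each shard file and merging the partial state dicts before assigning them to the model. A minimal plain-torch sketch of that idea (illustrative file names, `.bin` instead of safetensors for simplicity):

```python
import os
import tempfile

import torch

# Two tensors saved as two separate shard files, then stitched back together.
state_dict = {"a.weight": torch.randn(4, 4), "b.weight": torch.randn(4, 4)}
with tempfile.TemporaryDirectory() as tmp:
    torch.save({"a.weight": state_dict["a.weight"]}, os.path.join(tmp, "shard-00001-of-00002.bin"))
    torch.save({"b.weight": state_dict["b.weight"]}, os.path.join(tmp, "shard-00002-of-00002.bin"))

    # Merging the per-shard state dicts recovers the full checkpoint.
    merged = {}
    for fname in sorted(os.listdir(tmp)):
        merged.update(torch.load(os.path.join(tmp, fname), weights_only=True))

print(sorted(merged))
```

With `device_map="auto"`, accelerate additionally places each merged parameter on the device chosen by the computed device map instead of materializing everything on one device.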
I purposefully haven't added documentation because all of this will become useful once we use it in the context of a full-fledged pipeline execution (up next) :)
@yiyixuxu @SunMarc a gentle ping here.
@BenjaminBossan thanks for the awesome reviews. I have either addressed or replied to them.
@yiyixuxu gentle ping.
@Wauplin thanks for your comments. I have incorporated the state dict sharding utility from huggingface_hub which greatly simplifies things. @SunMarc @yiyixuxu I think this is ready for a final review now.
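For reference, the `huggingface_hub` utility in question, `split_torch_state_dict_into_shards`, takes a state dict and a `max_shard_size` and returns a split plan describing which tensors go into which file. A minimal sketch (the tensor names and sizes are made up for illustration):

```python
import torch
from huggingface_hub import split_torch_state_dict_into_shards

# Ten ~4 MB fp32 tensors with a 10 MB budget: the greedy packing fits two
# tensors per shard, so the plan spans five files.
state_dict = {f"layer_{i}.weight": torch.randn(1024, 1024) for i in range(10)}
split = split_torch_state_dict_into_shards(state_dict, max_shard_size="10MB")

print(split.is_sharded)
print(len(split.filename_to_tensors))
```

The returned object also exposes `tensor_to_filename` and `metadata`, which is what the serialization path uses to write the shards and the accompanying index file.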
Tests are failing because the Dockerfiles haven't been updated with the latest huggingface_hub. Triggered it: https://github.com/huggingface/diffusers/actions/runs/9254070149.
@Wauplin about the failing test here, it passes on `main`.
After a bit of investigation, I traced it to: https://github.com/huggingface/diffusers/blob/c98d7790627ca8d66bcfd363e520713bf799c85c/src/diffusers/models/modeling_utils.py#L334
Printing `state_dict_split.filename_to_tensors.keys()` returns an empty dictionary. While I know that the module in question doesn't have any layers per se, it is still handled on `main` (we don't give it any special treatment there), so I wonder if it needs to be tackled from `huggingface_hub`. LMK.
Hi @Wauplin. Thanks for your detailed comments. In summary:
- I have maintained `os.path.join()` throughout when dealing with local paths.
- Modified the directory file removal logic per your suggestions.
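On the removal logic: the idea is to delete previously saved shard files that are no longer part of the new checkpoint, so stale shards from an older sharding layout don't linger in the save directory. A hypothetical stdlib-only sketch of that behavior (the helper name and filename pattern here are illustrative, not the actual implementation):

```python
import os
import re
import tempfile


def clean_stale_shards(save_dir, current_shards, weights_name="diffusion_pytorch_model"):
    # Hypothetical helper: remove files matching the shard naming scheme
    # that are not part of the checkpoint being saved right now.
    pattern = re.compile(rf"{re.escape(weights_name)}(-\d{{5}}-of-\d{{5}})?\.safetensors")
    for fname in os.listdir(save_dir):
        if pattern.fullmatch(fname) and fname not in current_shards:
            os.remove(os.path.join(save_dir, fname))


with tempfile.TemporaryDirectory() as tmp:
    # Simulate leftovers from an older 3-shard save.
    for name in [
        "diffusion_pytorch_model-00001-of-00003.safetensors",
        "diffusion_pytorch_model-00002-of-00003.safetensors",
        "diffusion_pytorch_model-00003-of-00003.safetensors",
    ]:
        open(os.path.join(tmp, name), "w").close()

    # The new save produces only 2 shards; the 3-shard leftovers get removed.
    new_shards = {
        "diffusion_pytorch_model-00001-of-00002.safetensors",
        "diffusion_pytorch_model-00002-of-00002.safetensors",
    }
    clean_stale_shards(tmp, new_shards)
    print(sorted(os.listdir(tmp)))
```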
LMK if there's anything unclear about the issue I mentioned in https://github.com/huggingface/diffusers/pull/7830#pullrequestreview-2081930869.
> Printing `state_dict_split.filename_to_tensors.keys()` returns an empty dictionary.
What would be the expected output for you here? Happy to fix something in huggingface_hub but since there is nothing to store and therefore no tensors I don't see what could be outputted.
The expectation is that the test would pass; otherwise, this seems like backwards-breaking behavior to me. I can introduce a layer in the concerned encoder module so that it passes, but I'm not sure yet.
What I meant is "what change would make sense to have to handle empty models?". If it is a corner case that is there only to test internal logic of diffusers, then I won't update huggingface_hub for that and it's fine to have specific logic for this test. If the test is actually testing a real-world use case that makes sense to support then I'm happy to update split_torch_state_dict_into_shards output accordingly. Hence my question about what you would expect from split_torch_state_dict_into_shards in such a case?
@Wauplin addressed your comments.
I do think we need to still have two different codepaths for serialization and I provided my reasoning in the comments above. LMK.
https://github.com/huggingface/diffusers/commit/0706cae4daa223bd923cb77fd025deb99de24c16 should resolve the test failure. LMK if that works for you.
Yes sure!
@Wauplin landed another set of changes.
I'd rather have another pair of eyes reviewing it, given it's fairly easy to miss something when iterating/reviewing several times on the same code.
Yeah. @yiyixuxu would be the final approver here :)
@yiyixuxu do the recent changes work for you?
(I have run the tests)
Yay! Great job @sayakpaul ! :tada: