[Core] Introduce class variants for `Transformer2DModel`
What does this PR do?
Introduces two variants of `Transformer2DModel`:

- `DiTTransformer2DModel`
- `PixArtTransformer2DModel`
For the other instances where `Transformer2DModel` is used, they should later be turned into blocks, as they shouldn't be inheriting from `ModelMixin` (this has been discussed internally).
TODO:
(Will be tackled after I get an initial review)
- [x] Tests for each individual variant
- [x] Documentation
Some comments are in-line.
LMK.
@yiyixuxu @DN6 a gentle ping here.
Is the plan here to eventually map `Transformer2DModel` to the variant? E.g., a pipeline that uses `Transformer2DModel` with patched inference will now try to create `PatchedTransformer2DModel` under the hood?
Also, how feasible is it to break it up into model-specific variants rather than input-specific variants? E.g., `PixArtTransformer2DModel`?
> Is the plan here to eventually map `Transformer2DModel` to the variant? E.g., a pipeline that uses `Transformer2DModel` with patched inference will now try to create `PatchedTransformer2DModel` under the hood?

Yeah, that's the plan.

> Also, how feasible is it to break it up into model-specific variants rather than input-specific variants? E.g., `PixArtTransformer2DModel`?

Feasible, but I am not sure if we have enough such transformer-based pipelines yet. Most of them vary in very few things (such as the norm type and a cross-attention layer).
I think there is a fair trade-off when deciding which variant to use. If too many arguments are changing, it's better to use a dedicated class (like we did for the private model). If not, rely on an existing variant that depends on the input type.
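To illustrate the trade-off, a dedicated model-specific variant typically pins the handful of arguments that differ rather than exposing them. This is a minimal sketch; the class names and knobs below are illustrative, not the actual diffusers classes:

```python
class Transformer2DSketch:
    """Toy stand-in for a general-purpose transformer class with many knobs."""

    def __init__(self, norm_type="layer_norm", use_cross_attention=True, patch_size=None):
        self.norm_type = norm_type
        self.use_cross_attention = use_cross_attention
        self.patch_size = patch_size


class PixArtVariantSketch(Transformer2DSketch):
    """Dedicated variant: the few differing knobs are fixed, not exposed."""

    def __init__(self, patch_size=2):
        # PixArt-style models use single adaptive norm and cross-attention,
        # so those choices are baked in instead of being constructor args.
        super().__init__(
            norm_type="ada_norm_single",
            use_cross_attention=True,
            patch_size=patch_size,
        )


m = PixArtVariantSketch()
print(m.norm_type)  # ada_norm_single
```

The fewer knobs a variant pins, the less a dedicated class buys you over dispatching on the input type.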
@DN6 can I get another review? I have incorporated your feedback and am currently running the SLOW tests of DiT and PixArt-{Alpha,Sigma} to ensure nothing is breaking. After another round of feedback is addressed, I will add tests and documentation.
Have run the slow tests too and they pass barring https://huggingface.slack.com/archives/C061LUF9G6B/p1714621340122669.
@DN6 I think this is ready for another review now. I have gone ahead and run the slow tests on the PixArt and DiT tests, and they are passing as expected except for:

```
FAILED tests/models/transformers/test_models_pixart_transformer2d.py::PixArtTransformer2DModelTests::test_cpu_offload - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
FAILED tests/models/transformers/test_models_pixart_transformer2d.py::PixArtTransformer2DModelTests::test_model_parallelism - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and CPU!
```

I would appreciate @SunMarc's guidance for the above two. Here's what I have already tried:
- I added `PatchEmbed` to `_no_split_modules` here, because of this. This does prevent the first test failure, but the second error still remains.
- The second one happens because of `scale_shift_table`, and it gets utilized here.
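For reference, the first workaround amounts to declaring the embedding module unsplittable on the model class; accelerate's big-model inference reads the `_no_split_modules` attribute and keeps each listed module on a single device when inferring a `device_map`. The class below is a toy stand-in, not the actual diffusers model:

```python
import torch
import torch.nn as nn


class PatchEmbed(nn.Module):
    """Toy stand-in for diffusers' PatchEmbed: patchify with a strided conv."""

    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Conv2d(3, dim, kernel_size=2, stride=2)

    def forward(self, x):
        # (B, C, H, W) -> (B, num_patches, dim)
        return self.proj(x).flatten(2).transpose(1, 2)


class ToyPixArtTransformer2DModel(nn.Module):
    # Modules named here are never split across devices when a device_map
    # is computed, so PatchEmbed's weights stay on one device together.
    _no_split_modules = ["PatchEmbed"]

    def __init__(self, dim=8):
        super().__init__()
        self.pos_embed = PatchEmbed(dim)
        self.proj_out = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj_out(self.pos_embed(x))


model = ToyPixArtTransformer2DModel()
out = model(torch.randn(1, 3, 4, 4))
print(out.shape)  # torch.Size([1, 4, 8])
```

This only helps when the offending tensors belong to an `nn.Module`; a bare `nn.Parameter` like `scale_shift_table` used directly in `forward` is not covered, which is why the second failure persists.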
@yiyixuxu thanks for your review.
I have knocked off all the smaller comments (like removing error checks as appropriate, making args optional, etc.). But I will wait for @DN6 to return to decide on `LegacyModelMixin`.
Even though we don't have as many transformer-based checkpoints as the UNet, I think we all have reasons and signals to believe that this situation is going to change soon (especially with things like SD3, PixArt-Sigma, open SD3 aka HunyuanDiT, etc.). So, I would very much like to treat this PR as a good learning experience for conquering other similar (and potentially bigger) refactor PRs.
> I have gone ahead and run the slow tests on the PixArt and DiT tests, and they are passing as expected except for `test_cpu_offload` and `test_model_parallelism` (`RuntimeError: Expected all tensors to be on the same device`). I added `PatchEmbed` to `_no_split_modules`, which prevents the first failure, but the second one happens because of `scale_shift_table`.
For `PatchEmbed`, it is indeed needed. For the second error, we don't have a choice apart from moving the tensors to the same device. This is because these operations are done in the forward of `PixArtTransformer2DModel` but are not wrapped in an `nn.Module`, so we cannot move the tensor to the right device with `device_map` (and it's not really an option to put `PixArtTransformer2DModel` in `_no_split_modules`).
What could be great is a way to specify modules that should stay on the same device. We already do that for tied weights. I can have a look at some point, so that we don't touch the modeling code. Right now, the easiest (but not the prettiest) solution is:
```python
# 3. Output
shift, scale = (
    self.scale_shift_table[None]
    + embedded_timestep[:, None].to(self.scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
```
LMK if it makes sense!
Okay, that did it, @SunMarc. Thanks a lot! Might be worthwhile to revisit:

> What could be great is a way to specify modules that should stay on the same device. We already do that for tied weights. I can have a look at some point, so that we don't touch the modeling code.
@mishig25 any reason why the doc build test would fail? I am unable to reproduce the dependency import failure noticed here locally.
@DN6 done. I think I have addressed all your comments. LMK.
@DN6 resolved your comment on the location of `_CLASS_REMAPPING_DICT`. I have also moved `_fetch_remapped_cls_from_config` to `model_loading_utils`. I think this is better, as `_fetch_remapped_cls_from_config` has nothing to do with the Hub.
LGTM. cc: @yiyixuxu in case you want to take a look too.