
[Core] Introduce class variants for `Transformer2DModel`

Open sayakpaul opened this issue 1 year ago • 12 comments

What does this PR do?

Introduces two variants of Transformer2DModel:

  • DiTTransformer2DModel
  • PixArtTransformer2DModel

For the other places where Transformer2DModel is used, they should later be turned into blocks, as they shouldn't inherit from ModelMixin (this has been discussed internally).
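To make the split concrete, here is a minimal sketch of the idea, not the actual diffusers implementation. The norm-type strings below ("ada_norm_zero" for DiT, "ada_norm_single" for PixArt) mirror values used in diffusers, but treat the whole snippet as illustrative:

```python
# Illustrative sketch only: each variant pins down the configuration its
# checkpoints actually use, instead of one catch-all class with many mode flags.

class Transformer2D:
    def __init__(self, norm_type="layer_norm", use_cross_attention=True):
        self.norm_type = norm_type
        self.use_cross_attention = use_cross_attention


class DiTTransformer2D(Transformer2D):
    # DiT checkpoints: class-conditional, no cross-attention.
    def __init__(self):
        super().__init__(norm_type="ada_norm_zero", use_cross_attention=False)


class PixArtTransformer2D(Transformer2D):
    # PixArt checkpoints: text-conditional, cross-attention on.
    def __init__(self):
        super().__init__(norm_type="ada_norm_single", use_cross_attention=True)
```

The point of the refactor is that each variant's `__init__` only exposes arguments that make sense for its family of checkpoints, while the parent keeps the shared machinery.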

TODO:

(Will be tackled after I get an initial review)

  • [x] Tests for each individual variant
  • [x] Documentation

Some comments are in-line.

LMK.

sayakpaul avatar Apr 12 '24 05:04 sayakpaul

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu @DN6 a gentle ping here.

sayakpaul avatar Apr 29 '24 02:04 sayakpaul

Is the plan here to eventually map Transformer2DModel to the variant? e.g., a pipeline that uses Transformer2DModel with patched inference will now try to create PatchedTransformer2DModel under the hood?

Also, how feasible is it to break it up into model-specific variants rather than input-specific variants? e.g., PixArtTransformer2DModel?

DN6 avatar Apr 29 '24 06:04 DN6

Is the plan here to eventually map Transformer2DModel to the variant? e.g., a pipeline that uses Transformer2DModel with patched inference will now try to create PatchedTransformer2DModel under the hood?

Yeah, that's the plan.

Also, how feasible is it to break it up into model-specific variants rather than input-specific variants? e.g., PixArtTransformer2DModel?

Feasible, but I am not sure we have enough such transformer-based pipelines yet. Most of them vary in only a few things (such as the norm type and a cross-attention layer).

I think there is a fair trade-off to be had when deciding which variant to use. If too many arguments change, it's better to use a dedicated class (like we did for the private model). If not, rely on an existing variant that depends on the input type.

sayakpaul avatar Apr 29 '24 06:04 sayakpaul

@DN6 can I get another review? I have reflected your feedback and I am currently running the SLOW tests of DiT and PixArt-{Alpha,Sigma} to ensure nothing is breaking. After another round of feedback is addressed, I will add tests and documentation.

sayakpaul avatar May 02 '24 03:05 sayakpaul

Have run the slow tests too and they pass barring https://huggingface.slack.com/archives/C061LUF9G6B/p1714621340122669.

sayakpaul avatar May 02 '24 04:05 sayakpaul

@DN6 I think this is ready for another review now. I have gone ahead and run the slow tests for PixArt and DiT, and they are passing as expected except for:

    FAILED tests/models/transformers/test_models_pixart_transformer2d.py::PixArtTransformer2DModelTests::test_cpu_offload - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
    FAILED tests/models/transformers/test_models_pixart_transformer2d.py::PixArtTransformer2DModelTests::test_model_parallelism - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and CPU!

I would appreciate @SunMarc's guidance for the above two. Here's what I have already tried:

  • I added PatchEmbed to _no_split_modules here, because of this. This does prevent the first test failure, but the second error still remains.
  • The second one happens because of scale_shift_table and how it gets utilized here.

sayakpaul avatar May 15 '24 10:05 sayakpaul
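For context, here is a toy illustration of what the `_no_split_modules` contract means. The real device-map planning lives in accelerate; this only sketches the rule that listed module classes are treated as atomic (all names below are illustrative):

```python
# Toy sketch of the _no_split_modules contract (not accelerate's real logic):
# any module class named in the list must be kept whole when computing a
# device map, so its submodules never land on different devices.

class PixArtLikeTransformer:
    _no_split_modules = ["PatchEmbed", "BasicTransformerBlock"]


def must_stay_whole(module_cls_name, model_cls):
    """Return True if a device-map planner must not split this module."""
    return module_cls_name in getattr(model_cls, "_no_split_modules", [])
```

This explains why adding PatchEmbed to the list fixes the first failure: the patch projection and its positional embedding can no longer be separated across devices.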

@yiyixuxu thanks for your review.

I have knocked off all the smaller comments (like removing error checks as appropriate, making args optional, etc.). But I will wait for @DN6 to return and decide on LegacyModelMixin.

Even though we don't have as many transformer-based checkpoints as the UNet, I think we all have reasons and signals to believe that this situation is going to change soon (especially with things like SD3, PixArt-Sigma, open SD3 aka HunyuanDiT, etc.). So, would very much like to treat this PR as a good learning experience for us to conquer other similar (and potentially bigger) refactor PRs.

sayakpaul avatar May 16 '24 11:05 sayakpaul

@DN6 I think this is ready for another review now. I have gone ahead and run the slow tests for PixArt and DiT, and they are passing as expected except for:

    FAILED tests/models/transformers/test_models_pixart_transformer2d.py::PixArtTransformer2DModelTests::test_cpu_offload - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
    FAILED tests/models/transformers/test_models_pixart_transformer2d.py::PixArtTransformer2DModelTests::test_model_parallelism - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and CPU!

I would appreciate @SunMarc's guidance for the above two. Here's what I have already tried:

  • I added PatchEmbed to _no_split_modules here, because of this. This does prevent the first test failure, but the second error still remains.
  • The second one happens because of scale_shift_table and how it gets utilized here.

For PatchEmbed, it is indeed needed. For the second error, we don't have a choice apart from moving the tensors to the same device manually. These operations are done in the forward of PixArtTransformer2DModel but are not wrapped in an nn.Module, so we cannot place the tensors on the same device with device_map (and putting PixArtTransformer2DModel itself in _no_split_modules is not really an option).

What could be great is a way to specify modules that should stay on the same device. We already do that for tied weights. I can have a look at some point, so that we don't touch the modeling code. Right now, the easiest (but not the prettiest) solution is to do:

        # 3. Output
        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states)
        # Modulation
        hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)

LMK if it makes sense!

SunMarc avatar May 16 '24 15:05 SunMarc
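The failure mode and the `.to()` fix above can be modeled in a few lines. This is a toy stand-in for tensors (not real PyTorch), just to show why a bare parameter like scale_shift_table trips up device_map and how moving the operand resolves it:

```python
class FakeTensor:
    """Toy tensor that tracks a device tag, mimicking PyTorch's rule that
    elementwise ops require both operands on the same device."""

    def __init__(self, data, device="cpu"):
        self.data, self.device = data, device

    def to(self, device):
        # Return a copy "moved" to the target device.
        return FakeTensor(self.data, device)

    def __add__(self, other):
        if self.device != other.device:
            raise RuntimeError(
                "Expected all tensors to be on the same device, but found "
                f"at least two devices, {self.device} and {other.device}!"
            )
        return FakeTensor([a + b for a, b in zip(self.data, other.data)], self.device)


# scale_shift_table is a bare parameter the device-map planner put on one GPU,
# while the timestep embedding was computed elsewhere.
table = FakeTensor([1.0, 2.0], device="cuda:1")
timestep = FakeTensor([0.5, 0.5], device="cpu")

# table + timestep                 # would raise the RuntimeError from the tests
fixed = table + timestep.to(table.device)  # the fix: move the operand first
```

Because the addition happens directly in the model's forward rather than inside an nn.Module, no hook ever moves `timestep` for us, which is exactly why the explicit `.to(...)` call is needed.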

Okay, that did the trick, @SunMarc. Thanks much! Might be worthwhile to revisit:

What could be great is a way to specify modules that should stay on the same device. We already do that for tied weights. I can have a look at some point, so that we don't touch the modeling code.

sayakpaul avatar May 17 '24 09:05 sayakpaul

@mishig25 any reason why the doc build test would fail? I am unable to locally reproduce the dependency import failure noticed here.

sayakpaul avatar May 21 '24 10:05 sayakpaul

@DN6 done. I think I have addressed all your comments. LMK.

sayakpaul avatar May 22 '24 07:05 sayakpaul

@DN6 resolved your comment on the location of _CLASS_REMAPPING_DICT. I have also moved _fetch_remapped_cls_from_config to model_loading_utils; I think this is better, as _fetch_remapped_cls_from_config has nothing to do with the Hub.

sayakpaul avatar May 28 '24 14:05 sayakpaul

LGTM. cc: @yiyixuxu in case you want to take a look too.

DN6 avatar May 30 '24 14:05 DN6