
Allow Flax model to save as pytorch (If converted from pytorch)

[Open] Lime-Cakes opened this issue 3 years ago · 9 comments

Allows saving the mapping_dict when loading a PyTorch model through from_pretrained(from_pt=True); the mapping can then be used to save the Flax model in the PyTorch .bin format.

Adds a new option return_mapping_dict to from_pretrained. When set to True (default False), the mapping_dict is returned as well. The mapping_dict can be passed to the new method save_as_pytorch_bin, allowing the Flax params dict to be converted and saved as a PyTorch .bin file.

Currently the config isn't saved, because config loading for Flax appears outdated (it triggers a FutureWarning). I'll see if I can implement saving the config as well once config loading for Flax models has been updated.
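A rough sketch of how the proposed API could look in use (return_mapping_dict and save_as_pytorch_bin are the names proposed above; the exact signatures, the model id, and the output path are illustrative assumptions, not the final implementation):

```python
# Hypothetical usage of the API proposed in this PR; signatures may differ
# from whatever gets merged.
from diffusers import FlaxUNet2DConditionModel

# Load PyTorch weights into Flax and also capture the PyTorch->Flax key mapping.
model, params, mapping_dict = FlaxUNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # example model id
    subfolder="unet",
    from_pt=True,
    return_mapping_dict=True,  # proposed option, default False
)

# ... train / modify the Flax params ...

# Convert the Flax params back and write them out as a PyTorch .bin file.
model.save_as_pytorch_bin(
    "converted-unet/diffusion_pytorch_model.bin", params, mapping_dict
)
```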

Lime-Cakes avatar Dec 17 '22 11:12 Lime-Cakes

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

(If converted from pytorch) oh no 🥺

camenduru avatar Dec 17 '22 12:12 camenduru

The same mapping dict could probably be used to convert any Flax model back to PyTorch. I only tested with the UNet, but every model converted through convert_pytorch_state_dict_to_flax should be reversible.

The same model class should generate the same mapping_dict, so old models can be converted back as well. For loading Flax models natively in PyTorch, check PR #1241.
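Conceptually, the mapping_dict just records how each PyTorch state-dict key was renamed when convert_pytorch_state_dict_to_flax built the Flax params, so reversing it is mostly a key-by-key lookup. A simplified illustration (not the PR's actual code; a real conversion also has to undo the weight transposes the forward conversion applies, e.g. for conv and dense kernels):

```python
# Simplified illustration of inverting a pytorch->flax key mapping.
import numpy as np
import torch
from flax.traverse_util import flatten_dict

def flax_params_to_pytorch_state_dict(flax_params, mapping_dict):
    """mapping_dict: {pytorch_key: flax_key_tuple} captured while loading with from_pt=True."""
    flat_flax = flatten_dict(flax_params)
    state_dict = {}
    for pt_key, flax_key in mapping_dict.items():
        tensor = np.asarray(flat_flax[flax_key])
        # NOTE: weights whose layout differs between frameworks (conv kernels,
        # dense kernels) would also need their transpose undone here.
        state_dict[pt_key] = torch.from_numpy(tensor)
    return state_dict
```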

Lime-Cakes avatar Dec 17 '22 13:12 Lime-Cakes

@patrickvonplaten knows the code. How does transformers do it? Why does diffusers need a mapping_dict when transformers doesn't? Also, every human being on this planet is waiting for this model to be converted to PyTorch 😋 @patrickvonplaten help!

[Screenshot 2022-12-17 193934]

camenduru avatar Dec 17 '22 13:12 camenduru

from_flax=True https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.from_pretrained
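For reference, this is what the linked transformers API does; a minimal sketch (the model id is just an example that ships Flax weights):

```python
# Load Flax weights into a PyTorch transformers model, then re-save as PyTorch.
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased", from_flax=True)
model.save_pretrained("./bert-pt")  # writes pytorch_model.bin
```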

averad avatar Dec 19 '22 18:12 averad

from_flax=True https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.from_pretrained

Thanks! I didn't know it was implemented. Does from_flax work with diffusers?

Lime-Cakes avatar Dec 19 '22 20:12 Lime-Cakes

@averad from_flax=True is for transformers; we need it in diffusers 😋

camenduru avatar Dec 19 '22 20:12 camenduru

@averad from_flax=True is for transformers; we need it in diffusers 😋

Just providing documentation for when @pcuenca inevitably pops in here.

averad avatar Dec 19 '22 20:12 averad

Maybe also cc @yiyixuxu here :-)

patrickvonplaten avatar Dec 20 '22 00:12 patrickvonplaten

hi @Lime-Cakes, please change the title. People aren't reading the (If converted from pytorch) part; I also didn't read it and only realized I needed the mapping_dict after I had set everything up 😋 Maybe something like "🚨 If converted from pytorch 🚨 Allow Flax model to save as pytorch".

Note for newcomers, please also check: https://github.com/huggingface/diffusers/issues/1161 https://github.com/huggingface/diffusers/pull/1217 https://github.com/huggingface/diffusers/pull/1241

camenduru avatar Dec 21 '22 14:12 camenduru

Download https://github.com/camenduru/diffusers/blob/from_flax_v2/src/diffusers/modeling_pytorch_flax_utils.py (see https://huggingface.co/camenduru/plushies-pt)

camenduru avatar Dec 24 '22 07:12 camenduru

No longer needed after the #1900 merge. Closing. Use from_pretrained(from_flax=True) instead.
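For anyone landing here later, the supported path after #1900 looks roughly like this (a sketch; the model id is just an example that ships Flax weights):

```python
# Load a Flax checkpoint into a PyTorch diffusers model, then save it in PyTorch format.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", from_flax=True
)
unet.save_pretrained("./unet-pt")  # writes the PyTorch weights
```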

Lime-Cakes avatar Jan 20 '23 07:01 Lime-Cakes