Allow Flax model to save as pytorch (If converted from pytorch)
Allows saving mapping_dict when loading a PyTorch model through from_pretrained(from_pt=True); the mapping_dict can then be used to save the Flax model in PyTorch bin format.
Adds a new option, return_mapping_dict, to from_pretrained. When set to True (default False), the mapping_dict is returned as well. The mapping_dict can be passed to the new method save_as_pytorch_bin, which converts the Flax params dict and saves it as a PyTorch bin.
The config isn't saved yet, because the current config loading for Flax seems outdated (it triggers a FutureWarning). I'll see if I can implement saving the config as well once config loading for Flax models has been updated.
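To illustrate the mapping_dict idea described above, here is a minimal, library-free sketch. It is not the code from this PR: the helper names (flatten_params, transpose_2d, flax_to_pytorch_state_dict) and the shape of the mapping (Flax key tuple to a PyTorch key plus a transpose flag) are assumptions made for illustration, and nested lists stand in for real arrays.

```python
# Hypothetical sketch of reversing a PyTorch -> Flax conversion with a
# recorded key mapping. Names and mapping layout are illustrative only.
from typing import Any, Dict, Tuple

FlaxKey = Tuple[str, ...]


def flatten_params(params: Dict[str, Any], prefix: FlaxKey = ()) -> Dict[FlaxKey, Any]:
    """Flatten a nested Flax-style params dict into {tuple_key: leaf}."""
    flat: Dict[FlaxKey, Any] = {}
    for name, value in params.items():
        key = prefix + (name,)
        if isinstance(value, dict):
            flat.update(flatten_params(value, key))
        else:
            flat[key] = value
    return flat


def transpose_2d(matrix):
    """Transpose a 2-D nested list (stand-in for a dense-kernel transpose)."""
    return [list(row) for row in zip(*matrix)]


def flax_to_pytorch_state_dict(params, mapping_dict):
    """Rename Flax leaves to PyTorch keys using the recorded mapping.

    mapping_dict (assumed layout): {flax_key_tuple: (pytorch_key, needs_transpose)}
    """
    state_dict = {}
    for flax_key, leaf in flatten_params(params).items():
        pt_key, needs_transpose = mapping_dict[flax_key]
        state_dict[pt_key] = transpose_2d(leaf) if needs_transpose else leaf
    return state_dict


# A dense layer: Flax stores the kernel as (in, out), PyTorch as (out, in).
flax_params = {"proj": {"kernel": [[1, 2, 3], [4, 5, 6]], "bias": [0, 0, 0]}}
mapping_dict = {
    ("proj", "kernel"): ("proj.weight", True),
    ("proj", "bias"): ("proj.bias", False),
}
state_dict = flax_to_pytorch_state_dict(flax_params, mapping_dict)
```

Because the mapping is recorded at load time, the reverse conversion does not need to re-derive the renaming rules; it only looks up each Flax key. A real implementation would also have to handle conv kernels and other layouts, which this sketch leaves out.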
(If converted from pytorch) oh no 🥺
The same mapping dict could probably be used to convert any Flax model back to PyTorch. I only tested with the UNet, but any model converted through 'convert_pytorch_state_dict_to_flax' should be reversible.
The same model class should generate the same mapping_dict, so old models can be converted back as well. For loading a Flax model natively in PyTorch, see pull #1241.
@patrickvonplaten knows the code. How does transformers do it? Why does diffusers need a mapping_dict when transformers doesn't? Also, every human being on this planet is waiting for this model to be converted to pytorch 😋 @patrickvonplaten help!

from_flax=True
https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.from_pretrained
Thanks! I didn't know it was implemented. Does from_flax work with diffusers?
@averad from_flax=True is for transformers; we need it for diffusers 😋
Just providing documentation for when @pcuenca inevitably pops in here.
Maybe also cc @yiyixuxu here :-)
Hi @Lime-Cakes, please change the title. People aren't reading the (If converted from pytorch) part. I didn't read it either, and I only realized I needed a mapping_dict after I had set everything up 😋 Maybe something like "🚨 If converted from pytorch 🚨 Allow Flax model to save as pytorch".
Note for newcomers: please also check https://github.com/huggingface/diffusers/issues/1161 https://github.com/huggingface/diffusers/pull/1217 https://github.com/huggingface/diffusers/pull/1241
https://github.com/camenduru/diffusers/blob/from_flax_v2/src/diffusers/modeling_pytorch_flax_utils.py
https://huggingface.co/camenduru/plushies-pt
No longer needed after #1900 merge. Closing. Use from_pretrained(from_flax=True) instead.