Convert Flax-trained model to PyTorch
Based on #1161. I've mostly taken and adapted Flax->PT code from transformers.
My test scenario is currently this:
from diffusers import FlaxStableDiffusionPipeline, StableDiffusionPipeline
from diffusers import AutoencoderKL, UNet2DConditionModel
import jax.numpy as jnp
pipe, params = FlaxStableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', revision='bf16', dtype=jnp.bfloat16)
pipe.save_pretrained('output-flax', params)
vae = AutoencoderKL.from_pretrained('output-flax/vae', from_flax=True)
unet = UNet2DConditionModel.from_pretrained('output-flax/unet', from_flax=True)
pipe2 = StableDiffusionPipeline.from_pretrained('output-flax', from_flax=True)
The VAE and UNet load, but I have yet to test whether they run without problems.
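A quick smoke test I plan to run on them (my own check, not part of the PR; the dummy shapes assume SD v1-4): a forward pass through each module.

import torch

with torch.no_grad():
    # VAE round trip: output should have the same shape as the input batch
    recon = vae(torch.randn(1, 3, 64, 64)).sample
    # UNet: predict noise for dummy latents / timestep / text embeddings
    noise_pred = unet(
        torch.randn(1, 4, 64, 64),                      # latents
        torch.tensor([10]),                             # timestep
        encoder_hidden_states=torch.randn(1, 77, 768),  # CLIP hidden states
    ).sample
print(recon.shape, noise_pred.shape)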
The last line of my test code for the pipeline currently fails with:
Traceback (most recent call last):

test.py, line 11, in <module>
     8
     9 vae = AutoencoderKL.from_pretrained('output-flax/vae', from_flax=True)
    10 unet = UNet2DConditionModel.from_pretrained('output-flax/unet', from_flax=True)
--> 11 pipe2 = StableDiffusionPipeline.from_pretrained('output-flax', from_flax=True)
    12

diffusers/src/diffusers/pipeline_utils.py, line 632, in from_pretrained
   629
   630                 # check if the module is in a subdirectory
   631                 if os.path.isdir(os.path.join(cached_folder, name)):
--> 632                     loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
   633                 else:
   634                     # else load from the root directory
   635                     loaded_sub_model = load_method(cached_folder, **loading_kwargs)

miniconda3/lib/python3.9/site-packages/transformers/modeling_utils.py, line 1815, in from_pretrained
  1812                         "weights."
  1813                     )
  1814                 elif os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME):
--> 1815                     raise EnvironmentError(
  1816                         f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
  1817                         "there is a file for Flax weights. Use `from_flax=True` to load this model from those "
  1818                         "weights."
OSError: Error no file named pytorch_model.bin found in directory output-flax/text_encoder but there is a file for Flax weights. Use `from_flax=True` to load this model from those weights.
I'm not sure whether this needs fixing in diffusers or in transformers, though I assume it should be possible to make this work without modifying transformers.
@patrickvonplaten
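My first guess at the kind of change needed on the diffusers side (a hypothetical sketch against pipeline_utils.from_pretrained, untested): forward the flag into each component's loading_kwargs so that transformers sub-models also look for their Flax weights (flax_model.msgpack) instead of pytorch_model.bin.

# hypothetical: in DiffusionPipeline.from_pretrained, before load_method
# is called for each sub-model, propagate the flag
if from_flax:
    loading_kwargs["from_flax"] = True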
@jorahn ~~I haven't looked at your code yet, but the issue may be that diffusers uses a different name for the weights: https://github.com/huggingface/diffusers/blob/main/src/diffusers/utils/__init__.py#L68.~~
Sorry, I think you already got that.
Yes, it's currently not in the right place (it should go into the __init__.py you mention), but I created a variable for it.
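For reference, the variable would presumably sit next to the other weight names, something like this (the filename matches what diffusers uses for Flax checkpoints; the exact placement is my assumption):

# in src/diffusers/utils/__init__.py
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"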
This runs without exceptions now:
from diffusers import FlaxStableDiffusionPipeline, StableDiffusionPipeline
from diffusers import AutoencoderKL, UNet2DConditionModel
import jax.numpy as jnp
pipe, params = FlaxStableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', revision='bf16', dtype=jnp.bfloat16)
pipe.save_pretrained('output-flax', params)
pipe2 = StableDiffusionPipeline.from_pretrained('output-flax', from_flax=True, safety_checker=None)
I still need to test whether it actually produces images.
I am testing...
😥
RuntimeError Traceback (most recent call last)
<ipython-input-3-185344f22e90> in <module>
1 get_ipython().system('pip install diffusers transformers -qq')
2 from diffusers import StableDiffusionPipeline
----> 3 pipe = StableDiffusionPipeline.from_pretrained("camenduru/test", safety_checker=None).to("cuda")
/usr/local/lib/python3.7/dist-packages/transformers/modeling_utils.py in _load_pretrained_model(cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, _fast_init, low_cpu_mem_usage, device_map, offload_folder, offload_state_dict, dtype, load_in_8bit)
2593 "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
2594 )
-> 2595 raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
2596
2597 if len(unexpected_keys) > 0:
RuntimeError: Error(s) in loading state_dict for CLIPTextModel:
While copying the parameter named "text_model.embeddings.token_embedding.weight", whose dimensions in the model are torch.Size([49408, 768]) and whose dimensions in the checkpoint are torch.Size([49408, 768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
While copying the parameter named "text_model.embeddings.position_embedding.weight", whose dimensions in the model are torch.Size([77, 768]) and whose dimensions in the checkpoint are torch.Size([77, 768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
While copying the parameter named "text_model.encoder.layers.0.self_attn.k_proj.weight", whose dimensions in the model are torch.Size([768, 768]) and whose dimensions in the checkpoint are torch.Size([768, 768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
While copying the parameter named "text_model.encoder.layers.0.self_attn.k_proj.bias", whose dimensions in the model are torch.Size([768]) and whose dimensions in the checkpoint are torch.Size([768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
While copying the parameter named "text_model.encoder.layers.0.self_attn.v_proj.weight", whose dimensions in the model are torch.Size([768, 768]) and whose dimensions in the checkpoint are torch.Size([768, 768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
While copying the parameter named "text_model.encoder.layers.0.self_attn.v_proj.bias", whose dimensions in the model are torch.Size([768]) and whose dimensions in the checkpoint are torch.Size([768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
While copying the parameter named "text_model.encoder.layers.0.self_attn.q_proj.weight", whose dimensions in the model are torch.Size([768, 768]) and whose dimensions in the checkpoint are torch.Size([768, 768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
While copying the parameter named "text_model.encoder.layers.0.self_attn.q_proj.bias", whose dimensions in the model are torch.Size([768]) and whose dimensions in the checkpoint are torch.Size([768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
While copying the parameter named "text_model.encoder.layers.0.self_attn.out_proj.weight", whose dimensions in the model are torch.Size([768, 768]) and whose dimensions in the checkpoint are torch.Size([768, 768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
While copying the parameter named "text_model.encoder.layers.0.self_attn.out_proj.bias", whose dimensions in the model are torch.Size([768]) and whose dimensions in the checkpoint are torch.Size([768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
While copying the parameter named "text_model.encoder.layers.0.layer_norm1.weight", whose dimensions in the model are torch.Size([768]) and whose dimensions in the checkpoint are torch.Size([768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
While copying the parameter named "text_model.encoder.layers.0.layer_norm1.bias", whose dimensions in the model are torch.Size([768]) and whose dimensions in the checkpoint are torch.Size([768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
While copying the parameter named "text_model.encoder.layers.0.mlp.fc1.weight", whose dimensions in the model are torch.Size([3072, 768]) and whose dimensions in the checkpoint are torch.Size([3072, 768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
While copying the parameter named "text_model.encoder.layers.0.mlp.fc1.bias", whose dimensions in the model are torch.Size([3072]) and whose dimensions in the checkpoint are torch.Size([3072]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
While copying the parameter named "text_model.encoder.layers.0.mlp.fc2.weight", whose dimensions in the model are torch.Size([768, 3072]) and whose dimensions in the checkpoint are torch.Size([768, 3072]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
While copying the parameter named "text_model.encoder.layers.0.mlp.fc2.bias", whose dimensions in the model are torch.Size([768]) and whose dimensions in the checkpoint are torch.Size([768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
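For context, the underlying PyTorch error is easy to reproduce in isolation (my own minimal repro, not taken from this run): a meta tensor carries only shape and dtype, no storage, so nothing can be copied out of it. Given that the traceback goes through the low_cpu_mem_usage/_fast_init machinery, my reading is that the converted checkpoint is still handing meta tensors to load_state_dict.

import torch

real = torch.zeros(768)
meta = torch.empty(768, device="meta")  # metadata only, no backing data
try:
    real.copy_(meta)
except (RuntimeError, NotImplementedError) as e:
    print(e)  # Cannot copy out of meta tensor; no data!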
This is my updated test scenario:
from diffusers import FlaxStableDiffusionPipeline, StableDiffusionPipeline
from diffusers import AutoencoderKL, UNet2DConditionModel
import jax.numpy as jnp
import jax
from flax.jax_utils import replicate
from flax.training.common_utils import shard
prompt = 'an astronaut on a horse'
rng = jax.random.PRNGKey(42)
rng = jax.random.split(rng, jax.device_count())
prompt_jax = [prompt] * jax.device_count()
pipe, params = FlaxStableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', revision='bf16', dtype=jnp.bfloat16)
prompt_ids = pipe.prepare_inputs(prompt_jax)
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
images = pipe(prompt_ids, p_params, rng, jit=True, num_inference_steps=10)[0]
images = images.reshape((images.shape[0],) + images.shape[-3:])
images = pipe.numpy_to_pil(images)
images[0].save('test_flax.jpg')
pipe.save_pretrained('output-flax', params)
del pipe
pipe2 = StableDiffusionPipeline.from_pretrained('output-flax', from_flax=True, safety_checker=None)
pipe2(prompt, num_inference_steps=10).images[0].save('test_pt.jpg')
The first image looks fine:
[image: output of the Flax pipeline]
The second doesn't:
[image: output of the converted PyTorch pipeline]
So I assume something is wrong with the conversion of the weights.
And if I don't set safety_checker=None, it also fails, because it tries to import FlaxStableDiffusionSafetyChecker from transformers instead of from diffusers' stable_diffusion module.
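One way I can try to narrow this down (my own sanity check, not yet in the PR): diff the converted weights against the official PyTorch checkpoint and print the outliers. The bf16 round trip means small differences are expected, so only large ones are suspicious.

import torch
from diffusers import UNet2DConditionModel

unet_ref = UNet2DConditionModel.from_pretrained('CompVis/stable-diffusion-v1-4', subfolder='unet')
unet_conv = UNet2DConditionModel.from_pretrained('output-flax/unet', from_flax=True)
ref_params = dict(unet_ref.named_parameters())
for name, param in unet_conv.named_parameters():
    max_diff = (ref_params[name].float() - param.float()).abs().max().item()
    if max_diff > 1e-1:  # tolerance for the bf16 round trip
        print(name, max_diff)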
@jorahn I added the model here if you need it: https://huggingface.co/camenduru/test
Looks like you made some nice progress - let me or @pcuenca know if you need help / are stuck somewhere :-)
Hi @patrickvonplaten, yes, we need help. Maybe the problem is this: https://github.com/huggingface/diffusers/pull/1217#issuecomment-1312739419. Please tell us how to adapt the old function to diffusers.
I'm pretty busy right now, but happy to take a look at the end of the week :)
Thank you!
Also quite busy at the moment. @jorahn, essentially we should copy-paste the functionality from PyTorch here and then debug that each component is correctly converted.
I found an alternative, easier but hacky method. I'm still testing it at the moment, but it should allow saving the Flax UNet/VAE/any model that uses convert_pytorch_state_dict_to_flax in PyTorch format. Instead of implementing Flax loading in the PyTorch model's from_pretrained, I wrote a script that saves the Flax params in PyTorch format directly, by reversing convert_pytorch_state_dict_to_flax.
Would this alternative method be a better option?
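Roughly, the idea looks like this (a simplified sketch, not the final script; the renaming and transpose rules follow the usual Flax<->PyTorch conventions):

import numpy as np
import torch
from flax.traverse_util import flatten_dict

def flax_params_to_pt_state_dict(params):
    state_dict = {}
    for key_tuple, value in flatten_dict(params).items():
        # upcast: torch.from_numpy can't take bfloat16 NumPy arrays
        tensor = torch.from_numpy(np.asarray(value, dtype=np.float32))
        *path, leaf = key_tuple
        if leaf == "kernel" and tensor.ndim == 4:
            # Flax conv kernels are (H, W, in, out); PT wants (out, in, H, W)
            tensor, leaf = tensor.permute(3, 2, 0, 1), "weight"
        elif leaf == "kernel":
            # dense kernels are transposed between Flax and PyTorch
            tensor, leaf = tensor.T, "weight"
        elif leaf == "scale":
            leaf = "weight"  # norm 'scale' maps to PT 'weight'
        state_dict[".".join(path + [leaf])] = tensor
    return state_dict

The result can then be written with torch.save(state_dict, 'diffusion_pytorch_model.bin') and picked up by the regular PyTorch from_pretrained.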
Hi @Lime-Cakes, that's a perfect approach in my opinion :)
thanks @Lime-Cakes ❤ this is super cool news 🔥 if you need tester I can test
https://github.com/camenduru/diffusers/blob/from_flax_v2/src/diffusers/modeling_pytorch_flax_utils.py
https://huggingface.co/camenduru/plushies-pt
@camenduru, it seems like you got it working? :-) Very nice! Would you like to open a PR to add your conversion script?
Hi @patrickvonplaten 👋 ok
Linking here for reference: https://github.com/huggingface/diffusers/pull/1900