Convert Flax-trained model to PyTorch

Open · jorahn opened this issue 3 years ago · 19 comments

Based on #1161. I've mostly taken and adapted Flax->PT code from transformers.

My test scenario is currently this:

from diffusers import FlaxStableDiffusionPipeline, StableDiffusionPipeline
from diffusers import AutoencoderKL, UNet2DConditionModel
import jax.numpy as jnp

pipe, params = FlaxStableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', revision='bf16', dtype=jnp.bfloat16)
pipe.save_pretrained('output-flax', params)


vae = AutoencoderKL.from_pretrained('output-flax/vae', from_flax=True)
unet = UNet2DConditionModel.from_pretrained('output-flax/unet', from_flax=True)
pipe2 = StableDiffusionPipeline.from_pretrained('output-flax', from_flax=True)

VAE and UNet load, but I have yet to test whether they run without problems.

The last line of my test code for the pipeline currently fails with:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ test.py:11 in <module>                                                                           │
│                                                                                                  │
│    8                                                                                             │
│    9 vae = AutoencoderKL.from_pretrained('output-flax/vae', from_flax=True)                      │
│   10 unet = UNet2DConditionModel.from_pretrained('output-flax/unet', from_flax=True)             │
│ ❱ 11 pipe2 = StableDiffusionPipeline.from_pretrained('output-flax', from_flax=True)              │
│   12                                                                                             │
│                                                                                                  │
│ diffusers/src/diffusers/pipeline_utils.py:632 in                                                 │
│ from_pretrained                                                                                  │
│                                                                                                  │
│   629 │   │   │   │                                                                              │
│   630 │   │   │   │   # check if the module is in a subdirectory                                 │
│   631 │   │   │   │   if os.path.isdir(os.path.join(cached_folder, name)):                       │
│ ❱ 632 │   │   │   │   │   loaded_sub_model = load_method(os.path.join(cached_folder, name), **   │
│   633 │   │   │   │   else:                                                                      │
│   634 │   │   │   │   │   # else load from the root directory                                    │
│   635 │   │   │   │   │   loaded_sub_model = load_method(cached_folder, **loading_kwargs)        │
│                                                                                                  │
│ miniconda3/lib/python3.9/site-packages/transformers/modeling_utils.py:1815 in                    │
│ from_pretrained                                                                                  │
│                                                                                                  │
│   1812 │   │   │   │   │   │   "weights."                                                        │
│   1813 │   │   │   │   │   )                                                                     │
│   1814 │   │   │   │   elif os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME):      │
│ ❱ 1815 │   │   │   │   │   raise EnvironmentError(                                               │
│   1816 │   │   │   │   │   │   f"Error no file named {WEIGHTS_NAME} found in directory {pretrai  │
│   1817 │   │   │   │   │   │   "there is a file for Flax weights. Use `from_flax=True` to load   │
│   1818 │   │   │   │   │   │   "weights."                                                        │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
OSError: Error no file named pytorch_model.bin found in directory output-flax/text_encoder but there is a file for Flax weights. Use `from_flax=True` to load this model from those weights.

Not sure if this needs fixing in diffusers or in transformers, though I assume it should be possible to get this to work without modifying transformers.

@patrickvonplaten

jorahn avatar Nov 10 '22 07:11 jorahn
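
For reference, transformers can already convert Flax weights for its own models, so the failing text encoder can in principle be loaded on its own and passed into the pipeline. A minimal sketch, assuming the output-flax layout produced by the test scenario above (untested here):

from transformers import CLIPTextModel
from diffusers import StableDiffusionPipeline

# transformers handles the Flax->PT conversion for its own models
text_encoder = CLIPTextModel.from_pretrained('output-flax/text_encoder', from_flax=True)

# pass the already-converted component in, bypassing the failing sub-model load
pipe2 = StableDiffusionPipeline.from_pretrained('output-flax', text_encoder=text_encoder, from_flax=True, safety_checker=None)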

@jorahn ~~I haven't looked at your code yet, but the issue may be that diffusers uses a different name for the weights: https://github.com/huggingface/diffusers/blob/main/src/diffusers/utils/__init__.py#L68.~~

Sorry, I think you already got that.

pcuenca avatar Nov 10 '22 11:11 pcuenca

Yes, it's currently not in the right place (it should go into the __init__.py you mention), but I created a variable for that.

jorahn avatar Nov 10 '22 11:11 jorahn

This runs without exceptions now:

from diffusers import FlaxStableDiffusionPipeline, StableDiffusionPipeline
from diffusers import AutoencoderKL, UNet2DConditionModel
import jax.numpy as jnp

pipe, params = FlaxStableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', revision='bf16', dtype=jnp.bfloat16)
pipe.save_pretrained('output-flax', params)

pipe2 = StableDiffusionPipeline.from_pretrained('output-flax', from_flax=True, safety_checker=None)

Still need to test whether it actually produces images.

jorahn avatar Nov 10 '22 12:11 jorahn

I am testing...

camenduru avatar Nov 10 '22 13:11 camenduru

😥

RuntimeError                              Traceback (most recent call last)
<ipython-input-3-185344f22e90> in <module>
      1 get_ipython().system('pip install diffusers transformers -qq')
      2 from diffusers import StableDiffusionPipeline
----> 3 pipe = StableDiffusionPipeline.from_pretrained("camenduru/test", safety_checker=None).to("cuda")

2 frames
/usr/local/lib/python3.7/dist-packages/transformers/modeling_utils.py in _load_pretrained_model(cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, _fast_init, low_cpu_mem_usage, device_map, offload_folder, offload_state_dict, dtype, load_in_8bit)
   2593                     "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
   2594                 )
-> 2595             raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
   2596 
   2597         if len(unexpected_keys) > 0:

RuntimeError: Error(s) in loading state_dict for CLIPTextModel:
	While copying the parameter named "text_model.embeddings.token_embedding.weight", whose dimensions in the model are torch.Size([49408, 768]) and whose dimensions in the checkpoint are torch.Size([49408, 768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
	While copying the parameter named "text_model.embeddings.position_embedding.weight", whose dimensions in the model are torch.Size([77, 768]) and whose dimensions in the checkpoint are torch.Size([77, 768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
	While copying the parameter named "text_model.encoder.layers.0.self_attn.k_proj.weight", whose dimensions in the model are torch.Size([768, 768]) and whose dimensions in the checkpoint are torch.Size([768, 768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
	While copying the parameter named "text_model.encoder.layers.0.self_attn.k_proj.bias", whose dimensions in the model are torch.Size([768]) and whose dimensions in the checkpoint are torch.Size([768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
	While copying the parameter named "text_model.encoder.layers.0.self_attn.v_proj.weight", whose dimensions in the model are torch.Size([768, 768]) and whose dimensions in the checkpoint are torch.Size([768, 768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
	While copying the parameter named "text_model.encoder.layers.0.self_attn.v_proj.bias", whose dimensions in the model are torch.Size([768]) and whose dimensions in the checkpoint are torch.Size([768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
	While copying the parameter named "text_model.encoder.layers.0.self_attn.q_proj.weight", whose dimensions in the model are torch.Size([768, 768]) and whose dimensions in the checkpoint are torch.Size([768, 768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
	While copying the parameter named "text_model.encoder.layers.0.self_attn.q_proj.bias", whose dimensions in the model are torch.Size([768]) and whose dimensions in the checkpoint are torch.Size([768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
	While copying the parameter named "text_model.encoder.layers.0.self_attn.out_proj.weight", whose dimensions in the model are torch.Size([768, 768]) and whose dimensions in the checkpoint are torch.Size([768, 768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
	While copying the parameter named "text_model.encoder.layers.0.self_attn.out_proj.bias", whose dimensions in the model are torch.Size([768]) and whose dimensions in the checkpoint are torch.Size([768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
	While copying the parameter named "text_model.encoder.layers.0.layer_norm1.weight", whose dimensions in the model are torch.Size([768]) and whose dimensions in the checkpoint are torch.Size([768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
	While copying the parameter named "text_model.encoder.layers.0.layer_norm1.bias", whose dimensions in the model are torch.Size([768]) and whose dimensions in the checkpoint are torch.Size([768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
	While copying the parameter named "text_model.encoder.layers.0.mlp.fc1.weight", whose dimensions in the model are torch.Size([3072, 768]) and whose dimensions in the checkpoint are torch.Size([3072, 768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
	While copying the parameter named "text_model.encoder.layers.0.mlp.fc1.bias", whose dimensions in the model are torch.Size([3072]) and whose dimensions in the checkpoint are torch.Size([3072]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
	While copying the parameter named "text_model.encoder.layers.0.mlp.fc2.weight", whose dimensions in the model are torch.Size([768, 3072]) and whose dimensions in the checkpoint are torch.Size([768, 3072]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
	While copying the parameter named "text_model.encoder.layers.0.mlp.fc2.bias", whose dimensions in the model are torch.Size([768]) and whose dimensions in the checkpoint are torch.Size([768]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).

camenduru avatar Nov 10 '22 13:11 camenduru
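
The "Cannot copy out of meta tensor; no data!" errors say the checkpoint's tensors have no backing storage, i.e. they were created on the meta device instead of being materialized from the Flax arrays. A minimal sketch of the distinction, with illustrative shapes:

import jax.numpy as jnp
import numpy as np
import torch

# materialize the Flax array as real data before building the torch tensor;
# the float32 cast avoids torch.from_numpy choking on bfloat16 NumPy arrays
flax_leaf = jnp.ones((77, 768), dtype=jnp.bfloat16)
pt_tensor = torch.from_numpy(np.asarray(flax_leaf, dtype=np.float32))
assert not pt_tensor.is_meta and pt_tensor.shape == (77, 768)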

This is my updated test scenario:

from diffusers import FlaxStableDiffusionPipeline, StableDiffusionPipeline
from diffusers import AutoencoderKL, UNet2DConditionModel
import jax.numpy as jnp
import jax
from flax.jax_utils import replicate
from flax.training.common_utils import shard

prompt = 'an astronaut on a horse'

# one RNG and one copy of the prompt per device, for pmapped inference
rng = jax.random.PRNGKey(42)
rng = jax.random.split(rng, jax.device_count())
prompt_jax = [prompt] * jax.device_count()

pipe, params = FlaxStableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', revision='bf16', dtype=jnp.bfloat16)
prompt_ids = pipe.prepare_inputs(prompt_jax)
p_params = replicate(params)    # replicate params across devices
prompt_ids = shard(prompt_ids)  # add a leading device axis

# generate with Flax, collapse the device axis, and save a reference image
images = pipe(prompt_ids, p_params, rng, jit=True, num_inference_steps=10)[0]
images = images.reshape((images.shape[0],) + images.shape[-3:])
images = pipe.numpy_to_pil(images)
images[0].save('test_flax.jpg')
pipe.save_pretrained('output-flax', params)
del pipe

# reload through the PyTorch pipeline from the Flax weights and compare
pipe2 = StableDiffusionPipeline.from_pretrained('output-flax', from_flax=True, safety_checker=None)
pipe2(prompt, num_inference_steps=10).images[0].save('test_pt.jpg')

The first image (test_flax.jpg) looks fine; the second (test_pt.jpg) doesn't. So I assume there is something wrong with the conversion of the weights.

And if I don't set safety_checker=None, it also doesn't work, since it tries to import FlaxStableDiffusionSafetyChecker from transformers instead of from stable-diffusion.

jorahn avatar Nov 10 '22 13:11 jorahn
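
When converted weights load without errors but produce garbage images, the usual suspect is tensor layout: Flax stores conv kernels as (H, W, in, out) while PyTorch expects (out, in, H, W), and Flax Dense kernels are transposed relative to torch.nn.Linear. A minimal sketch of the required transpositions; the helper name and the float32 cast are illustrative assumptions:

import numpy as np
import torch

def convert_param(name, array):
    tensor = torch.from_numpy(np.asarray(array, dtype=np.float32))
    if name.endswith('kernel') and tensor.ndim == 4:
        # conv kernel: HWIO -> OIHW
        return tensor.permute(3, 2, 0, 1).contiguous()
    if name.endswith('kernel') and tensor.ndim == 2:
        # dense kernel: (in, out) -> (out, in)
        return tensor.T.contiguous()
    return tensor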

@jorahn I added the model here if you need it: https://huggingface.co/camenduru/test

camenduru avatar Nov 10 '22 13:11 camenduru

Looks like you made some nice progress! Let me or @pcuenca know if you need help or are stuck somewhere :-)

patrickvonplaten avatar Nov 15 '22 22:11 patrickvonplaten

Hi @patrickvonplaten, yes, we need help. Maybe the problem is this: https://github.com/huggingface/diffusers/pull/1217#issuecomment-1312739419. Please tell us how to adapt the old function to diffusers.

camenduru avatar Nov 15 '22 23:11 camenduru

I'm pretty busy right now, but happy to take a look at the end of the week :)

pcuenca avatar Nov 16 '22 10:11 pcuenca

Thank you!

jorahn avatar Nov 16 '22 11:11 jorahn

Also quite busy at the moment. @jorahn, essentially we should copy-paste the functionality from PyTorch here and then debug that each component can be correctly converted.

patrickvonplaten avatar Nov 16 '22 16:11 patrickvonplaten
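
transformers exposes the conversion utilities this refers to. A minimal sketch of converting just the text encoder through them, assuming the output-flax layout from the test scenario above (untested here):

from transformers import CLIPTextModel, FlaxCLIPTextModel
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model

# load the Flax weights, build an empty PyTorch model with the same config,
# and copy the converted parameters into it in place
flax_model = FlaxCLIPTextModel.from_pretrained('output-flax/text_encoder')
pt_model = CLIPTextModel(flax_model.config)
pt_model = load_flax_weights_in_pytorch_model(pt_model, flax_model.params)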

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Dec 11 '22 15:12 github-actions[bot]

I found an alternative method: easier, but hacky. I'm still testing it at the moment, but it should allow saving the Flax UNet/VAE/any model that uses convert_pytorch_state_dict_to_flax in PyTorch format. Instead of implementing Flax loading in the PyTorch model's from_pretrained, I wrote a script that saves the Flax params in PyTorch format directly by reversing convert_pytorch_state_dict_to_flax.

Would this alternative method be a better option?

Lime-Cakes avatar Dec 15 '22 19:12 Lime-Cakes
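
A minimal sketch of that reversal, assuming the pipeline's params dict is keyed by component name ('unet' etc.); the function name is hypothetical, and key renaming (kernel -> weight, scale -> weight for norms) plus the kernel transpositions from the sketch further up are elided:

import numpy as np
import torch
from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict

def save_flax_params_as_pytorch(params, path):
    # flatten {'a': {'b': w}} into {'a.b': w} and materialize each leaf
    flat = flatten_dict(unfreeze(params), sep='.')
    state_dict = {k: torch.from_numpy(np.asarray(v, dtype=np.float32)) for k, v in flat.items()}
    torch.save(state_dict, path)

# e.g. the UNet, under the filename diffusers' from_pretrained looks for
save_flax_params_as_pytorch(params['unet'], 'output-flax/unet/diffusion_pytorch_model.bin')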

Hi @Lime-Cakes, that's a perfect approach in my opinion :)

pcuenca avatar Dec 15 '22 20:12 pcuenca

Thanks @Lime-Cakes ❤, this is super cool news 🔥! If you need a tester, I can test.

camenduru avatar Dec 15 '22 21:12 camenduru

> Thanks @Lime-Cakes ❤, this is super cool news 🔥! If you need a tester, I can test.

> Hi @Lime-Cakes, that's a perfect approach in my opinion :)

Pull request here

Lime-Cakes avatar Dec 17 '22 11:12 Lime-Cakes

Download https://github.com/camenduru/diffusers/blob/from_flax_v2/src/diffusers/modeling_pytorch_flax_utils.py
https://huggingface.co/camenduru/plushies-pt

camenduru avatar Dec 24 '22 07:12 camenduru

@camenduru, it seems like you got it working? :-) Very nice! Would you like to open a PR to add your conversion script?

patrickvonplaten avatar Jan 03 '23 11:01 patrickvonplaten

Hi @patrickvonplaten 👋 ok

camenduru avatar Jan 03 '23 12:01 camenduru

Linking here for reference: https://github.com/huggingface/diffusers/pull/1900

patrickvonplaten avatar Jan 04 '23 21:01 patrickvonplaten