
Flax models cannot be loaded if they don't contain dtype

borisdayma opened this issue 3 years ago · 2 comments

Describe the bug

When creating a Flax model, `dtype` currently has to be defined as a required argument of the model definition.

It should be possible to define a model without a `dtype`, for example when the model wraps two submodules that require different dtypes (or no dtype at all).

Reproduction

The code below raises an error:

from diffusers import ConfigMixin, FlaxModelMixin
from diffusers.configuration_utils import flax_register_to_config
from flax import linen as nn

@flax_register_to_config
class MyClass(nn.Module, FlaxModelMixin, ConfigMixin):
    pass

# create an instance
instance = MyClass()

# save config
instance.save_config("test")

# load config
instance.from_config("test")

If I define a fake `dtype` key, the model can be loaded.

Logs

Error:


File ~/diffusers/src/diffusers/configuration_utils.py:195, in ConfigMixin.from_config(cls, config, return_unused_kwargs, **kwargs)
    192     deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
    193     config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
--> 195 init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
    197 # Allow dtype to be specified on initialization
    198 if "dtype" in unused_kwargs:

File ~/diffusers/src/diffusers/configuration_utils.py:399, in ConfigMixin.extract_init_dict(cls, config_dict, **kwargs)
    397 if hasattr(cls, "_flax_internal_args"):
    398     for arg in cls._flax_internal_args:
--> 399         expected_keys.remove(arg)
    401 # 2. Remove attributes that cannot be expected from expected config attributes
    402 # remove keys to be ignored
    403 if len(cls.ignore_for_config) > 0:

KeyError: 'dtype'
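The `KeyError` comes from `set.remove`, which raises when the element is absent. A minimal sketch of the failure mode, with no flax needed (the key names are illustrative; swapping `remove` for `set.discard` is one possible fix I'm assuming here, not a committed change):

```python
# extract_init_dict removes each internal arg from the expected keys
# with set.remove, which raises KeyError if "dtype" was never declared.
expected_keys = {"in_channels", "out_channels"}   # model declared no dtype
flax_internal_args = ["name", "parent", "dtype"]  # per the traceback

failed = []
for arg in flax_internal_args:
    try:
        expected_keys.remove(arg)  # KeyError when arg is absent
    except KeyError:
        failed.append(arg)

print(failed)  # ['name', 'parent', 'dtype']

# set.discard is a drop-in alternative that ignores missing keys:
expected_keys = {"in_channels", "out_channels"}
for arg in flax_internal_args:
    expected_keys.discard(arg)     # no error if arg is absent
print(sorted(expected_keys))       # ['in_channels', 'out_channels']
```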

System Info

  • diffusers version: 0.8.0.dev0 (installed from source)
  • Platform: Linux-5.4.0-1043-gcp-x86_64-with-glibc2.31
  • Python version: 3.9.7
  • PyTorch version (GPU?): 1.12.1+cu102 (False)
  • Huggingface_hub version: 0.10.1
  • Transformers version: 4.25.0.dev0
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: N/A

borisdayma · Nov 16 '22 23:11

Thanks for opening an issue here @borisdayma!

Would you be willing to open a PR to fix it? :-)

patrickvonplaten · Nov 20 '22 18:11

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] · Dec 17 '22 15:12