Flax models cannot be loaded if they don't define a dtype attribute
Describe the bug
When defining a Flax model, dtype currently has to be declared as an attribute of the model class; otherwise the model cannot be loaded from its saved config.
It should be possible to define a model without a dtype attribute, for example when the model wraps two submodules that require different dtypes (or none at all).
Reproduction
The code below raises an error:
from diffusers import ConfigMixin, FlaxModelMixin
from diffusers.configuration_utils import flax_register_to_config
from flax import linen as nn
@flax_register_to_config
class MyClass(nn.Module, FlaxModelMixin, ConfigMixin):
    pass
# create an instance
instance = MyClass()
# save config
instance.save_config("test")
# load config
instance.from_config("test")
If I add a dummy dtype attribute to the class, the model loads without error.
Logs
Error:
File ~/diffusers/src/diffusers/configuration_utils.py:195, in ConfigMixin.from_config(cls, config, return_unused_kwargs, **kwargs)
192 deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
193 config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
--> 195 init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
197 # Allow dtype to be specified on initialization
198 if "dtype" in unused_kwargs:
File ~/diffusers/src/diffusers/configuration_utils.py:399, in ConfigMixin.extract_init_dict(cls, config_dict, **kwargs)
397 if hasattr(cls, "_flax_internal_args"):
398 for arg in cls._flax_internal_args:
--> 399 expected_keys.remove(arg)
401 # 2. Remove attributes that cannot be expected from expected config attributes
402 # remove keys to be ignored
403 if len(cls.ignore_for_config) > 0:
KeyError: 'dtype'
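The KeyError comes from the loop in extract_init_dict: expected_keys is a set (set.remove is what raises KeyError for a missing element) built from the model's __init__ signature, and every entry of _flax_internal_args is removed from it unconditionally. The sketch below mimics that step with plain dataclasses standing in for Flax modules (Flax linen modules are dataclasses whose generated __init__ also accepts parent and name). It assumes _flax_internal_args is ["name", "parent", "dtype"] and that the keys are collected via inspect.signature; the class and function names are hypothetical. Swapping set.remove for set.discard is shown as one possible fix, not as the patch that was merged.

```python
import dataclasses
import inspect

# Assumed contents of _flax_internal_args (the traceback confirms "dtype" is in it).
FLAX_INTERNAL_ARGS = ["name", "parent", "dtype"]


@dataclasses.dataclass
class ModelWithDtype:
    """Hypothetical stand-in for a Flax module that declares dtype."""
    features: int = 8
    dtype: str = "float32"
    parent: object = None
    name: str = None


@dataclasses.dataclass
class ModelWithoutDtype:
    """Hypothetical stand-in for a Flax module with no dtype field."""
    features: int = 8
    parent: object = None
    name: str = None


def init_keys(cls):
    # Collect the __init__ parameter names as a set, dropping "self".
    keys = set(inspect.signature(cls.__init__).parameters)
    keys.remove("self")
    return keys


def expected_keys_strict(cls):
    # Mirrors the failing code path: set.remove raises KeyError
    # when an internal arg (here "dtype") is not a field of the model.
    keys = init_keys(cls)
    for arg in FLAX_INTERNAL_ARGS:
        keys.remove(arg)
    return keys


def expected_keys_tolerant(cls):
    # Possible fix: set.discard silently skips internal args the model lacks.
    keys = init_keys(cls)
    for arg in FLAX_INTERNAL_ARGS:
        keys.discard(arg)
    return keys
```

With the strict version, ModelWithDtype yields {"features"} while ModelWithoutDtype raises KeyError('dtype'), matching the traceback; the tolerant version returns {"features"} for both, so a model without a dtype field would still get the correct set of config keys.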
System Info
- diffusers version: 0.8.0.dev0 (installed from source)
- Platform: Linux-5.4.0-1043-gcp-x86_64-with-glibc2.31
- Python version: 3.9.7
- PyTorch version (GPU?): 1.12.1+cu102 (False)
- Huggingface_hub version: 0.10.1
- Transformers version: 4.25.0.dev0
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: N/A
Thanks for opening an issue here @borisdayma!
Would you be willing to open a PR to fix it? :-)
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.