FlaxStableDiffusionPipeline.from_pretrained() fails to load SD-2
Describe the bug
I am unable to load SD-2 with FlaxStableDiffusionPipeline.from_pretrained(). Loading fails because FlaxStableDiffusionPipeline.from_pretrained() expects an integer value for 'attention_head_dim', while the new SD-2 model config provides a list of values for 'attention_head_dim'.
After making hacky modifications to accommodate the list of 'attention_head_dim' values, I run into an out-of-memory error. This OOM error persists even after applying the modifications in @pcuenca's flax-sd-2 branch.
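For reference, a minimal sketch (my own, not from any branch) of what "accommodating the list" could look like: pick the per-block entry when the config value is a sequence, and fall back to the single int otherwise. resolve_head_dim is a hypothetical helper name, not anything in diffusers:

```python
# Hypothetical helper (not in diffusers): pick the per-block value when
# 'attention_head_dim' is a list/tuple, otherwise use the single int.
def resolve_head_dim(attention_head_dim, block_index):
    if isinstance(attention_head_dim, (list, tuple)):
        return attention_head_dim[block_index]
    return attention_head_dim

# SD-2-style per-block list vs. SD-1-style single int
print(resolve_head_dim([5, 10, 20, 20], 2))  # 20
print(resolve_head_dim(8, 0))                # 8
```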
Reproduction
To reproduce, first install diffusers and its dependencies:
pip install --upgrade git+https://github.com/huggingface/diffusers.git transformers accelerate scipy
pip install flax
Then try to set up the pipeline:
from diffusers import FlaxStableDiffusionPipeline
import torch

repo_id = "stabilityai/stable-diffusion-2"
device = "cuda"
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16", from_pt=True)
The error received is:
OSError: Error no file named flax_model.msgpack or pytorch_model.bin found in directory
/root/.cache/huggingface/diffusers/models--stabilityai--stable-diffusion-2/snapshots/f97795c9354774aaf9087079c984be
0291f82ae0/text_encoder.
I noticed that this error can be avoided by first creating a PyTorch pipeline and then creating the Flax pipeline, like this:
from diffusers import DiffusionPipeline, FlaxStableDiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16", from_pt=True)
However, if you create a pipe with DiffusionPipeline and then try to create a pipeline with FlaxStableDiffusionPipeline, you run into this error in /usr/local/lib/python3.7/dist-packages/diffusers/models/unet_2d_blocks_flax.py:69 in setup:

  66 │ attn_block = FlaxTransformer2DModel(
  67 │     in_channels=self.out_channels,
  68 │     n_heads=self.attn_num_head_channels,
❱ 69 │     d_head=self.out_channels // self.attn_num_head_channels,
  70 │     depth=1,
  71 │     dtype=self.dtype,
  72 │ )

TypeError: unsupported operand type(s) for //: 'int' and 'tuple'
This error comes from the fact that the value associated with the key 'attention_head_dim' in unet/config.json is a list, while FlaxTransformer2DModel expects an integer. I tried hardcoding n_heads = 64 and d_head = self.out_channels // 64 into the initialization arguments for FlaxTransformer2DModel at lines 68/69, 201/202, and 331/332 of diffusers/models/unet_2d_blocks_flax.py. After making this change I re-ran the following:
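The mismatch is easy to reproduce outside diffusers; floor division of an int by a tuple is exactly what raises the TypeError (the values below are illustrative, matching the shape of SD-2's config):

```python
out_channels = 320
attn_num_head_channels = (5, 10, 20, 20)  # list-style value as in SD-2's unet/config.json

try:
    d_head = out_channels // attn_num_head_channels
except TypeError as err:
    # unsupported operand type(s) for //: 'int' and 'tuple'
    print(err)
```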
from diffusers import FlaxStableDiffusionPipeline
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16", from_pt=True)
And received an out-of-memory error; the logs are in the section below. Any suggestions would be greatly appreciated! These errors also persisted after implementing the changes in @pcuenca's flax-sd-2 branch.
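Side note (my own workaround attempt, not a fix for the underlying bug): JAX preallocates most of the GPU memory by default, which can mask or worsen OOMs. Its allocator can be tuned through environment variables, which must be set before JAX initializes its backend:

```python
import os

# Must be set before jax/flax (or anything importing them) is imported.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand instead of upfront
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.8"   # or cap the preallocated fraction
```

This did not change the outcome for me, which suggests the allocation itself is genuinely too large rather than an artifact of preallocation.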
Logs
/usr/local/lib/python3.7/dist-packages/diffusers/utils/deprecation_utils.py:35: FutureWarning: It is deprecated to pass a pretrained model name or path to `from_config`.If you were trying to load a model, please use <class 'diffusers.models.unet_2d_condition_flax.FlaxUNet2DConditionModel'>.load_config(...) followed by <class 'diffusers.models.unet_2d_condition_flax.FlaxUNet2DConditionModel'>.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary instead. This functionality will be removed in v1.0.0.
warnings.warn(warning + message, FutureWarning)
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ <ipython-input-1-6166fe597f80>:5 in <module> │
│ │
│ /usr/local/lib/python3.7/dist-packages/diffusers/pipeline_flax_utils.py:456 in from_pretrained │
│ │
│ 453 │ │ │ │ │ loaded_sub_model = cached_folder │
│ 454 │ │ │ │ │
│ 455 │ │ │ │ if issubclass(class_obj, FlaxModelMixin): │
│ ❱ 456 │ │ │ │ │ loaded_sub_model, loaded_params = load_method(loadable_folder, from_ │
│ 457 │ │ │ │ │ params[name] = loaded_params │
│ 458 │ │ │ │ elif is_transformers_available() and issubclass(class_obj, FlaxPreTraine │
│ 459 │ │ │ │ │ if from_pt: │
│ │
│ /usr/local/lib/python3.7/dist-packages/diffusers/modeling_flax_utils.py:409 in from_pretrained │
│ │
│ 406 │ │ │ pytorch_model_file = load_state_dict(model_file) │
│ 407 │ │ │ │
│ 408 │ │ │ # Step 2: Convert the weights │
│ ❱ 409 │ │ │ state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model) │
│ 410 │ │ else: │
│ 411 │ │ │ try: │
│ 412 │ │ │ │ with open(model_file, "rb") as state_f: │
│ │
│ /usr/local/lib/python3.7/dist-packages/diffusers/modeling_flax_pytorch_utils.py:94 in │
│ convert_pytorch_state_dict_to_flax │
│ │
│ 91 │ pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} │
│ 92 │ │
│ 93 │ # Step 2: Since the model is stateless, get random Flax params │
│ ❱ 94 │ random_flax_params = flax_model.init_weights(PRNGKey(init_key)) │
│ 95 │ │
│ 96 │ random_flax_state_dict = flatten_dict(random_flax_params) │
│ 97 │ flax_state_dict = {} │
│ │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/unet_2d_condition_flax.py:118 in │
│ init_weights │
│ │
│ 115 │ │ params_rng, dropout_rng = jax.random.split(rng) │
│ 116 │ │ rngs = {"params": params_rng, "dropout": dropout_rng} │
│ 117 │ │ │
│ ❱ 118 │ │ return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"] │
│ 119 │ │
│ 120 │ def setup(self): │
│ 121 │ │ block_out_channels = self.block_out_channels │
│ │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py:162 in │
│ reraise_with_filtered_traceback │
│ │
│ 159 def reraise_with_filtered_traceback(*args, **kwargs): │
│ 160 │ __tracebackhide__ = True │
│ 161 │ try: │
│ ❱ 162 │ return fun(*args, **kwargs) │
│ 163 │ except Exception as e: │
│ 164 │ mode = filtering_mode() │
│ 165 │ if is_under_reraiser(e) or mode == "off": │
│ │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:1387 in init │
│ │
│ 1384 │ │ method=method, │
│ 1385 │ │ mutable=mutable, │
│ 1386 │ │ capture_intermediates=capture_intermediates, │
│ ❱ 1387 │ │ **kwargs) │
│ 1388 │ return v_out │
│ 1389 │
│ 1390 @property │
│ │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py:162 in │
│ reraise_with_filtered_traceback │
│ │
│ 159 def reraise_with_filtered_traceback(*args, **kwargs): │
│ 160 │ __tracebackhide__ = True │
│ 161 │ try: │
│ ❱ 162 │ return fun(*args, **kwargs) │
│ 163 │ except Exception as e: │
│ 164 │ mode = filtering_mode() │
│ 165 │ if is_under_reraiser(e) or mode == "off": │
│ │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:1340 in init_with_output │
│ │
│ 1337 │ │ self, │
│ 1338 │ │ mutable=mutable, │
│ 1339 │ │ capture_intermediates=capture_intermediates │
│ ❱ 1340 │ )(rngs, *args, **kwargs) │
│ 1341 │
│ 1342 @traceback_util.api_boundary │
│ 1343 def init(self, │
│ │
│ /usr/local/lib/python3.7/dist-packages/flax/core/scope.py:898 in wrapper │
│ │
│ 895 │ rngs = {'params': rngs} │
│ 896 │ init_flags = {**(flags if flags is not None else {}), 'initializing': True} │
│ 897 │ return apply(fn, mutable=mutable, flags=init_flags)({}, *args, rngs=rngs, │
│ ❱ 898 │ │ │ │ │ │ │ │ │ │ │ │ │ │ **kwargs) │
│ 899 │
│ 900 return wrapper │
│ 901 │
│ │
│ /usr/local/lib/python3.7/dist-packages/flax/core/scope.py:865 in wrapper │
│ │
│ 862 │ │
│ 863 │ with bind(variables, rngs=rngs, mutable=mutable, │
│ 864 │ │ │ flags=flags).temporary() as root: │
│ ❱ 865 │ y = fn(root, *args, **kwargs) │
│ 866 │ if mutable is not False: │
│ 867 │ return y, root.mutable_variables() │
│ 868 │ else: │
│ │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:1798 in scope_fn │
│ │
│ 1795 def scope_fn(scope, *args, **kwargs): │
│ 1796 │ _context.capture_stack.append(capture_intermediates) │
│ 1797 │ try: │
│ ❱ 1798 │ return fn(module.clone(parent=scope), *args, **kwargs) │
│ 1799 │ finally: │
│ 1800 │ _context.capture_stack.pop() │
│ 1801 │
│ │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:411 in wrapped_module_method │
│ │
│ 408 │ # otherwise call the wrapped function as is. │
│ 409 │ if args and isinstance(args[0], Module): │
│ 410 │ self, args = args[0], args[1:] │
│ ❱ 411 │ return self._call_wrapped_method(fun, args, kwargs) │
│ 412 │ else: │
│ 413 │ return fun(*args, **kwargs) │
│ 414 wrapped_module_method.method_handler_wrapped = True # type: ignore[attr-defined] │
│ │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:735 in _call_wrapped_method │
│ │
│ 732 │ # call method │
│ 733 │ if _use_named_call: │
│ 734 │ │ with jax.named_scope(_derive_profiling_name(self, fun)): │
│ ❱ 735 │ │ y = fun(self, *args, **kwargs) │
│ 736 │ else: │
│ 737 │ │ y = fun(self, *args, **kwargs) │
│ 738 │
│ │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/unet_2d_condition_flax.py:265 in │
│ __call__ │
│ │
│ 262 │ │ down_block_res_samples = (sample,) │
│ 263 │ │ for down_block in self.down_blocks: │
│ 264 │ │ │ if isinstance(down_block, FlaxCrossAttnDownBlock2D): │
│ ❱ 265 │ │ │ │ sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, d │
│ 266 │ │ │ else: │
│ 267 │ │ │ │ sample, res_samples = down_block(sample, t_emb, deterministic=not train) │
│ 268 │ │ │ down_block_res_samples += res_samples │
│ │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:411 in wrapped_module_method │
│ │
│ 408 │ # otherwise call the wrapped function as is. │
│ 409 │ if args and isinstance(args[0], Module): │
│ 410 │ self, args = args[0], args[1:] │
│ ❱ 411 │ return self._call_wrapped_method(fun, args, kwargs) │
│ 412 │ else: │
│ 413 │ return fun(*args, **kwargs) │
│ 414 wrapped_module_method.method_handler_wrapped = True # type: ignore[attr-defined] │
│ │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:735 in _call_wrapped_method │
│ │
│ 732 │ # call method │
│ 733 │ if _use_named_call: │
│ 734 │ │ with jax.named_scope(_derive_profiling_name(self, fun)): │
│ ❱ 735 │ │ y = fun(self, *args, **kwargs) │
│ 736 │ else: │
│ 737 │ │ y = fun(self, *args, **kwargs) │
│ 738 │
│ │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/unet_2d_blocks_flax.py:86 in __call__ │
│ │
│ 83 │ │ │
│ 84 │ │ for resnet, attn in zip(self.resnets, self.attentions): │
│ 85 │ │ │ hidden_states = resnet(hidden_states, temb, deterministic=deterministic) │
│ ❱ 86 │ │ │ hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=det │
│ 87 │ │ │ output_states += (hidden_states,) │
│ 88 │ │ │
│ 89 │ │ if self.add_downsample: │
│ │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:411 in wrapped_module_method │
│ │
│ 408 │ # otherwise call the wrapped function as is. │
│ 409 │ if args and isinstance(args[0], Module): │
│ 410 │ self, args = args[0], args[1:] │
│ ❱ 411 │ return self._call_wrapped_method(fun, args, kwargs) │
│ 412 │ else: │
│ 413 │ return fun(*args, **kwargs) │
│ 414 wrapped_module_method.method_handler_wrapped = True # type: ignore[attr-defined] │
│ │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:735 in _call_wrapped_method │
│ │
│ 732 │ # call method │
│ 733 │ if _use_named_call: │
│ 734 │ │ with jax.named_scope(_derive_profiling_name(self, fun)): │
│ ❱ 735 │ │ y = fun(self, *args, **kwargs) │
│ 736 │ else: │
│ 737 │ │ y = fun(self, *args, **kwargs) │
│ 738 │
│ │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/attention_flax.py:233 in __call__ │
│ │
│ 230 │ │ │ hidden_states = hidden_states.reshape(batch, height * width, channels) │
│ 231 │ │ │
│ 232 │ │ for transformer_block in self.transformer_blocks: │
│ ❱ 233 │ │ │ hidden_states = transformer_block(hidden_states, context, deterministic=dete │
│ 234 │ │ │
│ 235 │ │ if self.use_linear_projection: │
│ 236 │ │ │ hidden_states = self.proj_out(hidden_states) │
│ │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:411 in wrapped_module_method │
│ │
│ 408 │ # otherwise call the wrapped function as is. │
│ 409 │ if args and isinstance(args[0], Module): │
│ 410 │ self, args = args[0], args[1:] │
│ ❱ 411 │ return self._call_wrapped_method(fun, args, kwargs) │
│ 412 │ else: │
│ 413 │ return fun(*args, **kwargs) │
│ 414 wrapped_module_method.method_handler_wrapped = True # type: ignore[attr-defined] │
│ │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:735 in _call_wrapped_method │
│ │
│ 732 │ # call method │
│ 733 │ if _use_named_call: │
│ 734 │ │ with jax.named_scope(_derive_profiling_name(self, fun)): │
│ ❱ 735 │ │ y = fun(self, *args, **kwargs) │
│ 736 │ else: │
│ 737 │ │ y = fun(self, *args, **kwargs) │
│ 738 │
│ │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/attention_flax.py:136 in __call__ │
│ │
│ 133 │ │ if self.only_cross_attention: │
│ 134 │ │ │ hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic │
│ 135 │ │ else: │
│ ❱ 136 │ │ │ hidden_states = self.attn1(self.norm1(hidden_states), deterministic=determin │
│ 137 │ │ hidden_states = hidden_states + residual │
│ 138 │ │ │
│ 139 │ │ # cross attention │
│ │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:411 in wrapped_module_method │
│ │
│ 408 │ # otherwise call the wrapped function as is. │
│ 409 │ if args and isinstance(args[0], Module): │
│ 410 │ self, args = args[0], args[1:] │
│ ❱ 411 │ return self._call_wrapped_method(fun, args, kwargs) │
│ 412 │ else: │
│ 413 │ return fun(*args, **kwargs) │
│ 414 wrapped_module_method.method_handler_wrapped = True # type: ignore[attr-defined] │
│ │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:735 in _call_wrapped_method │
│ │
│ 732 │ # call method │
│ 733 │ if _use_named_call: │
│ 734 │ │ with jax.named_scope(_derive_profiling_name(self, fun)): │
│ ❱ 735 │ │ y = fun(self, *args, **kwargs) │
│ 736 │ else: │
│ 737 │ │ y = fun(self, *args, **kwargs) │
│ 738 │
│ │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/attention_flax.py:80 in __call__ │
│ │
│ 77 │ │ value_states = self.reshape_heads_to_batch_dim(value_proj) │
│ 78 │ │ │
│ 79 │ │ # compute attentions │
│ ❱ 80 │ │ attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) │
│ 81 │ │ attention_scores = attention_scores * self.scale │
│ 82 │ │ attention_probs = nn.softmax(attention_scores, axis=2) │
│ 83 │
│ │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:3066 in einsum │
│ │
│ 3063 │
│ 3064 _einsum_computation = jax.named_call( │
│ 3065 │ _einsum, name=spec) if spec is not None else _einsum │
│ ❱ 3066 return _einsum_computation(operands, contractions, precision) │
│ 3067 │
│ 3068 # Enable other modules to override einsum_contact_path. │
│ 3069 # Indexed by the type of the non constant dimension │
│ │
│ /usr/lib/python3.7/contextlib.py:74 in inner │
│ │
│ 71 │ │ @wraps(func) │
│ 72 │ │ def inner(*args, **kwds): │
│ 73 │ │ │ with self._recreate_cm(): │
│ ❱ 74 │ │ │ │ return func(*args, **kwds) │
│ 75 │ │ return inner │
│ 76 │
│ 77 │
│ │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py:162 in │
│ reraise_with_filtered_traceback │
│ │
│ 159 def reraise_with_filtered_traceback(*args, **kwargs): │
│ 160 │ __tracebackhide__ = True │
│ 161 │ try: │
│ ❱ 162 │ return fun(*args, **kwargs) │
│ 163 │ except Exception as e: │
│ 164 │ mode = filtering_mode() │
│ 165 │ if is_under_reraiser(e) or mode == "off": │
│ │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/api.py:622 in cache_miss │
│ │
│ 619 │ execute = None │
│ 620 │ if isinstance(top_trace, core.EvalTrace) and not ( │
│ 621 │ │ jax.config.jax_debug_nans or jax.config.jax_debug_infs): │
│ ❱ 622 │ execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params) │
│ 623 │ out_flat = call_bind_continuation(execute(*args_flat)) │
│ 624 │ else: │
│ 625 │ out_flat = call_bind_continuation( │
│ │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py:237 in _xla_call_impl_lazy │
│ │
│ 234 │ raise NotImplementedError('Dynamic shapes do not work with Array.') │
│ 235 │ arg_specs = [(None, getattr(x, '_device', None)) for x in args] │
│ 236 return xla_callable(fun, device, backend, name, donated_invars, keep_unused, │
│ ❱ 237 │ │ │ │ │ *arg_specs) │
│ 238 │
│ 239 │
│ 240 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, │
│ │
│ /usr/local/lib/python3.7/dist-packages/jax/linear_util.py:303 in memoized_fun │
│ │
│ 300 │ ans, stores = result │
│ 301 │ fun.populate_stores(stores) │
│ 302 │ else: │
│ ❱ 303 │ ans = call(fun, *args) │
│ 304 │ cache[key] = (ans, fun.stores) │
│ 305 │ │
│ 306 │ return ans │
│ │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py:360 in _xla_callable_uncached │
│ │
│ 357 │ return computation.compile(_allow_propagation_to_outputs=True).unsafe_call │
│ 358 else: │
│ 359 │ return lower_xla_callable(fun, device, backend, name, donated_invars, False, │
│ ❱ 360 │ │ │ │ │ │ │ keep_unused, *arg_specs).compile().unsafe_call │
│ 361 │
│ 362 xla_callable = lu.cache(_xla_callable_uncached) │
│ 363 │
│ │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py:998 in compile │
│ │
│ 995 │ │ assert self._out_type is not None │
│ 996 │ │ self._executable = XlaCompiledComputation.from_xla_computation( │
│ 997 │ │ │ self.name, self._hlo, self._in_type, self._out_type, │
│ ❱ 998 │ │ │ **self.compile_args) │
│ 999 │ │
│ 1000 │ return self._executable │
│ 1001 │
│ │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py:1195 in from_xla_computation │
│ │
│ 1192 │ with log_elapsed_time(f"Finished XLA compilation of {name} " │
│ 1193 │ │ │ │ │ │ "in {elapsed_time} sec"): │
│ 1194 │ compiled = compile_or_get_cached(backend, xla_computation, options, │
│ ❱ 1195 │ │ │ │ │ │ │ │ │ host_callbacks) │
│ 1196 │ buffer_counts = get_buffer_counts(out_avals, ordered_effects, │
│ 1197 │ │ │ │ │ │ │ │ │ has_unordered_effects) │
│ 1198 │ execute = _execute_compiled if nreps == 1 else _execute_replicated │
│ │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py:1078 in compile_or_get_cached │
│ │
│ 1075 │ return compiled │
│ 1076 │
│ 1077 return backend_compile(backend, serialized_computation, compile_options, │
│ ❱ 1078 │ │ │ │ │ │ host_callbacks) │
│ 1079 │
│ 1080 │
│ 1081 def _cache_read(computation: Union[str, bytes, ir.Module], module_name: str, │
│ │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py:314 in wrapper │
│ │
│ 311 @wraps(func) │
│ 312 def wrapper(*args, **kwargs): │
│ 313 │ with TraceAnnotation(name, **decorator_kwargs): │
│ ❱ 314 │ return func(*args, **kwargs) │
│ 315 │ return wrapper │
│ 316 return wrapper │
│ 317 │
│ │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py:1012 in backend_compile │
│ │
│ 1009 # Some backends don't have `host_callbacks` option yet │
│ 1010 # TODO(sharadmv): remove this fallback when all backends allow `compile` │
│ 1011 # to take in `host_callbacks` │
│ ❱ 1012 return backend.compile(built_c, compile_options=options) │
│ 1013 │
│ 1014 # TODO(phawkins): update users. │
│ 1015 xla.backend_compile = backend_compile │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to
allocate 21760049152 bytes.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ <ipython-input-1-6166fe597f80>:5 in <module> │
│ │
│ /usr/local/lib/python3.7/dist-packages/diffusers/pipeline_flax_utils.py:456 in from_pretrained │
│ │
│ 453 │ │ │ │ │ loaded_sub_model = cached_folder │
│ 454 │ │ │ │ │
│ 455 │ │ │ │ if issubclass(class_obj, FlaxModelMixin): │
│ ❱ 456 │ │ │ │ │ loaded_sub_model, loaded_params = load_method(loadable_folder, from_ │
│ 457 │ │ │ │ │ params[name] = loaded_params │
│ 458 │ │ │ │ elif is_transformers_available() and issubclass(class_obj, FlaxPreTraine │
│ 459 │ │ │ │ │ if from_pt: │
│ │
│ /usr/local/lib/python3.7/dist-packages/diffusers/modeling_flax_utils.py:409 in from_pretrained │
│ │
│ 406 │ │ │ pytorch_model_file = load_state_dict(model_file) │
│ 407 │ │ │ │
│ 408 │ │ │ # Step 2: Convert the weights │
│ ❱ 409 │ │ │ state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model) │
│ 410 │ │ else: │
│ 411 │ │ │ try: │
│ 412 │ │ │ │ with open(model_file, "rb") as state_f: │
│ │
│ /usr/local/lib/python3.7/dist-packages/diffusers/modeling_flax_pytorch_utils.py:94 in │
│ convert_pytorch_state_dict_to_flax │
│ │
│ 91 │ pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} │
│ 92 │ │
│ 93 │ # Step 2: Since the model is stateless, get random Flax params │
│ ❱ 94 │ random_flax_params = flax_model.init_weights(PRNGKey(init_key)) │
│ 95 │ │
│ 96 │ random_flax_state_dict = flatten_dict(random_flax_params) │
│ 97 │ flax_state_dict = {} │
│ │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:411 in wrapped_module_method │
│ │
│ 408 │ # otherwise call the wrapped function as is. │
│ 409 │ if args and isinstance(args[0], Module): │
│ 410 │ self, args = args[0], args[1:] │
│ ❱ 411 │ return self._call_wrapped_method(fun, args, kwargs) │
│ 412 │ else: │
│ 413 │ return fun(*args, **kwargs) │
│ 414 wrapped_module_method.method_handler_wrapped = True # type: ignore[attr-defined] │
│ │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:735 in _call_wrapped_method │
│ │
│ 732 │ # call method │
│ 733 │ if _use_named_call: │
│ 734 │ │ with jax.named_scope(_derive_profiling_name(self, fun)): │
│ ❱ 735 │ │ y = fun(self, *args, **kwargs) │
│ 736 │ else: │
│ 737 │ │ y = fun(self, *args, **kwargs) │
│ 738 │
│ │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/unet_2d_condition_flax.py:118 in │
│ init_weights │
│ │
│ 115 │ │ params_rng, dropout_rng = jax.random.split(rng) │
│ 116 │ │ rngs = {"params": params_rng, "dropout": dropout_rng} │
│ 117 │ │ │
│ ❱ 118 │ │ return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"] │
│ 119 │ │
│ 120 │ def setup(self): │
│ 121 │ │ block_out_channels = self.block_out_channels │
│ │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/unet_2d_condition_flax.py:265 in │
│ __call__ │
│ │
│ 262 │ │ down_block_res_samples = (sample,) │
│ 263 │ │ for down_block in self.down_blocks: │
│ 264 │ │ │ if isinstance(down_block, FlaxCrossAttnDownBlock2D): │
│ ❱ 265 │ │ │ │ sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, d │
│ 266 │ │ │ else: │
│ 267 │ │ │ │ sample, res_samples = down_block(sample, t_emb, deterministic=not train) │
│ 268 │ │ │ down_block_res_samples += res_samples │
│ │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/unet_2d_blocks_flax.py:86 in __call__ │
│ │
│ 83 │ │ │
│ 84 │ │ for resnet, attn in zip(self.resnets, self.attentions): │
│ 85 │ │ │ hidden_states = resnet(hidden_states, temb, deterministic=deterministic) │
│ ❱ 86 │ │ │ hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=det │
│ 87 │ │ │ output_states += (hidden_states,) │
│ 88 │ │ │
│ 89 │ │ if self.add_downsample: │
│ │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/attention_flax.py:233 in __call__ │
│ │
│ 230 │ │ │ hidden_states = hidden_states.reshape(batch, height * width, channels) │
│ 231 │ │ │
│ 232 │ │ for transformer_block in self.transformer_blocks: │
│ ❱ 233 │ │ │ hidden_states = transformer_block(hidden_states, context, deterministic=dete │
│ 234 │ │ │
│ 235 │ │ if self.use_linear_projection: │
│ 236 │ │ │ hidden_states = self.proj_out(hidden_states) │
│ │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/attention_flax.py:136 in __call__ │
│ │
│ 133 │ │ if self.only_cross_attention: │
│ 134 │ │ │ hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic │
│ 135 │ │ else: │
│ ❱ 136 │ │ │ hidden_states = self.attn1(self.norm1(hidden_states), deterministic=determin │
│ 137 │ │ hidden_states = hidden_states + residual │
│ 138 │ │ │
│ 139 │ │ # cross attention │
│ │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/attention_flax.py:80 in __call__ │
│ │
│ 77 │ │ value_states = self.reshape_heads_to_batch_dim(value_proj) │
│ 78 │ │ │
│ 79 │ │ # compute attentions │
│ ❱ 80 │ │ attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) │
│ 81 │ │ attention_scores = attention_scores * self.scale │
│ 82 │ │ attention_probs = nn.softmax(attention_scores, axis=2) │
│ 83 │
│ │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:3066 in einsum │
│ │
│ 3063 │
│ 3064 _einsum_computation = jax.named_call( │
│ 3065 │ _einsum, name=spec) if spec is not None else _einsum │
│ ❱ 3066 return _einsum_computation(operands, contractions, precision) │
│ 3067 │
│ 3068 # Enable other modules to override einsum_contact_path. │
│ 3069 # Indexed by the type of the non constant dimension │
│ │
│ /usr/lib/python3.7/contextlib.py:74 in inner │
│ │
│ 71 │ │ @wraps(func) │
│ 72 │ │ def inner(*args, **kwds): │
│ 73 │ │ │ with self._recreate_cm(): │
│ ❱ 74 │ │ │ │ return func(*args, **kwds) │
│ 75 │ │ return inner │
│ 76 │
│ 77 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 21760049152 bytes.
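A back-of-envelope check (my own estimate, not from the logs) suggests the allocation size is consistent with the hardcoded n_heads = 64: the float32 self-attention score tensor over SD-2's 96×96 latent would be about that large on its own:

```python
heads = 64           # the value I hardcoded in unet_2d_blocks_flax.py above
seq = 96 * 96        # SD-2 latent resolution (768 / 8 = 96), flattened
dtype_bytes = 4      # float32
scores_bytes = heads * seq * seq * dtype_bytes
print(scores_bytes)  # 21743271936, ~21.7 GB, close to the 21760049152 bytes reported
```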
System Info
Google Colab
- diffusers version: 0.9.0
- Platform: Linux-5.10.133+-x86_64-with-Ubuntu-18.04-bionic
- Python version: 3.7.15
- PyTorch version (GPU?): 1.12.1+cu113 (True)
- Huggingface_hub version: 0.11.0
- Transformers version: 4.24.0
- Using GPU in script?:
- Using distributed or parallel set-up in script?: