
FlaxStableDiffusionPipeline.from_pretrained() fails to load SD-2

Status: Open · opened by lk-wq 3 years ago · 0 comments

Describe the bug

I am unable to load SD-2 with FlaxStableDiffusionPipeline.from_pretrained(). The Flax loader expects an integer value for 'attention_head_dim', but the new SD-2 model config provides a list of values for 'attention_head_dim'.
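A minimal sketch of one way a loader could accept both forms of 'attention_head_dim': a single integer shared by every block, or a per-block list like the one in the SD-2 config. `per_block_head_counts` is a hypothetical helper for illustration, not part of diffusers, and the values in the usage lines are only examples.

```python
from typing import Sequence, Tuple, Union


def per_block_head_counts(
    attention_head_dim: Union[int, Sequence[int]], num_blocks: int
) -> Tuple[int, ...]:
    """Return one value per block (hypothetical helper, not a diffusers API)."""
    if isinstance(attention_head_dim, int):
        # Scalar config: every block shares the same value.
        return (attention_head_dim,) * num_blocks
    values = tuple(attention_head_dim)
    if len(values) != num_blocks:
        # A per-block list must have exactly one entry per block.
        raise ValueError(f"expected {num_blocks} values, got {len(values)}: {values}")
    return values


print(per_block_head_counts(8, 4))                # (8, 8, 8, 8)
print(per_block_head_counts([5, 10, 20, 20], 4))  # (5, 10, 20, 20)
```

Normalizing once, up front, lets each block index its own entry instead of every block having to handle both types.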

After making hacky modifications to accommodate the list of 'attention_head_dim' values, I run into an out-of-memory (OOM) error. The OOM error persists after applying the modifications from @pcuenca's flax-sd-2 branch.

Reproduction

To reproduce, first install diffusers and its dependencies:

pip install --upgrade git+https://github.com/huggingface/diffusers.git transformers accelerate scipy
pip install flax

Then try to set up the pipeline:

import torch
from diffusers import FlaxStableDiffusionPipeline

repo_id = "stabilityai/stable-diffusion-2"
# from_pt=True converts the PyTorch weights to Flax while loading
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16", from_pt=True)

The error received is:

OSError: Error no file named flax_model.msgpack or pytorch_model.bin found in directory /root/.cache/huggingface/diffusers/models--stabilityai--stable-diffusion-2/snapshots/f97795c9354774aaf9087079c984be0291f82ae0/text_encoder.

I noticed that this error can be avoided by first loading the pipeline with PyTorch and then creating the Flax pipeline, like this:

from diffusers import DiffusionPipeline, FlaxStableDiffusionPipeline
# Loading the PyTorch pipeline first avoids the OSError above
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16", from_pt=True)

However, after creating a pipeline with DiffusionPipeline and then creating one with FlaxStableDiffusionPipeline, you run into the following error:

/usr/local/lib/python3.7/dist-packages/diffusers/models/unet_2d_blocks_flax.py:69 in setup

   66 │ attn_block = FlaxTransformer2DModel(
   67 │     in_channels=self.out_channels,
   68 │     n_heads=self.attn_num_head_channels,
 ❱ 69 │     d_head=self.out_channels // self.attn_num_head_channels,
   70 │     depth=1,
   71 │     dtype=self.dtype,
   72 │ )

TypeError: unsupported operand type(s) for //: 'int' and 'tuple'
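The TypeError can be reproduced in plain Python: `//` is not defined between an int and a tuple, so line 69 fails as soon as `attn_num_head_channels` holds the whole per-block sequence. Selecting the current block's entry yields an int and the division works. The values below are assumptions (320 output channels for the first down block, and the SD-2 per-block list); check them against your own unet/config.json.

```python
out_channels = 320                        # assumed: first SD-2 down block
attn_num_head_channels = (5, 10, 20, 20)  # assumed: SD-2 list, arriving as a tuple

error_message = None
try:
    out_channels // attn_num_head_channels
except TypeError as err:
    error_message = str(err)
    print(error_message)  # unsupported operand type(s) for //: 'int' and 'tuple'

# Select this block's entry (index 0 for the first down block) to get an int.
d_head = out_channels // attn_num_head_channels[0]
print(d_head)  # 64
```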

This error occurs because the value of the key 'attention_head_dim' in unet/config.json is a list (it shows up as a tuple in the error above), whereas FlaxTransformer2DModel expects an integer. I tried hardcoding n_heads = 64 and d_head = self.out_channels // 64 in the FlaxTransformer2DModel initialization arguments at lines 68/69, 201/202, and 331/332 of diffusers/models/unet_2d_blocks_flax.py. After making this change, I re-ran the following:

from diffusers import FlaxStableDiffusionPipeline
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16", from_pt=True)

And received an out-of-memory error; the logs are shown in the section below. Any suggestions would be greatly appreciated! These errors also persisted after applying the changes in @pcuenca's flax-sd-2 branch.
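A back-of-the-envelope sketch of what the hardcoded workaround does to the attention layers, under stated assumptions: block_out_channels (320, 640, 1280, 1280) and attention_head_dim (5, 10, 20, 20) as in the SD-2 config, a 96x96 latent during `init_weights` (768 px images with a VAE factor of 8), batch size 1, and float32 scores. Line 68 of unet_2d_blocks_flax.py passes the config value as `n_heads`, so the config implies 64-dimensional heads, while hardcoding `n_heads = 64` swaps the two. The einsum at attention_flax.py:80 builds a (batch * n_heads, seq, seq) score tensor, so its size grows linearly with the head count. This is one plausible contributor to the OOM, not a confirmed diagnosis, especially since the OOM was also reported with the flax-sd-2 branch.

```python
# Values assumed from the SD-2 unet/config.json; verify against your checkout.
BLOCK_OUT_CHANNELS = (320, 640, 1280, 1280)
ATTENTION_HEAD_DIM = (5, 10, 20, 20)  # passed as n_heads per block (line 68)
HARDCODED_HEADS = 64                  # the workaround described above
LATENT_SIZE = 96                      # assumed: 768 px / VAE downsampling factor 8
GIB = 1024 ** 3


def score_bytes(n_heads: int, height: int, width: int,
                batch: int = 1, itemsize: int = 4) -> int:
    """Bytes in the (batch * n_heads, seq, seq) score tensor of self-attention."""
    seq_len = height * width  # one token per latent position
    return batch * n_heads * seq_len * seq_len * itemsize


# (n_heads, d_head) per block: config-driven vs hardcoded
for out_ch, heads in zip(BLOCK_OUT_CHANNELS, ATTENTION_HEAD_DIM):
    print(f"{out_ch:5d} channels: config ({heads}, {out_ch // heads}) "
          f"vs hardcoded ({HARDCODED_HEADS}, {out_ch // HARDCODED_HEADS})")

# Score tensor in the first, full-resolution cross-attention down block
for heads in (ATTENTION_HEAD_DIM[0], HARDCODED_HEADS):
    print(f"{heads:2d} heads: {score_bytes(heads, LATENT_SIZE, LATENT_SIZE) / GIB:.2f} GiB")
```

Under these assumptions the score tensor of a single self-attention call in the first block is 1.58 GiB with the config's 5 heads and 20.25 GiB with 64 heads.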

Logs

/usr/local/lib/python3.7/dist-packages/diffusers/utils/deprecation_utils.py:35: FutureWarning: It is deprecated to pass a pretrained model name or path to `from_config`.If you were trying to load a model, please use <class 'diffusers.models.unet_2d_condition_flax.FlaxUNet2DConditionModel'>.load_config(...) followed by <class 'diffusers.models.unet_2d_condition_flax.FlaxUNet2DConditionModel'>.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary instead. This functionality will be removed in v1.0.0.
  warnings.warn(warning + message, FutureWarning)
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ <ipython-input-1-6166fe597f80>:5 in <module>                                                     │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/diffusers/pipeline_flax_utils.py:456 in from_pretrained   │
│                                                                                                  │
│   453 │   │   │   │   │   loaded_sub_model = cached_folder                                       │
│   454 │   │   │   │                                                                              │
│   455 │   │   │   │   if issubclass(class_obj, FlaxModelMixin):                                  │
│ ❱ 456 │   │   │   │   │   loaded_sub_model, loaded_params = load_method(loadable_folder, from_   │
│   457 │   │   │   │   │   params[name] = loaded_params                                           │
│   458 │   │   │   │   elif is_transformers_available() and issubclass(class_obj, FlaxPreTraine   │
│   459 │   │   │   │   │   if from_pt:                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/diffusers/modeling_flax_utils.py:409 in from_pretrained   │
│                                                                                                  │
│   406 │   │   │   pytorch_model_file = load_state_dict(model_file)                               │
│   407 │   │   │                                                                                  │
│   408 │   │   │   # Step 2: Convert the weights                                                  │
│ ❱ 409 │   │   │   state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)          │
│   410 │   │   else:                                                                              │
│   411 │   │   │   try:                                                                           │
│   412 │   │   │   │   with open(model_file, "rb") as state_f:                                    │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/diffusers/modeling_flax_pytorch_utils.py:94 in            │
│ convert_pytorch_state_dict_to_flax                                                               │
│                                                                                                  │
│    91 │   pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}                       │
│    92 │                                                                                          │
│    93 │   # Step 2: Since the model is stateless, get random Flax params                         │
│ ❱  94 │   random_flax_params = flax_model.init_weights(PRNGKey(init_key))                        │
│    95 │                                                                                          │
│    96 │   random_flax_state_dict = flatten_dict(random_flax_params)                              │
│    97 │   flax_state_dict = {}                                                                   │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/unet_2d_condition_flax.py:118 in         │
│ init_weights                                                                                     │
│                                                                                                  │
│   115 │   │   params_rng, dropout_rng = jax.random.split(rng)                                    │
│   116 │   │   rngs = {"params": params_rng, "dropout": dropout_rng}                              │
│   117 │   │                                                                                      │
│ ❱ 118 │   │   return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]         │
│   119 │                                                                                          │
│   120 │   def setup(self):                                                                       │
│   121 │   │   block_out_channels = self.block_out_channels                                       │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py:162 in                         │
│ reraise_with_filtered_traceback                                                                  │
│                                                                                                  │
│   159   def reraise_with_filtered_traceback(*args, **kwargs):                                    │
│   160 │   __tracebackhide__ = True                                                               │
│   161 │   try:                                                                                   │
│ ❱ 162 │     return fun(*args, **kwargs)                                                          │
│   163 │   except Exception as e:                                                                 │
│   164 │     mode = filtering_mode()                                                              │
│   165 │     if is_under_reraiser(e) or mode == "off":                                            │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:1387 in init                         │
│                                                                                                  │
│   1384 │   │   method=method,                                                                    │
│   1385 │   │   mutable=mutable,                                                                  │
│   1386 │   │   capture_intermediates=capture_intermediates,                                      │
│ ❱ 1387 │   │   **kwargs)                                                                         │
│   1388 │   return v_out                                                                          │
│   1389                                                                                           │
│   1390   @property                                                                               │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py:162 in                         │
│ reraise_with_filtered_traceback                                                                  │
│                                                                                                  │
│   159   def reraise_with_filtered_traceback(*args, **kwargs):                                    │
│   160 │   __tracebackhide__ = True                                                               │
│   161 │   try:                                                                                   │
│ ❱ 162 │     return fun(*args, **kwargs)                                                          │
│   163 │   except Exception as e:                                                                 │
│   164 │     mode = filtering_mode()                                                              │
│   165 │     if is_under_reraiser(e) or mode == "off":                                            │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:1340 in init_with_output             │
│                                                                                                  │
│   1337 │   │   self,                                                                             │
│   1338 │   │   mutable=mutable,                                                                  │
│   1339 │   │   capture_intermediates=capture_intermediates                                       │
│ ❱ 1340 │   )(rngs, *args, **kwargs)                                                              │
│   1341                                                                                           │
│   1342   @traceback_util.api_boundary                                                            │
│   1343   def init(self,                                                                          │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/flax/core/scope.py:898 in wrapper                         │
│                                                                                                  │
│   895 │     rngs = {'params': rngs}                                                              │
│   896 │   init_flags = {**(flags if flags is not None else {}), 'initializing': True}            │
│   897 │   return apply(fn, mutable=mutable, flags=init_flags)({}, *args, rngs=rngs,              │
│ ❱ 898 │   │   │   │   │   │   │   │   │   │   │   │   │   │   **kwargs)                          │
│   899                                                                                            │
│   900   return wrapper                                                                           │
│   901                                                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/flax/core/scope.py:865 in wrapper                         │
│                                                                                                  │
│   862 │                                                                                          │
│   863 │   with bind(variables, rngs=rngs, mutable=mutable,                                       │
│   864 │   │   │     flags=flags).temporary() as root:                                            │
│ ❱ 865 │     y = fn(root, *args, **kwargs)                                                        │
│   866 │   if mutable is not False:                                                               │
│   867 │     return y, root.mutable_variables()                                                   │
│   868 │   else:                                                                                  │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:1798 in scope_fn                     │
│                                                                                                  │
│   1795   def scope_fn(scope, *args, **kwargs):                                                   │
│   1796 │   _context.capture_stack.append(capture_intermediates)                                  │
│   1797 │   try:                                                                                  │
│ ❱ 1798 │     return fn(module.clone(parent=scope), *args, **kwargs)                              │
│   1799 │   finally:                                                                              │
│   1800 │     _context.capture_stack.pop()                                                        │
│   1801                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:411 in wrapped_module_method         │
│                                                                                                  │
│    408 │   # otherwise call the wrapped function as is.                                          │
│    409 │   if args and isinstance(args[0], Module):                                              │
│    410 │     self, args = args[0], args[1:]                                                      │
│ ❱  411 │     return self._call_wrapped_method(fun, args, kwargs)                                 │
│    412 │   else:                                                                                 │
│    413 │     return fun(*args, **kwargs)                                                         │
│    414   wrapped_module_method.method_handler_wrapped = True  # type: ignore[attr-defined]       │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:735 in _call_wrapped_method          │
│                                                                                                  │
│    732 │     # call method                                                                       │
│    733 │     if _use_named_call:                                                                 │
│    734 │   │   with jax.named_scope(_derive_profiling_name(self, fun)):                          │
│ ❱  735 │   │     y = fun(self, *args, **kwargs)                                                  │
│    736 │     else:                                                                               │
│    737 │   │   y = fun(self, *args, **kwargs)                                                    │
│    738                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/unet_2d_condition_flax.py:265 in         │
│ __call__                                                                                         │
│                                                                                                  │
│   262 │   │   down_block_res_samples = (sample,)                                                 │
│   263 │   │   for down_block in self.down_blocks:                                                │
│   264 │   │   │   if isinstance(down_block, FlaxCrossAttnDownBlock2D):                           │
│ ❱ 265 │   │   │   │   sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, d   │
│   266 │   │   │   else:                                                                          │
│   267 │   │   │   │   sample, res_samples = down_block(sample, t_emb, deterministic=not train)   │
│   268 │   │   │   down_block_res_samples += res_samples                                          │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:411 in wrapped_module_method         │
│                                                                                                  │
│    408 │   # otherwise call the wrapped function as is.                                          │
│    409 │   if args and isinstance(args[0], Module):                                              │
│    410 │     self, args = args[0], args[1:]                                                      │
│ ❱  411 │     return self._call_wrapped_method(fun, args, kwargs)                                 │
│    412 │   else:                                                                                 │
│    413 │     return fun(*args, **kwargs)                                                         │
│    414   wrapped_module_method.method_handler_wrapped = True  # type: ignore[attr-defined]       │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:735 in _call_wrapped_method          │
│                                                                                                  │
│    732 │     # call method                                                                       │
│    733 │     if _use_named_call:                                                                 │
│    734 │   │   with jax.named_scope(_derive_profiling_name(self, fun)):                          │
│ ❱  735 │   │     y = fun(self, *args, **kwargs)                                                  │
│    736 │     else:                                                                               │
│    737 │   │   y = fun(self, *args, **kwargs)                                                    │
│    738                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/unet_2d_blocks_flax.py:86 in __call__    │
│                                                                                                  │
│    83 │   │                                                                                      │
│    84 │   │   for resnet, attn in zip(self.resnets, self.attentions):                            │
│    85 │   │   │   hidden_states = resnet(hidden_states, temb, deterministic=deterministic)       │
│ ❱  86 │   │   │   hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=det   │
│    87 │   │   │   output_states += (hidden_states,)                                              │
│    88 │   │                                                                                      │
│    89 │   │   if self.add_downsample:                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:411 in wrapped_module_method         │
│                                                                                                  │
│    408 │   # otherwise call the wrapped function as is.                                          │
│    409 │   if args and isinstance(args[0], Module):                                              │
│    410 │     self, args = args[0], args[1:]                                                      │
│ ❱  411 │     return self._call_wrapped_method(fun, args, kwargs)                                 │
│    412 │   else:                                                                                 │
│    413 │     return fun(*args, **kwargs)                                                         │
│    414   wrapped_module_method.method_handler_wrapped = True  # type: ignore[attr-defined]       │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:735 in _call_wrapped_method          │
│                                                                                                  │
│    732 │     # call method                                                                       │
│    733 │     if _use_named_call:                                                                 │
│    734 │   │   with jax.named_scope(_derive_profiling_name(self, fun)):                          │
│ ❱  735 │   │     y = fun(self, *args, **kwargs)                                                  │
│    736 │     else:                                                                               │
│    737 │   │   y = fun(self, *args, **kwargs)                                                    │
│    738                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/attention_flax.py:233 in __call__        │
│                                                                                                  │
│   230 │   │   │   hidden_states = hidden_states.reshape(batch, height * width, channels)         │
│   231 │   │                                                                                      │
│   232 │   │   for transformer_block in self.transformer_blocks:                                  │
│ ❱ 233 │   │   │   hidden_states = transformer_block(hidden_states, context, deterministic=dete   │
│   234 │   │                                                                                      │
│   235 │   │   if self.use_linear_projection:                                                     │
│   236 │   │   │   hidden_states = self.proj_out(hidden_states)                                   │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:411 in wrapped_module_method         │
│                                                                                                  │
│    408 │   # otherwise call the wrapped function as is.                                          │
│    409 │   if args and isinstance(args[0], Module):                                              │
│    410 │     self, args = args[0], args[1:]                                                      │
│ ❱  411 │     return self._call_wrapped_method(fun, args, kwargs)                                 │
│    412 │   else:                                                                                 │
│    413 │     return fun(*args, **kwargs)                                                         │
│    414   wrapped_module_method.method_handler_wrapped = True  # type: ignore[attr-defined]       │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:735 in _call_wrapped_method          │
│                                                                                                  │
│    732 │     # call method                                                                       │
│    733 │     if _use_named_call:                                                                 │
│    734 │   │   with jax.named_scope(_derive_profiling_name(self, fun)):                          │
│ ❱  735 │   │     y = fun(self, *args, **kwargs)                                                  │
│    736 │     else:                                                                               │
│    737 │   │   y = fun(self, *args, **kwargs)                                                    │
│    738                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/attention_flax.py:136 in __call__        │
│                                                                                                  │
│   133 │   │   if self.only_cross_attention:                                                      │
│   134 │   │   │   hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic   │
│   135 │   │   else:                                                                              │
│ ❱ 136 │   │   │   hidden_states = self.attn1(self.norm1(hidden_states), deterministic=determin   │
│   137 │   │   hidden_states = hidden_states + residual                                           │
│   138 │   │                                                                                      │
│   139 │   │   # cross attention                                                                  │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:411 in wrapped_module_method         │
│                                                                                                  │
│    408 │   # otherwise call the wrapped function as is.                                          │
│    409 │   if args and isinstance(args[0], Module):                                              │
│    410 │     self, args = args[0], args[1:]                                                      │
│ ❱  411 │     return self._call_wrapped_method(fun, args, kwargs)                                 │
│    412 │   else:                                                                                 │
│    413 │     return fun(*args, **kwargs)                                                         │
│    414   wrapped_module_method.method_handler_wrapped = True  # type: ignore[attr-defined]       │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:735 in _call_wrapped_method          │
│                                                                                                  │
│    732 │     # call method                                                                       │
│    733 │     if _use_named_call:                                                                 │
│    734 │   │   with jax.named_scope(_derive_profiling_name(self, fun)):                          │
│ ❱  735 │   │     y = fun(self, *args, **kwargs)                                                  │
│    736 │     else:                                                                               │
│    737 │   │   y = fun(self, *args, **kwargs)                                                    │
│    738                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/attention_flax.py:80 in __call__         │
│                                                                                                  │
│    77 │   │   value_states = self.reshape_heads_to_batch_dim(value_proj)                         │
│    78 │   │                                                                                      │
│    79 │   │   # compute attentions                                                               │
│ ❱  80 │   │   attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)     │
│    81 │   │   attention_scores = attention_scores * self.scale                                   │
│    82 │   │   attention_probs = nn.softmax(attention_scores, axis=2)                             │
│    83                                                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:3066 in einsum                │
│                                                                                                  │
│   3063                                                                                           │
│   3064   _einsum_computation = jax.named_call(                                                   │
│   3065 │     _einsum, name=spec) if spec is not None else _einsum                                │
│ ❱ 3066   return _einsum_computation(operands, contractions, precision)                           │
│   3067                                                                                           │
│   3068 # Enable other modules to override einsum_contact_path.                                   │
│   3069 # Indexed by the type of the non constant dimension                                       │
│                                                                                                  │
│ /usr/lib/python3.7/contextlib.py:74 in inner                                                     │
│                                                                                                  │
│    71 │   │   @wraps(func)                                                                       │
│    72 │   │   def inner(*args, **kwds):                                                          │
│    73 │   │   │   with self._recreate_cm():                                                      │
│ ❱  74 │   │   │   │   return func(*args, **kwds)                                                 │
│    75 │   │   return inner                                                                       │
│    76                                                                                            │
│    77                                                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py:162 in                         │
│ reraise_with_filtered_traceback                                                                  │
│                                                                                                  │
│   159   def reraise_with_filtered_traceback(*args, **kwargs):                                    │
│   160 │   __tracebackhide__ = True                                                               │
│   161 │   try:                                                                                   │
│ ❱ 162 │     return fun(*args, **kwargs)                                                          │
│   163 │   except Exception as e:                                                                 │
│   164 │     mode = filtering_mode()                                                              │
│   165 │     if is_under_reraiser(e) or mode == "off":                                            │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/api.py:622 in cache_miss                         │
│                                                                                                  │
│    619 │   execute = None                                                                        │
│    620 │   if isinstance(top_trace, core.EvalTrace) and not (                                    │
│    621 │   │   jax.config.jax_debug_nans or jax.config.jax_debug_infs):                          │
│ ❱  622 │     execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)                    │
│    623 │     out_flat = call_bind_continuation(execute(*args_flat))                              │
│    624 │   else:                                                                                 │
│    625 │     out_flat = call_bind_continuation(                                                  │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py:237 in _xla_call_impl_lazy           │
│                                                                                                  │
│    234 │     raise NotImplementedError('Dynamic shapes do not work with Array.')                 │
│    235 │   arg_specs = [(None, getattr(x, '_device', None)) for x in args]                       │
│    236   return xla_callable(fun, device, backend, name, donated_invars, keep_unused,            │
│ ❱  237 │   │   │   │   │     *arg_specs)                                                         │
│    238                                                                                           │
│    239                                                                                           │
│    240 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,                      │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/jax/linear_util.py:303 in memoized_fun                    │
│                                                                                                  │
│   300 │     ans, stores = result                                                                 │
│   301 │     fun.populate_stores(stores)                                                          │
│   302 │   else:                                                                                  │
│ ❱ 303 │     ans = call(fun, *args)                                                               │
│   304 │     cache[key] = (ans, fun.stores)                                                       │
│   305 │                                                                                          │
│   306 │   return ans                                                                             │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py:360 in _xla_callable_uncached        │
│                                                                                                  │
│    357 │   return computation.compile(_allow_propagation_to_outputs=True).unsafe_call            │
│    358   else:                                                                                   │
│    359 │   return lower_xla_callable(fun, device, backend, name, donated_invars, False,          │
│ ❱  360 │   │   │   │   │   │   │     keep_unused, *arg_specs).compile().unsafe_call              │
│    361                                                                                           │
│    362 xla_callable = lu.cache(_xla_callable_uncached)                                           │
│    363                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py:998 in compile                       │
│                                                                                                  │
│    995 │   │   assert self._out_type is not None                                                 │
│    996 │   │   self._executable = XlaCompiledComputation.from_xla_computation(                   │
│    997 │   │   │   self.name, self._hlo, self._in_type, self._out_type,                          │
│ ❱  998 │   │   │   **self.compile_args)                                                          │
│    999 │                                                                                         │
│   1000 │   return self._executable                                                               │
│   1001                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py:1195 in from_xla_computation         │
│                                                                                                  │
│   1192 │   with log_elapsed_time(f"Finished XLA compilation of {name} "                          │
│   1193 │   │   │   │   │   │     "in {elapsed_time} sec"):                                       │
│   1194 │     compiled = compile_or_get_cached(backend, xla_computation, options,                 │
│ ❱ 1195 │   │   │   │   │   │   │   │   │      host_callbacks)                                    │
│   1196 │   buffer_counts = get_buffer_counts(out_avals, ordered_effects,                         │
│   1197 │   │   │   │   │   │   │   │   │     has_unordered_effects)                              │
│   1198 │   execute = _execute_compiled if nreps == 1 else _execute_replicated                    │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py:1078 in compile_or_get_cached        │
│                                                                                                  │
│   1075 │     return compiled                                                                     │
│   1076                                                                                           │
│   1077   return backend_compile(backend, serialized_computation, compile_options,                │
│ ❱ 1078 │   │   │   │   │   │    host_callbacks)                                                  │
│   1079                                                                                           │
│   1080                                                                                           │
│   1081 def _cache_read(computation: Union[str, bytes, ir.Module], module_name: str,              │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py:314 in wrapper                       │
│                                                                                                  │
│   311   @wraps(func)                                                                             │
│   312   def wrapper(*args, **kwargs):                                                            │
│   313 │   with TraceAnnotation(name, **decorator_kwargs):                                        │
│ ❱ 314 │     return func(*args, **kwargs)                                                         │
│   315 │   return wrapper                                                                         │
│   316   return wrapper                                                                           │
│   317                                                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py:1012 in backend_compile              │
│                                                                                                  │
│   1009   # Some backends don't have `host_callbacks` option yet                                  │
│   1010   # TODO(sharadmv): remove this fallback when all backends allow `compile`                │
│   1011   # to take in `host_callbacks`                                                           │
│ ❱ 1012   return backend.compile(built_c, compile_options=options)                                │
│   1013                                                                                           │
│   1014 # TODO(phawkins): update users.                                                           │
│   1015 xla.backend_compile = backend_compile                                                     │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 21760049152 bytes.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ <ipython-input-1-6166fe597f80>:5 in <module>                                                     │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/diffusers/pipeline_flax_utils.py:456 in from_pretrained   │
│                                                                                                  │
│   453 │   │   │   │   │   loaded_sub_model = cached_folder                                       │
│   454 │   │   │   │                                                                              │
│   455 │   │   │   │   if issubclass(class_obj, FlaxModelMixin):                                  │
│ ❱ 456 │   │   │   │   │   loaded_sub_model, loaded_params = load_method(loadable_folder, from_   │
│   457 │   │   │   │   │   params[name] = loaded_params                                           │
│   458 │   │   │   │   elif is_transformers_available() and issubclass(class_obj, FlaxPreTraine   │
│   459 │   │   │   │   │   if from_pt:                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/diffusers/modeling_flax_utils.py:409 in from_pretrained   │
│                                                                                                  │
│   406 │   │   │   pytorch_model_file = load_state_dict(model_file)                               │
│   407 │   │   │                                                                                  │
│   408 │   │   │   # Step 2: Convert the weights                                                  │
│ ❱ 409 │   │   │   state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)          │
│   410 │   │   else:                                                                              │
│   411 │   │   │   try:                                                                           │
│   412 │   │   │   │   with open(model_file, "rb") as state_f:                                    │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/diffusers/modeling_flax_pytorch_utils.py:94 in            │
│ convert_pytorch_state_dict_to_flax                                                               │
│                                                                                                  │
│    91 │   pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}                       │
│    92 │                                                                                          │
│    93 │   # Step 2: Since the model is stateless, get random Flax params                         │
│ ❱  94 │   random_flax_params = flax_model.init_weights(PRNGKey(init_key))                        │
│    95 │                                                                                          │
│    96 │   random_flax_state_dict = flatten_dict(random_flax_params)                              │
│    97 │   flax_state_dict = {}                                                                   │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:411 in wrapped_module_method         │
│                                                                                                  │
│    408 │   # otherwise call the wrapped function as is.                                          │
│    409 │   if args and isinstance(args[0], Module):                                              │
│    410 │     self, args = args[0], args[1:]                                                      │
│ ❱  411 │     return self._call_wrapped_method(fun, args, kwargs)                                 │
│    412 │   else:                                                                                 │
│    413 │     return fun(*args, **kwargs)                                                         │
│    414   wrapped_module_method.method_handler_wrapped = True  # type: ignore[attr-defined]       │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/flax/linen/module.py:735 in _call_wrapped_method          │
│                                                                                                  │
│    732 │     # call method                                                                       │
│    733 │     if _use_named_call:                                                                 │
│    734 │   │   with jax.named_scope(_derive_profiling_name(self, fun)):                          │
│ ❱  735 │   │     y = fun(self, *args, **kwargs)                                                  │
│    736 │     else:                                                                               │
│    737 │   │   y = fun(self, *args, **kwargs)                                                    │
│    738                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/unet_2d_condition_flax.py:118 in         │
│ init_weights                                                                                     │
│                                                                                                  │
│   115 │   │   params_rng, dropout_rng = jax.random.split(rng)                                    │
│   116 │   │   rngs = {"params": params_rng, "dropout": dropout_rng}                              │
│   117 │   │                                                                                      │
│ ❱ 118 │   │   return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]         │
│   119 │                                                                                          │
│   120 │   def setup(self):                                                                       │
│   121 │   │   block_out_channels = self.block_out_channels                                       │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/unet_2d_condition_flax.py:265 in         │
│ __call__                                                                                         │
│                                                                                                  │
│   262 │   │   down_block_res_samples = (sample,)                                                 │
│   263 │   │   for down_block in self.down_blocks:                                                │
│   264 │   │   │   if isinstance(down_block, FlaxCrossAttnDownBlock2D):                           │
│ ❱ 265 │   │   │   │   sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, d   │
│   266 │   │   │   else:                                                                          │
│   267 │   │   │   │   sample, res_samples = down_block(sample, t_emb, deterministic=not train)   │
│   268 │   │   │   down_block_res_samples += res_samples                                          │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/unet_2d_blocks_flax.py:86 in __call__    │
│                                                                                                  │
│    83 │   │                                                                                      │
│    84 │   │   for resnet, attn in zip(self.resnets, self.attentions):                            │
│    85 │   │   │   hidden_states = resnet(hidden_states, temb, deterministic=deterministic)       │
│ ❱  86 │   │   │   hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=det   │
│    87 │   │   │   output_states += (hidden_states,)                                              │
│    88 │   │                                                                                      │
│    89 │   │   if self.add_downsample:                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/attention_flax.py:233 in __call__        │
│                                                                                                  │
│   230 │   │   │   hidden_states = hidden_states.reshape(batch, height * width, channels)         │
│   231 │   │                                                                                      │
│   232 │   │   for transformer_block in self.transformer_blocks:                                  │
│ ❱ 233 │   │   │   hidden_states = transformer_block(hidden_states, context, deterministic=dete   │
│   234 │   │                                                                                      │
│   235 │   │   if self.use_linear_projection:                                                     │
│   236 │   │   │   hidden_states = self.proj_out(hidden_states)                                   │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/attention_flax.py:136 in __call__        │
│                                                                                                  │
│   133 │   │   if self.only_cross_attention:                                                      │
│   134 │   │   │   hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic   │
│   135 │   │   else:                                                                              │
│ ❱ 136 │   │   │   hidden_states = self.attn1(self.norm1(hidden_states), deterministic=determin   │
│   137 │   │   hidden_states = hidden_states + residual                                           │
│   138 │   │                                                                                      │
│   139 │   │   # cross attention                                                                  │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/diffusers/models/attention_flax.py:80 in __call__         │
│                                                                                                  │
│    77 │   │   value_states = self.reshape_heads_to_batch_dim(value_proj)                         │
│    78 │   │                                                                                      │
│    79 │   │   # compute attentions                                                               │
│ ❱  80 │   │   attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)     │
│    81 │   │   attention_scores = attention_scores * self.scale                                   │
│    82 │   │   attention_probs = nn.softmax(attention_scores, axis=2)                             │
│    83                                                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:3066 in einsum                │
│                                                                                                  │
│   3063                                                                                           │
│   3064   _einsum_computation = jax.named_call(                                                   │
│   3065 │     _einsum, name=spec) if spec is not None else _einsum                                │
│ ❱ 3066   return _einsum_computation(operands, contractions, precision)                           │
│   3067                                                                                           │
│   3068 # Enable other modules to override einsum_contact_path.                                   │
│   3069 # Indexed by the type of the non constant dimension                                       │
│                                                                                                  │
│ /usr/lib/python3.7/contextlib.py:74 in inner                                                     │
│                                                                                                  │
│    71 │   │   @wraps(func)                                                                       │
│    72 │   │   def inner(*args, **kwds):                                                          │
│    73 │   │   │   with self._recreate_cm():                                                      │
│ ❱  74 │   │   │   │   return func(*args, **kwds)                                                 │
│    75 │   │   return inner                                                                       │
│    76                                                                                            │
│    77                                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 21760049152 bytes.
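For scale, the failed allocation can be put in perspective with some arithmetic. The sketch below converts the reported byte count to GiB. It also estimates the size of the `attention_scores` tensor produced by `jnp.einsum("b i d, b j d->b i j", query_states, key_states)`, which has shape `(batch * heads, seq_len, seq_len)` for self-attention. The concrete inputs are assumptions for illustration only and are not stated in the issue: a 96x96 latent at the first down block (`seq_len = 9216`), float32 scores, and arbitrary `batch * heads` values. The helper names `bytes_to_gib` and `attention_scores_bytes` are hypothetical.

```python
GIB = 1024 ** 3  # bytes per GiB


def bytes_to_gib(num_bytes: int) -> float:
    """Convert a raw byte count to GiB."""
    return num_bytes / GIB


def attention_scores_bytes(batch_times_heads: int, seq_len: int, itemsize: int = 4) -> int:
    """Bytes for a self-attention score tensor of shape
    (batch_times_heads, seq_len, seq_len), i.e. the output of
    einsum("b i d, b j d->b i j", q, k) when q and k share seq_len.
    itemsize=4 assumes float32 scores."""
    return batch_times_heads * seq_len * seq_len * itemsize


reported = 21760049152  # byte count from the RESOURCE_EXHAUSTED message above
print(f"reported allocation: {bytes_to_gib(reported):.2f} GiB")

# Illustrative assumptions, NOT taken from the issue:
latent_side = 96             # assumed 96x96 latent at the first down block
seq_len = latent_side ** 2   # 9216 spatial positions after flattening
for rows in (1, 8, 64):      # arbitrary batch * heads values
    size = attention_scores_bytes(rows, seq_len)
    print(f"batch*heads={rows:>2}: {bytes_to_gib(size):.2f} GiB")
```

Because the score tensor grows with `seq_len ** 2`, halving the latent side length shrinks it 16x. Per the traceback above, the allocation happens inside `init_weights`, which calls `self.init(...)` and therefore runs a full forward pass just to create the random parameters used during PyTorch-to-Flax conversion. Whether this exact tensor is what XLA tried to allocate is not established by the traceback.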

System Info

Google Colab

  • diffusers version: 0.9.0
  • Platform: Linux-5.10.133+-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.7.15
  • PyTorch version (GPU?): 1.12.1+cu113 (True)
  • Huggingface_hub version: 0.11.0
  • Transformers version: 4.24.0
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

lk-wq, Nov 25 '22 20:11