[enhancement]: Graphcore IPU Support

Open • brucethemoose opened this issue 3 years ago • 5 comments

Is there an existing issue for this?

  • [X] I have searched the existing issues

Contact Details

InvokeAI Discord user of the same name.

What should this feature add?

Hello, since y'all are migrating to diffusers anyway (#1583), would you consider adding Graphcore IPU support, as implemented in the repo here?

https://www.graphcore.ai/posts/how-to-run-stable-diffusion-inference-on-ipus-with-paperspace

See ipu_models.py in the text-to-image demo; it looks like a fairly simple extension of StableDiffusionPipeline.

I mention this because Paperspace is now offering some particularly beefy free IPU instances. For reference, Nvidia rates an RTX 4090 at 165 FP16 tensor TFLOPS.

[Screenshot_4]

brucethemoose, Dec 24 '22 04:12

(Note that I would make a crude attempt to hack this in myself and test it, but it looks like the diffusers implementation isn't done yet.)

brucethemoose, Dec 24 '22 04:12

Their pipeline code is here: https://github.com/gradient-ai/Graphcore-HuggingFace/blob/main/stable-diffusion/ipu_models.py

Seems totally plausible; I just don't know how to do it cleanly yet. So much of this stuff is very subclass-happy (including the work-in-progress pipeline in #1583), and that's not great for composition. They also monkey-patch the cross-attention code, which might conflict with Invoke's own desire to monkey-patch the cross-attention code.
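The composition alternative can be sketched as wrapping the UNet instead of subclassing the pipeline. The sketch mirrors the dtype casts visible in Graphcore's `ipu_models.py` `forward()` in the traceback further down (`encoder_hidden_states.to(self.input_dtype)` before the call, `ret.sample.to(self.output_dtype)` after it). The class name `DtypeCastingWrapper`, the default dtypes, and casting `sample` as well are assumptions. This is not Graphcore's code. A real pipeline also reads other attributes (such as the UNet's config) from `pipe.unet`, and this sketch does not forward them.

```python
import torch
from torch import nn


class DtypeCastingWrapper(nn.Module):
    """Hypothetical wrapper that composes around a UNet-like module.

    Inputs are cast to `input_dtype` before the wrapped call, and the returned
    object's `.sample` is cast back to `output_dtype`, so the pipeline itself
    does not need to be subclassed.
    """

    def __init__(self, unet: nn.Module,
                 input_dtype: torch.dtype = torch.float16,
                 output_dtype: torch.dtype = torch.float32):
        super().__init__()
        self.unet = unet
        self.input_dtype = input_dtype
        self.output_dtype = output_dtype

    def forward(self, sample, timestep, encoder_hidden_states):
        # Assumption: both tensor inputs should run in the device dtype.
        sample = sample.to(self.input_dtype)
        encoder_hidden_states = encoder_hidden_states.to(self.input_dtype)
        ret = self.unet(sample, timestep, encoder_hidden_states)
        # Hand the pipeline its result in the dtype it expects.
        ret.sample = ret.sample.to(self.output_dtype)
        return ret
```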

If you want to work on this, I think a good next step would be to take a look at this new diffusers API for cross-attention: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py

See if you can use that to replace Graphcore's current override_attention method. If you can't, raise that as an issue on the diffusers repo: that cross-attention API isn't released yet, so it could be good design feedback for them.
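The key job of such a replacement is to compute attention scores without `torch.baddbmm`. Below is a minimal sketch in plain PyTorch. It assumes that diffusers calls `torch.baddbmm` on a `torch.empty` buffer (as the traceback further down shows) with `beta=0` and `alpha=scale`. With `beta=0`, PyTorch ignores the buffer's contents, so `torch.bmm(q, k^T) * scale` is mathematically the same. The function names are hypothetical, and this is not the diffusers processor API, which was still unreleased at the time.

```python
import torch


def attention_scores_baddbmm(query, key, scale):
    # baddbmm form: with beta=0 the uninitialized buffer is ignored,
    # leaving alpha * (query @ key^T).
    buf = torch.empty(query.shape[0], query.shape[1], key.shape[1],
                      dtype=query.dtype, device=query.device)
    return torch.baddbmm(buf, query, key.transpose(-1, -2), beta=0, alpha=scale)


def attention_scores_bmm(query, key, scale):
    # Same scores from a plain batched matmul and a scalar multiply,
    # so aten::baddbmm never appears in the traced graph.
    return torch.bmm(query, key.transpose(-1, -2)) * scale


def attention_bmm(query, key, value, scale, attention_mask=None):
    # Full (non-sliced) attention built only from bmm, add and softmax.
    scores = attention_scores_bmm(query, key, scale)
    if attention_mask is not None:
        scores = scores + attention_mask
    probs = scores.softmax(dim=-1)
    return torch.bmm(probs, value)


if __name__ == "__main__":
    torch.manual_seed(0)
    # Layout assumed: (batch * heads, sequence length, head dim).
    q, k, v = torch.randn(4, 16, 8), torch.randn(4, 12, 8), torch.randn(4, 12, 8)
    scale = 8 ** -0.5
    print(torch.allclose(attention_scores_baddbmm(q, k, scale),
                         attention_scores_bmm(q, k, scale), atol=1e-6))  # True
    print(attention_bmm(q, k, v, scale).shape)  # torch.Size([4, 16, 8])
```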

keturn, Dec 24 '22 05:12

Hmmm, looks like the function descriptions are accurate: https://github.com/gradient-ai/Graphcore-HuggingFace/blob/main/stable-diffusion/ipu_models.py#L36

Overriding this implementation as the torch.baddbmm op is not registered.

I can't find it at the moment, but I read in a Graphcore blog post that torch.baddbmm is not supported on IPU. Replacing either the attention override or the sliced-attention override with the diffusers versions always reaches this call, which trips it up:

https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py#L184

Cross Attention Error
---------------------------------------------------------------------------
  Error                                     Traceback (most recent call last)
  <ipython-input-10-3a9910a550ac> in <module>
  ----> 1 pipe("apple", height=image_height, width=image_width, num_inference_steps=25, guidance_scale=9);
  
  /usr/local/lib/python3.8/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
       26         def decorate_context(*args, **kwargs):
       27             with self.__class__():
  ---> 28                 return func(*args, **kwargs)
       29         return cast(F, decorate_context)
       30 
  
  /usr/local/lib/python3.8/dist-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py in __call__(self, prompt, height, width, num_inference_steps, guidance_scale, negative_prompt, num_images_per_prompt, eta, generator, latents, output_type, return_dict, callback, callback_steps)
      527 
      528                 # predict the noise residual
  --> 529                 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
      530 
      531                 # perform guidance
  
  /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
     1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
     1101                 or _global_forward_hooks or _global_forward_pre_hooks):
  -> 1102             return forward_call(*input, **kwargs)
     1103         # Do not call functions when jit is used
     1104         full_backward_hooks, non_full_backward_hooks = [], []
  
  /notebooks/stable-diffusion/ipu_models.py in forward(self, sample, timestep, encoder_hidden_states, return_dict)
      105         encoder_hidden_states = encoder_hidden_states.to(self.input_dtype)
      106 
  --> 107         ret = self.unet(sample, timestep, encoder_hidden_states)
      108 
      109         ret.sample = ret.sample.to(self.output_dtype)
  
  /usr/local/lib/python3.8/dist-packages/poptorch/_poplar_executor.py in __call__(self, *args, **kwargs)
      919 
      920         if not self.isCompiled():
  --> 921             self._compile(in_tensors)
      922 
      923         if not self._is_attached:
  
  /usr/local/lib/python3.8/dist-packages/poptorch/_impl.py in wrapper(self, *args, **kwargs)
      356         def wrapper(self, *args, **kwargs):
      357             with self._profiling.tracepoint(label):  # pylint: disable=protected-access
  --> 358                 return func(self, *args, **kwargs)
      359 
      360         return wrapper
  
  /usr/local/lib/python3.8/dist-packages/poptorch/_poplar_executor.py in _compile(self, in_tensors)
      644                 self._executable = poptorch_core.compileWithTrace(*trace_args)
      645             else:
  --> 646                 self._executable = self._compileWithDispatch(
      647                     in_tensors_trace_view)
      648 
  
  /usr/local/lib/python3.8/dist-packages/poptorch/_poplar_executor.py in _compileWithDispatch(self, in_tensors, executable_filename)
      593                                    **in_tensors.kwargs)
      594             else:
  --> 595                 ctx.compile(*in_tensors.args, **in_tensors.kwargs)
      596             self._outputs_structure = ctx.ipu._outputs_structure  # pylint: disable=protected-access
      597 
  
  /usr/local/lib/python3.8/dist-packages/poptorch/experimental.py in compile(self, *args, **kwargs)
      339 
      340     def compile(self, *args, **kwargs):
  --> 341         return self._compileOrLoadExecutable(args, kwargs)
      342 
      343     def loadExecutable(self, filename, *args, **kwargs):
  
  /usr/local/lib/python3.8/dist-packages/poptorch/_impl.py in wrapper(*args, **kwargs)
      162     def wrapper(*args, **kwargs):
      163         with OnExit():
  --> 164             return func(*args, **kwargs)
      165 
      166     return wrapper
  
  /usr/local/lib/python3.8/dist-packages/poptorch/experimental.py in _compileOrLoadExecutable(self, args, kwargs, filename)
      380                                                       tensor_args)
      381 
  --> 382             result = self.func(*args, **kwargs)
      383             if result is not None:
      384                 ipu.outputs(result)
  
  /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
     1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
     1101                 or _global_forward_hooks or _global_forward_pre_hooks):
  -> 1102             return forward_call(*input, **kwargs)
     1103         # Do not call functions when jit is used
     1104         full_backward_hooks, non_full_backward_hooks = [], []
  
  /usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_condition.py in forward(self, sample, timestep, encoder_hidden_states, class_labels, attention_mask, return_dict)
      422         for downsample_block in self.down_blocks:
      423             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
  --> 424                 sample, res_samples = downsample_block(
      425                     hidden_states=sample,
      426                     temb=emb,
  
  /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
     1118             input = bw_hook.setup_input_hook(input)
     1119 
  -> 1120         result = forward_call(*input, **kwargs)
     1121         if _global_forward_hooks or self._forward_hooks:
     1122             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
  
  /usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py in forward(self, hidden_states, temb, encoder_hidden_states, attention_mask)
      775             else:
      776                 hidden_states = resnet(hidden_states, temb)
  --> 777                 hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
      778 
      779             output_states += (hidden_states,)
  
  /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
     1118             input = bw_hook.setup_input_hook(input)
     1119 
  -> 1120         result = forward_call(*input, **kwargs)
     1121         if _global_forward_hooks or self._forward_hooks:
     1122             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
  
  /usr/local/lib/python3.8/dist-packages/diffusers/models/attention.py in forward(self, hidden_states, encoder_hidden_states, timestep, return_dict)
      214         # 2. Blocks
      215         for block in self.transformer_blocks:
  --> 216             hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep)
      217 
      218         # 3. Output
  
  /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
     1118             input = bw_hook.setup_input_hook(input)
     1119 
  -> 1120         result = forward_call(*input, **kwargs)
     1121         if _global_forward_hooks or self._forward_hooks:
     1122             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
  
  /usr/local/lib/python3.8/dist-packages/diffusers/models/attention.py in forward(self, hidden_states, encoder_hidden_states, timestep, attention_mask)
      488             )
      489         else:
  --> 490             hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
      491 
      492         if self.attn2 is not None:
  
  /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
     1118             input = bw_hook.setup_input_hook(input)
     1119 
  -> 1120         result = forward_call(*input, **kwargs)
     1121         if _global_forward_hooks or self._forward_hooks:
     1122             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
  
  /usr/local/lib/python3.8/dist-packages/diffusers/models/attention.py in forward(self, hidden_states, encoder_hidden_states, attention_mask)
      638                 hidden_states = self._attention(query, key, value, attention_mask)
      639             else:
  --> 640                 hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
      641 
      642         # linear proj
  
  /usr/local/lib/python3.8/dist-packages/diffusers/models/attention.py in _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask)
      695                 key_slice = key_slice.float()
      696 
  --> 697             attn_slice = torch.baddbmm(
      698                 torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
      699                 query_slice,
  
  Error: In poptorch/source/dispatch_tracer/dispatchers/MLIRDispatch.cpp:372: 'poptorch_cpp_error': No shape inference handler for aten::baddbmm

Not sure if there is an easy way to work around that.
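One possible workaround is sketched below as a hypothetical standalone function, not a patch to diffusers. It keeps the slicing idea from `_sliced_attention` but swaps each `torch.baddbmm` call for `torch.bmm` followed by a scalar multiply, so `aten::baddbmm` is never dispatched. The `(batch * heads, seq_len, head_dim)` layout and slicing along the first dimension are assumptions read off the traceback. Whether PopTorch then compiles the rest of the UNet is untested.

```python
import torch


def sliced_attention_bmm(query, key, value, scale, slice_size):
    # Process `slice_size` rows of the (batch * heads) dimension at a time
    # to cap peak memory, computing scores with bmm instead of baddbmm.
    batch_x_heads, seq_len, _ = query.shape
    out = torch.zeros(batch_x_heads, seq_len, value.shape[-1],
                      dtype=query.dtype, device=query.device)
    for start in range(0, batch_x_heads, slice_size):
        end = min(start + slice_size, batch_x_heads)  # last slice may be shorter
        q = query[start:end]
        k = key[start:end]
        # Equivalent to baddbmm(empty, q, k^T, beta=0, alpha=scale).
        scores = torch.bmm(q, k.transpose(-1, -2)) * scale
        probs = scores.softmax(dim=-1)
        out[start:end] = torch.bmm(probs, value[start:end])
    return out
```

Up to floating-point rounding, the result matches unsliced attention for any `slice_size`. Slicing only changes how much is computed at once.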

brucethemoose, Dec 26 '22 00:12

I worked with the mentioned code a few weeks ago (in Optimum, but with the same attention override). The code targets the older attention implementation, from before diffusers refactored much of its attention code. It should be possible to get it working without too much hassle, since it just replaces attention with simpler, less efficient code.

However, there is another major problem that needs solving: you will need to compile the model, and cache the compiled model somewhere, for every possible resolution.

The first time the model is run, it compiles itself. This takes around 15 minutes on the free Pod16 machine that was offered a while ago. Once it's compiled, feeding it a different resolution triggers an error. The only way to avoid that error is to delete the model and recompile.

The recompile can be skipped if you load an already-compiled model. So unless you want the user to wait 15 minutes for an image every time they pick a different resolution, caching every possible resolution is more or less the only option.
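The bookkeeping this implies can be sketched as follows. Each compiled model is keyed by the static shape it was built for, and compilation happens only on the first request for that shape. `compile_fn` is a hypothetical placeholder for the expensive PopTorch compile and for wherever the result is stored (in memory or on disk). PopTorch's `poptorch.Options().enableExecutableCaching(path)` may cover the on-disk part; check the docs for your SDK version. Batch size is included in the key on the assumption that, like resolution, it is fixed at compile time.

```python
from typing import Callable, Dict, Tuple

# Static input shape a compiled model was built for.
ShapeKey = Tuple[int, int, int]  # (height, width, batch_size)


class CompiledModelCache:
    """Compile once per input shape, then reuse the result.

    `compile_fn` is a hypothetical stand-in for the slow compile step; it is
    not a PopTorch API.
    """

    def __init__(self, compile_fn: Callable[[ShapeKey], object]):
        self._compile_fn = compile_fn
        self._models: Dict[ShapeKey, object] = {}

    def get(self, height: int, width: int, batch_size: int = 1) -> object:
        key = (height, width, batch_size)
        if key not in self._models:
            # Pay the compile cost only the first time this shape is requested,
            # instead of deleting and recompiling one model on every change.
            self._models[key] = self._compile_fn(key)
        return self._models[key]

    def __contains__(self, key: ShapeKey) -> bool:
        return key in self._models
```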

Lime-Cakes, Dec 28 '22 19:12

Yes, I noticed this more recently on the Paperspace Pod4 demo. I also noticed that it only seems to compile on a single thread. Changing num_images_per_prompt to 2 makes compilation take so long that the kernel times out (after ~30 min) before it finishes.

The free Pod4 instance is a 56-thread VM with gobs of RAM. So in a hypothetical Paperspace notebook, the user could perhaps pick some resolution/model/batch combinations in the UI at the start, and these could be compiled in parallel at startup and stored? But if this is too difficult, I am content with a single resolution/model combination for each existing pipeline.
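The parallel-startup idea can be sketched like this: enumerate the selected (model, resolution, batch size) combinations and submit one compile per combination to an executor. `compile_fn` and the combination layout are hypothetical. If, as observed above, each compile is single-threaded and holds the GIL, a `ProcessPoolExecutor` (with a module-level, picklable `compile_fn`) would be needed to use the extra cores; that is an assumption, not something tested on an IPU pod.

```python
from concurrent.futures import Executor, ThreadPoolExecutor
from itertools import product
from typing import Callable, Dict, Iterable, Tuple

Combo = Tuple[str, int, int, int]  # (model name, height, width, batch size)


def precompile_all(
    compile_fn: Callable[[Combo], str],
    models: Iterable[str],
    resolutions: Iterable[Tuple[int, int]],
    batch_sizes: Iterable[int],
    executor: Executor,
) -> Dict[Combo, str]:
    """Compile every selected combination concurrently.

    `compile_fn` is hypothetical: it compiles one combination and returns
    something identifying the stored result (for example a file path).
    """
    combos = [(m, h, w, b) for m, (h, w), b in product(models, resolutions, batch_sizes)]
    futures = {combo: executor.submit(compile_fn, combo) for combo in combos}
    # .result() blocks until each compile finishes and re-raises its errors.
    return {combo: fut.result() for combo, fut in futures.items()}


if __name__ == "__main__":
    # Stand-in compile step, for illustration only.
    def fake_compile(combo: Combo) -> str:
        model, h, w, b = combo
        return f"{model}-{h}x{w}-b{b}"

    with ThreadPoolExecutor(max_workers=4) as pool:
        compiled = precompile_all(fake_compile, ["sd-1.5"], [(512, 512), (768, 512)], [1], pool)
    print(sorted(compiled.values()))  # ['sd-1.5-512x512-b1', 'sd-1.5-768x512-b1']
```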

On the topic of an InvokeAI Paperspace notebook: instead of recreating the UI, I think a workaround like this would let the user access the web UI running in the notebook directly:

https://nni.readthedocs.io/en/stable/sharings/nni_colab_support.html

brucethemoose, Dec 28 '22 20:12

There has been no activity in this issue for 14 days. If this issue is still being experienced, please reply with an updated confirmation that the issue is still being experienced with the latest release.

github-actions[bot], Mar 13 '23 06:03