[enhancement]: Graphcore IPU Support
Is there an existing issue for this?
- [X] I have searched the existing issues
Contact Details
InvokeAI Discord user of the same name.
What should this feature add?
Hello, since y'all are migrating to diffusers anyway (#1583), would you consider adding Graphcore IPU support, as seen in the repo here?
https://www.graphcore.ai/posts/how-to-run-stable-diffusion-inference-on-ipus-with-paperspace
See ipu_models.py in the text-to-image demo; it looks like a fairly simple extension of StableDiffusionPipeline.
I mention this because Paperspace is offering some particularly beefy free IPU instances now. For comparison, Nvidia claims an RTX 4090 does 165 FP16 tensor TFLOPS.

(Note that I would make a crude attempt to hack this in myself and test it, but it looks like the diffusers implementation isn't done yet).
Their pipeline code is here: https://github.com/gradient-ai/Graphcore-HuggingFace/blob/main/stable-diffusion/ipu_models.py
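For reference, the basic shape of that file (heavily simplified; `IPUUNetWrapper` is just an illustrative name, and the real code also handles fp16 casting and the attention override) is to swap the pipeline's UNet for a PopTorch-compiled wrapper:

```python
# Sketch only -- loosely modeled on Graphcore's ipu_models.py, not a copy of it.
import poptorch
import torch
from diffusers import StableDiffusionPipeline


class IPUUNetWrapper(torch.nn.Module):
    """Thin wrapper so PopTorch can trace and compile the diffusers UNet."""

    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(self, sample, timestep, encoder_hidden_states):
        # The pipeline only reads .sample from the result, so a pass-through is enough.
        return self.unet(sample, timestep, encoder_hidden_states)


pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

opts = poptorch.Options()
# inferenceModel compiles the wrapped module for the IPU on its first call.
pipe.unet = poptorch.inferenceModel(IPUUNetWrapper(pipe.unet).eval(), options=opts)
```

After that, the first `pipe(...)` call is the slow compile step discussed further down the thread.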
Seems totally plausible. I just don't know how to do it cleanly yet. So much of this stuff is very subclass-happy -- including the work-in-progress pipeline in #1583 -- and that's not the best for composition. They also monkey with the cross-attention code which might conflict with some of Invoke's own desire to monkey with cross-attention code.
If you want to work on this, I think a good next step would be to take a look at this new diffusers API for cross-attention: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py
See if you can use that to replace graphcore's current override_attention method. And if you can't, raise that as an issue on the diffusers repo, because that cross-attention API isn't released yet and that could be good design feedback for them.
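For whoever picks this up, here's roughly the shape such a processor could take. This is only a sketch against the work-in-progress file linked above (so `CrossAttention.set_processor` and the module path may change before release), and `IPUCrossAttnProcessor` / `use_ipu_attention` are illustrative names, not anything that exists in diffusers or the Graphcore repo:

```python
import torch
from diffusers.models.cross_attention import CrossAttention


class IPUCrossAttnProcessor:
    """Attention with plain bmm + softmax, avoiding torch.baddbmm.

    Simplified: ignores attention slicing and cross-attention norms.
    """

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        # Self-attention when no encoder states are given, cross-attention otherwise.
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        # (batch, seq, heads * dim_head) -> (batch * heads, seq, dim_head)
        def split_heads(t):
            b, s, d = t.shape
            t = t.reshape(b, s, attn.heads, d // attn.heads).permute(0, 2, 1, 3)
            return t.reshape(b * attn.heads, s, d // attn.heads)

        query, key, value = (split_heads(t) for t in (query, key, value))

        # Plain matmul + softmax instead of torch.baddbmm.
        scores = torch.bmm(query, key.transpose(1, 2)) * attn.scale
        if attention_mask is not None:
            scores = scores + attention_mask
        probs = scores.softmax(dim=-1)
        out = torch.bmm(probs, value)

        # (batch * heads, seq, dim_head) -> (batch, seq, heads * dim_head)
        bh, s, d = out.shape
        out = out.reshape(bh // attn.heads, attn.heads, s, d).permute(0, 2, 1, 3)
        out = out.reshape(bh // attn.heads, s, attn.heads * d)

        out = attn.to_out[0](out)  # output projection
        out = attn.to_out[1](out)  # dropout (identity at inference time)
        return out


def use_ipu_attention(unet):
    # Swap the processor into every cross-attention module in the UNet.
    for module in unet.modules():
        if isinstance(module, CrossAttention):
            module.set_processor(IPUCrossAttnProcessor())
```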
Hmmm, looks like the function descriptions are accurate: https://github.com/gradient-ai/Graphcore-HuggingFace/blob/main/stable-diffusion/ipu_models.py#L36
> Overriding this implementation as the torch.baddbmm op is not registered.
I can't find it at the moment, but I read in some Graphcore blog post that torch.baddbmm is not supported on the IPU. Replacing either the attention or the sliced-attention override with the diffusers version always executes this line, which trips it up:
https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py#L184
Cross Attention Error
```
---------------------------------------------------------------------------
Error Traceback (most recent call last)
<ipython-input-10-3a9910a550ac> in <module>
----> 1 pipe("apple", height=image_height, width=image_width, num_inference_steps=25, guidance_scale=9);
/usr/local/lib/python3.8/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
26 def decorate_context(*args, **kwargs):
27 with self.__class__():
---> 28 return func(*args, **kwargs)
29 return cast(F, decorate_context)
30
/usr/local/lib/python3.8/dist-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py in __call__(self, prompt, height, width, num_inference_steps, guidance_scale, negative_prompt, num_images_per_prompt, eta, generator, latents, output_type, return_dict, callback, callback_steps)
527
528 # predict the noise residual
--> 529 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
530
531 # perform guidance
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
/notebooks/stable-diffusion/ipu_models.py in forward(self, sample, timestep, encoder_hidden_states, return_dict)
105 encoder_hidden_states = encoder_hidden_states.to(self.input_dtype)
106
--> 107 ret = self.unet(sample, timestep, encoder_hidden_states)
108
109 ret.sample = ret.sample.to(self.output_dtype)
/usr/local/lib/python3.8/dist-packages/poptorch/_poplar_executor.py in __call__(self, *args, **kwargs)
919
920 if not self.isCompiled():
--> 921 self._compile(in_tensors)
922
923 if not self._is_attached:
/usr/local/lib/python3.8/dist-packages/poptorch/_impl.py in wrapper(self, *args, **kwargs)
356 def wrapper(self, *args, **kwargs):
357 with self._profiling.tracepoint(label): # pylint: disable=protected-access
--> 358 return func(self, *args, **kwargs)
359
360 return wrapper
/usr/local/lib/python3.8/dist-packages/poptorch/_poplar_executor.py in _compile(self, in_tensors)
644 self._executable = poptorch_core.compileWithTrace(*trace_args)
645 else:
--> 646 self._executable = self._compileWithDispatch(
647 in_tensors_trace_view)
648
/usr/local/lib/python3.8/dist-packages/poptorch/_poplar_executor.py in _compileWithDispatch(self, in_tensors, executable_filename)
593 **in_tensors.kwargs)
594 else:
--> 595 ctx.compile(*in_tensors.args, **in_tensors.kwargs)
596 self._outputs_structure = ctx.ipu._outputs_structure # pylint: disable=protected-access
597
/usr/local/lib/python3.8/dist-packages/poptorch/experimental.py in compile(self, *args, **kwargs)
339
340 def compile(self, *args, **kwargs):
--> 341 return self._compileOrLoadExecutable(args, kwargs)
342
343 def loadExecutable(self, filename, *args, **kwargs):
/usr/local/lib/python3.8/dist-packages/poptorch/_impl.py in wrapper(*args, **kwargs)
162 def wrapper(*args, **kwargs):
163 with OnExit():
--> 164 return func(*args, **kwargs)
165
166 return wrapper
/usr/local/lib/python3.8/dist-packages/poptorch/experimental.py in _compileOrLoadExecutable(self, args, kwargs, filename)
380 tensor_args)
381
--> 382 result = self.func(*args, **kwargs)
383 if result is not None:
384 ipu.outputs(result)
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_condition.py in forward(self, sample, timestep, encoder_hidden_states, class_labels, attention_mask, return_dict)
422 for downsample_block in self.down_blocks:
423 if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
--> 424 sample, res_samples = downsample_block(
425 hidden_states=sample,
426 temb=emb,
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1118 input = bw_hook.setup_input_hook(input)
1119
-> 1120 result = forward_call(*input, **kwargs)
1121 if _global_forward_hooks or self._forward_hooks:
1122 for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
/usr/local/lib/python3.8/dist-packages/diffusers/models/unet_2d_blocks.py in forward(self, hidden_states, temb, encoder_hidden_states, attention_mask)
775 else:
776 hidden_states = resnet(hidden_states, temb)
--> 777 hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
778
779 output_states += (hidden_states,)
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1118 input = bw_hook.setup_input_hook(input)
1119
-> 1120 result = forward_call(*input, **kwargs)
1121 if _global_forward_hooks or self._forward_hooks:
1122 for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
/usr/local/lib/python3.8/dist-packages/diffusers/models/attention.py in forward(self, hidden_states, encoder_hidden_states, timestep, return_dict)
214 # 2. Blocks
215 for block in self.transformer_blocks:
--> 216 hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep)
217
218 # 3. Output
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1118 input = bw_hook.setup_input_hook(input)
1119
-> 1120 result = forward_call(*input, **kwargs)
1121 if _global_forward_hooks or self._forward_hooks:
1122 for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
/usr/local/lib/python3.8/dist-packages/diffusers/models/attention.py in forward(self, hidden_states, encoder_hidden_states, timestep, attention_mask)
488 )
489 else:
--> 490 hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
491
492 if self.attn2 is not None:
/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1118 input = bw_hook.setup_input_hook(input)
1119
-> 1120 result = forward_call(*input, **kwargs)
1121 if _global_forward_hooks or self._forward_hooks:
1122 for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
/usr/local/lib/python3.8/dist-packages/diffusers/models/attention.py in forward(self, hidden_states, encoder_hidden_states, attention_mask)
638 hidden_states = self._attention(query, key, value, attention_mask)
639 else:
--> 640 hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
641
642 # linear proj
/usr/local/lib/python3.8/dist-packages/diffusers/models/attention.py in _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask)
695 key_slice = key_slice.float()
696
--> 697 attn_slice = torch.baddbmm(
698 torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
699 query_slice,
Error: In poptorch/source/dispatch_tracer/dispatchers/MLIRDispatch.cpp:372: 'poptorch_cpp_error': No shape inference handler for aten::baddbmm
```
Not sure if there is an easy way to work around that.
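One thought, untested on an IPU: as far as I can tell that call passes `beta=0`, so the empty bias tensor is ignored and the op reduces to a scaled batched matmul. A patch along these lines ought to be numerically equivalent:

```python
# What the diffusers sliced-attention code does (the empty tensor only carries
# dtype/shape, since beta=0):
#   attn_slice = torch.baddbmm(torch.empty(...), query_slice,
#                              key_slice.transpose(-1, -2), beta=0, alpha=self.scale)
# Equivalent without baddbmm:
attn_slice = torch.bmm(query_slice, key_slice.transpose(-1, -2)) * self.scale
```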
I worked with the mentioned code a few weeks ago (on optimum, but with the same attention overwrite). The code targets the older attention implementation (from before diffusers refactored much of the attention code). It should be possible to get it working without too much hassle, since it just replaces the attention with simpler, less efficient code.
There is another major problem that requires solving, though: you will need to compile the model, and cache the compiled model somewhere, for every possible resolution.
The first time the model is run, it compiles itself. This process takes around 15 minutes on the free pod16 machine offered a while ago. Once it's compiled, feeding it a different resolution triggers an error, and the only way to avoid that error is to delete the model and recompile.
The recompile can be skipped if you load a previously compiled model. So unless you want users to wait 15 minutes for an image every time they pick a different resolution, caching all possible resolutions is pretty much the only option.
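For what it's worth, the "load a compiled model" path can lean on PopTorch's executable caching, roughly like this (the cache path is just an example, and `unet_wrapper` stands in for whatever module actually gets compiled):

```python
import poptorch

opts = poptorch.Options()
# Compiled executables are written to this directory and re-loaded on later runs
# with the same model and input shapes, so each resolution/batch combination only
# pays the ~15 minute compile once.
opts.enableExecutableCaching("./ipu_exe_cache")  # path is just an example

# `unet_wrapper` would be the module compiled for the IPU, e.g. the UNet wrapper
# sketched earlier in the thread:
# ipu_unet = poptorch.inferenceModel(unet_wrapper, options=opts)
```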
Yes, I noticed this on the Paperspace Pod4 demo more recently. I also noticed that it only seems to compile on a single thread, and that changing num_images_per_prompt to 2 takes so long to compile that the kernel times out (after ~30 min) before it finishes.
The free Pod4 instance is a 56-thread VM with gobs of RAM, so in a hypothetical Paperspace notebook, a set of resolution/model/batch combinations could be selected via the UI at the start, then compiled in parallel at startup and stored (roughly as sketched below). But if this is too difficult, I am content with a single resolution/model combination for each existing pipe.
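To make that concrete, the startup step could look something like this sketch. `compile_pipeline_for_combo` is purely hypothetical; it would build a pipeline for one shape and rely on executable caching to persist the result:

```python
from concurrent.futures import ProcessPoolExecutor

# Combinations the user picks in the UI at startup: (width, height, batch_size).
COMBOS = [(512, 512, 1), (512, 768, 1), (768, 512, 1)]


def compile_pipeline_for_combo(combo):
    width, height, batch_size = combo
    # Hypothetical: build the IPU pipeline for this shape and run one dummy
    # inference so PopTorch compiles it and writes the executable to the cache.
    ...


def precompile_all(combos=COMBOS, max_workers=4):
    # One worker process per combination; the serving process then loads the
    # cached executables instead of recompiling.
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        list(pool.map(compile_pipeline_for_combo, combos))
```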
On the topic of an InvokeAI Paperspace notebook, instead of recreating the UI, I think a workaround like this would allow the user to access the web UI running in the notebook directly:
https://nni.readthedocs.io/en/stable/sharings/nni_colab_support.html