[ONNX] PyTorch IO binding for faster GPU inference
## Context
Currently, `OnnxStableDiffusionPipeline` performs unnecessary tensor conversions between torch and numpy. The downsides of this are:
- The pipeline code is harder to maintain: we have to keep track of which tensors are numpy, when they have to be converted to torch (e.g. before `scheduler.step()`), and when to convert them back (see the sketch after this list).
- Passing tensors between CPU and GPU hurts `CUDAExecutionProvider` latency: ideally the UNet inputs/outputs (latents) should stay on the same device between sampling iterations.
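To make the first point concrete, here is a rough sketch of the round trip each denoising step currently makes (illustrative only, not the actual pipeline code: `unet`, `scheduler`, `latents`, `t`, and the input names are placeholders):

```python
import numpy as np
import torch

# ONNX Runtime consumes and returns numpy arrays...
noise_pred = unet.run(None, {"sample": latents, "timestep": np.array([t])})[0]
# ...the diffusers scheduler expects torch tensors...
step_output = scheduler.step(torch.from_numpy(noise_pred), t, torch.from_numpy(latents))
# ...and the latents must be cast back to numpy for the next UNet call.
latents = step_output.prev_sample.numpy()
```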
## Proposed solution
Take advantage of the IO binding mechanism in ONNX Runtime to bind the PyTorch tensors to model inputs and outputs and keep them on the same device. For more details, see: https://onnxruntime.ai/docs/api/python/api_summary.html#data-on-device
Standalone example of torch IO binding:
```python
import numpy as np
import onnxruntime
import torch

# X is a PyTorch tensor on device
session = onnxruntime.InferenceSession(
    'model.onnx', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)
binding = session.io_binding()

# Bind the input directly from its CUDA buffer
X_tensor = X.contiguous()
binding.bind_input(
    name='X',
    device_type='cuda',
    device_id=0,
    element_type=np.float32,
    shape=tuple(X_tensor.shape),
    buffer_ptr=X_tensor.data_ptr(),
)

# Allocate the PyTorch tensor for the model output
Y_shape = ...  # You need to specify the output PyTorch tensor shape
Y_tensor = torch.empty(Y_shape, dtype=torch.float32, device='cuda:0').contiguous()
binding.bind_output(
    name='Y',
    device_type='cuda',
    device_id=0,
    element_type=np.float32,
    shape=tuple(Y_tensor.shape),
    buffer_ptr=Y_tensor.data_ptr(),
)

# Run inference; the output is written straight into Y_tensor on the GPU
session.run_with_iobinding(binding)
```
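After `run_with_iobinding` returns, `Y_tensor` holds the result on the GPU with no device-to-host copy; in the pipeline this would let the latents feed straight into the next sampling iteration.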
This functionality can be integrated either into `OnnxRuntimeModel` or into each of the ONNX pipelines individually. For easier maintenance I would go with the first option; a sketch of what that could look like follows.
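A minimal sketch of the first option, assuming a CUDA session and float32 tensors. The binding calls are the real ONNX Runtime API, but the wrapper class, the method name `run_with_torch`, the explicit `output_shape` argument, and the single-output assumption are all hypothetical, not existing diffusers API:

```python
import numpy as np
import onnxruntime
import torch


class OnnxRuntimeModelWithBinding:
    """Sketch of option 1: IO binding folded into the model wrapper (hypothetical)."""

    def __init__(self, session: onnxruntime.InferenceSession):
        self.session = session

    def run_with_torch(self, inputs: dict, output_name: str, output_shape: tuple) -> torch.Tensor:
        binding = self.session.io_binding()
        keep_alive = []  # keep contiguous copies alive until the run completes
        for name, tensor in inputs.items():
            tensor = tensor.contiguous()
            keep_alive.append(tensor)
            binding.bind_input(
                name=name,
                device_type='cuda',
                device_id=0,
                element_type=np.float32,
                shape=tuple(tensor.shape),
                buffer_ptr=tensor.data_ptr(),
            )
        # Pre-allocate the output on the GPU so the result never touches the host.
        output = torch.empty(output_shape, dtype=torch.float32, device='cuda:0')
        binding.bind_output(
            name=output_name,
            device_type='cuda',
            device_id=0,
            element_type=np.float32,
            shape=tuple(output.shape),
            buffer_ptr=output.data_ptr(),
        )
        self.session.run_with_iobinding(binding)
        return output
```

With something like this, the pipelines would call `run_with_torch` inside the sampling loop and never see numpy arrays at all.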
cc @mfuntowicz @echarlaix here - it would be good not to replicate Optimum's efforts here
Talked to @echarlaix and @JingyaHuang; we'll make this part of the Optimum integration (moving `diffusers.OnnxRuntimeModel` to something like `optimum.DiffusersModel`): https://github.com/huggingface/optimum/pull/447