GPU acceleration of the reproject package
This is a copy of a question I posed on the slack channel opened up to all users
Hello! I am working on speeding up the reproject package using the GPU. I've already updated the pixel-to-pixel functionality for a 30% reduction in computation time for that algorithm, and I'm going to work on updating the other functions in the package to run on the GPU. I wanted to know if anyone has already done this, or if someone is currently working on something similar. According to the Roadmap (or at least this is how I understood it), there is a need for someone to work on this type of implementation. Would it be possible to talk to anyone on the dev team about this? I'm going to be working on this acceleration either way, so I'd like to be able to contribute if the community thinks it would be helpful.
I don't know anything about using GPUs (and I wish I did), but I'm very interested in this, as are a few of my colleagues. I'd be really interested in seeing how you're doing this, and I'm happy to help in whatever (perhaps limited) ways I can!
So I rewrote the pixel_to_pixel algorithm using cupy and got a 33% reduction in time. I also implemented it in torch and got a 45% reduction. This was the most time-consuming function that didn't involve changing WCS. The next step for me is to try and tackle some of the WCS computations using torch.
Here is the implementation in torch
import numpy as np
import torch
from astropy.wcs.wcsapi import BaseHighLevelWCS


def pixel_to_pixel_gpu(wcs_in: BaseHighLevelWCS, wcs_out: BaseHighLevelWCS, *inputs):
    """
    GPU version: transform pixel coordinates using PyTorch, optimized to reduce transfer overhead.
    """
    # Automatically select device (GPU if available, otherwise CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if np.isscalar(inputs[0]):
        world_outputs = wcs_in.pixel_to_world(*inputs)
        if not isinstance(world_outputs, (tuple, list)):
            world_outputs = (world_outputs,)
        return wcs_out.world_to_pixel(*world_outputs)
    outputs = [None] * wcs_out.pixel_n_dim
    # Ensure inputs are torch tensors and move them to the selected device
    pixel_inputs = [
        torch.tensor(arr).to(device) if not isinstance(arr, torch.Tensor) else arr.to(device)
        for arr in inputs
    ]
    pixel_inputs = torch.broadcast_tensors(*pixel_inputs)
    # Compute world outputs on the CPU using the WCS functions
    world_outputs_cpu = wcs_in.pixel_to_world(*[arr.cpu().numpy() for arr in pixel_inputs])
    if not isinstance(world_outputs_cpu, (tuple, list)):
        world_outputs_cpu = (world_outputs_cpu,)
    pixel_outputs_cpu = wcs_out.world_to_pixel(*world_outputs_cpu)
    if wcs_out.pixel_n_dim == 1:
        pixel_outputs_cpu = (pixel_outputs_cpu,)
    for i in range(wcs_out.pixel_n_dim):
        outputs[i] = pixel_outputs_cpu[i]
    # The WCS calls above already return NumPy arrays, so no transfer back is needed
    return outputs[0] if wcs_out.pixel_n_dim == 1 else outputs
Thanks for posting your code! I tried out your function on my computer and I'm afraid I'm not seeing any speedup. Here's how I explored this, on a computer with an RTX 3070. The GPU version was faster for small inputs of just a thousand coordinates, but that difference could come down to simply having fewer lines of Python code: I saw a similar speedup from a stripped-down CPU-only version. For a larger case of 4k x 4k input coordinates, where any speedup could be really valuable, the GPU version takes longer (I'm guessing from the overhead of moving data to the GPU and back). Does it work differently on your computer? Or do you have different inputs that show a difference?
I don't know anything about PyTorch, but it looks like your function only uses the GPU to broadcast the input arrays, and then still uses the CPU for the coordinate conversions (pixel_to_world and world_to_pixel). Is that right? When I profile pixel_to_pixel (at the bottom of that link), basically all of the compute time is spent in those two functions, so I suspect big speedups will only come from accelerating those functions (which I'm guessing would be very involved).
Let me know if I'm missing anything!
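For anyone who wants to reproduce this kind of comparison, a minimal timing harness along these lines is handy. This is just a sketch: benchmark is a made-up helper, and the commented usage assumes the pixel_to_pixel_gpu function from above together with a CPU implementation such as astropy's pixel_to_pixel.

```python
import time


def benchmark(fn, *args, repeats=3):
    """Call fn(*args) several times and return (best_time_s, result)."""
    best = float("inf")
    result = None
    for _ in range(repeats):
        start = time.perf_counter()
        result = fn(*args)
        best = min(best, time.perf_counter() - start)
    return best, result


# Usage sketch: compare two implementations on identical inputs, e.g.
#   t_cpu, out_cpu = benchmark(pixel_to_pixel, wcs_in, wcs_out, x, y)
#   t_gpu, out_gpu = benchmark(pixel_to_pixel_gpu, wcs_in, wcs_out, x, y)
# and verify np.allclose(out_cpu, out_gpu) before comparing t_cpu and t_gpu.
```

Taking the best of several repeats helps to smooth over one-off costs such as CUDA context initialization, which can otherwise make the first GPU call look much worse than steady-state performance.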
Hello, thanks for posting this! I reran your code on my machine (RTX 4060), and, unfortunately, I get the same results as you. I'm including my test that showed me the non-negligible (45%) speedup for using the GPU on 4000x6000 images. I'll need to figure out why it is doing so much more poorly on the test you provided. Perhaps it is because of my inputs...
You are absolutely correct in your assessment of what PyTorch is doing! My goal with this little test was to start looking at where functions could be rewritten to speed things up. I also profiled the functionality and found that the pixel_to_world and world_to_pixel functions are the main bottlenecks. I've slowly started working on updating these to run using PyTorch, since we lose a lot of potential speedup having to recast torch objects to numpy objects for these calculations.
I'm also getting a big speedup with your code:
Running CPU version...
CPU execution time: 4.105450 seconds
Running GPU version...
GPU execution time: 2.498859 seconds
GPU speedup: 39.13%
Results match within tolerance.
I tried filling in the details of the WCSes, to make sure we're getting a representative workload for the coordinate conversions:
wcs_in = WCS(naxis=2)
wcs_out = WCS(naxis=2)
wcs_in.wcs.crpix = 500, 500
wcs_in.wcs.crval = 0, 10
wcs_in.wcs.ctype = 'RA---CAR', 'DEC--CAR'
wcs_in.wcs.cdelt = 0.05, 0.05
wcs_out.wcs.crpix = 500, 500
wcs_out.wcs.crval = 30, 12
wcs_out.wcs.ctype = 'RA---AZP', 'DEC--AZP'
wcs_out.wcs.cdelt = 0.05, 0.05
That didn't really change the speedup factor though. I explored some and found that the execution time drops a lot if I use astropy.wcs.utils.pixel_to_pixel for the CPU version:
Running CPU version...
CPU execution time: 19.076000 seconds
Running astropy version...
astropy execution time: 9.292996 seconds
Running GPU version...
GPU execution time: 9.689164 seconds
GPU speedup: 49.21%
Results match within tolerance.
I think the problem is that in the cut-down CPU version in your file, this loop
for i in range(wcs_out.pixel_n_dim):
    pixel_inputs = np.broadcast_arrays(*inputs)
    world_outputs = wcs_in.pixel_to_world(*pixel_inputs)
    ...
has it run the whole thing twice. I think that loop came from the astropy version, where, if the coordinates are independent (e.g. longitude depends only on x and latitude depends only on y), it tries to transform the x and y coordinates separately without broadcasting them together (and unbroadcasting them afterwards if necessary!), potentially saving a lot of compute time. The cut-down version removes that independence check but keeps a modified loop, creating this bug that doubles the execution time.
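The duplicated work is easy to see with a stub transform that counts its calls. This is a toy sketch, not the actual astropy or reproject code; pixel_to_world here is a stand-in for the expensive WCS conversion.

```python
import numpy as np

calls = {"pixel_to_world": 0}


def pixel_to_world(*pixel_arrays):
    # Stand-in for the expensive WCS call
    calls["pixel_to_world"] += 1
    return pixel_arrays


def transform_buggy(inputs, n_dim=2):
    # Mirrors the cut-down loop: the full conversion runs once per
    # output dimension, i.e. n_dim times in total
    for i in range(n_dim):
        pixel_inputs = np.broadcast_arrays(*inputs)
        world_outputs = pixel_to_world(*pixel_inputs)
    return world_outputs


def transform_fixed(inputs, n_dim=2):
    # The conversion only needs to run once for all output dimensions
    pixel_inputs = np.broadcast_arrays(*inputs)
    return pixel_to_world(*pixel_inputs)


x, y = np.arange(4.0), np.arange(4.0)
transform_buggy((x, y))
print(calls["pixel_to_world"])  # 2: the whole conversion ran twice
calls["pixel_to_world"] = 0
transform_fixed((x, y))
print(calls["pixel_to_world"])  # 1
```

Since pixel_to_world dominates the runtime, running it once per output dimension roughly multiplies the total cost by the number of dimensions, which matches the ~2x gap seen in the 2D benchmarks above.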
I really hope you're successful getting pixel_to_world and world_to_pixel to run on the GPU! The pipeline for the PUNCH mission reprojects a lot of large files, so we could save a lot of runtime with this sort of optimization!
Thanks for explaining that!
Instead of trying to wrangle astropy (which is obviously an amazing package), I wrote a standalone package using torch that computes reprojections. If that's acceptable, I'll post the link to the GitHub repository when it becomes public.
Has anyone tested cupyx.scipy.ndimage.map_coordinates?
scipy.ndimage.map_coordinates seems to be much slower than cv2.remap for FITS images larger than 8k×8k.
Maybe GPU acceleration is preferred here, but copying a cp.array from the GPU back to the CPU costs a lot of time.
#cv2
import cv2
import numpy as np
src = np.random.rand(8000, 8000).astype(np.float32)
x, y = np.meshgrid(np.arange(2000), np.arange(2000))
map_x = (x + 0.5).astype(np.float32)
map_y = (y + 0.5).astype(np.float32)
result = cv2.remap(src, map_x, map_y, cv2.INTER_LINEAR)
#scipy
import numpy as np
from scipy.ndimage import map_coordinates
src = np.random.rand(8000, 8000)
coords = np.indices((8000, 8000)).astype(float)
coords += 0.5
result = map_coordinates(src, coords, order=1, mode='constant', cval=0)
#cuda+scipy
import cupy as cp
from cupyx.scipy.ndimage import map_coordinates
src_gpu = cp.random.rand(8000, 8000).astype(cp.float32)
coords = cp.indices((8000, 8000)).astype(cp.float32)
coords += 0.5  # example offset
result_gpu = map_coordinates(src_gpu, coords, order=1, mode='constant')
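All three snippets above sample with bilinear (order=1) interpolation at half-pixel offsets. To cross-check that different backends agree, it helps to know exactly what that computes; here is a tiny pure-NumPy sketch of bilinear sampling (edge handling here is simple clipping, so it only matches map_coordinates away from the borders).

```python
import numpy as np


def bilinear_sample(src, rows, cols):
    """Bilinear interpolation of src at fractional (rows, cols), matching
    map_coordinates(src, [rows, cols], order=1) away from the edges."""
    r0 = np.floor(rows).astype(int)
    c0 = np.floor(cols).astype(int)
    fr = rows - r0  # fractional row offset
    fc = cols - c0  # fractional column offset
    r1 = np.clip(r0 + 1, 0, src.shape[0] - 1)
    c1 = np.clip(c0 + 1, 0, src.shape[1] - 1)
    r0 = np.clip(r0, 0, src.shape[0] - 1)
    c0 = np.clip(c0, 0, src.shape[1] - 1)
    # Weighted average of the four surrounding pixels
    return ((1 - fr) * (1 - fc) * src[r0, c0]
            + (1 - fr) * fc * src[r0, c1]
            + fr * (1 - fc) * src[r1, c0]
            + fr * fc * src[r1, c1])


src = np.arange(16.0).reshape(4, 4)
# Sampling half a pixel down and right averages the four neighbours
print(bilinear_sample(src, np.array([0.5]), np.array([0.5])))  # [2.5]
```

A check like this on a small array is a quick way to confirm that cv2.remap, scipy, and a GPU backend are all using the same coordinate convention before timing them on 8k×8k data.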
Hello, I haven't tried that function, but I was able to get an order of magnitude speed up over map_coordinates by using torch.
I also tried PyTorch, but not using map_coordinates; I used F.grid_sample and found it a bit slower than cupyx.
Either way, both the tensor and the cupy.array need to be sent back to the CPU from the GPU, which costs some time.
import torch
import torch.nn.functional as F

src_tensor = torch.rand(1, 1, 8000, 8000).cuda()
# Build a sampling grid in grid_sample's normalized [-1, 1] coordinates;
# the added constant shifts the sampling points, mirroring the
# half-pixel offset used in the examples above
grid = torch.meshgrid(
    torch.linspace(-1, 1, 8000),
    torch.linspace(-1, 1, 8000),
    indexing="ij",
)
grid = torch.stack(grid, dim=-1).unsqueeze(0).cuda() + 0.5 / 2000
result_tensor = F.grid_sample(
    src_tensor,
    grid,
    mode='bilinear',
    padding_mode='zeros',
    align_corners=False,
)
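One subtlety with F.grid_sample is that it takes normalized coordinates rather than pixel indices, so any drop-in replacement for map_coordinates needs a conversion step. A small helper along these lines (a sketch; the function name is made up) captures the convention, with no torch required for the math itself:

```python
def pixel_to_grid_sample(px, size, align_corners=False):
    """Convert a 0-based pixel index to grid_sample's normalized
    [-1, 1] coordinate convention along one axis of length `size`."""
    if align_corners:
        # Corner pixel centers map exactly to -1 and 1
        return 2.0 * px / (size - 1) - 1.0
    # Pixel centers map strictly inside (-1, 1):
    # px = -0.5 maps to -1 and px = size - 0.5 maps to 1
    return (2.0 * px + 1.0) / size - 1.0


# With align_corners=False and an axis of length 8000, the image
# center at pixel coordinate 3999.5 maps to normalized coordinate 0.0
print(pixel_to_grid_sample(3999.5, 8000))  # 0.0
```

Getting this mapping (and the x-before-y ordering of grid_sample's last dimension) wrong produces results that are shifted by a fraction of a pixel, which is easy to miss in a visual check but matters for reprojection accuracy.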
(attached: array_utils.txt, common.txt, core.txt)
I managed to write a cupy GPU "acceleration" version of the reproject_interp function above. But the speed still needs careful tuning... Currently 8 s for an 8K×8K image with parallel=8, order=1, block_size=(2048, 2048). Sadly my 12 GB of GPU memory limits the maximum core number here.
Thanks all, this is very interesting! If it's possible to get speedups by just changing map_coordinates to e.g. F.grid_sample, then it seems like we should support it as an option (and have e.g. pytorch or cupy as an optional dependency). I've also been playing around with jax to see if we could speed up some of the WCS transformations, and it seems very promising.
So to summarize, there are two main places where we can speed things up in my view:
- The first is pixel_to_pixel, which lives in astropy and could potentially be optimized for certain projections. I will open an issue over at astropy to keep track of what we can do there.
- The second is that there might be some almost-drop-in replacements for map_coordinates in reproject_interp. I think we should implement this, and add a keyword argument to reproject_interp allowing users to select the implementation to use; in future we could make the default auto instead of scipy once we have a better understanding of performance/issues.
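A keyword argument with an auto default could look roughly like the following. This is purely a hypothetical sketch: none of these names are from the actual reproject API, and the registered backends here are stand-ins.

```python
# Hypothetical backend registry for an interpolation keyword argument;
# none of these names come from the actual reproject codebase.
_BACKENDS = {}


def register_backend(name, is_available, func):
    """Register an interpolation backend with an availability check."""
    _BACKENDS[name] = (is_available, func)


def resolve_backend(name="auto", preference=("cupy", "torch", "scipy")):
    """Return the sampling function for the requested backend.

    With name="auto", pick the first available backend in `preference`;
    otherwise require the explicitly named backend to be available.
    """
    if name != "auto":
        available, func = _BACKENDS[name]
        if not available():
            raise ValueError(f"backend {name!r} is not available")
        return func
    for candidate in preference:
        if candidate in _BACKENDS and _BACKENDS[candidate][0]():
            return _BACKENDS[candidate][1]
    raise ValueError("no interpolation backend available")


# Usage sketch: register stand-in backends, then resolve "auto",
# which skips the unavailable cupy backend and falls back to scipy
register_backend("scipy", lambda: True, lambda *a: "scipy result")
register_backend("cupy", lambda: False, lambda *a: "cupy result")
print(resolve_backend("auto")())  # scipy result
```

Keeping the availability checks as callables means optional dependencies like cupy or torch are only imported when a user actually asks for (or could benefit from) that backend.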
This was mentioned on the astropy slack today. Just cross posting here in case it's helpful: https://github.com/DragonflyTelescope/dfreproject
Thanks! I've now opened an issue over at astropy to discuss specifically the speeding up of WCS transformations: https://github.com/astropy/astropy/issues/18113
This might be an easy drop-in replacement for map_coordinates: https://docs.jax.dev/en/latest/_autosummary/jax.scipy.ndimage.map_coordinates.html - it is faster than scipy's map_coordinates, but only for large numbers of points to interpolate: for instance, for 10^7 points it is around 5x faster, while for 10^5 points it is slower, presumably because of the cost of getting the data onto the GPU. Nevertheless, it could always be provided as an opt-in option for people who want to explore whether it helps their use case.
So one thing to consider is that map_coordinates replacements will probably only be useful if they support numpy memmap natively or if the array is fully in memory. For example, jax's map_coordinates requires the array to be in memory, so if we did use it, it would only be for in-memory arrays (and order <= 1).
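A simple guard along these lines could decide when the accelerated path is safe to take. This is a sketch with a made-up helper name; the in-memory requirement and the order <= 1 restriction are the constraints mentioned above.

```python
import numpy as np


def can_use_gpu_backend(array, order):
    """Return True only when an accelerated map_coordinates replacement
    is safe: the array must be fully in memory and order <= 1."""
    if isinstance(array, np.memmap):
        # jax-style backends need the whole array resident in memory,
        # so memory-mapped inputs fall back to scipy
        return False
    return order <= 1


in_memory = np.zeros((4, 4))
print(can_use_gpu_backend(in_memory, order=1))  # True
print(can_use_gpu_backend(in_memory, order=3))  # False
```

With a check like this, the auto default discussed above could silently fall back to scipy for memory-mapped FITS data instead of forcing the whole array to load.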