
Usage: Poor performance of NaN-aware xarray computations?

Open peanutfun opened this issue 2 months ago • 6 comments

Please provide a description of what you'd like to do.

I am using sparse arrays with xarray and found that a simple DataArray.sum() operation performs poorly with skipna=True, which is the default for float data types. It seems it is usually faster to densify the underlying array, compute the sum on it, and then sparsify the result again than to compute the sum on the original sparse array (granted, this only works if the array fits into memory).

What is even more confusing is that performance gets even worse when I set fill_value=np.nan, in which case skipna=True should be essentially trivial. Why is that?

I am using fill_value=np.nan because I find it to be the "natural" choice for xarray data: xarray uses NaNs to indicate "no data" and uses them as the default fill value when merging, aligning, or extending data. I therefore think it's important that fill_value=np.nan does not incur a penalty compared to the default fill_value=0.
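
A minimal illustration of what I mean (plain dense arrays, just to show where the NaNs come from):

import xarray as xr

a = xr.DataArray([1.0, 2.0], dims="x", coords={"x": [0, 1]})
b = xr.DataArray([3.0, 4.0], dims="x", coords={"x": [1, 2]})

# Outer alignment fills the missing positions with NaN by default:
a2, b2 = xr.align(a, b, join="outer")
print(a2.values)  # [ 1.  2. nan]
print(b2.values)  # [nan  3.  4.]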

PS: I think this has to do with the sparse data structures rather than with xarray, so I raised the issue here.

Example Code

# Fill value is 0
$ python -m timeit -s "import sparse; import xarray as xr; arr = xr.DataArray(sparse.random((100, 100, 100), density=0.1, fill_value=0), dims=['x', 'y', 'z'])" "arr.sum(dim='x')"
1 loop, best of 5: 24.1 msec per loop

# Fill value is NaN, takes twice (!) the time
$ python -m timeit -s "import sparse; import xarray as xr; import numpy as np; arr = xr.DataArray(sparse.random((100, 100, 100), density=0.1, fill_value=np.nan), dims=['x', 'y', 'z'])" "arr.sum(dim='x')"
1 loop, best of 5: 44.6 msec per loop

# Densifying is slightly faster
$ python -m timeit -s "import sparse; import xarray as xr; import numpy as np; arr = xr.DataArray(sparse.random((100, 100, 100), density=0.1, fill_value=np.nan), dims=['x', 'y', 'z'])" "arr.data = arr.data.todense(); arr.sum(dim='x'); arr.data = sparse.as_coo(arr.data)"
5 loops, best of 5: 38.9 msec per loop

# Setting the fill value to zero and not skipping NaNs is way faster
$ python -m timeit -s "import sparse; import xarray as xr; import numpy as np; arr = xr.DataArray(sparse.random((100, 100, 100), density=0.1, fill_value=np.nan), dims=['x', 'y', 'z'])" "arr.data.fill_value=0.0; arr.sum(dim='x', skipna=False)"
20 loops, best of 5: 7.81 msec per loop

# For comparison, the computation on a dense array is faster, even with skipna=True
$ python -m timeit -s "import sparse; import xarray as xr; import numpy as np; arr = xr.DataArray(sparse.random((100, 100, 100), density=0.1, fill_value=np.nan).todense(), dims=['x', 'y', 'z'])" "arr.sum(dim='x')"
100 loops, best of 5: 2.99 msec per loop

peanutfun avatar Nov 12 '25 11:11 peanutfun

It turns out that when calculating the result array, one needs to check whether each result is NaN and, if so, not add it to the data array.

Indeed, comparing with NaN can be harder than other comparisons. This is because NaN isn't just one combination of 32/64 bits; it's many different combinations. This is further complicated by the fact that IEEE-754 requires that NaN != NaN, even if the underlying bits are the same.

We use isnan to check for "semantic equality" when the fill-value is NaN, and regular comparisons otherwise. I'm assuming this is what's taking the extra time.
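
As a rough illustration (plain NumPy, not sparse's internal code) of why a NaN fill value forces a different, costlier check:

import numpy as np

data = np.array([1.0, np.nan, 3.0, 0.0])

# With a numeric fill value, finding entries equal to the fill value is a plain comparison:
print(data == 0.0)      # [False False False  True]

# With a NaN fill value, == never matches, because IEEE-754 mandates NaN != NaN ...
print(data == np.nan)   # [False False False False]

# ... so an explicit isnan pass is required instead:
print(np.isnan(data))   # [False  True False False]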

hameerabbasi avatar Nov 12 '25 11:11 hameerabbasi

@hameerabbasi Thanks for your answer. I understand that checking for NaNs is expensive in general. However, this does not explain why changing the fill value of a sparse array makes the skipna check so much more expensive; after all, the same number of values has to be checked for NaN.

After running some profiling, I think I understand what's happening:

  • If the fill value is zero, the condition array in where is a sparse matrix with fill value False and no other values (the original array is not NaN everywhere)
  • If the fill value is NaN, the condition array is a sparse matrix with fill value True and stored values False at the same coordinates as the original array. So it actually stores much more information than the other condition array and elemwise has to perform coordinate matching. That's why it takes longer.

I guess the remaining unfortunate thing is that where() with a sparse boolean condition that explicitly stores False at exactly the same coordinates as the array (and whose fill value is True) takes so much longer than where() with a condition that stores nothing and has fill value False:

import sparse as sp
import numpy as np

arr = sp.random((100, 100, 100), density=0.1, fill_value=0.0)
cond = sp.isnan(arr)  # fill_value = False, coords = [], data = []
sp.where(cond, sp.zeros_like(arr), arr)

arr.fill_value = np.nan
cond = sp.isnan(arr)  # fill_value = True, coords = arr.coords, data = np.full_like(arr.data, False)
sp.where(cond, sp.zeros_like(arr), arr)  # Takes twice the amount of time
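
To put rough numbers on this, the two where() calls can be timed directly; a small sketch (absolute timings will of course vary by machine):

import timeit
import numpy as np
import sparse as sp

arr = sp.random((100, 100, 100), density=0.1, fill_value=0.0)
zeros = sp.zeros_like(arr)

# Case 1: fill_value = 0.0 -> condition stores nothing (fill_value False)
cond = sp.isnan(arr)
print("fill 0.0:", timeit.timeit(lambda: sp.where(cond, zeros, arr), number=50))

# Case 2: fill_value = NaN -> condition stores False at every coordinate of arr (fill_value True)
arr.fill_value = np.nan
cond = sp.isnan(arr)
print("fill nan:", timeit.timeit(lambda: sp.where(cond, zeros, arr), number=50))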

My profiling:

import cProfile
import sparse
import xarray as xr
import numpy as np

arr = xr.DataArray(
    sparse.random((100, 100, 100), density=0.1, fill_value=0.0),
    dims=["x", "y", "z"],
)
for fill_value in (0.0, np.nan):
    arr.data.fill_value = fill_value
    with cProfile.Profile() as pr:
        for _ in range(100):
            result = arr.sum(dim="x", skipna=True)
    pr.dump_stats(f"sparse-{fill_value}")

fill_value=0.0

[screenshot of cProfile output]

fill_value=np.nan

[screenshot of cProfile output]

peanutfun avatar Nov 13 '25 10:11 peanutfun

I found that dispatching to sparse.nansum instead of relying on xarray.computation.nanops.nansum gives a significant speedup, see https://github.com/pydata/xarray/issues/10922
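
A minimal way to compare the two paths by hand (just an illustration; the dispatch change itself is discussed in the linked issue):

import timeit
import numpy as np
import sparse
import xarray as xr

arr = xr.DataArray(
    sparse.random((100, 100, 100), density=0.1, fill_value=np.nan),
    dims=["x", "y", "z"],
)

# Current path: xarray's nanops, which mask the sparse array via where() and then sum
print("xarray nanops:", timeit.timeit(lambda: arr.sum(dim="x"), number=20))

# Direct call to sparse's own NaN-aware reduction on the wrapped array
print("sparse.nansum:", timeit.timeit(lambda: sparse.nansum(arr.data, axis=0), number=20))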

peanutfun avatar Nov 14 '25 11:11 peanutfun

That explains it! I appreciate the insight; thanks for tracking this down, @peanutfun!

hameerabbasi avatar Nov 14 '25 11:11 hameerabbasi

@hameerabbasi Are there any "guarantees" about the fill value not being present in the data? Meaning: is there a way to be sure that the matrix has been pruned and is read-only? For the case fill_value=np.nan, one could then simplify _replace_nan to just replace the fill value and not check any stored values, which would significantly improve the performance of all the nan-aware functions.
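
A rough sketch of the kind of shortcut I have in mind (hypothetical helper, not the current xarray or sparse code):

import numpy as np
import sparse

def _replace_nan_pruned(arr, val):
    # Hypothetical fast path: if the COO is guaranteed to be pruned, a NaN fill
    # value implies that none of the stored values are NaN, so only the fill
    # value needs replacing and the stored data can be reused unchanged.
    if isinstance(arr, sparse.COO) and np.isnan(arr.fill_value):
        return sparse.COO(arr.coords, arr.data, shape=arr.shape, fill_value=val)
    # Generic fallback: elementwise NaN check over all stored values.
    return sparse.where(sparse.isnan(arr), val, arr)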

To make this work from xarray, it would need to dispatch to the nan functions, which are unfortunately not part of the general array API. See https://github.com/pydata/xarray/issues/10922 for the ongoing discussion.

peanutfun avatar Nov 19 '25 08:11 peanutfun

Yes, we guarantee that tensors will be pruned, but I'm not 100% sure that function isn't used anywhere else. I'm happy to accept a PR.

hameerabbasi avatar Nov 22 '25 08:11 hameerabbasi