Chunked pint arrays break on `rolling()`
Hi folks,
I noticed that running `.rolling(...)` on a chunked pint array raises an exception that breaks the process:

```
TypeError: `pad_value` must be composed of integral typed values.
```
I outline three different cases below for running `.rolling()` on a pint-aware `DataArray`.

- Calculating the rolling sum on an in-memory `pint` array.
  - Works, but loses units as expected (e.g. https://github.com/xarray-contrib/pint-xarray/issues/6#issuecomment-611134048). Although it seems like running it with `xr.set_options(use_bottleneck=False)` preserves units (https://github.com/pydata/xarray/issues/7062#issuecomment-1254047656).
- Calculating the rolling sum on a chunked `pint` array, using `xarray` chunking.
  - This works, even without turning off `bottleneck`. However, this isn't an optimal solution for me, since one cannot query `ds.pint.units` on an `xarray`-chunked `pint` array. I like being able to do that for various QOL checks in a data pipeline.
- Calculating the rolling sum on a `pint` array chunked with `ds.pint.chunk(...)`.
  - This method preserves the units, but leads to the traceback seen above and in full detail below. It also breaks when turning off `bottleneck`.
```python
import pint_xarray
import xarray as xr
print(xr.__version__)
>>> '2022.6.0'
print(pint_xarray.__version__)
>>> '0.3'
data = xr.DataArray(range(3), dims='time').pint.quantify('kelvin')
print(data)
>>> <xarray.DataArray (time: 3)>
>>> <Quantity([0 1 2], 'kelvin')>
# Case 1: rolling sum with `pint` units.
# Lose the units as expected, but executes properly.
rs = data.rolling(time=2).sum()
print(rs)
>>> <xarray.DataArray (time: 3)>
>>> array([nan, 1., 3.])
# Case 2: rolling sum with `xr.chunk()`
# Maintain the units after compute,
# but `data_xr_chunk.pint.units` returns `None` in the interim
data_xr_chunk = data.chunk({'time': 1})
rs = data_xr_chunk.rolling(time=2).sum().compute()
print(rs)
>>> <xarray.DataArray (time: 3)>
>>> <Quantity([nan 1. 3.], 'kelvin')>
# Case 3: rolling sum with `xr.pint.chunk()`
# Maintains units on chunked array, but raises exception
# (see full traceback below)
data_pint_chunk = data.pint.chunk({"time": 1})
rs = data_pint_chunk.rolling(time=2).sum().compute()
```
```pytb
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [31], in <cell line: 1>()
----> 1 rs = data_pint_chunk.rolling(time=2).sum().compute()
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/rolling.py:155, in Rolling._reduce_method.<locals>.method(self, keep_attrs, **kwargs)
151 def method(self, keep_attrs=None, **kwargs):
153 keep_attrs = self._get_keep_attrs(keep_attrs)
--> 155 return self._numpy_or_bottleneck_reduce(
156 array_agg_func,
157 bottleneck_move_func,
158 rolling_agg_func,
159 keep_attrs=keep_attrs,
160 fillna=fillna,
161 **kwargs,
162 )
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/rolling.py:589, in DataArrayRolling._numpy_or_bottleneck_reduce(self, array_agg_func, bottleneck_move_func, rolling_agg_func, keep_attrs, fillna, **kwargs)
586 kwargs.setdefault("skipna", False)
587 kwargs.setdefault("fillna", fillna)
--> 589 return self.reduce(array_agg_func, keep_attrs=keep_attrs, **kwargs)
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/rolling.py:472, in DataArrayRolling.reduce(self, func, keep_attrs, **kwargs)
470 else:
471 obj = self.obj
--> 472 windows = self._construct(
473 obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna
474 )
476 result = windows.reduce(
477 func, dim=list(rolling_dim.values()), keep_attrs=keep_attrs, **kwargs
478 )
480 # Find valid windows based on count.
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/rolling.py:389, in DataArrayRolling._construct(self, obj, window_dim, stride, fill_value, keep_attrs, **window_dim_kwargs)
384 window_dims = self._mapping_to_list(
385 window_dim, allow_default=False, allow_allsame=False # type: ignore[arg-type] # https://github.com/python/mypy/issues/12506
386 )
387 strides = self._mapping_to_list(stride, default=1)
--> 389 window = obj.variable.rolling_window(
390 self.dim, self.window, window_dims, self.center, fill_value=fill_value
391 )
393 attrs = obj.attrs if keep_attrs else {}
395 result = DataArray(
396 window,
397 dims=obj.dims + tuple(window_dims),
(...)
400 name=obj.name,
401 )
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/variable.py:2314, in Variable.rolling_window(self, dim, window, window_dim, center, fill_value)
2311 else:
2312 pads[d] = (win - 1, 0)
-> 2314 padded = var.pad(pads, mode="constant", constant_values=fill_value)
2315 axis = [self.get_axis_num(d) for d in dim]
2316 new_dims = self.dims + tuple(window_dim)
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/variable.py:1416, in Variable.pad(self, pad_width, mode, stat_length, constant_values, end_values, reflect_type, **pad_width_kwargs)
1413 if reflect_type is not None:
1414 pad_option_kwargs["reflect_type"] = reflect_type
-> 1416 array = np.pad( # type: ignore[call-overload]
1417 self.data.astype(dtype, copy=False),
1418 pad_width_by_index,
1419 mode=mode,
1420 **pad_option_kwargs,
1421 )
1423 return type(self)(self.dims, array)
File <__array_function__ internals>:180, in pad(*args, **kwargs)
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/pint/quantity.py:1730, in Quantity.__array_function__(self, func, types, args, kwargs)
1729 def __array_function__(self, func, types, args, kwargs):
-> 1730 return numpy_wrap("function", func, args, kwargs, types)
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/pint/numpy_func.py:936, in numpy_wrap(func_type, func, args, kwargs, types)
934 if name not in handled or any(is_upcast_type(t) for t in types):
935 return NotImplemented
--> 936 return handled[name](*args, **kwargs)
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/pint/numpy_func.py:660, in _pad(array, pad_width, mode, **kwargs)
656 if key in kwargs:
657 kwargs[key] = _recursive_convert(kwargs[key], units)
659 return units._REGISTRY.Quantity(
--> 660 np.pad(array._magnitude, pad_width, mode=mode, **kwargs), units
661 )
File <__array_function__ internals>:180, in pad(*args, **kwargs)
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/core.py:1762, in Array.__array_function__(self, func, types, args, kwargs)
1759 if has_keyword(da_func, "like"):
1760 kwargs["like"] = self
-> 1762 return da_func(*args, **kwargs)
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/creation.py:1229, in pad(array, pad_width, mode, **kwargs)
1227 elif mode == "constant":
1228 kwargs.setdefault("constant_values", 0)
-> 1229 return pad_edge(array, pad_width, mode, **kwargs)
1230 elif mode == "linear_ramp":
1231 kwargs.setdefault("end_values", 0)
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/creation.py:964, in pad_edge(array, pad_width, mode, **kwargs)
957 def pad_edge(array, pad_width, mode, **kwargs):
958 """
959 Helper function for padding edges.
960
961 Handles the cases where the only the values on the edge are needed.
962 """
--> 964 kwargs = {k: expand_pad_value(array, v) for k, v in kwargs.items()}
966 result = array
967 for d in range(array.ndim):
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/creation.py:964, in <dictcomp>(.0)
957 def pad_edge(array, pad_width, mode, **kwargs):
958 """
959 Helper function for padding edges.
960
961 Handles the cases where the only the values on the edge are needed.
962 """
--> 964 kwargs = {k: expand_pad_value(array, v) for k, v in kwargs.items()}
966 result = array
967 for d in range(array.ndim):
File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/creation.py:910, in expand_pad_value(array, pad_value)
908 pad_value = array.ndim * (tuple(pad_value[0]),)
909 else:
--> 910 raise TypeError("`pad_value` must be composed of integral typed values.")
912 return pad_value
TypeError: `pad_value` must be composed of integral typed values.
```
My solution in the interim is to do something like:

```python
units = data.pint.units
data = data.pint.dequantify()
rs = data.rolling(time=2).sum()
rs = rs.pint.quantify(units)
```
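Wrapped up as a small helper (a hypothetical convenience function, not part of `pint-xarray`), the same workaround looks like this; dequantifying first sidesteps the failing `np.pad` call on the pint-wrapped dask array:

```python
import pint_xarray  # noqa: F401  (registers the .pint accessor)
import xarray as xr


def rolling_sum_with_units(da: xr.DataArray, **rolling_kwargs) -> xr.DataArray:
    """Strip pint units, compute the rolling sum, then reattach the units."""
    units = da.pint.units
    rs = da.pint.dequantify().rolling(**rolling_kwargs).sum()
    return rs.pint.quantify(units)


# e.g. rs = rolling_sum_with_units(data_pint_chunk, time=2)
```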
Another side note that came up here -- I'm curious if there's any roadmap plan for recognizing integration of units for methods like `rolling().sum()`. E.g.:
```python
data = xr.DataArray(range(3), dims='time').pint.quantify('mm/day')
data.pint.units
>>> mm/day
data = data.rolling(time=2).sum()
data.pint.units
>>> mm
```
Thanks for the report, @riley-brady. It seems that xarray operations on pint+dask are not as thoroughly tested as pint and dask on their own. I think this is a bug in pint (or dask, not sure): we enable `force_ndarray_like` to convert scalars to 0d arrays, which means that the final call to `np.pad` becomes:

```python
np.pad(magnitude, pad_width, mode="constant", constant_values=np.array(0))
```

numpy seems to be fine with that, but dask is not.
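The contrast can be isolated outside of pint/xarray. A minimal sketch (assuming the dask version from this report; newer dask releases may already handle 0d arrays):

```python
import dask.array as da
import numpy as np

fill = np.array(0)  # 0d array, like what pint's force_ndarray_like produces

# numpy accepts a 0d array as the fill value:
np.pad(np.arange(3), (1, 0), mode="constant", constant_values=fill)

# dask's expand_pad_value rejects it:
da.pad(da.arange(3, chunks=1), (1, 0), mode="constant", constant_values=fill)
# TypeError: `pad_value` must be composed of integral typed values.

# unpacking the 0d array first works:
da.pad(da.arange(3, chunks=1), (1, 0), mode="constant", constant_values=fill.item())
```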
@jrbourbeau, what do you think? Would it make sense to extend `expand_pad_value` to unpack 0d arrays (using `.item()`), or would you rather have the caller (pint, in this case) do that?
> I'm curious if there's any roadmap plan for recognizing integration of units for methods like `rolling().sum()`

I'm not sure I follow. Why would `rolling().sum()` work similarly to integration, when all it does is compute a grouped sum? I'm not sure this actually counts as integration, but you can multiply the result of the rolling sum by the diff of the time coordinate (which is a bit tricky because `time` is an indexed coordinate):
```python
data = xr.DataArray(
    range(3), dims="time", coords={"time2": ("time", [1, 2, 3])}
).pint.quantify("mm/day", time2="day")

dt = data.time2.pad(time=(1, 0)).diff(dim="time")
data.rolling(time=2).sum() * dt
```
and then you would have the correct units (with the same numerical result, because I chose the time coordinate to have increments of 1 day).
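As a quick sanity check of the dimensional bookkeeping (a sketch that assumes the rolling sum keeps its units here, e.g. with `xr.set_options(use_bottleneck=False)` as noted above):

```python
print(dt.pint.units)      # day
result = data.rolling(time=2).sum() * dt
print(result.pint.units)  # mm/day * day -> millimeter
```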
Thanks for the quick feedback on this issue @keewis.
Also thanks for the demo with `.diff()`. You're right about the integration assumptions. In my specific use case I am doing a rolling sum of units `mm/day` with daily time steps, so in this case it should reflect total precip in `mm`, but that's not a fair assumption for many other cases. I'll give the `.diff()` method a try.
This has been fixed in dask for quite a while now, but I'll leave this open until we have tests for it (probably after copying the test suite from xarray).