Incompatibilities with bfloat16 after updating to numpy 2
What happened?
Computing max or isnull on a DataArray with bfloat16 values raises:
TypeError: dtype argument must be a NumPy dtype, but it is a <class 'numpy.dtype[bfloat16]'>.
This worked fine before updating numpy to version 2. The main difference in the code seems to be that with numpy < 2, xarray uses its own implementation of isdtype, while for numpy >= 2 it relies on np.isdtype. This is confirmed by the fact that running import numpy as np; del np.isdtype before importing xarray fixes the problem.
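For context, the numpy-version-dependent dispatch happens once at import time in xarray.core.npcompat, roughly along these lines (a paraphrased sketch, not the exact source):

```python
# Paraphrased sketch of the import-time dispatch in xarray.core.npcompat:
try:
    # numpy >= 2 provides an array-API-compliant isdtype
    from numpy import isdtype
except ImportError:
    # numpy < 2: fall back to xarray's own isinstance-based implementation
    def isdtype(dtype, kind):
        ...
```

Because the choice is made at import time, deleting np.isdtype before xarray is imported (or reloading npcompat afterwards, as in the reproducer further below) switches back to the fallback.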
What did you expect to happen?
I expected the computation to succeed, just as it did prior to numpy 2.
Minimal Complete Verifiable Example
import numpy as np
# del np.isdtype # Uncommenting this line fixes it.
import xarray
import ml_dtypes
da = xarray.DataArray(np.zeros([2], dtype=ml_dtypes.bfloat16), dims=("dim",))
da.isnull() # Or da.max("dim")
MVCE confirmation
- [ ] Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
- [ ] Complete example — the example is self-contained, including all data and the text of any traceback.
- [ ] Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
- [ ] New issue — a search of GitHub Issues suggests this is not a duplicate.
- [ ] Recent environment — the issue occurs with the latest version of xarray and its dependencies.
Relevant log output
TypeError Traceback (most recent call last)
Cell In[1], line 5
3 import numpy as np
4 da = xarray.DataArray(np.zeros([2], dtype=jnp.bfloat16), dims=("dim",))
----> 5 da.isnull()
File ~/dev/xarray/xarray/core/common.py:1293, in DataWithCoords.isnull(self, keep_attrs)
1290 if keep_attrs is None:
1291 keep_attrs = _get_keep_attrs(default=False)
-> 1293 return apply_ufunc(
1294 duck_array_ops.isnull,
1295 self,
1296 dask="allowed",
1297 keep_attrs=keep_attrs,
1298 )
File ~/dev/xarray/xarray/core/computation.py:1278, in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, on_missing_core_dim, *args)
1276 # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
1277 elif any(isinstance(a, DataArray) for a in args):
-> 1278 return apply_dataarray_vfunc(
1279 variables_vfunc,
1280 *args,
1281 signature=signature,
1282 join=join,
1283 exclude_dims=exclude_dims,
1284 keep_attrs=keep_attrs,
1285 )
1286 # feed Variables directly through apply_variable_ufunc
1287 elif any(isinstance(a, Variable) for a in args):
File ~/dev/xarray/xarray/core/computation.py:320, in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
315 result_coords, result_indexes = build_output_coords_and_indexes(
316 args, signature, exclude_dims, combine_attrs=keep_attrs
317 )
319 data_vars = [getattr(a, "variable", a) for a in args]
--> 320 result_var = func(*data_vars)
322 out: tuple[DataArray, ...] | DataArray
323 if signature.num_outputs > 1:
File ~/dev/xarray/xarray/core/computation.py:831, in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
826 if vectorize:
827 func = _vectorize(
828 func, signature, output_dtypes=output_dtypes, exclude_dims=exclude_dims
829 )
--> 831 result_data = func(*input_data)
833 if signature.num_outputs == 1:
834 result_data = (result_data,)
File ~/dev/xarray/xarray/core/duck_array_ops.py:144, in isnull(data)
139 if dtypes.is_datetime_like(scalar_type):
140 # datetime types use NaT for null
141 # note: must check timedelta64 before integers, because currently
142 # timedelta64 inherits from np.integer
143 return isnat(data)
--> 144 elif dtypes.isdtype(scalar_type, ("real floating", "complex floating"), xp=xp):
145 # float types use NaN for null
146 xp = get_array_namespace(data)
147 return xp.isnan(data)
File ~/dev/xarray/xarray/core/dtypes.py:208, in isdtype(dtype, kind, xp)
205 raise TypeError(f"kind must be a string or a tuple of strings: {repr(kind)}")
207 if isinstance(dtype, np.dtype):
--> 208 return npcompat.isdtype(dtype, kind)
209 elif is_extension_array_dtype(dtype):
210 # we never want to match pandas extension array dtypes
211 return False
File ~/miniconda3/envs/xarray-py312/lib/python3.12/site-packages/numpy/_core/numerictypes.py:425, in isdtype(dtype, kind)
423 dtype = _preprocess_dtype(dtype)
424 except _PreprocessDTypeError:
--> 425 raise TypeError(
426 "dtype argument must be a NumPy dtype, "
427 f"but it is a {type(dtype)}."
428 ) from None
430 input_kinds = kind if isinstance(kind, tuple) else (kind,)
432 processed_kinds = set()
TypeError: dtype argument must be a NumPy dtype, but it is a <class 'numpy.dtype[bfloat16]'>.
Anything else we need to know?
Here's a different reproducer showing the inconsistency between np.isdtype and npcompat.isdtype:
import importlib

import ml_dtypes
import numpy as np

from xarray.core import npcompat

try:
    # `AttributeError: 'module' object has no attribute 'isdtype'`
    npcompat.isdtype(ml_dtypes.bfloat16.dtype, 'real floating')
except Exception as e:
    print(e)

# Hide np.isdtype so reloading npcompat picks up xarray's own implementation.
numpy_isdtype = np.isdtype
del np.isdtype
importlib.reload(npcompat)
np.isdtype = numpy_isdtype

npcompat.isdtype(ml_dtypes.bfloat16.dtype, 'real floating')  # No error, but returns False.
Environment
In [5]: xarray.show_versions()
INSTALLED VERSIONS
commit: 03d3e0b5992051901c777cbf2c481abe2201facd python: 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:35:20) [Clang 16.0.6 ] python-bits: 64 OS: Darwin OS-release: 23.6.0 machine: arm64 processor: arm byteorder: little LC_ALL: None LANG: en_US.UTF-8 LOCALE: ('en_US', 'UTF-8') libhdf5: 1.14.3 libnetcdf: 4.9.2
xarray: 2024.7.1.dev73+g781877cb pandas: 2.2.2 numpy: 2.1.1 scipy: 1.13.1 netCDF4: 1.6.5 pydap: None h5netcdf: None h5py: None zarr: 2.18.2 cftime: 1.6.4 nc_time_axis: None iris: None bottleneck: None dask: 2024.8.2 distributed: 2024.5.2 matplotlib: 3.9.0 cartopy: None seaborn: None numbagg: None fsspec: 2024.6.0 cupy: None pint: None sparse: None flox: None numpy_groupies: 0.11.1 setuptools: 70.0.0 pip: 24.0 conda: 24.7.1 pytest: 8.2.2 mypy: 1.10.0 IPython: 8.25.0
The difference here is that npcompat.isdtype translates the kind string to an abstract numpy scalar superclass and then uses isinstance to perform the check, while np.isdtype explicitly raises if it receives anything other than known np.dtype subclasses or the string categories.
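As a minimal sketch of that isinstance-based approach (illustrative names, not xarray's exact code):

```python
import numpy as np

# Map array-API kind strings to abstract numpy scalar superclasses.
_KIND_TO_SUPERCLASS = {
    "bool": np.bool_,
    "signed integer": np.signedinteger,
    "unsigned integer": np.unsignedinteger,
    "integral": np.integer,
    "real floating": np.floating,
    "complex floating": np.complexfloating,
    "numeric": np.number,
}

def isdtype_fallback(dtype, kind):
    # Hypothetical helper: test the dtype's scalar type against numpy's
    # abstract type hierarchy instead of an explicit allow-list.
    kinds = kind if isinstance(kind, tuple) else (kind,)
    return any(issubclass(dtype.type, _KIND_TO_SUPERCLASS[k]) for k in kinds)
```

An unknown third-party dtype simply fails the issubclass test and yields False (which is what the reproducer above prints for bfloat16), rather than raising a TypeError.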
I don't think we can do a lot here (correct me if I'm wrong, @shoyer), so it might make more sense to take this up with the numpy devs.
cc @rgommers, @seberg for awareness
The quick thing is to use np.dtype() for conversion of the dtype (i.e. also in your code). I suspect np.isdtype (and maybe other "array API" functions) should do this explicitly.
(I am not sure why it ends up where it ends up with the DType class.)
EDIT: To be clear, since this tries to use the array API, I don't know that it is possible to work around easily.
For reference, the reason this error is raised is that numpy._core._type_aliases.allTypes contains an explicit list of allowed dtypes, so any new dtype that is not in that list (like ml_dtypes.bfloat16 or even the new numpy.dtypes.StringDType) will trigger this error when passed to numpy.isdtype.
Which means that wrapping in numpy.dtype does not help, unfortunately.
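A quick check makes this concrete (assuming numpy >= 2 and ml_dtypes are installed): the argument that np.isdtype rejects is already a proper np.dtype instance, so converting it changes nothing:

```python
import numpy as np
import ml_dtypes

dt = np.dtype(ml_dtypes.bfloat16)
print(isinstance(dt, np.dtype))  # True: already a dtype instance
np.isdtype(dt, "real floating")  # still raises TypeError on numpy 2.x
```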
contains an explicit list of allowed dtypes
Can you open a NumPy issue about it? I know that there is always this knee jerk reaction to focus on the Array API blessed dtypes only, but honestly, that is just wrong. This is NumPy API and while there may be some guarantees missing, it shouldn't be artificially limiting here.
I have to look at what is going on more closely. Maybe using this list was just a case of cargo-culting from the wrong place. Translating arbitrary objects to dtype instances is tricky.
In the meantime, would it make sense to simply continue falling back to the xarray implementation npcompat.isdtype, even when np.isdtype exists (instead of the current try/except)? At the end of the day this is failing at an xarray call site.
Can you open a NumPy issue about it?
See numpy/numpy#27545
In the meantime, would it make sense to simply continue falling back to the xarray implementation npcompat.isdtype
As npcompat is compatibility code that's supposed to go away as soon as we can require a specific numpy version, I'd prefer to wait until the numpy team has reached a decision. However, we don't really have to wait for that change to be released in numpy before writing the compat code.
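For what it's worth, one possible shape for such compat code (a hedged sketch, not a committed design, assuming numpy keeps raising TypeError for unknown dtypes) would be to try np.isdtype first and fall back only for the dtypes it rejects:

```python
import numpy as np

# Kind strings handled by this sketch; np.isdtype also accepts dtypes as
# kinds, which the simplified fallback below does not cover.
_KINDS = {
    "bool": np.bool_,
    "signed integer": np.signedinteger,
    "unsigned integer": np.unsignedinteger,
    "integral": np.integer,
    "real floating": np.floating,
    "complex floating": np.complexfloating,
    "numeric": np.number,
}

def isdtype_compat(dtype, kind):
    """Sketch: prefer np.isdtype, but degrade gracefully for third-party
    dtypes that numpy 2 refuses to classify."""
    try:
        return np.isdtype(dtype, kind)
    except TypeError:
        # Fall back to an issubclass check against the abstract hierarchy,
        # which returns False for unknown dtypes instead of raising.
        kinds = kind if isinstance(kind, tuple) else (kind,)
        return any(issubclass(dtype.type, _KINDS[k]) for k in kinds)
```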