result_type() for mixed arrays/Python scalars

Open shoyer opened this issue 1 year ago • 11 comments

The array API's type promotion rules support mixed scalar/array operations, e.g., 1 + xp.arange(3).

For Xarray, we would like to be able to figure out the resulting dtype from this sort of operation before actually doing it (https://github.com/pydata/xarray/pull/8946).

Ideally, we could use xp.result_type() for this purpose, but as documented, result_type only supports arrays and dtype objects. Could we extend result_type to also handle Python scalars? It is worth noting that this already works today in NumPy, e.g.,

>>> np.result_type(1, np.arange(3))
dtype('int64')

shoyer, May 15 '24 16:05

This makes sense to me. torch seems to support this as well. What should the result be if there are multiple Python scalars? Undefined?

asmeurer, May 15 '24 17:05

> What should the result be if there are multiple Python scalars? Undefined?

This should indeed probably be undefined by the spec.

In most cases I imagine array libraries will have a default dtype, but different libraries will make different choices (e.g., int32 in JAX vs int64 in NumPy):

>>> np.result_type(1, 2)
dtype('int64')
>>> jnp.result_type(1, 2)
dtype('int32')

shoyer, May 15 '24 18:05

One concern I see with this is that libraries need not support Python scalars in functions, only for operators. So result_type(a, b) working does not imply that func(a, b) will work.

asmeurer, May 15 '24 18:05

> One concern I see with this is that libraries need not support Python scalars in functions, only for operators. So result_type(a, b) working does not imply that func(a, b) will work.

In Xarray, we are thinking of defining something like:

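def as_shared_dtype(scalars_or_arrays):
    xp = get_array_namespace(scalars_or_arrays)
    # Relies on result_type() accepting Python scalars, as proposed here.
    dtype = xp.result_type(*scalars_or_arrays)
    # dtype is keyword-only in the array API's asarray().
    return tuple(xp.asarray(x, dtype=dtype) for x in scalars_or_arrays)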

shoyer, May 15 '24 19:05
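
Editor's note: until result_type() accepts Python scalars, a helper like the one above can be approximated by computing the dtype from the arrays alone. The following is a minimal sketch under stated assumptions, not Xarray's implementation: the name as_shared_dtype_fallback is hypothetical, and the namespace xp is passed in explicitly rather than looked up. It only relies on xp.result_type() and xp.asarray(..., dtype=...) from the standard.

```python
# Hypothetical helper, for illustration only (not part of Xarray or the standard).
_PY_SCALAR_TYPES = (bool, int, float, complex)


def as_shared_dtype_fallback(xp, scalars_or_arrays):
    """Convert a mix of arrays and Python scalars to arrays of one dtype.

    The dtype is taken from the arrays only, so Python scalars never
    influence it; they are simply converted to the resulting dtype.
    """
    # Exact type match, so array-library scalar types are not treated
    # as Python scalars.
    arrays = [x for x in scalars_or_arrays if type(x) not in _PY_SCALAR_TYPES]
    if not arrays:
        # With only Python scalars there is no dtype to anchor on.
        raise ValueError("at least one array is required")
    dtype = xp.result_type(*arrays)
    return tuple(xp.asarray(x, dtype=dtype) for x in scalars_or_arrays)
```

Note that this sketch does not check whether a scalar fits the chosen dtype (for example, 1.5 with an integer array); what happens in that case is left to the library's asarray().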

Does xarray automatically call asarray on scalar function arguments like NumPy does? Certainly the recommendation of the standard is to not do that, because it's cleaner from a typing perspective. Implicitly calling asarray at the top of every function is considered a historical NumPy antipattern. It's not disallowed, but we also should probably avoid standardizing things that encourage it.

asmeurer, May 15 '24 20:05

The only time we call that function is when preparing arguments for where, which as far as I can tell doesn't support Python scalars (and for concat / stack, but there we don't expect to encounter Python scalars).

keewis, May 15 '24 20:05

Xarray objects always contain array objects, but indeed there are functions like where() for which it's convenient to be able to use scalars.

I opened a separate issue to discuss: https://github.com/data-apis/array-api/issues/807

shoyer, May 15 '24 20:05

This sounds like a useful change to me.

> What should the result be if there are multiple Python scalars? Undefined?
>
> This should indeed probably be undefined by the spec.

What is the problem? It seems well-defined to allow multiple. If multiple arrays and dtype objects are allowed, why not multiple Python scalars?

rgommers, May 17 '24 13:05

I'm not sure, but I think that was referring to a situation where you have no explicit dtypes, just (compatible) Python scalars. In that case, we'd have to make an arbitrary choice (or raise an error).

keewis, May 17 '24 13:05

Ah of course. Agreed, there must be at least one array or dtype object.

rgommers, May 17 '24 13:05

Making this change to result_type seemed fair to everyone in the discussion we just had. Given that our type promotion rules include Python scalars, the function that can be used to apply those promotion rules should support them as well.

rgommers, May 30 '24 18:05
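
Editor's note: the outcome of the thread (Python scalars are accepted, but at least one array or dtype object must be present) can be sketched as a thin wrapper around an existing result_type(). This is an illustration, not the specification's wording: the name result_type_with_scalars is hypothetical, the namespace xp is passed in explicitly, and the compatibility checks reflect one reading of the standard's mixed array/scalar rules for operators. Raising TypeError for an incompatible scalar is a choice made for this sketch; a library may behave differently there.

```python
# Hypothetical wrapper, for illustration only (not part of the standard).
_PY_SCALARS = (bool, int, float, complex)


def _scalar_fits(xp, scalar, dtype):
    """Return True if a Python scalar may be mixed with `dtype`."""
    kind = type(scalar)
    if kind is bool:
        return xp.isdtype(dtype, "bool")
    if kind is int:
        return xp.isdtype(dtype, "numeric")
    if kind is float:
        return xp.isdtype(dtype, ("real floating", "complex floating"))
    # Remaining case: Python complex.
    return xp.isdtype(dtype, "complex floating")


def result_type_with_scalars(xp, *arrays_dtypes_or_scalars):
    """Like xp.result_type(), but Python scalars may be mixed in.

    Scalars never change the result; the dtype comes from the arrays
    and dtype objects alone.
    """
    args = arrays_dtypes_or_scalars
    scalars = [a for a in args if type(a) in _PY_SCALARS]
    others = [a for a in args if type(a) not in _PY_SCALARS]
    if not others:
        # Only Python scalars: the result would depend on each
        # library's default dtype, so refuse instead of guessing.
        raise ValueError("at least one array or dtype object is required")
    dtype = xp.result_type(*others)
    for scalar in scalars:
        if not _scalar_fits(xp, scalar, dtype):
            raise TypeError(
                f"Python {type(scalar).__name__} cannot be mixed with {dtype}"
            )
    return dtype
```

With a namespace that provides isdtype() (for example NumPy 2.x), a call such as result_type_with_scalars(np, 1, np.arange(3)) would return the array's dtype, while result_type_with_scalars(np, 1, 2) would raise ValueError rather than pick a library-specific default.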