
`sign` complex case implementations

Open mdhaber opened this issue 1 year ago • 3 comments

Since the 2022.12 standard, the required implementation of sign has been:

[image: the standard's definition of `sign`, including the complex case]

but I think only the most recent versions of libraries follow this (if any). Older versions of all libraries and even the most recent versions of some (e.g. CuPy, and even array_api_strict, which I can report separately if need be) use other conventions. It would be helpful if all libraries had aliases of sign that use the new definition.
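For reference, the definition in question can be written out as below. This is a transcription of the 2022.12 wording as commonly summarized, not a quote from the specification, so check the spec text before relying on edge cases such as NaN handling.

```math
\operatorname{sign}(x_i) =
\begin{cases}
0 & \text{if } x_i = 0 \\
\dfrac{x_i}{|x_i|} & \text{otherwise}
\end{cases}
\qquad \text{for complex } x_i
```

For real inputs the result is unchanged: -1, 0, or +1 depending on whether `x_i` is negative, zero, or positive.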

mdhaber avatar Sep 21 '24 19:09 mdhaber

sign for torch was already fixed at https://github.com/data-apis/array-api-compat/pull/137/files. I didn't realize cupy had the issue too. Do older versions of NumPy have this problem as well?

asmeurer avatar Sep 23 '24 18:09 asmeurer

Yes. Basically everything needs to be patched unless it is recent enough. (I know the torch error would not be present with array_api_compat main, but this is what an environment on Colab looks like after `!pip install array_api_compat array_api_strict`.)

```python
import array_api_compat
print(array_api_compat.__version__)  # 1.8

from array_api_compat import numpy as xp
print(xp.__version__)  # 1.26.4
x = xp.asarray(1 + 2j)
print(xp.sign(x))  # (1+0j)

import cupy as cp
print(cp.__version__)  # 12.2.0
from array_api_compat import cupy as xp
x = xp.asarray(1 + 2j)
print(xp.sign(x))  # (1+0j)

import torch
print(torch.__version__)  # 2.4.1+cu121
from array_api_compat import torch as xp
x = xp.asarray(1 + 2j)
# print(xp.sign(x))  # RuntimeError: Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.

import dask
print(dask.__version__)  # 2024.8.0
from array_api_compat.dask import array as xp
x = xp.asarray(1 + 2j)
print(xp.sign(x))  # dask.array<sign, shape=(), dtype=complex128, chunksize=(), chunktype=numpy.ndarray>

import array_api_strict as xp
print(xp.__version__)  # 2.0.1
x = xp.asarray(1 + 2j)
print(xp.sign(x))  # (1+0j)

import jax
print(jax.__version__)  # 0.4.26
import jax.numpy as xp
x = xp.asarray(1 + 2j)
print(xp.sign(x))  # (0.44721365+0.8944273j)
```

Basically, jax.numpy is the only thing that works with the default installation, and I see in the Change Log (jax 0.4.24, Feb 6, 2024) that it was only just updated in February.
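To make the comparison concrete, here is a minimal sketch that checks the outputs reported above against a scalar reference for the `x / |x|` convention. `sign_reference` is a hypothetical helper written for this illustration, not part of array_api_compat or any of the libraries listed, and it uses only plain Python complex numbers.

```python
import cmath


def sign_reference(z: complex) -> complex:
    """Hypothetical scalar reference: 0 for zero input, otherwise z / |z|."""
    if z == 0:
        return 0j
    return z / abs(z)


# Values copied from the outputs reported above for x = 1 + 2j.
reported = {
    "numpy 1.26.4": 1 + 0j,
    "cupy 12.2.0": 1 + 0j,
    "array_api_strict 2.0.1": 1 + 0j,
    "jax 0.4.26": 0.44721365 + 0.8944273j,
}

expected = sign_reference(1 + 2j)

# The tolerance is loose enough to accept jax's single-precision result.
conforming = {
    name: cmath.isclose(value, expected, rel_tol=1e-6)
    for name, value in reported.items()
}

for name, ok in conforming.items():
    status = "matches" if ok else "differs from"
    print(f"{name}: {status} x / |x|")
```

Only the jax value passes this check; the other three return a result with zero imaginary part, which is the older convention.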

mdhaber avatar Sep 23 '24 21:09 mdhaber

Interesting. The test suite should be checking this as far as I can tell, but it hasn't come up, even though we do explicitly test against older versions of NumPy. That will require some investigation.

asmeurer avatar Sep 23 '24 21:09 asmeurer

So I dug into this and it looks like the test suite has been ignoring any exceptions raised in the reference implementations in the elementwise function tests. This appears to affect quite a few functions, although it isn't clear yet if there are any actual unwrapped incompatibilities due to this other than this sign one that you've pointed out.
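The failure mode described here can be illustrated with a small sketch. This is not the real test-suite code: `check_elementwise`, `real_only_reference`, and `old_style_sign` are hypothetical names invented for the example. It only shows how swallowing exceptions from a reference implementation lets a non-conforming function pass without ever being compared.

```python
import math


def real_only_reference(x):
    """Hypothetical reference that only supports real input."""
    if x == 0:
        return 0.0
    # math.copysign raises TypeError when given a complex number.
    return math.copysign(1.0, x)


def old_style_sign(x):
    """Deliberately non-conforming for complex input: sign of the real part."""
    real = x.real  # float and complex both expose .real
    return math.copysign(1.0, real) if real else 0.0


def check_elementwise(func, reference, inputs, swallow=True):
    """Compare func against reference; return how many inputs were checked."""
    checked = 0
    for x in inputs:
        try:
            expected = reference(x)
        except Exception:
            if swallow:
                continue  # input silently skipped, so func is never tested on it
            raise
        assert func(x) == expected
        checked += 1
    return checked


# Real inputs are compared as intended.
print(check_elementwise(old_style_sign, real_only_reference, [2.0, -5.0]))

# Complex inputs "pass" only because nothing was actually compared.
print(check_elementwise(old_style_sign, real_only_reference, [1 + 2j]))
```

Calling the checker with `swallow=False` makes the reference's `TypeError` propagate, which is one way such a gap would surface instead of being hidden.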

asmeurer avatar Oct 16 '24 20:10 asmeurer