
Fix sign() for torch and cupy

asmeurer opened this issue 1 year ago • 3 comments

Neither torch nor CuPy propagates NaNs correctly in sign(), and torch's sign() does not support complex numbers.

Fixes https://github.com/data-apis/array-api-compat/issues/136
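To pin down the behavior being fixed, here is a minimal pure-Python scalar reference for sign(). It follows my reading of the array API standard's special cases and is an illustrative assumption, not code from this PR: NaN input gives NaN, zero gives 0, and a nonzero complex z gives z/|z|. The name `reference_sign` is hypothetical.

```python
import math


def reference_sign(x):
    """Hypothetical scalar reference for array-API-style sign() (sketch).

    Real input: NaN -> NaN, negative -> -1, zero -> 0, positive -> 1.
    Complex input: NaN in either part -> nan+nanj, zero -> 0j,
    otherwise x / abs(x), a complex number with magnitude 1.
    Infinite complex components are not special-cased here.
    """
    if isinstance(x, complex):
        if math.isnan(x.real) or math.isnan(x.imag):
            return complex(math.nan, math.nan)
        if x == 0:
            # z / |z| would be 0/0 here, so return 0 explicitly.
            return 0j
        return x / abs(x)
    if math.isnan(x):
        # Propagate NaN instead of collapsing it to 0.
        return math.nan
    if x > 0:
        return 1.0
    if x < 0:
        return -1.0
    return 0.0
```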

asmeurer, May 10 '24 20:05

https://github.com/data-apis/array-api-compat/issues/136 should be resolved before this is merged. Specifically, we should decide whether it's worth fixing the sign(nan) special case, and whether we want to keep that special case at all. Regardless, we should keep the torch complex handling, since it's very straightforward to implement.
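The complex handling really is short. Below is a minimal sketch of a torch wrapper, assuming PyTorch is installed; the function is illustrative and not necessarily the code merged here. It computes x/|x| for complex input, maps complex zeros back to 0, and restores NaN wherever the input is NaN for real floating-point tensors.

```python
import torch


def sign(x: torch.Tensor) -> torch.Tensor:
    """NaN-propagating, complex-aware sign for torch tensors (sketch)."""
    if x.is_complex():
        # sign(z) = z / |z|; torch.abs returns a real tensor here.
        out = x / torch.abs(x)
        # z == 0 gives 0/0 = nan above, so map those entries back to 0.
        return torch.where(x == 0, torch.zeros_like(x), out)
    out = torch.sign(x)
    if x.is_floating_point():
        # torch.sign does not propagate NaN (the problem reported here),
        # so copy NaN inputs through unchanged.
        out = torch.where(torch.isnan(x), x, out)
    return out
```

Using `torch.where` keeps the function free of in-place writes, at the cost of always evaluating both branches.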

asmeurer, May 10 '24 20:05

It seems that torch has gained quite a few new test failures since we last ran the tests. I don't know whether that's because of a test suite update or a torch update.

asmeurer, May 10 '24 20:05

Based on a simple timing test on PyTorch CPU, this implementation is 3-10x slower than torch.sign, depending on how many NaNs are in the tensor, although sign itself is a fast operation to begin with. Still, it would definitely be better for this to be fixed upstream.
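A simple CPU timing harness of the kind described above might look like the following. It assumes PyTorch is installed and times a hypothetical `masked_sign` variant that restores NaNs through boolean-mask assignment, so its cost can grow with the NaN count. It is a sketch for reproducing the comparison, and it makes no claim about the numbers you will see.

```python
import timeit

import torch


def masked_sign(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical variant: plain torch.sign, then overwrite NaN positions.
    out = torch.sign(x)
    out[torch.isnan(x)] = float("nan")
    return out


def time_sign(n: int = 100_000, nan_fraction: float = 0.0, repeat: int = 5):
    """Return per-call seconds (torch.sign, masked_sign), best of `repeat`."""
    x = torch.randn(n)
    # Overwrite a random subset of entries with NaN.
    x[torch.rand(n) < nan_fraction] = float("nan")
    number = 10
    base = min(timeit.repeat(lambda: torch.sign(x), number=number, repeat=repeat))
    wrapped = min(timeit.repeat(lambda: masked_sign(x), number=number, repeat=repeat))
    return base / number, wrapped / number


if __name__ == "__main__":
    for frac in (0.0, 0.5, 1.0):
        base, wrapped = time_sign(nan_fraction=frac)
        print(f"nan fraction {frac:.1f}: masked_sign / torch.sign = {wrapped / base:.1f}")
```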

asmeurer, May 24 '24 19:05

It sounds like this will be useful, so I'm going to merge.

asmeurer, Sep 03 '24 21:09