Signatures of `get_xp`-wrapped functions?
There seem to be some issues with the signatures of functions wrapped by get_xp.
I haven't narrowed down the exact problem, but here's an MRE:
```python
import cupy as xp
from array_api_compat import cupy as xp_compat
A = xp.eye(3)
A = xp.asarray(A)
xp.linalg.eigh(A)  # fine
xp.linalg.eigh(a=A)  # fine
xp_compat.linalg.eigh(A)  # fine
xp_compat.linalg.eigh(a=A)  # error
# TypeError: eigh() missing 1 required positional argument: 'x'
```
Also, e.g.:
```python
xp.linalg.eigh(A, 'U')  # fine
xp_compat.linalg.eigh(A, 'U')  # error
# TypeError: eigh() got multiple values for argument 'xp'
```
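Both errors can be reproduced without CuPy, which makes the argument binding easier to see. This is a sketch under an assumption: that `get_xp` calls the wrapped function as `f(*args, xp=xp, **kwargs)`, which is consistent with the "multiple values for argument 'xp'" message. `get_xp_sketch`, `fake_backend` and `compat_eigh` are hypothetical stand-ins, not array_api_compat code.

```python
import functools
from types import SimpleNamespace


def get_xp_sketch(xp):
    """Hypothetical stand-in for get_xp: forward every argument and add xp=... as a keyword."""
    def decorator(f):
        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            return f(*args, xp=xp, **kwargs)
        return wrapped
    return decorator


# Hypothetical backend whose eigh takes (a, UPLO='L'), like NumPy's and CuPy's.
fake_backend = SimpleNamespace(
    linalg=SimpleNamespace(eigh=lambda a, UPLO="L": ("w", "v"))
)


# Same shape as the current compat definition: xp is positional-or-keyword.
def eigh(x, /, xp, **kwargs):
    return xp.linalg.eigh(x, **kwargs)


compat_eigh = get_xp_sketch(fake_backend)(eigh)

print(compat_eigh("A"))  # works: x="A", xp injected by keyword

errors = []
for call in (lambda: compat_eigh(a="A"), lambda: compat_eigh("A", "U")):
    try:
        call()
    except TypeError as exc:
        errors.append(str(exc))

for message in errors:
    print(message)
```

In the first failing call `a="A"` lands in `**kwargs` and `x` is never supplied; in the second, `'U'` is bound positionally to `xp` before the injected `xp=` keyword arrives.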
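The problem is straightforward. `eigh` is defined as
https://github.com/data-apis/array-api-compat/blob/ac15c526d9769f77c780958a00097dfd183a2a37/array_api_compat/common/_linalg.py#L45-L46
In other words, it passes `**kwargs` through but not `*args`. Because `xp` is an ordinary positional-or-keyword parameter, the extra positional `'U'` gets bound to `xp`, which `get_xp` also passes by keyword, hence the "multiple values" error. This is easy enough to fix: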
```diff
diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index dc2b69d..01db3a0 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -42,8 +42,8 @@ class SVDResult(NamedTuple):
 # These functions are the same as their NumPy counterparts except they return
 # a namedtuple.
-def eigh(x: ndarray, /, xp, **kwargs) -> EighResult:
-    return EighResult(*xp.linalg.eigh(x, **kwargs))
+def eigh(x: ndarray, /, *args, xp, **kwargs) -> EighResult:
+    return EighResult(*xp.linalg.eigh(x, *args, **kwargs))
 def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced',
        **kwargs) -> QRResult:
```
We just need to do this for every function for which NumPy supports more positional arguments than the standard.
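Here is a self-contained sketch of the patched signature from the diff above, under the same assumption that `get_xp` injects `xp` by keyword. `get_xp_sketch` and `fake_backend` are hypothetical stand-ins, and `EighResult` only mirrors the namedtuple in `_linalg.py`; it is redefined so the snippet runs on its own.

```python
import functools
from types import SimpleNamespace
from typing import Any, NamedTuple


class EighResult(NamedTuple):
    eigenvalues: Any
    eigenvectors: Any


def get_xp_sketch(xp):
    """Hypothetical stand-in for get_xp: forward every argument and add xp=... as a keyword."""
    def decorator(f):
        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            return f(*args, xp=xp, **kwargs)
        return wrapped
    return decorator


def _backend_eigh(a, UPLO="L"):
    # Stand-in for a NumPy-style eigh; echoes which triangle was requested.
    return ("w", f"v[{UPLO}]")


fake_backend = SimpleNamespace(linalg=SimpleNamespace(eigh=_backend_eigh))


# Patched shape: extra positionals go to *args, and xp is now keyword-only.
def eigh(x, /, *args, xp, **kwargs):
    return EighResult(*xp.linalg.eigh(x, *args, **kwargs))


compat_eigh = get_xp_sketch(fake_backend)(eigh)

print(compat_eigh("A"))
print(compat_eigh("A", "U"))
print(compat_eigh("A", UPLO="U"))
```

Making `xp` keyword-only is what stops the extra positional from colliding with it. Note that `compat_eigh(a="A")` still fails with this signature, since `x` remains positional-only.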
As for the other issue, with `eigh(a=A)` you are passing by keyword an argument that is positional-only in the standard (the compat wrapper names it `x`, while CuPy names it `a`).
We could make this work by instead defining all functions like this:
```diff
diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py
index dc2b69d..11b54bb 100644
--- a/array_api_compat/common/_linalg.py
+++ b/array_api_compat/common/_linalg.py
@@ -42,8 +42,8 @@ class SVDResult(NamedTuple):
 # These functions are the same as their NumPy counterparts except they return
 # a namedtuple.
-def eigh(x: ndarray, /, xp, **kwargs) -> EighResult:
-    return EighResult(*xp.linalg.eigh(x, **kwargs))
+def eigh(*args, xp, **kwargs) -> EighResult:
+    return EighResult(*xp.linalg.eigh(*args, **kwargs))
 def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced',
        **kwargs) -> QRResult:
```
The downside is that this would completely kill introspection of these functions (right now I have it set up so that `help(array_api_compat.numpy.linalg.eigh)` shows the actual arguments).
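The introspection cost is visible with `inspect.signature`, which is where `help()` gets the argument list from. The two functions below are bare stand-ins for the current definition and the catch-all one from the second diff (the names are made up for this comparison), shown before any `xp` hiding a decorator might do.

```python
import inspect


# Bare stand-ins with the two candidate signatures; their bodies are irrelevant here.
def eigh_current(x, /, xp, **kwargs):
    ...


def eigh_catch_all(*args, xp, **kwargs):
    ...


print(inspect.signature(eigh_current))    # (x, /, xp, **kwargs)
print(inspect.signature(eigh_catch_all))  # (*args, xp, **kwargs)
```

The catch-all version no longer tells a reader that there is exactly one required, positional-only array argument.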
Note that both of these examples are not portable with the standard, which defines the signature as
```python
eigh(x: array, /)
```
> Note that both of these examples are not portable with the standard
Yeah, the thing is that this came up in the context of dispatching calls to SciPy's `eigh` function to other backends. It's unclear at the moment what we want to do when the SciPy function has a much more flexible signature (including many other arguments) than the standard.
That said, my impression from here was that array_api_compat did not intend to limit capabilities of the wrapped libraries to those of the standard, so I went ahead and reported it.
Yes, in principle we should support this. Maybe I can modify the `get_xp` decorator to keep the standard signature for introspection purposes, but always pass through `*args` and `**kwargs` automatically. I'll need to think a bit about it.
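One possible shape for that idea, as a sketch only and not the library's implementation: the hypothetical `get_xp_passthrough` forwards all positional and keyword arguments, injects `xp` by keyword, and sets `__signature__` (which `inspect.signature`, and therefore `help()`, honors) to just the named, non-`xp` parameters. The underlying function still needs catch-all `*args`/`**kwargs` for the extras to reach the backend; `fake_backend` is again a stand-in.

```python
import functools
import inspect
from types import SimpleNamespace


def get_xp_passthrough(xp):
    """Hypothetical get_xp variant: forward *args/**kwargs, inject xp, advertise only named params."""
    def decorator(f):
        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            # Everything the caller passed goes straight through; xp is added by keyword.
            return f(*args, xp=xp, **kwargs)

        sig = inspect.signature(f)
        hidden = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        # Keep only the named, non-xp parameters so help() shows a standard-style signature.
        params = [
            p for p in sig.parameters.values()
            if p.name != "xp" and p.kind not in hidden
        ]
        wrapped.__signature__ = sig.replace(parameters=params)
        return wrapped
    return decorator


def _backend_eigh(a, UPLO="L"):
    # Stand-in for a NumPy-style eigh; echoes which triangle was requested.
    return ("w", f"v[{UPLO}]")


fake_backend = SimpleNamespace(linalg=SimpleNamespace(eigh=_backend_eigh))


@get_xp_passthrough(fake_backend)
def eigh(x, /, *args, xp, **kwargs):
    return xp.linalg.eigh(x, *args, **kwargs)


print(inspect.signature(eigh))  # (x, /)
print(eigh("A", "U"))           # ('w', 'v[U]')
print(eigh("A", UPLO="U"))      # ('w', 'v[U]')
```

The trade-off is that the advertised `(x, /)` hides the backend-specific extras that the function actually accepts.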