Can an API for dealing with composite operations be defined?
We've discussed several times whether a function that is a composite of other functions already present in the standard can be added. The most recent example is gh-460, which proposed adding `abs2` (= `abs(x) ** 2`). Typically this isn't worth it unless there are very compelling reasons. Multiple libraries have compilers that can fuse function calls, so for those libraries adding a function like `abs2` may bring no gain, just extra API surface. And performance-wise, the gain must be quite large to justify adding a function.
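Written out for a complex input $z = a + bi$ with real $a, b$, the composite reduces to a plain sum of squares. A fused implementation therefore never needs the square root that `abs` computes and `** 2` then undoes:

$$
\operatorname{abs2}(z) = |z|^2 = \left(\sqrt{a^2 + b^2}\right)^2 = a^2 + b^2 = \operatorname{Re}(z)^2 + \operatorname{Im}(z)^2
$$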
The discussion then turned to whether it is possible to write, for example, a standardized way for array API consumers to apply arbitrary functions element-wise to arrays: something along the lines of `np.vectorize`.
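For reference, this is roughly what the `np.vectorize` route looks like in NumPy today: a scalar Python function is wrapped so it can be called on a whole array. A minimal sketch; the helper name `abs2_scalar` is illustrative, and `otypes` is passed so NumPy does not have to infer the output dtype from a trial call.

```python
import numpy as np

def abs2_scalar(z):
    """Squared absolute value of a single (complex) number."""
    return z.real * z.real + z.imag * z.imag

# np.vectorize wraps the scalar function so it accepts whole arrays;
# otypes fixes the output dtype up front.
abs2_vec = np.vectorize(abs2_scalar, otypes=[np.float64])

x = np.array([3 + 4j, 1 - 1j, 0j])
print(abs2_vec(x))
```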
Another direction could be to try and write a portable JIT-able set of functions. Something like:
```python
# Here `xp` and `compiler` can be injected
@compiler.jit
def abs2(x):
    """Squared absolute value"""
    return xp.real(x) * xp.real(x) + xp.imag(x) * xp.imag(x)
```
There are multiple options for compilers here, and they work with different array libraries. This is perhaps a more future-proof direction (`np.vectorize` wasn't quite a success, and `numexpr`-type string expressions are a bit of a hack as well ...).
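One way the "injection" of `xp` and `compiler` could look is a small factory that receives both and returns the compiled function. A minimal sketch under stated assumptions: `make_abs2` and `identity_jit` are hypothetical names, and the identity decorator stands in for a backend without a JIT so the example runs with NumPy alone.

```python
import numpy as np

def identity_jit(func):
    """Stand-in 'compiler' for eager execution: returns func unchanged."""
    return func

def make_abs2(xp, jit=identity_jit):
    """Build abs2 for array namespace `xp`, compiled with decorator `jit`."""
    @jit
    def abs2(x):
        """Squared absolute value."""
        return xp.real(x) * xp.real(x) + xp.imag(x) * xp.imag(x)
    return abs2

abs2_np = make_abs2(np)
# With JAX installed, this could instead be: make_abs2(jax.numpy, jax.jit)

x = np.array([3 + 4j, 1 - 1j])
print(abs2_np(x))
```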
This is not a worked out proposal, just opening it here as a follow-up to the discussion in gh-460 and as a tracker issue to collect more ideas and serve as a reference for when other composite functions are proposed to be added to the standard.
> `np.vectorize` wasn't quite a success, and `numexpr`-type string expressions are a bit of a hack as well ...
I'd love to know why they didn't work.
@leofang For `np.vectorize`, the docstring says: "The vectorize function is provided primarily for convenience, not for performance. The implementation is essentially a for loop." If you look at the implementation (https://github.com/numpy/numpy/blob/2a6daf39cc4fd895ab803edf018907cb8044f821/numpy/lib/function_base.py#L2118), you'll see that no C compiler is invoked: it's just pure Python code that happens to be a bit faster than an actual for-loop, but not by all that much, and there's no fusing of operations.
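The "essentially a for loop" point can be made concrete by counting how often the wrapped function runs: it is invoked once per element, so nothing is fused. A minimal sketch; `CountingSquare` is a hypothetical helper, and `otypes` is set because without it NumPy may make an extra trial call on the first element to determine the output type.

```python
import numpy as np

class CountingSquare:
    """Scalar function v -> v*v + 1 that records how often it is called."""

    def __init__(self):
        self.calls = 0

    def __call__(self, v):
        self.calls += 1
        return v * v + 1.0

sq = CountingSquare()
f = np.vectorize(sq, otypes=[np.float64])

x = np.arange(5.0)
y = f(x)

# Same result as an explicit Python loop over the elements
expected = np.array([v * v + 1.0 for v in x])
print(y, sq.calls)
```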
`numexpr` does do fusing and is fast. Its problems are more about design and usability:
- You need a working C++ compiler installed. That's a nonstarter for end users, so you can only use it in a (NumPy-using) library that ships compiled extensions. Unlike something that also works with JIT compilers and pure Python code, this makes it much harder to write portable code (à la array API standard compliant code).
- The UX is strings: `ne.evaluate("2*x +y + 1")`. This works, but it's 2022 and that just doesn't seem like a healthy way of doing things.
```python
import numpy as np
import numba
import jax
import jax.numpy as jnp

# Here `xp` and `compiler` can be injected
compiler = jax  # or: numba
xp = jnp        # or: np

@compiler.jit
def abs2(x):
    """Squared absolute value"""
    return xp.real(x) * xp.real(x) + xp.imag(x) * xp.imag(x)

x = xp.arange(3000, dtype=xp.complex64)

# In IPython:
# %timeit abs2(x)
# %timeit xp.abs(x)**2
```
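Outside IPython, a comparable measurement could be sketched with the standard-library `timeit` module, adding a check that both expressions agree before their speed is compared. A minimal NumPy-only sketch (no JIT applied); if `xp` were `jax.numpy`, each timed callable should also call `.block_until_ready()` on its result, since JAX dispatches work asynchronously.

```python
import timeit

import numpy as np

xp = np  # swap in another array namespace if desired

def abs2(x):
    """Squared absolute value, written as a composite of standard functions."""
    return xp.real(x) * xp.real(x) + xp.imag(x) * xp.imag(x)

x = xp.arange(3000, dtype=xp.complex64)

# Correctness first: results are float32, so use a float32-friendly tolerance.
assert np.allclose(abs2(x), xp.abs(x) ** 2, rtol=1e-5)

candidates = [
    ("abs2(x)", lambda: abs2(x)),
    ("xp.abs(x)**2", lambda: xp.abs(x) ** 2),
]
for label, stmt in candidates:
    # Best of 7 repeats of 1000 calls, reported per call in microseconds
    best = min(timeit.repeat(stmt, number=1000, repeat=7)) / 1000
    print(f"{label}: {best * 1e6:.2f} µs per call")
```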
For NumPy + Numba:
```
>>> %timeit abs2(x)
2.33 µs ± 10.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
>>> %timeit xp.abs(x)**2
3.82 µs ± 56.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
```
For JAX:
```
>>> %timeit abs2(x)
2.58 µs ± 17.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
>>> %timeit xp.abs(x)**2
86.1 µs ± 383 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
For PyTorch, I tried `torch.jit.script` (note that it's not `modulename.jit`, so it needed a tweak). It does not yield any speedup, even after adding type annotations:
```
>>> %timeit abs2(x)
9.71 µs ± 36.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
>>> %timeit xp.abs(x)**2
9.86 µs ± 74.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
There are other PyTorch compilers though, and this really should work given the right one (e.g., TorchDynamo: https://github.com/pytorch/torchdynamo#usage-example).
There's also Transonic (see https://transonic.readthedocs.io/en/latest/backends/pythran.html), which has `@jit` and `@boost` (for AOT compilation) decorators and supports Pythran, Numba and Cython as backends. It is basically a worked-out version of this hand-wavy "here's how composite ops could be written" idea.
CuPy will work with Numba too: https://docs.cupy.dev/en/stable/user_guide/interoperability.html#numba.
Triton has a `jit` decorator, but it has a lower-level programming model, so pure Python + type annotations is not enough (see for example https://triton-lang.org/master/getting-started/tutorials/01-vector-add.html).
I think the point of this issue is not that we should add an API for this, or that someone should write a separate package for it (that is possible, though I'm not sure it's high-value). It is rather that this is the right way of doing it in principle: every function that can easily be written this way probably should be, rather than expanding the API surface of the standard.
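As an illustration of writing a composite on top of the standard instead of adding it to the standard, a function can look up the namespace from its input via the standard's `__array_namespace__` method and use only standard functions. A minimal sketch: `get_namespace` is a hypothetical helper, and falling back to NumPy for arrays without `__array_namespace__` (e.g. older NumPy versions) is an assumption of this example.

```python
import numpy as np

def get_namespace(x):
    """Array namespace of x; falls back to NumPy (assumption of this sketch)."""
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return np

def abs2(x):
    """Squared absolute value using only array API standard functions."""
    xp = get_namespace(x)
    return xp.real(x) * xp.real(x) + xp.imag(x) * xp.imag(x)

x = np.array([3 + 4j, 1j])
print(abs2(x))
```

Any library whose arrays implement `__array_namespace__` would then get `abs2` without the standard growing a new function; a compiler can still be applied on top, as in the `@compiler.jit` examples above.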