array-api-compat
WIP: Add `top_k` compatibility
This PR references data-apis/array-api-tests#274 and implements the compatibility layer for `top_k`.
Summary of Compatibility
- jax:
  - JAX's `top_k` does not implement the `axis` or `largest` arguments. While `axis` is easily implemented with `jax.numpy.swapaxes`, `largest` is not. Implementing the spec in JAX can be done similarly to the pure Python implementation in numpy/numpy#26666.
  - Notes: Currently, tests for JAX fail on unsigned dtypes because of this issue: google/jax#22137.
- numpy:
  - No concerns if numpy/numpy#26666 gets merged.
- dask:
  - The runtime of `top_k` is currently about 2x longer than necessary, because the indices and the values have to be computed separately. This can be rectified once `take_along_axis` is implemented in dask: dask/dask#3663.
- torch:
  - No concerns.
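To make the `axis`/`largest` discussion concrete, here is a minimal NumPy sketch of a `top_k(x, k, /, axis=-1, *, largest=True)` function. The signature, the `axis=-1` default, and the ordering of the returned elements are assumptions for illustration, not the draft spec. This is not the code from numpy/numpy#26666 or from this PR. `axis` is handled by moving it to the last position (the same idea as `jax.numpy.swapaxes`). `largest=False` selects from the other end of `np.argpartition` instead of negating the input, because negation wraps around for unsigned dtypes. The values are then gathered from the indices with `np.take_along_axis`, so both come out of a single selection.

```python
import numpy as np


def top_k(x, k, /, axis=-1, *, largest=True):
    """Hypothetical sketch: return (values, indices) of the k largest
    (or smallest) elements of `x` along `axis`."""
    x = np.asarray(x)
    # Move the target axis to the end so selection always happens on axis -1.
    xm = np.moveaxis(x, axis, -1)
    n = xm.shape[-1]
    if not 1 <= k <= n:
        raise ValueError(f"k must be in [1, {n}], got {k}")

    if largest:
        # After argpartition, the last k slots hold the k largest elements.
        idx = np.argpartition(xm, n - k, axis=-1)[..., n - k:]
    else:
        # The first k slots hold the k smallest elements; no negation needed,
        # so unsigned dtypes do not wrap around.
        idx = np.argpartition(xm, k - 1, axis=-1)[..., :k]

    # Gather values from the indices instead of running a second selection.
    vals = np.take_along_axis(xm, idx, axis=-1)

    # Assumption: order results descending for largest, ascending otherwise.
    order = np.argsort(vals, axis=-1, kind="stable")
    if largest:
        order = order[..., ::-1]
    idx = np.take_along_axis(idx, order, axis=-1)
    vals = np.take_along_axis(vals, order, axis=-1)

    # Move the selected axis back to its original position.
    return np.moveaxis(vals, -1, axis), np.moveaxis(idx, -1, axis)
```

For example, `top_k(np.array([3, 1, 4, 1, 5, 9, 2, 6]), 3)` gives values `[9, 6, 5]` at indices `[5, 7, 4]`. The `take_along_axis` step matters for dask. Without it, values and indices need two separate selections, which matches the roughly 2x overhead noted above.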
Process
As mentioned in the referenced PR, this process will likely be repeated, so here are the steps I took:
- Create a branch for `array-api` that adds the corresponding specification.
  - I added the spec in `.draft`.
- Create a branch for `array-api-tests` that implements the new tests and has its `array-api` submodule pointing to the newly created `array-api` branch.
  - Being new to submodules, I took forever to debug changing the submodule.
- Create a branch for `array-api-compat` (this PR) that implements the compatibility layer and points the CI to the newly created `array-api-tests` branch.
  - Add the environment variable `ARRAY_API_TESTS_VERSION=draft` in the CI.
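The submodule step was the hardest part to get right, so here is a hedged sketch of the git commands that repoint a submodule at a branch on a fork. The helper name `repoint_submodule`, the fork URL, and the branch name are hypothetical. The submodule path `array-api` is assumed from the description above. The git subcommands themselves (`submodule set-url`, `submodule set-branch`, `submodule update --remote`) are standard, but `set-url` and `set-branch` need a reasonably recent git. By default the helper only returns the commands so they can be reviewed before running.

```python
import subprocess


def repoint_submodule(fork_url, branch, path="array-api", run=False):
    """Hypothetical helper: point the submodule at `path` to `branch` of `fork_url`.

    Returns the list of git commands; executes them only when run=True.
    """
    commands = [
        # Record the fork's URL in .gitmodules (git also syncs it to .git/config).
        ["git", "submodule", "set-url", path, fork_url],
        # Track the draft-spec branch instead of the default branch.
        ["git", "submodule", "set-branch", "--branch", branch, path],
        # Fetch and check out the tip of that branch inside the submodule.
        ["git", "submodule", "update", "--init", "--remote", path],
        # Stage .gitmodules and the new submodule commit in the parent repo.
        ["git", "add", ".gitmodules", path],
    ]
    if run:
        for cmd in commands:
            subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    # Dry run with placeholder values; prints the commands instead of running them.
    for cmd in repoint_submodule("https://github.com/<your-user>/array-api.git", "top-k-draft"):
        print(" ".join(cmd))
```

After running the commands inside `array-api-tests`, a normal `git commit` records the new submodule pointer, and the `array-api-compat` CI can then check out that `array-api-tests` branch.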
Since I was implementing tests and compatibility against a spec that did not exist yet, developing all three concurrently was incredibly messy. As of now I don't have many opinions on how to improve this process, but a documentation page listing the necessary steps would be really helpful for future contributors.