
WIP: Add `top_k` compatibility

Open JuliaPoo opened this issue 1 year ago • 0 comments

This references the PR data-apis/array-api-tests#274 and implements the compatibility layer for top_k.

Summary of Compatibility

  • jax:
    • JAX's top_k does not implement the axis or largest arguments. While axis is easily implemented with jax.numpy.swapaxes, largest is not. The spec can be implemented in JAX in a similar way to the pure Python implementation in numpy/numpy#26666.
    • Note: Tests for JAX currently fail on unsigned dtypes because of this issue: google/jax#22137.
  • numpy:
    • No concerns if numpy/numpy#26666 gets merged.
  • dask:
    • The runtime of top_k is currently about 2x longer than it has to be, since the indices and the values have to be computed separately. This can be rectified once take_along_axis is implemented in dask: dask/dask#3663.
  • torch:
    • No concerns.
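
For readers unfamiliar with the function, here is a minimal NumPy sketch of the behaviour the summary above discusses. It is not the code from this PR or from numpy/numpy#26666. The signature (`x`, `k`, `axis`, `largest`), the `(values, indices)` return order, and the choice to return the selected entries in sorted order are assumptions made for illustration.

```python
import numpy as np


def top_k(x, k, axis=-1, largest=True):
    """Illustrative top_k: returns (values, indices) along `axis`.

    Assumed signature and return order; not taken from the draft spec text.
    """
    x = np.asarray(x)
    n = x.shape[axis]
    if not 0 < k <= n:
        raise ValueError("k must satisfy 0 < k <= x.shape[axis]")

    if largest:
        # After partitioning at n - k, the last k slots hold the k largest.
        part = np.argpartition(x, n - k, axis=axis)
        idx = np.take(part, np.arange(n - k, n), axis=axis)
    else:
        # After partitioning at k - 1, the first k slots hold the k smallest.
        part = np.argpartition(x, k - 1, axis=axis)
        idx = np.take(part, np.arange(k), axis=axis)

    # Gather the values once from the indices.
    vals = np.take_along_axis(x, idx, axis=axis)

    # Order the k selected entries (descending for largest, ascending otherwise).
    # Sorting here is a choice made for this sketch.
    order = np.argsort(vals, axis=axis)
    if largest:
        order = np.flip(order, axis=axis)
    idx = np.take_along_axis(idx, order, axis=axis)
    vals = np.take_along_axis(vals, order, axis=axis)
    return vals, idx


x = np.array([[3, 1, 4, 1, 5], [9, 2, 6, 5, 3]])
vals, idx = top_k(x, 2)                      # two largest per row
small_vals, small_idx = top_k(x, 2, largest=False)
col_vals, col_idx = top_k(x, 1, axis=0)      # largest per column
```

The values are gathered from the indices in a single `take_along_axis` call, which is the step the dask note refers to: without that function, values and indices have to be computed in two separate passes.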

Process

As mentioned in the referenced PR, since the process I went through is likely going to be repeated again, here are the steps I took:

  • Create a branch for array-api that adds the corresponding specification.
    • I added the spec in .draft.
  • Create a branch for array-api-tests which implements the new tests and has its array-api submodule pointing to the newly created array-api branch.
    • Being new to submodules, changing the submodule took me forever to debug.
  • Create a branch for array-api-compat (this PR) that implements the compatibility layer and points the CI to the newly created array-api-tests branch.
    • Add the environment variable ARRAY_API_TESTS_VERSION=draft in the CI.

Since I was implementing tests and compatibility against a spec that does not exist yet, developing all three concurrently was incredibly messy. As of now I don't have strong opinions on how to improve this process, but a documentation page listing the necessary steps would be really helpful for future contributors.

JuliaPoo · Jun 27 '24 12:06