WIP: `top_k` tests
The purpose of this PR is to continue several threads of discussion regarding `top_k`.

This roughly follows the specification of `top_k` in data-apis/array-api#722, with slight modifications to the API to increase compatibility:
```python
def top_k(
    x: array,
    k: int,
    /,
    axis: Optional[int] = None,
    *,
    largest: bool = True,
) -> Tuple[array, array]:
    ...
```
Modifications:

- `mode: Literal["largest", "smallest"]` is replaced with `largest: bool`.
- `axis` is no longer a kw-only arg. This makes `torch.topk` slightly more compatible.
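To make the proposed signature concrete, a call might look like this (illustrative only; it assumes a `top_k` with the signature above is in scope, and the ordering of the returned k elements is left to the implementation):

```python
import numpy as np

x = np.asarray([5, 1, 4, 2, 3])

values, indices = top_k(x, 2)                  # two largest: e.g. values ~ [5, 4], indices ~ [0, 2]
values, indices = top_k(x, 2, largest=False)   # two smallest: e.g. values ~ [1, 2], indices ~ [1, 3]

# axis is positional-or-keyword, so both spellings are accepted:
top_k(x, 2, -1)
top_k(x, 2, axis=-1)
```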
The tests implemented here follow the proposed `top_k` implementation at numpy/numpy#26666.
Compatibility concerns with prior art:

- numpy: None, if numpy/numpy#26666 gets merged.
- torch: In torch the API name is `topk` instead, and `torch.topk` is only implemented for certain dtypes (e.g., `topk_cpu` does not implement `UInt16`).
- tensorflow: The `axis` keyword does not exist; it behaves like `axis=-1`.
- JAX: Same as tensorflow.
- Dask: The `largest` keyword does not exist; the `largest` flag is instead determined by the sign of `k` (a sketch of how this could be wrapped follows this list).
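For Dask in particular, the sign-of-`k` behaviour looks easy to fold into a wrapper. A rough sketch, assuming `dask.array.topk`/`argtopk` keep their current semantics and that `axis=None` means operating on the flattened array, as in the NumPy proposal:

```python
import dask.array as da


def top_k(x, k, /, axis=None, *, largest=True):
    # dask.array.topk/argtopk return the -k smallest elements when k is
    # negative, so the `largest` flag can be folded into the sign of k.
    if axis is None:
        x = da.reshape(x, (-1,))
        axis = -1
    k = k if largest else -k
    return da.topk(x, k, axis=axis), da.argtopk(x, k, axis=axis)
```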
So the next step here would be to implement wrappers for this function in array_api_compat. That way we will be able to see just how complex the required changes are, and also verify that there aren't other incompatibilities, since later checks in a test can't run if earlier ones fail.
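To give a feel for how complex a from-scratch wrapper would be for a namespace with no native `top_k` at all, here is an illustrative NumPy-based sketch (this is not the numpy/numpy#26666 implementation, just a plain `np.argpartition` version with no input validation):

```python
import numpy as np


def top_k(x, k, /, axis=None, *, largest=True):
    # axis=None is taken to mean the flattened array, as in the NumPy proposal.
    if axis is None:
        x = np.reshape(x, -1)
        axis = -1
    if largest:
        # A negative kth makes argpartition put the k largest entries in the
        # last k positions along `axis` (in no particular order).
        part = np.argpartition(x, -k, axis=axis)
        indices = np.take(part, np.arange(-k, 0), axis=axis)
    else:
        part = np.argpartition(x, k - 1, axis=axis)
        indices = np.take(part, np.arange(k), axis=axis)
    return np.take_along_axis(x, indices, axis=axis), indices
```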
The way I would do this is to make a PR to the compat library that:

- Modifies the actions file to point to this PR: https://github.com/data-apis/array-api-compat/blob/ac15c526d9769f77c780958a00097dfd183a2a37/.github/workflows/array-api-tests.yml#L53
- Adds compat wrappers for the different libraries.
- Sparse and JAX are currently not tested in the compat library CI, because their support is entirely in the libraries themselves. So what you can do is make a simple wrapper namespace that just wraps `top_k` and nothing else (a rough sketch of such a namespace follows this list). Then add a CI script like

  ```yaml
  # .github/workflows/array-api-tests-jax.yml
  name: Array API Tests (JAX)

  on: [push, pull_request]

  jobs:
    array-api-tests-jax:
      uses: ./.github/workflows/array-api-tests.yml
      with:
        package-name: jax
        pytest-extra-args: -k top_k
  ```

  This should hopefully be straightforward, but let me know if you run into any issues.
- CuPy cannot be tested on CI. However, CuPy should be identical to NumPy, so if you don't have access to a CUDA machine, I wouldn't worry about it for now.
- Don't worry about tensorflow. It hasn't been included in the compat library at all yet.
- Considering https://github.com/numpy/numpy/pull/26666 is a simple pure-Python implementation, if `top_k` is accepted we can reuse it for the NumPy 1.26 wrapper. For now, though, you can either copy it as the NumPy wrapper into your compat PR to verify it, or change the NumPy dev CI job to point to your NumPy PR.
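As a sketch of the JAX wrapper namespace mentioned above, something like the following might be enough (the module name is hypothetical, and the `axis`/`largest` handling is my assumption; `jax.lax.top_k` itself only selects the largest `k` along the last axis):

```python
# jax_top_k_compat.py -- hypothetical module: re-export jax.numpy and add top_k.
import jax.lax
import jax.numpy as jnp
from jax.numpy import *  # noqa: F401,F403  (everything except top_k comes from jax.numpy)


def top_k(x, k, /, axis=None, *, largest=True):
    # jax.lax.top_k only selects the largest k along the last axis, so flatten
    # for axis=None, move `axis` to the end, and negate for the smallest case.
    # The negation shortcut does not cover bool or unsigned dtypes.
    if axis is None:
        x = jnp.reshape(x, -1)
        axis = -1
    x = jnp.moveaxis(x, axis, -1)
    values, indices = jax.lax.top_k(x if largest else -x, k)
    if not largest:
        values = -values
    return jnp.moveaxis(values, -1, axis), jnp.moveaxis(indices, -1, axis)
```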
Here are some development notes for the compat library, which should be helpful: https://data-apis.org/array-api-compat/dev/index.html. But also feel free to ask any questions here.
The main purpose is just to get an idea of what the wrappers will look like and to get a CI log showing the tests pass (or, if something is too hard to wrap, what the error is). So you don't need to worry too much about making everything perfect and mergeable. If `top_k` is eventually added to the standard, we can clean up these PRs and use them.
If you see something that could be changed in array-api-compat or the test suite to make this process easier, make a note of it. We are going to want to be able to repeat this whole process in the future any time a new function is proposed for inclusion in the array API.
> `torch.topk` is only implemented for certain dtypes (e.g., `topk_cpu` does not implement `UInt16`).
Quite a few things in PyTorch don't work with smaller uint dtypes. They are skipped in the CI, so you don't need to worry about that. If the torch wrapper is just `top_k = topk` and these tests pass, that will be a good sign that the proposed specification matches the existing PyTorch implementation.
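Beyond the bare alias, a torch wrapper probably also needs to translate the keyword name (`dim` vs `axis`) and handle `axis=None`. A rough sketch, assuming `axis=None` means operating on the flattened array, as in the NumPy proposal:

```python
import torch
from torch import topk


def top_k(x, k, /, axis=None, *, largest=True):
    # torch.topk already returns a (values, indices) pair and has a `largest`
    # flag; only the keyword name and the axis=None case need translating.
    if axis is None:
        return topk(torch.reshape(x, (-1,)), k, largest=largest)
    return topk(x, k, dim=axis, largest=largest)
```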