feat: add API specification for returning the `k` largest elements
Update: This PR

- resolves https://github.com/data-apis/array-api/issues/629 by adding one new API to the Array API specification:
  - `top_k`: returns a tuple whose first element is an array containing the top `k` largest (or smallest) values and whose second element is an array containing the indices of those `k` values.
- The design decisions largely follow the discussion below. The specialized methods `top_k_indices` and `top_k_values` have been dropped from the specification.

This PR

- resolves https://github.com/data-apis/array-api/issues/629 by adding 3 new APIs to the Array API specification:
  - `top_k`: returns a tuple whose first element is an array containing the top `k` largest (or smallest) values and whose second element is an array containing the indices of those `k` values.
  - `top_k_indices`: returns an array containing the indices of the `k` largest (or smallest) values.
  - `top_k_values`: returns an array containing the `k` largest (or smallest) values.

Prior Art
As illustrated in the API comparison, there is currently no consistent API across array libraries for returning the k largest or smallest values.
- NumPy has `partition` and `argpartition`, but these return full arrays. When `axis` is `None`, NumPy operates on a flattened input array. To get the top `k` values, one must index and, if wanting sorted values, sort.
- CuPy has `partition` and `argpartition` and follows NumPy; however, for `argpartition`, for implementation reasons, it performs a full sort.
- Dask has `topk` and switches "modes" (largest or smallest) based on whether `k` is positive or negative.
- JAX has `top_k` which ~~only returns values~~ always returns both values and indices, as well as NumPy equivalent `partition` and `argpartition` APIs (however, JAX differs in how it handles NaNs). The function only supports searching along the last axis.
- PyTorch has `topk` which always returns both values and indices.
- TensorFlow has `top_k` which always returns both values and indices and only supports searching along the last axis.
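
As a rough illustration of the NumPy recipe described above (partition, index, then sort if ordering matters), the following sketch uses only existing NumPy functions. The input data and variable names are made up for the example.

```python
import numpy as np

x = np.array([5, 1, 9, 3, 7, 2])
k = 3

# argpartition returns a full-length index array; with kth=-k the last k
# positions hold the indices of the k largest elements, in no particular order.
idx = np.argpartition(x, -k)[-k:]
values = x[idx]

# Sorting is a separate, explicit step: order the k results largest-first.
order = np.argsort(values)[::-1]
top_values = values[order]
top_indices = idx[order]

print(top_values)   # [9 7 5]
print(top_indices)  # [2 4 0]
```

The index-then-sort steps are exactly what a dedicated "top k" API would fold into a single call.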
Proposed APIs
This PR attempts to synthesize the common themes and best ideas for "top k" APIs as observed among array libraries and attempts to define APIs which adhere to specification precedent in order to promote consistent design and reduce cognitive load.
top_k
    def top_k(
        x: array,
        k: int,
        /,
        *,
        axis: Optional[int] = None,
        mode: Literal["largest", "smallest"] = "largest",
    ) -> Tuple[array, array]

Returns a tuple containing the `k` largest (or smallest) elements in `x`.

    def top_k_indices(
        x: array,
        k: int,
        /,
        *,
        axis: Optional[int] = None,
        mode: Literal["largest", "smallest"] = "largest",
    ) -> array

Returns an array containing the indices of the `k` largest (or smallest) elements in `x`.

    def top_k_values(
        x: array,
        k: int,
        /,
        *,
        axis: Optional[int] = None,
        mode: Literal["largest", "smallest"] = "largest",
    ) -> array

Returns an array containing the `k` largest (or smallest) elements in `x`.
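
To make the proposed semantics concrete, here is a minimal, non-normative sketch of how `top_k` could be emulated on top of NumPy. Building it on `np.argpartition`, raising `ValueError` for an out-of-range `k`, and returning the results in unspecified order are all assumptions of this sketch, not requirements stated in the PR.

```python
from typing import Literal, Optional, Tuple

import numpy as np


def top_k(
    x: np.ndarray,
    k: int,
    /,
    *,
    axis: Optional[int] = None,
    mode: Literal["largest", "smallest"] = "largest",
) -> Tuple[np.ndarray, np.ndarray]:
    """Sketch: return (values, indices) of the k largest/smallest elements."""
    if mode not in ("largest", "smallest"):
        raise ValueError(f"unsupported mode: {mode!r}")
    if axis is None:
        # Default behavior in the proposal: search over the flattened array.
        x = np.reshape(x, -1)
        axis = 0
    n = x.shape[axis]
    if not 0 < k <= n:
        # Assumption: the PR leaves this case unspecified; this sketch raises.
        raise ValueError("k must satisfy 0 < k <= number of elements")
    if mode == "largest":
        # After partitioning at n - k, positions n-k..n-1 hold the k largest.
        part = np.argpartition(x, n - k, axis=axis)
        indices = np.take(part, np.arange(n - k, n), axis=axis)
    else:
        # After partitioning at k - 1, positions 0..k-1 hold the k smallest.
        part = np.argpartition(x, k - 1, axis=axis)
        indices = np.take(part, np.arange(k), axis=axis)
    values = np.take_along_axis(x, indices, axis=axis)
    return values, indices


a = np.array([[3, 1, 4], [1, 5, 9]])
vals_rows, idx_rows = top_k(a, 2, axis=1)        # two largest per row
vals_flat, idx_flat = top_k(a, 2)                # two largest overall
vals_small, _ = top_k(a, 2, mode="smallest")     # two smallest overall
```

In this sketch, `top_k_values` and `top_k_indices` would simply return the first or second element of the tuple; a native implementation could avoid materializing the array that is not requested.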
Design Decision Rationale
- The default for `axis` is `None` in order to match `min`, `max`, `argmin`, and `argmax`. In those APIs, when `axis` is `None` (the default), the functions operate over a flattened array. Given that `top_k*` may be considered a generalization of the mentioned APIs, ensuring consistency seemed preferable to requiring users to remember a separate set of rules for `top_k*`.
- `axis` only supports `int` and `None` in order to match `argmin` and `argmax`. In `min` and `max`, the specification supports specifying multiple axes. Support for multiple axes can be a future specification extension.
- `top_k` was chosen over `topk` due to naming concerns discussed elsewhere (namely `top k` vs `to pk`). Furthermore, "top k" follows ML conventions, as opposed to `maxk`/`max_k` or `nlargest`/`nsmallest` as found in other languages.
- The PR includes three separate APIs following the lead of `unique`. In that case, rather than support polymorphic return values (e.g., returning values, returning values and indices, returning values and counts, etc.), we chose to define specific APIs which are monomorphic in their output. We innovated there, and the thinking that went into those design decisions seemed applicable here, where a user may want only values, indices, or both.
- The PR follows the `unique_*` naming convention, rather than the `arg*` naming convention, as there are three different return value situations: values, indices, and indices and values. Hence, using a suffix to describe what is returned as in `unique_*` seems reasonable and follows existing precedent in the specification.
- The APIs include a "mode" option to specify the type (largest or smallest) of values to return. Most existing array libraries supporting a "top k" API return only the largest values; however, PyTorch supports returning either the smallest or largest and does so via a `largest` keyword argument. This PR chooses to name the kwarg `mode` in order to be more explicit (what does `largest=False` mean to the lay reader?) and follows precedent elsewhere in the specification (e.g., `linalg.qr`) where `mode` is used to toggle between different operating modes.
- The PR does not include a `sorted` kwarg in order to instruct the API to return sorted values (or indices corresponding to sorted values) because (a) the kwarg is not universally supported currently, (b) downstream users can, at least for values, explicitly call `sort` (except in Dask which doesn't currently support full sorting) after calling `top_k` or `top_k_values`, and (c) it can be addressed in a future specification extension. Additionally, if we support `sorted`, we may also want to support a `stable` kwarg as in `sort` to allow ensuring that returned indices are consistent when provided the same input array.
- The PR leaves unspecified what should happen when `k` exceeds the number of elements, as different behaviors seem acceptable (e.g., raising an exception or returning `m < k` values).
Questions
- Should we be more strict in specifying what should happen when `k` exceeds the number of elements?
- Should zero-dimensional arrays be supported?
- In `argmin` and `argmax`, the specification requires returning the index of first occurrence when a minimum/maximum value occurs multiple times. Given that `top_k*` can be implemented as a partial sort, presumably we do not want to specify a first occurrence restriction. Is this a reasonable assumption?
- Should we defer adding support for specifying multiple axes until a future revision of the specification or should we go ahead and add now for parity with `min` and `max`?
- Are we okay with `None` being the default for `axis`, where the default behavior is searching over a flattened array?
Considerations
The APIs included in this PR have implications for the following array libraries:
- NumPy: these will be new APIs, and, similar to `unique_*`, will need to be added to the main namespace.
- CuPy: same as NumPy.
- Dask: will need to introduce new APIs; however, the new APIs can be implemented as lightweight wrappers around Dask's existing `topk` and `argtopk`.
- JAX: were JAX to place the APIs in its `lax` namespace, this PR would introduce breaking changes, as ~~JAX would need to return both values and indices, by default, and~~ JAX would need to flatten by default rather than search along the last dimension. However, if implemented in its `numpy` namespace, these will simply be new APIs. In both scenarios, JAX will need to add support for `axis` and `mode` behavior.
- PyTorch: these will be new APIs; however, the new APIs can be implemented as lightweight wrappers around PyTorch's existing `topk`.
- TensorFlow: additional APIs (`top_k_values` and `top_k_indices`). If implemented in its `math` namespace, this PR would introduce breaking changes as TensorFlow would need to flatten by default. However, if implemented in its `numpy` namespace, these will simply be new APIs. In both scenarios, TensorFlow will need to add support for `axis` and `mode` behavior.
Related Links
cc @ogrisel for visibility since you originally opened #722.
@ogrisel Do you have opinions on whether having three separate APIs is preferable to having just argtop_k and top_k?
I tried tracking down actual usages of top k in the wild, but I wasn't able to get a good sense on whether having only two APIs suffices or having three separate APIs is more desirable.
There may also be other combinations. The PR currently specifies

- `top_k` returns `[ values, indices ]`
- `top_k_values` returns `values`
- `top_k_indices` returns `indices`

We could, e.g., only specify

- `top_k` returns `[ values, indices ]`
- `argtop_k` returns `indices`

or strictly complementary

- `top_k` returns `values`
- `argtop_k` returns `indices`

If you have a feel for what is preferable, that would be great to hear!
Thank you very much for the survey of current implementations and API proposal. From a potential API consumer point of view the main proposal seems good for me.
About the name: topk seems to be a bit more popular than top_k in existing libraries, but using top_k might help reduce breaking changes when adopting this spec. It might be better to hear from library implementers on this point.
Should we be more strict in specifying what should happen when k exceeds the number of elements?
I believe so. As a user I would accept the call to fail with a standard exception type such as ValueError.
In argmin and argmax, the specification requires returning the index of first occurrence when a minimum/maximum value occurs multiple times. Given that top_k* can be implemented as a partial sort, presumably we do not want specify a first occurrence restriction. Is this a reasonable assumption?
I think it's fine to allow topk to be faster by not enforcing this constraint. It would always be possible to add a stable=False bool kwarg later so that users can request stable results, possibly at the cost of a performance penalty, as is done for the xp.sort function.
Should we defer adding support for specifying multiple axes until a future revision of the specification or should we go ahead and add now for parity with min and max?
I wouldn't like this requirement to hurt speed of adoption by backing libraries. I don't think many users need that in practice.
Are we okay with None being the default for axis, where the default behavior is searching over a flattened array?
I don't think users have a need for this: they can flatten the input themselves if needed. But I have the impression that NumPy (and therefore the Array API) often uses axis=None as a default convention to work on a flattened 1D array anyway, so I am fine with staying consistent in that regard.
However, for this particular case, the default in numpy.argpartition is axis=-1 instead of axis=None. No strong opinion on which is most natural.
Should we defer adding support for specifying multiple axes until a future revision of the specification or should we go ahead and add now for parity with min and max?
That would remove a lot of value by preventing efficient parallelization by the underlying backend when running topk on a 2D array with axis=0 or axis=1. This pattern was the original motivation for #629 (e.g. k-nearest neighbors classification in scikit-learn).
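
For illustration, the per-row pattern mentioned above can be sketched with NumPy's existing `argpartition`. The random data, shapes, and variable names are made up for the example; under the proposed API, the selection step would presumably become something like `top_k(d2, k, axis=1, mode="smallest")`.

```python
import numpy as np

rng = np.random.default_rng(0)
train = rng.normal(size=(50, 4))    # 50 reference points with 4 features
queries = rng.normal(size=(8, 4))   # 8 query points
k = 5

# Pairwise squared Euclidean distances, shape (n_queries, n_train).
d2 = ((queries[:, None, :] - train[None, :, :]) ** 2).sum(axis=-1)

# For each query (row), select the indices of the k smallest distances.
# The k neighbors within a row come back in no particular order.
neighbors = np.argpartition(d2, k - 1, axis=1)[:, :k]
neighbor_d2 = np.take_along_axis(d2, neighbors, axis=1)

print(neighbors.shape)  # (8, 5)
```

Each row is an independent selection problem, which is what lets a backend parallelize across rows when `axis` is given explicitly.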
I tried tracking down actual usages of top k in the wild, but I wasn't able to get a good sense on whether having only two APIs suffices or having three separate APIs is more desirable.
The benefit for the 3 function API would be to always be able to optimize memory usage by not allocating unnecessary arrays when k is large enough for this to matter. However I have no good feeling about how much this would really be a problem in practice (and how much the underlying implementation would be able to skip the extra contiguous memory allocation internally).
Something that seems missing from this spec is the handling of NaN values. Maybe an extra kwarg is needed to specify whether they should be considered as either smallest or largest, or whether they need to be filtered out from the result (but then the result size would be data-dependent, potentially yielding empty arrays, which might also cause problems).
Also, I assume that NaN values are always smaller than +inf and larger than -inf, but maybe not all libraries agree on that.
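
As a small illustration of why this needs deciding, the sketch below shows what NumPy's existing `sort` does with NaN and infinities, and how filtering NaNs makes the result size data-dependent. It demonstrates NumPy's behavior only; other libraries may order NaNs differently.

```python
import numpy as np

x = np.array([1.0, np.nan, np.inf, -np.inf, 3.0])

# NumPy's sort places NaN after every other value, including +inf.
s = np.sort(x)
print(s)  # [-inf   1.   3.  inf  nan]

# A naive "2 largest" taken from the tail of the sort therefore includes NaN.
naive_largest = s[-2:]

# Filtering NaNs first avoids that, but the number of candidates now
# depends on the data (here 4 of the 5 elements remain).
non_nan = x[~np.isnan(x)]
largest_without_nan = np.sort(non_nan)[-2:]
print(largest_without_nan)  # [ 3. inf]
```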
Revisiting this topic in preparation of helping it move forward. Quick first comment on:
JAX has top_k which only returns values,
I am not sure if it changed in the meantime or you just misread the JAX docs at the time, but https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.top_k.html says that both values and indices are returned.
I had a peek at the jax.lax implementation as well, but couldn't tell so quickly because I'm not familiar with the .bind method. @jakevdp would you mind confirming that the docs are correct here?
So it looks like PyTorch, JAX and TF all return (values, indices). That seems to align well enough with the top_k definition proposed here.
JAX's current top_k functions are in the jax.lax namespace, while the array API implementations will be in the jax.numpy namespace. So there is no issue with having different API conventions here.
The rendered documentation is misleading due to a misplaced colon in the Returns block, but JAX returns (values, indices):
In [2]: x = jax.numpy.arange(100, 110)
In [3]: jax.lax.top_k(x, 3)
Out[3]: [Array([109, 108, 107], dtype=int32), Array([9, 8, 7], dtype=int32)]
Here is a PR with a draft implementation for NumPy: https://github.com/numpy/numpy/pull/26666, aligned with the top_k signature in this PR.
The most sane NaN handling IMO, is to sort NaNs always to the end, which however means that if you implement sort="desc" or largest values here, NaNs should end up also at the end, which is opposite of what happens for ascending sort/smallest values.
The annoyance with that is that it means sort behavior diverges for asc/desc sort beyond a [::-1], also for unstable sort.
Of course one can just leave it unspecified here. OTOH, I dunno how much that limits the usability.
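
The divergence described above can be seen in a short NumPy sketch: reversing an ascending sort moves NaNs to the front, whereas a "NaNs always last" descending order has to treat them separately. The `desc_nan_last` construction is just one illustrative way to get that ordering, not something proposed in this thread.

```python
import numpy as np

x = np.array([2.0, np.nan, 5.0, 1.0])

# Ascending sort: NumPy puts NaN at the end.
asc = np.sort(x)             # [ 1.  2.  5. nan]

# A plain reversal puts NaN first, so it is not "descending with NaNs last".
reversed_asc = asc[::-1]     # [nan  5.  2.  1.]

# Descending order with NaNs kept at the end: sort the non-NaN values
# in descending order, then append the NaNs.
nan_mask = np.isnan(x)
desc_nan_last = np.concatenate([np.sort(x[~nan_mask])[::-1], x[nan_mask]])
print(desc_nan_last)         # [ 5.  2.  1. nan]
```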