Argmax/argmin for Array1
It would be useful to have a function like https://athemathmo.github.io/rulinalg/doc/rulinalg/utils/fn.argmax.html
It'd be useful to have a general argmin/argmax for ArrayBase. Also, it seems there are no universal min, max, and sum methods available for convenience!
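A minimal std-only sketch of what such a function could compute. It works over a plain slice rather than ndarray's ArrayBase (for a contiguous Array1, `as_slice()` returns an `Option<&[A]>` that could be passed in). The names `argmax`, `argmin` and `arg_extreme` are hypothetical. Ties resolve to the first index, and an undefined comparison such as one involving NaN yields `None`.

```rust
use std::cmp::Ordering;

/// Scan `xs` once, keeping the first index whose element compares as
/// `wanted` against the current best. Returns `None` for an empty slice,
/// or as soon as a comparison is undefined (e.g. one side is a NaN float).
fn arg_extreme<T: PartialOrd>(xs: &[T], wanted: Ordering) -> Option<usize> {
    if xs.is_empty() {
        return None;
    }
    let mut best = 0;
    for (i, x) in xs.iter().enumerate().skip(1) {
        // `?` propagates `None` when `partial_cmp` cannot order the pair.
        if x.partial_cmp(&xs[best])? == wanted {
            best = i;
        }
    }
    Some(best)
}

/// Hypothetical argmax: index of the first largest element.
fn argmax<T: PartialOrd>(xs: &[T]) -> Option<usize> {
    arg_extreme(xs, Ordering::Greater)
}

/// Hypothetical argmin: index of the first smallest element.
fn argmin<T: PartialOrd>(xs: &[T]) -> Option<usize> {
    arg_extreme(xs, Ordering::Less)
}
```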
What's a universal max, min, sum? Just curious.
On the topic of sum, the plan is to rename the current Array method .scalar_sum() to .sum(). The plan is also to add methods like .std() and .std_axis() (for standard deviation), etc., in that style. Any innovative ideas about how to design the interface in a language where return types are statically known and we don't have default arguments are very welcome, of course.
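One possible answer to the missing default arguments is simply to make the parameter explicit. The sketch below assumes a NumPy-style `ddof` (delta degrees of freedom: 0 gives the population and 1 the sample standard deviation) and works on a plain `&[f64]`. `mean` and `std_dev` are hypothetical free functions, not ndarray methods.

```rust
/// Hypothetical whole-array mean; `None` for an empty slice.
fn mean(xs: &[f64]) -> Option<f64> {
    if xs.is_empty() {
        None
    } else {
        Some(xs.iter().sum::<f64>() / xs.len() as f64)
    }
}

/// Hypothetical whole-array standard deviation. With no default arguments,
/// the caller always states `ddof` explicitly.
fn std_dev(xs: &[f64], ddof: f64) -> Option<f64> {
    let m = mean(xs)?;
    let n = xs.len() as f64;
    // The divisor n - ddof must be positive for the estimate to be defined.
    if n - ddof <= 0.0 {
        return None;
    }
    let sq: f64 = xs.iter().map(|x| (x - m) * (x - m)).sum();
    Some((sq / (n - ddof)).sqrt())
}
```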
@bluss By universal, I meant the distinction between the sum of all elements and, for example, sum_axis. Please see this as well.
The sum/stdev/variance/min/max/etc. operations reduce the dimensionality of the array by one for each axis being iterated over (since they remove that axis). For example, NumPy behaves like this (where the axis argument is optional):
>>> import numpy as np
>>> x = np.zeros((2, 3, 4, 5))
>>> x.sum(axis=None).shape
()
>>> x.sum(axis=2).shape
(2, 3, 5)
>>> x.sum(axis=(1, 2)).shape
(2, 5)
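The shape rule in the NumPy session above can be written down directly: reducing over a set of axes drops exactly those entries from the shape, and reducing over all axes leaves an empty (0-dimensional) shape. A small std-only sketch; `reduced_shape` is a hypothetical helper, and `None` stands in for NumPy's `axis=None`.

```rust
/// Shape left after reducing `shape` over `axes` (`None` = all axes).
/// Panics if an axis is out of bounds, standing in for an invalid-axis error.
fn reduced_shape(shape: &[usize], axes: Option<&[usize]>) -> Vec<usize> {
    match axes {
        // Reducing over every axis leaves a 0-D (scalar) result.
        None => Vec::new(),
        Some(axes) => {
            for &ax in axes {
                assert!(ax < shape.len(), "axis {} out of bounds", ax);
            }
            // Keep only the lengths of the axes that are not reduced.
            shape
                .iter()
                .enumerate()
                .filter(|(i, _)| !axes.contains(i))
                .map(|(_, &len)| len)
                .collect()
        }
    }
}
```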
I'd suggest using a trait to handle the various types of axes arguments:
pub trait FoldAxes<A, D: Dimension> {
    type Output;
    type Repr: AsRef<[isize]>;
    /// Should return `None` if fold is over all axes.
    /// Instead of `Option<Self::Repr>` here, it could just be `Self::Repr`, where
    /// a repr of length zero would mean "all axes".
    fn into_repr(self) -> Option<Self::Repr>;
}
// Scalar output case (all axes).
impl<A, D: Dimension> FoldAxes<A, D> for () {
    type Output = A;
    type Repr = &'static [isize];
    fn into_repr(self) -> Option<&'static [isize]> {
        None
    }
}
// Iterate over single axis.
impl<A, D: Dimension> FoldAxes<A, D> for isize {
    type Output = Array<A, D::Smaller>;
    type Repr = [isize; 1];
    fn into_repr(self) -> Option<[isize; 1]> {
        Some([self])
    }
}
// impl<A, D: Dimension> FoldAxes<A, D> for (isize,) {...}
impl<A, D: Dimension> FoldAxes<A, D> for (isize, isize) {
    type Output = Array<A, <<D as Dimension>::Smaller as Dimension>::Smaller>;
    type Repr = [isize; 2];
    fn into_repr(self) -> Option<[isize; 2]> {
        Some([self.0, self.1])
    }
}
// impl<A, D: Dimension> FoldAxes<A, D> for (isize, isize, isize) {...}
// impl<A, D: Dimension> FoldAxes<A, D> for (isize, isize, isize, isize) {...}
// impl<A, D: Dimension> FoldAxes<A, D> for (isize, isize, isize, isize, isize) {...}
// impl<A, D: Dimension> FoldAxes<A, D> for (isize, isize, isize, isize, isize, isize) {...}
// impl<A, D: Dimension> FoldAxes<A, D> for [isize; 1] {...}
// impl<A, D: Dimension> FoldAxes<A, D> for [isize; 2] {...}
// impl<A, D: Dimension> FoldAxes<A, D> for [isize; 3] {...}
// impl<A, D: Dimension> FoldAxes<A, D> for [isize; 4] {...}
// impl<A, D: Dimension> FoldAxes<A, D> for [isize; 5] {...}
// impl<A, D: Dimension> FoldAxes<A, D> for [isize; 6] {...}
impl<'a, A, D: Dimension> FoldAxes<A, D> for &'a [isize] {
    type Output = Array<A, IxDyn>;
    type Repr = &'a [isize];
    fn into_repr(self) -> Option<&'a [isize]> {
        Some(self)
    }
}
impl<A, D: Dimension> FoldAxes<A, D> for Vec<isize> {
    type Output = Array<A, IxDyn>;
    type Repr = Vec<isize>;
    fn into_repr(self) -> Option<Vec<isize>> {
        Some(self)
    }
}
// same for `usize`, tuples of `usize`, slices of `usize`, vecs of `usize` (casting `usize` to `isize`)
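Whatever the argument type, the `&[isize]` produced by `into_repr` still has to be turned into valid axis indices before any folding happens. The sketch below assumes (the proposal does not say so) that `isize` is used so negative axes can count from the end, as in NumPy's `axis=-1`, and that repeated axes should be rejected. `normalize_axes` is a hypothetical std-only helper.

```rust
/// Map possibly-negative axes onto `0..ndim`, rejecting out-of-range and
/// repeated axes. An empty input yields an empty Vec; whether that means
/// "all axes" is left to the caller.
fn normalize_axes(axes: &[isize], ndim: usize) -> Result<Vec<usize>, String> {
    let n = ndim as isize;
    let mut out = Vec::with_capacity(axes.len());
    for &ax in axes {
        // Negative axes count from the end: -1 is the last axis.
        let idx = if ax < 0 { ax + n } else { ax };
        if idx < 0 || idx >= n {
            return Err(format!("axis {} is out of bounds for {} dimensions", ax, ndim));
        }
        let idx = idx as usize;
        if out.contains(&idx) {
            return Err(format!("axis {} is repeated", ax));
        }
        out.push(idx);
    }
    Ok(out)
}
```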
Then, for example, sum would be
use std::ops::Add;
impl<A, S, D> ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    fn sum<T>(&self, axes: T) -> <T as FoldAxes<A, D>>::Output
    where
        A: Clone + Add<A, Output = A>,
        T: FoldAxes<A, D>,
    {
        // (or check zero-length here if not using the Option approach)
        let axes_repr: Option<T::Repr> = axes.into_repr();
        if let Some(axes) = axes_repr {
            let axes_slice: &[isize] = axes.as_ref();
            // operate over specified axes...
            unimplemented!()
        } else {
            // operate over all axes...
            unimplemented!()
        }
    }
}
which you could call like
// Sum over all axes.
arr.sum(());
// Sum over axis 1.
arr.sum(1);
// Sum over axes 2 and 3.
arr.sum((2, 3));
// Sum over variable number of axes.
arr.sum(axes_vec);
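To see that call sites like these type-check, here is a cut-down, std-only model of the idea. A hypothetical `Mat` type stands in for a 2-D `ArrayBase`, and a hypothetical `SumAxes` trait plays the role of `FoldAxes`, with the associated `Output` chosen by the argument type. Unlike the proposal, the trait computes the sum itself instead of returning an axis representation, and only `usize` axes are implemented.

```rust
/// Hypothetical row-major 2-D array standing in for ndarray's ArrayBase.
struct Mat {
    rows: usize,
    cols: usize,
    data: Vec<f64>,
}

impl Mat {
    fn get(&self, r: usize, c: usize) -> f64 {
        self.data[r * self.cols + c]
    }

    /// The argument type picks both the axes and the return type.
    fn sum<T: SumAxes>(&self, axes: T) -> T::Output {
        axes.sum_over(self)
    }
}

/// Hypothetical, simplified analogue of the proposed `FoldAxes` trait.
trait SumAxes {
    type Output;
    fn sum_over(self, m: &Mat) -> Self::Output;
}

/// `()` = all axes: the result is a plain scalar, not a 0-D array.
impl SumAxes for () {
    type Output = f64;
    fn sum_over(self, m: &Mat) -> f64 {
        m.data.iter().sum()
    }
}

/// One axis: that axis is removed, leaving a 1-D result.
impl SumAxes for usize {
    type Output = Vec<f64>;
    fn sum_over(self, m: &Mat) -> Vec<f64> {
        match self {
            0 => (0..m.cols).map(|c| (0..m.rows).map(|r| m.get(r, c)).sum::<f64>()).collect(),
            1 => (0..m.rows).map(|r| (0..m.cols).map(|c| m.get(r, c)).sum::<f64>()).collect(),
            _ => panic!("axis {} out of bounds for a 2-D array", self),
        }
    }
}

/// Both axes of a 2-D array: a scalar again.
impl SumAxes for (usize, usize) {
    type Output = f64;
    fn sum_over(self, m: &Mat) -> f64 {
        assert!(self.0 < 2 && self.1 < 2 && self.0 != self.1, "invalid axes");
        m.data.iter().sum()
    }
}
```

In this model `m.sum(())` has type `f64` while `m.sum(1usize)` has type `Vec<f64>`. The explicit `usize` suffixes are a precaution: once impls for both `isize` and `usize` exist, as the comment above suggests, a bare literal like `1` would likely no longer infer to either type.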
Note that FoldAxes::Output specifies the entire output type instead of just the output dimension because for the scalar case it would be nice to return A instead of Array0<A>. All of mean/stdev/min/max/variance behave the same way. For the argmin/argmax case, the element type of the output scalar/array is usize instead of A, so you would change the return type slightly:
impl<A, S, D> ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    fn argmin<T>(&self, axes: T) -> <T as FoldAxes<usize, D>>::Output
    where
        for<'a> &'a A: PartialOrd<&'a A>,
        T: FoldAxes<usize, D>,
    {
        // ...
    }
}
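The change of element type is easy to see on a small example: reducing a 2-D array along one axis with argmin produces indices (`usize`) whatever the element type `A` is. A std-only sketch over a row-major buffer; `argmin_axis` is hypothetical, relies only on `PartialOrd` like the bound above, and keeps the first index on ties.

```rust
/// For a row-major `rows x cols` buffer, return, for each lane along `axis`,
/// the position (within that lane) of its first smallest element.
/// axis 0 -> one index per column; axis 1 -> one index per row.
/// Pairs that `PartialOrd` cannot order (e.g. with NaN) never replace the best.
fn argmin_axis<A: PartialOrd>(data: &[A], rows: usize, cols: usize, axis: usize) -> Vec<usize> {
    assert_eq!(data.len(), rows * cols, "buffer does not match shape");
    // (number of lanes, length of each lane)
    let (lanes, len) = match axis {
        0 => (cols, rows),
        1 => (rows, cols),
        _ => panic!("axis {} out of bounds for a 2-D array", axis),
    };
    assert!(lanes == 0 || len > 0, "argmin over an empty axis is undefined");
    // Element `k` of lane `lane`, whichever axis is being reduced.
    let at = |lane: usize, k: usize| {
        if axis == 0 {
            &data[k * cols + lane]
        } else {
            &data[lane * cols + k]
        }
    };
    (0..lanes)
        .map(|lane| {
            let mut best = 0;
            for k in 1..len {
                if at(lane, k) < at(lane, best) {
                    best = k;
                }
            }
            best
        })
        .collect()
}
```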
You could reduce duplication of code in the method implementations by implementing them in terms of a generalized fold operation (like the sum above, but taking an initial value and a closure) followed by map_inplace. (map_inplace is necessary for mean/stdev but not for min/max/sum.)
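A std-only sketch of that factoring for the single-axis case on a row-major 2-D buffer. A hypothetical `fold_axis` takes an initial value and a closure and returns one accumulator per lane; `sum_axis` is just that fold, and `mean_axis` is the same fold followed by an in-place division, the role `map_inplace` plays above. All names here are assumptions.

```rust
/// Fold each lane along `axis` of a row-major `rows x cols` buffer,
/// starting every lane from a clone of `init`.
fn fold_axis<A, B, F>(data: &[A], rows: usize, cols: usize, axis: usize, init: B, mut f: F) -> Vec<B>
where
    B: Clone,
    F: FnMut(B, &A) -> B,
{
    assert_eq!(data.len(), rows * cols, "buffer does not match shape");
    match axis {
        // Axis 0 is removed: one accumulator per column.
        0 => (0..cols)
            .map(|c| (0..rows).fold(init.clone(), |acc, r| f(acc, &data[r * cols + c])))
            .collect(),
        // Axis 1 is removed: one accumulator per row.
        1 => (0..rows)
            .map(|r| data[r * cols..(r + 1) * cols].iter().fold(init.clone(), |acc, x| f(acc, x)))
            .collect(),
        _ => panic!("axis {} out of bounds for a 2-D array", axis),
    }
}

/// Sum is the fold alone.
fn sum_axis(data: &[f64], rows: usize, cols: usize, axis: usize) -> Vec<f64> {
    fold_axis(data, rows, cols, axis, 0.0, |acc, &x| acc + x)
}

/// Mean is the same fold followed by an in-place map over the result.
fn mean_axis(data: &[f64], rows: usize, cols: usize, axis: usize) -> Vec<f64> {
    let mut out = sum_axis(data, rows, cols, axis);
    // Length of the axis that was removed.
    let n = (if axis == 0 { rows } else { cols }) as f64;
    for v in out.iter_mut() {
        *v /= n;
    }
    out
}
```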
Edit:
For what it's worth, if all you want to support is sum() and sum_axis(axis: isize), I think that's fine too and simpler than the generic approach. I can see potential use cases for summing over multiple axes, though.
You could also have sum() (scalar sum) and sum_axes<T>(axes: T) where T: FoldAxes<> (sum over axis/axes) which would be nice because FoldAxes<A, D> could then be simplified to FoldAxes<D> with the output dimension as an associated type. Now that I think about it some more, I like this approach better than combining sum() and sum_axes() together because the scalar case would be simplified to arr.sum() and the FoldAxes trait would be simpler.
Edit 2:
I remembered that computing the standard deviation requires multiple numerical accumulators. So, instead of fold_axes (which would allocate an array containing all the accumulators) followed by map (to map the accumulators to the final result), I'd suggest combining the fold and map like this:
impl<A, S, D> ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    fn fold_map_axes<'a, T, B, F, M, O>(
        &'a self,
        axes: T,
        init: B,
        fold: F,
        map: M,
    ) -> Array<O, <T as FoldAxes<D>>::OutDim>
    where
        A: 'a,
        T: FoldAxes<D>,
        F: FnMut(B, &'a A) -> B,
        M: FnMut(B) -> O,
    {
        // ...
    }
}
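A std-only sketch of the combined fold and map for a single axis, to show why it helps with the standard deviation. Each lane folds into a `(count, mean, M2)` triple using Welford's online update (chosen here for numerical stability; the proposal does not specify an algorithm), and the map turns that triple into the result, so no array of accumulators is kept. `fold_map_axis`, `std_axis` and the NumPy-style `ddof` parameter are assumptions modeled on the signature above.

```rust
/// Hypothetical single-axis version of `fold_map_axes` on a row-major
/// `rows x cols` buffer: fold each lane into an accumulator, then map it
/// straight to the output value.
fn fold_map_axis<A, B, F, M, O>(
    data: &[A],
    rows: usize,
    cols: usize,
    axis: usize,
    init: B,
    mut fold: F,
    mut map: M,
) -> Vec<O>
where
    B: Clone,
    F: FnMut(B, &A) -> B,
    M: FnMut(B) -> O,
{
    assert_eq!(data.len(), rows * cols, "buffer does not match shape");
    // (number of lanes, length of each lane)
    let (lanes, len) = match axis {
        0 => (cols, rows),
        1 => (rows, cols),
        _ => panic!("axis {} out of bounds for a 2-D array", axis),
    };
    (0..lanes)
        .map(|lane| {
            let mut acc = init.clone();
            for k in 0..len {
                let idx = if axis == 0 { k * cols + lane } else { lane * cols + k };
                acc = fold(acc, &data[idx]);
            }
            map(acc)
        })
        .collect()
}

/// Standard deviation along `axis` with three accumulators per lane:
/// count, running mean, and sum of squared deviations (M2).
/// Yields NaN or infinity when the lane length is <= `ddof`.
fn std_axis(data: &[f64], rows: usize, cols: usize, axis: usize, ddof: f64) -> Vec<f64> {
    fold_map_axis(
        data,
        rows,
        cols,
        axis,
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(n, mean, m2), &x| {
            // Welford's online update.
            let n = n + 1.0;
            let delta = x - mean;
            let mean = mean + delta / n;
            (n, mean, m2 + delta * (x - mean))
        },
        |(n, _, m2)| (m2 / (n - ddof)).sqrt(),
    )
}
```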
Any updates? This is frequently used... Thanks!