
I have a 2D array of type f32 (from ndarray::ArrayView2) and I want to find the index of the maximum value in each row, and put the index value into another array.

The equivalent in Python is something like:

import numpy as np

for i in range(0, max_val, batch_size):
    sims = np.dot(batch, vectors.T)
    # sims is the dot product of batch and vectors.T
    # the shape is, for example, (1024, 10000)

    best_rows[i: i+batch_size] = sims.argmax(axis=1)

In Python, NumPy's .argmax is very fast, but I don't see an equivalent function in Rust. What's the fastest way to do this?

Shepmaster

2 Answers


Consider the easy case of a general Ord type: The answer will differ slightly depending on whether you know the values are Copy or not, but here's the code:

fn position_max_copy<T: Ord + Copy>(slice: &[T]) -> Option<usize> {
    slice.iter().enumerate().max_by_key(|&(_, &value)| value).map(|(idx, _)| idx)
}

fn position_max<T: Ord>(slice: &[T]) -> Option<usize> {
    slice.iter().enumerate().max_by(|(_, value0), (_, value1)| value0.cmp(value1)).map(|(idx, _)| idx)
}

The basic idea is that we pair [a reference to] each item in the array (really, a slice; it doesn't matter if it's a Vec, an array, or something more exotic) with its index, use std::iter::Iterator functions to find the maximum according to the value only (not the index), and then return just the index. If the slice is empty, None is returned. Per the documentation, when several elements are equally maximal the rightmost index is returned; if you need the leftmost, call rev() after enumerate().
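
A quick way to see the tie-breaking described above is a minimal sketch with two helpers (the names `position_max_last` and `position_max_first` are only illustrative). The second applies `rev()` after `enumerate()`, so the original indices are kept and the leftmost maximum wins, which is also how NumPy's `argmax` breaks ties (it returns the first occurrence).

```rust
/// Index of the maximum; on ties, returns the rightmost index
/// (Iterator::max_by_key returns the last of several equal maxima).
fn position_max_last<T: Ord + Copy>(slice: &[T]) -> Option<usize> {
    slice
        .iter()
        .enumerate()
        .max_by_key(|&(_, &value)| value)
        .map(|(idx, _)| idx)
}

/// Same search, but reversed so the leftmost maximum wins.
fn position_max_first<T: Ord + Copy>(slice: &[T]) -> Option<usize> {
    slice
        .iter()
        .enumerate()
        .rev() // enumerate() comes first, so indices still refer to the original order
        .max_by_key(|&(_, &value)| value)
        .map(|(idx, _)| idx)
}
```

For `[3, 7, 7, 1]` the first helper returns `Some(2)` and the second `Some(1)`.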

rev(), enumerate(), max_by_key(), and max_by() are documented on std::iter::Iterator; slice::iter() is documented on the slice primitive type (but that one needs to be on your shortlist of things to recall without documentation as a Rust dev); map here is Option::map(), documented on std::option::Option (ditto). Oh, and cmp is Ord::cmp, but most of the time you can use the Copy version, which doesn't need it (e.g. if you're comparing integers).


Now here's the catch: f32 isn't Ord because of the way IEEE floats work (a NaN is neither less than, equal to, nor greater than anything, itself included). Most languages ignore this and have subtly wrong algorithms. The most popular crate that provides an Ord (total order) for floats, by declaring all NaNs equal to each other and greater than all numbers, seems to be ordered-float. Assuming it's implemented correctly, it should be very lightweight. It does pull in num-traits, but that crate is part of the widely used num family, so it may well be pulled in by other dependencies already.
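
If you are on Rust 1.62 or newer, the standard library also offers a total order for floats without any crate: `f32::total_cmp`, which implements the IEEE 754 totalOrder predicate. This is a different technique from ordered-float (its NaN policy differs), so treat the following as a minimal sketch with an illustrative function name.

```rust
/// Index of the largest value under f32::total_cmp (IEEE 754 totalOrder,
/// stable since Rust 1.62). Ties return the last (rightmost) maximum.
/// Under this order a NaN with the sign bit clear sorts above +inf, a NaN
/// with the sign bit set sorts below -inf, and -0.0 sorts below +0.0.
fn position_max_total(slice: &[f32]) -> Option<usize> {
    slice
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(idx, _)| idx)
}
```

Unlike ordered-float, `total_cmp` distinguishes `-0.0` from `+0.0` and depends on the NaN's sign bit, so pick whichever policy matches your data.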

You'd use it in this case by mapping ordered_float::OrderedFloat (the "constructor" of the tuple struct) over the copied slice iterator: slice.iter().copied().map(ordered_float::OrderedFloat). The copied() matters because OrderedFloat wraps an f32, not an &f32. Since you only want the position of the maximum element, there's no need to extract the f32 afterward.
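
To illustrate the mapping step without depending on the crate, here is a sketch with a hypothetical `NanIsMax` newtype that mimics the ordering described above (all NaNs equal, and greater than every number). It is not ordered-float's actual code; with the real crate you would map `ordered_float::OrderedFloat` in its place.

```rust
use std::cmp::Ordering;

/// Hypothetical stand-in for ordered_float::OrderedFloat<f32> (not the crate's
/// code): every NaN equals every other NaN and is greater than all numbers.
#[derive(Clone, Copy, Debug)]
struct NanIsMax(f32);

impl Ord for NanIsMax {
    fn cmp(&self, other: &Self) -> Ordering {
        match (self.0.is_nan(), other.0.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither value is NaN, so partial_cmp always returns Some.
            (false, false) => self.0.partial_cmp(&other.0).unwrap(),
        }
    }
}

impl PartialOrd for NanIsMax {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl PartialEq for NanIsMax {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for NanIsMax {}

/// Wrap each copied f32, then search with the Ord-based max_by_key.
fn position_max_wrapped(slice: &[f32]) -> Option<usize> {
    slice
        .iter()
        .copied() // NanIsMax wraps f32, but iter() yields &f32
        .map(NanIsMax) // tuple-struct constructor used as a function
        .enumerate()
        .max_by_key(|&(_, value)| value)
        .map(|(idx, _)| idx)
}
```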

David A
  • Note that this is for a 1D vector, but the OP is using a 2D array, so he will need to iterate over the rows of his array and call `position_max` for each row. – Jmb Sep 06 '19 at 06:21
  • For the one-dimensional case, another option is `(0..slice.len()).max_by_key(|&i| &slice[i])`. (I did not test this, but it should work regardless of whether `T: Copy`.) – Sven Marnach Sep 06 '19 at 07:45
  • Yes, that's probably easier to understand TBH. The issue with `max_by_key` for `T: !Copy` is that it (implicitly) requires the return type to be `Ord + 'static`; I'm not sure if that's actually necessary for the algorithm or maybe an oversight. – David A Sep 07 '19 at 00:52
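
Putting the first two comments together, here is a sketch that applies Sven Marnach's index-based `position_max` to each row, as Jmb suggests. The closure takes `|&i|` so that `slice[i]` is indexed by a `usize`. Because the returned key borrows from `slice` rather than from the closure's argument, it compiles for non-`Copy` types such as `String`. Rows are modelled as `Vec`s to stay dependency-free (an assumption; with ndarray the outer loop would run over the array's rows instead), and the helper name `argmax_per_row` is hypothetical. It still needs `T: Ord`, so it does not apply directly to `f32`.

```rust
/// Index-based search: compare keys that borrow from `slice`,
/// so T only needs Ord (not Copy). Ties return the rightmost maximum.
fn position_max<T: Ord>(slice: &[T]) -> Option<usize> {
    (0..slice.len()).max_by_key(|&i| &slice[i])
}

/// Hypothetical helper: run the 1D search on every row.
fn argmax_per_row<T: Ord>(rows: &[Vec<T>]) -> Vec<Option<usize>> {
    rows.iter().map(|row| position_max(row)).collect()
}
```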

The approach from @David A is cool, but as mentioned, there's a catch: f32 and f64 do not implement Ord, so Ord::cmp isn't available for them (which is really a pain in you-know-where).

There are multiple ways of solving that: you can implement the comparison yourself, use ordered-float, etc.
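
One way to "implement the comparison yourself" is sketched below, under the assumption that the data contains no NaN: compare with `partial_cmp`, which returns `None` only when a NaN is involved, and turn that case into a panic rather than a silently wrong index. The function name is illustrative.

```rust
/// Index of the largest f32, assuming the input has no NaN.
/// partial_cmp returns None only for NaN, and expect() then panics.
/// Ties return the last (rightmost) maximum, as with max_by in general.
fn position_max_no_nan(slice: &[f32]) -> Option<usize> {
    slice
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("NaN in input"))
        .map(|(idx, _)| idx)
}
```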

In my case, this is part of a bigger project and we are very careful about using external packages. Besides, I am pretty sure we don't have any NaN values. Therefore I prefer using fold, which, if you take a close look at the max_by_key source code, is essentially what the standard library uses as well.

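use ndarray::{Array1, ArrayView2, Axis};

// Row-wise argmax; assumes every row is non-empty (row[0]) and NaN-free.
fn argmax_rows(matrix: ArrayView2<f32>) -> Array1<usize> {
    let mut best_rows = Array1::<usize>::zeros(matrix.nrows());
    // Axis(0) yields one row per iteration; Axis(1) would yield columns.
    for (i, row) in matrix.axis_iter(Axis(0)).enumerate() {
        let (max_idx, _max_val) =
            row.iter()
                .enumerate()
                .fold((0, row[0]), |(idx_max, val_max), (idx, &val)| {
                    if val_max > val {
                        (idx_max, val_max)
                    } else {
                        (idx, val)
                    }
                });
        best_rows[i] = max_idx;
    }
    best_rows
}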
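
For a dependency-free variant of the same fold, here is a sketch that assumes the matrix is stored as one contiguous row-major buffer (with ndarray, `as_slice()` returns such a slice when the array is in standard layout) and writes one index per row into `best_rows`, like the NumPy line in the question. The function name `argmax_rows_into` is hypothetical. It uses `>=` instead of `>` so that ties keep the leftmost index, matching NumPy's `argmax`; with `>` the rightmost index wins.

```rust
/// Row-wise argmax over a row-major buffer with `ncols` columns.
/// Writes one column index per row into `best_rows`.
/// Assumptions: ncols > 0, data.len() == best_rows.len() * ncols, no NaN.
fn argmax_rows_into(data: &[f32], ncols: usize, best_rows: &mut [usize]) {
    assert!(ncols > 0, "need at least one column");
    assert_eq!(data.len(), best_rows.len() * ncols, "buffer size mismatch");
    // chunks_exact(ncols) yields one row slice at a time.
    for (row, out) in data.chunks_exact(ncols).zip(best_rows.iter_mut()) {
        let (max_idx, _max_val) = row
            .iter()
            .enumerate()
            .fold((0, row[0]), |(idx_max, val_max), (idx, &val)| {
                // `>=` keeps the earlier index on ties (leftmost, like NumPy).
                if val_max >= val {
                    (idx_max, val_max)
                } else {
                    (idx, val)
                }
            });
        *out = max_idx;
    }
}
```

For the rows `[1.0, 5.0, 5.0]` and `[9.0, 2.0, 3.0]` this writes `[1, 0]`.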
Shepmaster