
I am working on parallel matrix multiplication code in Rust, where I want to compute every element of the product in parallel. I use `ndarray` arrays to store my data. My code would be something along the lines of:

fn mul(lhs: &Array2<f32>, rhs: &Array2<f32>) -> Array2<f32> {
   let N = lhs.raw_size()[0];
   let M = rhs.raw_size()[1];
   let mut result = Array2::zeros((N,M));
   
   range_2d(0..N,0..M).par_iter().map(|(i, j)| {
      // load the result for the (i,j) element into 'result'
   }).count();

   result
}

Is there any way to achieve this?

tadman
stomfaig
  • A) How big is this matrix? B) What's your threading strategy? C) Why not use an existing math library? D) Are you using SIMD? E) Would using a GPU not be far better? Is this an option? – tadman Jul 18 '23 at 15:16
  • A) `rhs` will be a matrix with sparse rows, maybe around 1000 entries. `lhs` will be a dense random Gaussian matrix, potentially really big. B) I'm not sure what you mean. C) Partly for the sake of practice, and also because the use case is quite specialised and I couldn't find a suitable library. E) Yes, using a GPU would definitely help. – stomfaig Jul 18 '23 at 15:19
  • Matrix multiplication is a *very* well trod bit of terrain, there's even [crates for your specific use case](https://github.com/bluss/matrixmultiply). If you're looking to roll your own for academic reasons, nothing wrong with that, but worth studying the work of others. – tadman Jul 18 '23 at 15:21
  • I ask about threading strategy because there are tools like [Rayon](https://docs.rs/rayon/latest/rayon/) that make this pretty straightforward if you use them properly. Is that where `par_iter()` comes from here? – tadman Jul 18 '23 at 15:22
  • `ndarray` does support `rayon`, but I see no support for extracting the indices. – Chayim Friedman Jul 18 '23 at 15:22
  • `par_iter` is rayon's iterator, yes. My main concern, though, is getting all the mutable references. – stomfaig Jul 18 '23 at 15:24
  • You might go with a map-reduce style here, parallel iterate, then collect into your desired results. A functional style does not need mutable references. – tadman Jul 18 '23 at 15:26
  • Yes, that's a good point. Don't you think there would be tons of overhead for all the additions, though? Especially if the matrix is large? – stomfaig Jul 18 '23 at 15:31
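
The map-then-collect style suggested in the comments above can be sketched roughly as follows. This is only an illustration under stated assumptions: it uses plain `std::thread::scope` and flat row-major slices instead of rayon and `Array2`, and the name `mul_collect` and its `threads` parameter are hypothetical. Each thread returns an owned band of output rows, so no mutable references are shared, and the allocation overhead is one `Vec` per band rather than one per element.

```rust
use std::thread;

/// Hypothetical helper: `lhs` is n x k and `rhs` is k x m, both stored
/// row-major in flat slices. Returns the n x m product, row-major.
fn mul_collect(
    lhs: &[f32],
    rhs: &[f32],
    n: usize,
    k: usize,
    m: usize,
    threads: usize,
) -> Vec<f32> {
    assert_eq!(lhs.len(), n * k);
    assert_eq!(rhs.len(), k * m);
    let threads = threads.max(1);
    // Rows per band, rounded up; at least 1 so the step below is never zero.
    let band = ((n + threads - 1) / threads).max(1);

    thread::scope(|s| {
        // "Map": each thread computes an owned band of output rows.
        let handles: Vec<_> = (0..n)
            .step_by(band)
            .map(|start| {
                let end = (start + band).min(n);
                s.spawn(move || {
                    let mut out = Vec::with_capacity((end - start) * m);
                    for i in start..end {
                        for j in 0..m {
                            let dot: f32 = (0..k)
                                .map(|p| lhs[i * k + p] * rhs[p * m + j])
                                .sum();
                            out.push(dot);
                        }
                    }
                    out
                })
            })
            .collect();
        // "Reduce": concatenate the bands in row order.
        handles
            .into_iter()
            .flat_map(|h| h.join().unwrap())
            .collect()
    })
}
```

Calling `mul_collect(&lhs, &rhs, n, k, m, 4)` splits the output rows into at most four bands. With rayon, the same shape would typically be a parallel `map` followed by `collect`.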

1 Answer


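You can create a parallel iterator over every element of the result, together with its row and column index, this way (this requires `ndarray`'s `rayon` feature):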

use ndarray::{Array2, Axis};
use rayon::prelude::*;

pub fn mul(lhs: &Array2<f32>, rhs: &Array2<f32>) -> Array2<f32> {
    let n = lhs.raw_dim()[0];
    let m = rhs.raw_dim()[1];
    let mut result = Array2::zeros((n, m));

    result
        .axis_iter_mut(Axis(0))
        .into_par_iter()
        .enumerate()
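        // `i` is the row index; `row` is a mutable view of that row.
        .flat_map(|(i, row)| {
            // A freshly allocated `Array2` is contiguous, so `into_slice()` succeeds.
            row.into_slice()
                .unwrap()
                .par_iter_mut()
                .enumerate()
                .map(move |(j, item)| (i, j, item))
        })
        .for_each(|(i, j, item)| {
            // Placeholder: compute the (i, j) element of the product here.
            *item = i as f32 * j as f32;
        });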

    result
}
Chayim Friedman
  • As far as I know, `par_bridge` is really inefficient, though. In essence, I believe the bottleneck is the `Array2`, since slicing it into mutable blocks is possible but a pain. When asking the question, I thought there might be a way to generate raw pointers to each of the items, in the way [split_at_mut](https://doc.rust-lang.org/nomicon/borrow-splitting.html) does it, but it seems that's not the case. – stomfaig Jul 18 '23 at 16:33
  • @stomfaig After further thought, I came up with a solution without `par_bridge()`. I've edited the answer above. – Chayim Friedman Jul 18 '23 at 18:16
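
The borrow-splitting idea raised in the comments (handing out disjoint mutable pieces of the output, as `split_at_mut` does) can be illustrated without raw pointers. The sketch below is not the answer's method: it assumes a flat row-major `&mut [f32]` output instead of an `Array2`, plain `std::thread::scope` instead of rayon, and the name `mul_in_place` and its `rows_per_thread` parameter are hypothetical. The safe `chunks_mut` method yields non-overlapping mutable slices, so each thread writes only to its own rows.

```rust
use std::thread;

/// Hypothetical helper: `lhs` is n x k, `rhs` is k x m, `out` is n x m,
/// all stored row-major in flat slices. Writes the product into `out`.
fn mul_in_place(
    lhs: &[f32],
    rhs: &[f32],
    out: &mut [f32],
    k: usize,
    m: usize,
    rows_per_thread: usize,
) {
    assert!(k > 0 && m > 0);
    assert_eq!(lhs.len() % k, 0);
    assert_eq!(rhs.len(), k * m);
    assert_eq!(out.len(), lhs.len() / k * m);
    let band = rows_per_thread.max(1);

    thread::scope(|s| {
        // `chunks_mut` splits `out` into disjoint `&mut [f32]` bands of
        // `band` rows each, so no two threads can alias the same element.
        for (b, rows) in out.chunks_mut(band * m).enumerate() {
            s.spawn(move || {
                for (r, row) in rows.chunks_mut(m).enumerate() {
                    let i = b * band + r; // global row index
                    let lhs_row = &lhs[i * k..(i + 1) * k];
                    for (j, cell) in row.iter_mut().enumerate() {
                        // Dot product of row i of `lhs` with column j of `rhs`.
                        *cell = lhs_row
                            .iter()
                            .enumerate()
                            .map(|(p, a)| a * rhs[p * m + j])
                            .sum();
                    }
                }
            });
        }
    });
}
```

Spawning one OS thread per band is only reasonable for a small number of bands; for large matrices a thread pool such as rayon's is the more practical choice, and the same disjoint-slice reasoning applies there.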