
I am trying to implement a kind of "batched" matrix multiplication in Rust using ndarray. To do so, I am combining several `.axis_iter`s and, in particular, updating the axis iterators of a "result" tensor accordingly. The problem already occurs when I try something as simple as mutating the "tensor slices" (here, the columns) of a small array:

let mut c: ArrayBase<ndarray::OwnedRepr<i32>, Dim<[usize; 2]>> = array![
    [1, 2],
    [1, 2]
];
let d: ArrayBase<ndarray::OwnedRepr<i32>, Dim<[usize; 1]>> = array![1, 1];

c.axis_iter_mut(Axis(1)).for_each(|x| x = d);

The compiler reports an error at `d` in the last line:

mismatched types
expected struct `ndarray::ArrayBase<ViewRepr<&mut i32>, _>`
   found struct `ndarray::ArrayBase<OwnedRepr<i32>, _>`

I am new to Rust and not sure what to do here. I see that the types do not match, but I do not know how to set things up so that they do and the columns of `c` are updated / replaced as intended.

Dereferencing `x` with `*x` in the last line also does not work.

I also looked at "Updating a row of a matrix in rust ndarray", but I could not figure out how to make it work with `.axis_iter_mut`.

Please note that I explicitly want these updates to happen via `axis_iter`, because that is what my actual goal, batched matmul, requires.

ginger314

1 Answer


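There is a method that does what you want: `assign()`. The closure parameter `x` is an `ArrayViewMut1<i32>` (a mutable view of one column), so `x = d` tries to rebind that local variable to an owned array, hence the type mismatch. `assign()` instead copies the elements of `d` into the column the view points to. Because it takes `&mut self`, the parameter must be bound as `mut x`: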

c.axis_iter_mut(Axis(1)).for_each(|mut x| x.assign(&d));
Chayim Friedman
Thank you for your reply. I had already noticed `.assign`, but it did not work out; I had only forgotten the `mut` in front of `x`... Thanks again. – ginger314 Jul 27 '23 at 04:46