Based on the author's clarification in the comments, I'm assuming that the goal here is to iterate over a rectangular submatrix of a matrix. For example, given a matrix
```
100 200 300 400 500 600
110 210 310 410 510 610
120 220 320 420 520 620
130 230 330 430 530 630
```

as represented by a slice in row-major order

```
[100, 200, 300, 400, 500, 600, 110, ..., 530, 630]
```

we want to iterate over a submatrix such as

```
210 310 410 510
220 320 420 520
```

again in row-major order, so the elements we would get would be, in order,

```
210, 310, 410, 510, 220, 320, 420, 520
```
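In row-major order, element (r, c) of a matrix with `ncols` columns sits at index `r * ncols + c`. The submatrix elements therefore land at predictable positions in the flat slice. Below is a minimal sketch of that arithmetic. The helper `sub_indices` is hypothetical and not part of the answer below, and it assumes shapes and offsets are (rows, columns) pairs, as in `Iter2DMut::new`:

```rust
/// Hypothetical helper: returns the row-major flat indices of a `sub_shape`
/// submatrix whose top-left corner is at `offset` inside a `full_shape` matrix.
/// All pairs are (rows, columns).
fn sub_indices(
    full_shape: (usize, usize),
    sub_shape: (usize, usize),
    offset: (usize, usize),
) -> Vec<usize> {
    // The submatrix must fit inside the full matrix.
    assert!(offset.0 + sub_shape.0 <= full_shape.0);
    assert!(offset.1 + sub_shape.1 <= full_shape.1);
    let ncols = full_shape.1;
    let mut out = Vec::with_capacity(sub_shape.0 * sub_shape.1);
    for i in 0..sub_shape.0 {
        for j in 0..sub_shape.1 {
            // Element (r, c) of a row-major matrix lives at r * ncols + c.
            out.push((offset.0 + i) * ncols + (offset.1 + j));
        }
    }
    out
}
```

For the example above, `sub_indices((4, 6), (2, 4), (1, 1))` gives `[7, 8, 9, 10, 13, 14, 15, 16]`. This is enough for shared access, e.g. `idx.iter().map(|&k| &v[k])`. The mutable counterpart, `idx.iter().map(|&k| &mut v[k])`, is rejected by the compiler, because it cannot verify that the `&mut` references handed out are disjoint. That is what makes the mutable case harder.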
In this situation, it is possible to solve the problem relatively efficiently in safe Rust. The trick is to call `split_at_mut` on the slice stored in the `data` field of `Iter2DMut`, peeling off one mutable reference at a time as needed. As the iteration proceeds, the `data` field is updated to a smaller and smaller slice, so that it no longer covers elements that have already been yielded. This is necessary because Rust would not allow us to hand out a mutable reference to an element while also retaining a mutable slice that contains that element. By shrinking the slice, we ensure that it is always disjoint from the mutable references produced by all previous calls to `next()`, which satisfies the borrow checker. Here is how this can be done:
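```rust
// Requires the `itertools` crate as a dependency.
use itertools::{Itertools, Product};
use std::mem;
use std::ops::Range;

/// Mutable, row-major iterator over a rectangular submatrix of `data`.
struct Iter2DMut<'a, T: 'a> {
    // The part of the matrix not yet handed out, starting at the next element to yield.
    data: &'a mut [T],
    full_shape: (usize, usize),
    sub_shape: (usize, usize),
    // (row, column) positions within the submatrix, in row-major order.
    idx_iter: Product<Range<usize>, Range<usize>>,
}

impl<'a, T> Iter2DMut<'a, T> {
    /// Shapes and offsets are (rows, columns); `offset` is the top-left
    /// corner of the submatrix within the full matrix.
    fn new(
        data: &'a mut [T],
        full_shape: (usize, usize),
        sub_shape: (usize, usize),
        offset: (usize, usize),
    ) -> Self {
        assert!(full_shape.0 * full_shape.1 == data.len());
        assert!(offset.0 + sub_shape.0 <= full_shape.0);
        assert!(offset.1 + sub_shape.1 <= full_shape.1);
        Iter2DMut {
            // Drop everything before the submatrix's first element.
            data: &mut data[offset.0 * full_shape.1 + offset.1..],
            full_shape,
            sub_shape,
            idx_iter: (0..sub_shape.0).cartesian_product(0..sub_shape.1),
        }
    }
}

impl<'a, T: 'a> Iterator for Iter2DMut<'a, T> {
    type Item = &'a mut T;

    fn next(&mut self) -> Option<Self::Item> {
        if let Some((i, j)) = self.idx_iter.next() {
            // Move the slice out of `self`, leaving an empty dummy in its place.
            let mut data: &'a mut [T] = &mut [];
            mem::swap(&mut self.data, &mut data);
            // Peel off the element to return; `rest` starts right after it.
            let (first, rest) = data.split_at_mut(1);
            data = rest;
            // At the end of every submatrix row except the last, skip the
            // columns outside the submatrix to reach the next row's start.
            // (Skipping after the last row could run past the end of the slice.)
            if j == self.sub_shape.1 - 1 && i + 1 < self.sub_shape.0 {
                let n_skip = self.full_shape.1 - self.sub_shape.1;
                let (_, rest) = data.split_at_mut(n_skip);
                data = rest;
            }
            // Put back the part that has not been yielded yet.
            self.data = data;
            Some(&mut first[0])
        } else {
            None
        }
    }
}

fn main() {
    let mut v: Vec<usize> = vec![
        100, 200, 300, 400, 500, 600,
        110, 210, 310, 410, 510, 610,
        120, 220, 320, 420, 520, 620,
        130, 230, 330, 430, 530, 630,
    ];
    // 4x6 matrix; 2x4 submatrix whose top-left corner is at row 1, column 1.
    for x in Iter2DMut::new(&mut v, (4, 6), (2, 4), (1, 1)) {
        println!("{}", x);
    }
}
```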
There's one other trick here worth noting: we use `mem::swap` to move the `data` field out of the `Iter2DMut` in order to call `split_at_mut` on it. We temporarily swap in a dummy value, `&mut []`; this is necessary because Rust won't allow us to move a value out of a mutably borrowed struct (even temporarily) without putting something back at the same time. On the other hand, if we hadn't moved `data` out and had instead called `split_at_mut` directly, as in `self.data.split_at_mut(1)`, the borrow checker would have rejected it: we would then be borrowing `self.data` only for as long as the `&mut self` reference passed to `next`, which is not necessarily as long as the `'a` lifetime we need the result to have.
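Since Rust 1.40, `std::mem::take` packages this move-out-and-replace step: it returns the current value and leaves `Default::default()` behind, and `&mut [T]` implements `Default` as an empty slice. Below is a minimal sketch of the same move, split, and restore pattern on a hypothetical one-dimensional iterator, `PeelMut`, which yields each element of a slice mutably. It is an illustration under those assumptions rather than part of the original answer, and it uses `split_first_mut`, which here does the same job as `split_at_mut(1)` followed by indexing:

```rust
use std::mem;

/// Hypothetical iterator yielding `&mut T` for each element of a slice,
/// using the same "move the slice out, split it, put the rest back" idea.
struct PeelMut<'a, T> {
    // Elements that have not been handed out yet.
    data: &'a mut [T],
}

impl<'a, T> Iterator for PeelMut<'a, T> {
    type Item = &'a mut T;

    fn next(&mut self) -> Option<Self::Item> {
        // `mem::take` leaves an empty slice in `self.data`, playing the role
        // of the dummy `&mut []` in the `mem::swap` version.
        let data = mem::take(&mut self.data);
        // Empty slice: return None and leave the empty slice in place.
        let (first, rest) = data.split_first_mut()?;
        // Store only the not-yet-yielded tail, disjoint from `first`.
        self.data = rest;
        Some(first)
    }
}
```

Assuming this substitution, the two lines in `Iter2DMut::next` that declare the dummy and call `mem::swap` could be replaced by `let mut data = mem::take(&mut self.data);`, leaving the rest of the method unchanged.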