1

I am trying to learn Rust's multi-threading by implementing a parallel merge sort. A simple recursive version works just fine, but this version:

use rand;

use std::sync::{Arc, Mutex};
use std::thread;

/// Fills a vector with random `i64` values and sorts it with `merge_sort`.
fn main() {
    // Number of elements to sort; fixed here instead of being read from stdin.
    let amount = 1_000_000;

    // Collecting from a sized range allocates the vector once instead of
    // growing it push by push.
    let mut arr: Vec<i64> = (0..amount).map(|_| rand::random::<i64>()).collect();

    merge_sort(&mut arr);
}

/// Sorts `arr` in ascending order by handing it to `par_merge_sort`.
///
/// The vector reference is wrapped in an `Arc<Mutex<..>>` so it can be passed to the
/// parallel sort, which is given a thread budget of 4.
fn merge_sort(arr: &mut Vec<i64>) {
    // Zero or one element is already sorted. Returning here also avoids the
    // `arr.len() - 1` underflow on an empty vector.
    if arr.len() < 2 {
        return;
    }

    // Index of the last element; the sort works on inclusive index ranges.
    let last = arr.len() - 1;

    // The binding must be `mut` because `par_merge_sort` takes `&mut Arc<..>`.
    let mut arc = Arc::new(Mutex::new(arr));
    par_merge_sort(&mut arc, 0, last, 4);
}

/// Recursively sorts the inclusive index range `lo..=hi` of `arr` in ascending order.
///
/// Each half of the range is sorted on its own and the two halves are then combined
/// by `merge`.
fn simple_merge_sort(arr: &mut Vec<i64>, lo: usize, hi: usize) {
    // Fewer than two elements: nothing to sort. `>=` (rather than `==`) also ends the
    // recursion for an inverted range, which would otherwise never reach the base case.
    if lo >= hi {
        return;
    }

    // Written as `lo + (hi - lo) / 2` so the midpoint cannot overflow `usize`.
    let mi = lo + (hi - lo) / 2;
    simple_merge_sort(arr, lo, mi);
    simple_merge_sort(arr, mi + 1, hi);

    merge(arr, lo, mi, hi);
}

/// Intended to sort the inclusive index range `lo..=hi` of the shared vector, handing each
/// half of the range to its own spawned thread until the `threads` budget reaches 1.
///
/// NOTE(review): this is the function the compiler output quoted after the listing points
/// at: `thread::spawn` requires a `'static` closure, but the `Arc` wraps a borrowed
/// `&mut Vec<i64>`.
fn par_merge_sort(arc: &mut Arc<Mutex<&mut Vec<i64>>>, lo: usize, hi: usize, threads: i32) {
    // A single index is already sorted.
    if lo == hi {
        return;
    }

    let mi = (hi + lo) / 2_usize;
    if threads == 1 {
        // Budget exhausted: sort the whole range sequentially while holding the lock.
        let mut simple_arr = arc.lock().unwrap();
        simple_merge_sort(&mut simple_arr, lo, hi);
    } else {
        // NOTE(review): `Arc::from(*arc)` tries to move the `Arc` out from behind the
        // `&mut` borrow; a second handle (`Arc::clone`) is presumably what was meant — confirm.
        let thread_arc = Arc::from(*arc);
        // Split the thread budget between the two halves.
        let thread_rest = threads / 2;
        let thread_rest_2 = threads - thread_rest;
        // Lower half `lo..=mi`.
        let thread1 = thread::spawn(move || {
            par_merge_sort(&mut thread_arc, lo, mi, thread_rest);
        });
        let thread_arc = Arc::from(*arc);
        // Upper half `mi + 1..=hi`.
        let thread2 = thread::spawn(move || {
            par_merge_sort(&mut thread_arc, mi + 1, hi, thread_rest_2);
        });

        // Wait for both halves before merging them.
        thread1.join().unwrap();
        thread2.join().unwrap();
    }

    // Combine the halves `lo..=mi` and `mi + 1..=hi`; this also runs after the
    // sequential branch above.
    let mutex = arc.lock().unwrap();
    merge(&mut *mutex, lo, mi, hi);
}

/// Merges the two adjacent sorted runs `arr[lo..=mi]` and `arr[mi + 1..=hi]` in place.
///
/// Both runs must already be sorted in ascending order; afterwards `arr[lo..=hi]` is
/// sorted. When two elements compare equal, the one from the lower run is written first.
///
/// # Panics
///
/// Panics if the index ranges are not valid for `arr` (for example `hi >= arr.len()`).
fn merge(arr: &mut Vec<i64>, lo: usize, mi: usize, hi: usize) {
    // Copy both runs out first, because the merged output overwrites `arr[lo..=hi]`.
    let lo_arr = arr[lo..mi + 1].to_vec();
    let hi_arr = arr[mi + 1..hi + 1].to_vec();

    let mut i = 0;
    let mut j = 0;
    let mut counter = lo;

    // Repeatedly take the smaller head element until one run is used up.
    while i < lo_arr.len() && j < hi_arr.len() {
        if lo_arr[i] <= hi_arr[j] {
            arr[counter] = lo_arr[i];
            i += 1;
        } else {
            arr[counter] = hi_arr[j];
            j += 1;
        }
        counter += 1;
    }

    // At most one run still has elements left; they are already in order, so they are
    // copied to the tail unchanged.
    let rest = if i < lo_arr.len() {
        &lo_arr[i..]
    } else {
        &hi_arr[j..]
    };
    arr[counter..counter + rest.len()].copy_from_slice(rest);
}

Produces the error:

error[E0621]: explicit lifetime required in the type of `arc`
  --> src/main.rs:69:23
   |
56 | fn par_merge_sort(arc: &mut Arc<Mutex<&mut Vec<i64>>>, lo: usize, hi: usize, threads: i32) {
   |                        ------------------------------ help: add explicit lifetime `'static` to the type of `arc`: `&mut Arc<Mutex<&'static mut Vec<i64>>>`
...
69 |         let thread1 = thread::spawn(move || {
   |                       ^^^^^^^^^^^^^ lifetime `'static` required

error[E0621]: explicit lifetime required in the type of `arc`
  --> src/main.rs:73:23
   |
56 | fn par_merge_sort(arc: &mut Arc<Mutex<&mut Vec<i64>>>, lo: usize, hi: usize, threads: i32) {
   |                        ------------------------------ help: add explicit lifetime `'static` to the type of `arc`: `&mut Arc<Mutex<&'static mut Vec<i64>>>`
...
73 |         let thread2 = thread::spawn(move || {
   |                       ^^^^^^^^^^^^^ lifetime `'static` required
Shepmaster
  • 388,571
  • 95
  • 1,107
  • 1,366
Marek Nagy
  • 13
  • 2
  • 1
    Your question might be answered by the answers of [How to get mutable references to two array elements at the same time?](https://stackoverflow.com/q/30073684); [How to operate on 2 mutable slices of a Rust array?](https://stackoverflow.com/q/36244762); [How do I pass disjoint slices from a vector to different threads?](https://stackoverflow.com/q/33818141); [How can I pass a reference to a stack variable to a thread?](https://stackoverflow.com/q/32750829). If not, please **[edit]** your question to explain the differences. Otherwise, we can mark this question as already answered. – Shepmaster Dec 22 '20 at 20:32
  • 1
    See also [Why does my parallel merge algorithm produce the correct values in all positions of the output except the first?](https://stackoverflow.com/q/64124353/155423) – Shepmaster Dec 22 '20 at 20:33

1 Answer

0

Since your question is about parallelizing and not sorting, I've omitted the implementations of the `serial_sort` and `merge` functions in the example below, but you can easily fill them in yourself using the code you already have:

#![feature(is_sorted)]

use crossbeam; // 0.8.0
use rand; // 0.7.3
use rand::Rng;

/// Returns a vector holding `capacity` random `i64` values.
fn random_vec(capacity: usize) -> Vec<i64> {
    // Start from a zeroed buffer of the requested length, then let the
    // generator overwrite every element in a single call.
    let mut numbers = vec![0_i64; capacity];
    let mut rng = rand::thread_rng();
    rng.fill(numbers.as_mut_slice());
    numbers
}

/// Sorts `data` in ascending order using at most `threads` worker threads.
///
/// The slice is cut into contiguous chunks, each chunk is sorted by `serial_sort` on its
/// own scoped thread, and `merge` then combines the sorted chunks. A `threads` value of
/// zero is treated as one.
///
/// # Panics
///
/// Re-raises the panic of a worker thread, if one panicked.
fn parallel_sort(data: &mut [i64], threads: usize) {
    // An empty slice is already sorted; returning early also keeps the chunk-size
    // arithmetic below away from a division by zero.
    if data.is_empty() {
        return;
    }

    // Never plan for more chunks than elements, and always for at least one.
    let chunks = std::cmp::min(data.len(), threads.max(1));

    // Round the chunk length up: rounding down would leave a remainder that becomes an
    // extra chunk, and therefore an extra thread beyond `threads`.
    let chunk_len = (data.len() + chunks - 1) / chunks;
    // Number of chunks `chunks_mut(chunk_len)` actually yields (may be below `chunks`).
    let sorted_chunks = (data.len() + chunk_len - 1) / chunk_len;

    let outcome = crossbeam::scope(|scope| {
        for slice in data.chunks_mut(chunk_len) {
            scope.spawn(move |_| serial_sort(slice));
        }
    });

    // The scope reports a panicking worker as `Err`; propagate it rather than merging
    // chunks that may not have been sorted.
    if let Err(payload) = outcome {
        std::panic::resume_unwind(payload);
    }

    merge(data, sorted_chunks);
}

/// Sorts `data` in ascending order on the calling thread.
///
/// Stand-in for a hand-written sequential sort; the actual implementation is omitted
/// for conciseness.
fn serial_sort(data: &mut [i64]) {
    // Equal `i64` values are indistinguishable, so the unstable sort produces the same
    // result as `sort` without allocating a scratch buffer.
    data.sort_unstable();
}

/// Combines the sorted chunks of `data` into one ascending sequence.
///
/// Placeholder: the real merge is left out for brevity, so this simply re-sorts the
/// whole slice and ignores `_sorted_chunks`.
fn merge(data: &mut [i64], _sorted_chunks: usize) {
    data.sort();
}

/// Sorts 10,000 random numbers with `parallel_sort` on four threads and checks the result.
fn main() {
    let mut numbers = random_vec(10_000);
    parallel_sort(numbers.as_mut_slice(), 4);
    assert!(numbers.is_sorted());
}

playground

`parallel_sort` breaks the data into n chunks, sorts each chunk in its own thread, and then merges the sorted chunks together before finally returning.

pretzelhammer
  • 13,874
  • 15
  • 47
  • 98