
I am trying to make a worker pool in Rust. The design itself is simple: you add closures to a structure as tasks (assumed to be independent). That structure then spawns a number of threads equal to the number of hardware threads available on the machine.

Each thread then acquires a task from the queue, executes it, and once it is done acquires the next one, until all tasks are consumed.

To achieve this I tried the following:

use std::sync::{Arc, Mutex};
use std::thread;

pub struct ThreadPool<'a>
{
    task_queue: Vec<Task<'a>>,
    task_count: Arc<Mutex<usize>>,
}

pub struct Task<'a>(Box<dyn FnMut() + 'a>);
unsafe impl<'a> Send for Task<'a> {}

impl<'a> Task<'a>
{
    fn call(&mut self) { self.0() }
}

impl<'a> ThreadPool<'a>
{
    fn add_task<T>(&mut self, task: T)
    where
        T: 'a + FnMut() -> (),
    {
        self.task_queue.push(Task(Box::new(task)));
    }

    fn run(&mut self)
    {
        let thread_count = thread::available_parallelism().unwrap().get();
        println!("{}", thread_count);

        let mut handlers = Vec::with_capacity(thread_count);
        for _ in 0..thread_count
        {
            unsafe {
                let queue = &mut self.task_queue as *mut Vec<Task<'a>>;
                let task_count = Arc::clone(&self.task_count);
                handlers.push(thread::spawn(move || {
                    let index = task_count.lock().unwrap().overflowing_add(1).0 - 1;
                    (*queue)[index].call();
                }));
            }
        }
    }
}

This does not even compile.

The kind of interface I would like would be something like:

let mut thread_pool = ThreadPool {
    task_queue: Vec::new(),
    task_count: Arc::new(Mutex::new(0)),
};

for i in 0..100
{
    thread_pool.add_task(move || println!("ran {i} task"));
}

thread_pool.run();

i.e. you register the tasks you want, each task captures data from the surrounding scope, and then it all runs.

I tried searching for examples of something like this, but the thread pool example in the docs is very different.
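
The error I get is: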

error[E0277]: `*mut std::vec::Vec<Task<'a>>` cannot be sent between threads safely
   --> examples/06_fluid/thread_pool.rs:38:45
    |
38  |                   handlers.push(thread::spawn(move || {
    |                                 ------------- ^------
    |                                 |             |
    |  _______________________________|_____________within this `[closure@examples/06_fluid/thread_pool.rs:38:45: 38:52]`
    | |                               |
    | |                               required by a bound introduced by this call
39  | |                     let index = task_count.lock().unwrap().overflowing_add(1).0 - 1;
40  | |                     (*queue)[index].call();
41  | |                 }));
    | |_________________^ `*mut std::vec::Vec<Task<'a>>` cannot be sent between threads safely
    |
    = help: within `[closure@examples/06_fluid/thread_pool.rs:38:45: 38:52]`, the trait `Send` is not implemented for `*mut std::vec::Vec<Task<'a>>`
note: required because it's used within this closure
   --> examples/06_fluid/thread_pool.rs:38:45
    |
38  |                 handlers.push(thread::spawn(move || {
    |                                             ^^^^^^^
note: required by a bound in `spawn`
   --> /home/makogan/.rustup/toolchains/nightly-2022-10-29-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/std/src/thread/mod.rs:705:8
    |
705 |     F: Send + 'static,
    |        ^^^^ required by this bound in `spawn`

For more information about this error, try `rustc --explain E0277`.
error: could not compile `neverengine` due to previous error
Makogan
  • @Finomnis I added it. – Makogan Jul 08 '23 at 07:17
  • The code you posted does not produce the error you claim (https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=8cc72faeb1a8f882e881b6d475f7be43). Please provide a minimal reproducible example. – Finomnis Jul 08 '23 at 07:37
  • @Finomnis I copied the wrong part of the terminal log, that was my bad. – Makogan Jul 08 '23 at 08:28
  • Your problem is solved by "Can a struct containing a raw pointer implement Send and be FFI safe?" (https://stackoverflow.com/questions/50258359/can-a-struct-containing-a-raw-pointer-implement-send-and-be-ffi-safe), but your code is very, very bad. Specifically, `unsafe impl Send` in your code is a mistake and unsound, using a raw pointer for the `Vec` of tasks is also unsound from multiple sides, `overflowing_add()` could be just `wrapping_add()` but is also unsound, and lastly the `Mutex` could be replaced by an `AtomicUsize`. – Chayim Friedman Jul 08 '23 at 20:06
  • Do not drop into unsafe code if you don't understand every detail of what you're doing. – Chayim Friedman Jul 08 '23 at 20:06
  • I'm hesitant to close this as a duplicate of https://stackoverflow.com/questions/50258359/... On one hand, this _is_ a duplicate. On the other hand, solving this error won't give you good or safe code. – Chayim Friedman Jul 08 '23 at 20:08
  • @ChayimFriedman Here you go, I added an answer for your peace of mind ;) – Finomnis Jul 09 '23 at 14:12

1 Answer


As @ChayimFriedman pointed out, there are many things wrong with your code.

  • Most importantly, don't use unsafe to fix errors if you don't understand their root cause. You will almost certainly only fix the symptoms, while the root cause still exists and will bite you at a later point.
  • Don't ever, not even inside of unsafe, cast a & reference to a &mut reference. This is always a bad idea and a strong indicator that something else is already quite wrong.
  • Don't use Arcs prematurely. Most of the time, there are better alternatives, like std::thread::scope.
  • Don't enforce Send via unsafe. Every type is Send by default, and if a type isn't, then there is a reason. In your case, the reason is that your task type should be Box<dyn FnMut() + 'a + Send>; then the entire Task struct is Send automatically.
  • Don't use Mutex prematurely. In many situations (of course not all of them), there are better alternatives, like atomics; see the sketch after this list.
  • If you iterate through arrays via index inside of unsafe, make absolutely sure that you do not run out of bounds. Your code performs no bounds check before accessing (*queue)[index].
  • You never loop; every thread only takes a single task. Worse, the counter is never actually incremented (overflowing_add returns the new value instead of modifying the counter in place), so every thread executes the same first task. This already shows why handing the same mutable pointer to several threads is a bad idea.
  • Be aware of the difference between Fn, FnMut and FnOnce. In your case, if you are sure that you only want to run each task exactly once, don't use FnMut; use FnOnce instead. Every FnMut is automatically FnOnce, but not the other way round.
  • Last but not least: there are very good crates for exactly this type of parallelism, like rayon. Use those instead; it will be safer and perform better.

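For illustration, here is a minimal sketch of the atomics suggestion from the list above; try_claim is a hypothetical helper, not part of any code in this question. An AtomicUsize replaces the Mutex<usize> counter, and the claimed index is bounds-checked before use:

use std::sync::atomic::{AtomicUsize, Ordering};

// `fetch_add` increments the counter and returns the previous value in one
// atomic step, so every thread claims a unique index without taking a lock.
fn try_claim<T>(tasks: &[T], counter: &AtomicUsize) -> Option<usize> {
    let index = counter.fetch_add(1, Ordering::Relaxed);
    // Bounds check: past the end of the queue there is nothing left to claim.
    if index < tasks.len() {
        Some(index)
    } else {
        None
    }
}
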
That said, here are two examples of possible solutions:

  • Based on your code, but with the following changes:
    • Remove the Arc and use std::thread::scope instead.
    • Remove the entire unsafe construct and introduce a Mutex around self.task_queue instead. Use a VecDeque instead of a Vec, so we can simply pop the items out instead of having to count. With that, we can remove task_count altogether.
    • Change FnMut to FnOnce, as we consume the task object in the process and therefore have ownership of it. This also broadens the set of functions we accept.
    • Fix bugs.
use std::collections::VecDeque;
use std::sync::Mutex;
use std::thread;

pub struct ThreadPool<'a> {
    task_queue: Mutex<VecDeque<Task<'a>>>,
}

pub struct Task<'a>(Box<dyn FnOnce() + 'a + Send>);

impl<'a> Task<'a> {
    fn call(self) {
        self.0()
    }
}

impl<'a> ThreadPool<'a> {
    pub fn new() -> Self {
        Self {
            task_queue: Mutex::new(VecDeque::new()),
        }
    }

    pub fn add_task<T>(&mut self, task: T)
    where
        T: 'a + FnOnce() -> () + Send,
    {
        // Use `get_mut()` because we already have exclusive access to the mutex
        self.task_queue
            .get_mut()
            .unwrap()
            .push_back(Task(Box::new(task)));
    }

    pub fn run(&mut self) {
        let thread_count = thread::available_parallelism().unwrap().get();
        println!("Threads: {}", thread_count);

        thread::scope(|s| {
            let mut handlers = Vec::with_capacity(thread_count);

            for _ in 0..thread_count {
                handlers.push(s.spawn(|| loop {
                    let potential_task = {
                        // Important: put this in its own scope, to make sure `task_queue.lock()` is
                        // released before `task.call()`. Otherwise no two threads will be allowed
                        // to execute a task simultaneously.
                        self.task_queue.lock().unwrap().pop_front()
                    };
                    if let Some(task) = potential_task {
                        task.call();
                    } else {
                        break;
                    }
                }));
            }

            // Do something with `handlers`. Or don't, `handlers` can also be removed entirely.
            // Threads get joined automatically at the end of `scope`.
        });
    }
}
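
Note that std::thread::scope is what makes borrowing self from the spawned threads possible at all: scoped threads are guaranteed to be joined before scope returns, so they may borrow non-'static data. Driving the pool with a main() like this:
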
use std::{thread::sleep, time::Duration};

use rust_playground::ThreadPool;

fn main() {
    let mut threadpool = ThreadPool::new();

    for i in 0..10 {
        threadpool.add_task(move || {
            println!(
                "Running task {} on thread {:?} ...",
                i,
                std::thread::current().id()
            );
            sleep(Duration::from_millis(100));
        })
    }

    threadpool.run();
}
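
Running it prints (the exact interleaving varies from run to run):
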
Threads: 4
Running task 0 on thread ThreadId(3) ...
Running task 1 on thread ThreadId(4) ...
Running task 2 on thread ThreadId(2) ...
Running task 3 on thread ThreadId(5) ...
Running task 4 on thread ThreadId(3) ...
Running task 5 on thread ThreadId(2) ...
Running task 7 on thread ThreadId(5) ...
Running task 6 on thread ThreadId(4) ...
Running task 8 on thread ThreadId(3) ...
Running task 9 on thread ThreadId(2) ...
  • Using rayon, which I highly recommend:
use rayon::prelude::*;

pub struct ThreadPool<'a> {
    task_queue: Vec<Task<'a>>,
}

pub struct Task<'a>(Box<dyn FnOnce() + 'a + Send>);

impl<'a> Task<'a> {
    fn call(self) {
        self.0()
    }
}

impl<'a> ThreadPool<'a> {
    pub fn new() -> Self {
        Self {
            task_queue: Vec::new(),
        }
    }

    pub fn add_task<T>(&mut self, task: T)
    where
        T: 'a + FnOnce() -> () + Send,
    {
        self.task_queue.push(Task(Box::new(task)));
    }

    pub fn run(&mut self) {
        let task_queue = std::mem::take(&mut self.task_queue);
        task_queue.into_par_iter().for_each(|task| task.call());
    }
}
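
The std::mem::take is what makes this work with only &mut self: it swaps the queue out for an empty Vec, giving run() ownership of the tasks so that into_par_iter() can consume them.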

Using the same main() as before, we get:

Running task 0 on thread ThreadId(2) ...
Running task 5 on thread ThreadId(3) ...
Running task 2 on thread ThreadId(5) ...
Running task 3 on thread ThreadId(4) ...
Running task 1 on thread ThreadId(2) ...
Running task 6 on thread ThreadId(3) ...
Running task 7 on thread ThreadId(5) ...
Running task 4 on thread ThreadId(4) ...
Running task 8 on thread ThreadId(2) ...
Running task 9 on thread ThreadId(3) ...

Last but not least: I think once you understand how to use rayon properly, you won't even need to write your own worker pool any more. rayon already takes care of all of this internally; simply use its parallel iterators ;)

use std::{thread::sleep, time::Duration};

use rayon::prelude::*;

fn main() {
    let mut tasks = vec![];

    for i in 0..10 {
        tasks.push(move || {
            println!(
                "Running task {} on thread {:?} ...",
                i,
                std::thread::current().id()
            );
            sleep(Duration::from_millis(100));
        })
    }

    tasks.par_iter_mut().for_each(|task| task());
}
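
As before, the exact order varies between runs:
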
Running task 0 on thread ThreadId(2) ...
Running task 5 on thread ThreadId(5) ...
Running task 7 on thread ThreadId(3) ...
Running task 2 on thread ThreadId(4) ...
Running task 1 on thread ThreadId(2) ...
Running task 6 on thread ThreadId(5) ...
Running task 8 on thread ThreadId(3) ...
Running task 3 on thread ThreadId(4) ...
Running task 4 on thread ThreadId(5) ...
Running task 9 on thread ThreadId(3) ...
Finomnis