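In Java, C and Rust you can preempt a thread that is running a hot loop by having another thread overwrite the loop variable with a value past the loop's limit, so the loop condition fails on its next check. Any thread can do this. (For this to be well-defined the loop variable must be shared safely between threads — an `AtomicInteger` in Java, an atomic integer in C or Rust; with a plain or merely `volatile` `int` in C it is a data race.)
If the loop body needs the loop variable (it probably does), copy it into a local variable first, as below, so the body works on one stable value instead of re-reading a variable that another thread may overwrite at any moment. The two fragments are sketches that use the helpers of the Java and C programs below: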
for (initialVar(0, 0), int loopVal = 0; getValue(0) < getLimit(0); loopVal = increment(0)) {
Math.sqrt(loopVal);
}
Or
for (int loopVal = 0; m->value[0] < m->limit[0]; loopVal = m->value[0]++) {
sqrt(loopVal);
}
You can build cancellable APIs with this approach: the cancel token can record the indexes of all the loops the code is currently inside, and cancelling pushes each of them past its limit. This makes code highly responsive, even when it is deep inside encryption, compression or an upload — as long as it is not blocked in a system call, where no loop condition is being checked.
The Java version of the source code comes first, followed by a C version and then a Rust version.
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicIntegerArray;
/**
 * Demo of cooperative preemption: each {@link KernelThread} round-robins over its
 * {@link LightWeightThread}s, and a timer periodically preempts them by pushing the loop
 * counters they registered past their limits, so every hot loop ends at its next condition check.
 */
public class Scheduler {
    /**
     * An OS thread that runs its lightweight threads one after another, forever.
     * A lightweight thread keeps the CPU until {@link #preempt()} (called from another thread)
     * clears its scheduled flag and breaks every loop it registered.
     */
    public static class KernelThread extends Thread {
        /**
         * Scheduled flag per lightweight thread. It is written both by this thread ({@link #run()})
         * and by the preempting thread ({@link #preempt()}), so it must be a concurrent map: a plain
         * HashMap gives no visibility guarantee, and the running thread might never see its flag
         * cleared.
         */
        public Map<LightWeightThread, Boolean> scheduled = new ConcurrentHashMap<>();

        /** Deschedules every lightweight thread of this kernel thread and breaks all of their loops. */
        public void preempt() {
            for (LightWeightThread thread : scheduled.keySet()) {
                // Clear the flag before breaking the loops, so a thread whose loop ends right away
                // already sees that it is no longer scheduled.
                scheduled.put(thread, false);
                for (int loop = 0; loop < thread.getLoops(); loop++) {
                    thread.preempt(loop);
                }
            }
        }

        /** Adds a lightweight thread to this kernel thread's rotation, initially not scheduled. */
        public void addLightWeightThread(LightWeightThread thread) {
            scheduled.put(thread, false);
        }

        /** Returns whether the thread is currently scheduled; threads never added report false. */
        public boolean isScheduled(LightWeightThread lightWeightThread) {
            Boolean flag = scheduled.get(lightWeightThread);
            // Unboxing a missing entry (null) would throw NullPointerException.
            return flag != null && flag;
        }

        /** Scheduler loop: marks each lightweight thread scheduled in turn and runs it until it returns. */
        @Override
        public void run() {
            while (true) {
                LightWeightThread previous = null;
                for (LightWeightThread thread : scheduled.keySet()) {
                    scheduled.put(thread, false);
                }
                for (LightWeightThread thread : scheduled.keySet()) {
                    if (previous != null) {
                        scheduled.put(previous, false);
                    }
                    scheduled.put(thread, true);
                    thread.run();
                    previous = thread;
                }
            }
        }
    }

    /** Operations a preemptible task exposes; loops are identified by small integer ids ("name"). */
    public interface Preemptible {
        /** Prepares loop {@code name} to run from {@code defaultValue} (or a saved position) up to {@code limit}. */
        void registerLoop(int name, int defaultValue, int limit);
        /** Advances loop {@code name} by one and returns the new value. */
        int increment(int name);
        /** Whether the current run of loop {@code name} was cut short by {@link #preempt(int)}. */
        boolean isPreempted(int name);
        int getLimit(int name);
        int getValue(int name);
        /** Forces loop {@code id} past its limit; may be called from any thread. */
        void preempt(int id);
        /** Number of loop slots. */
        int getLoops();
    }

    /**
     * A task whose hot loops can be cut short from another thread. Subclasses override
     * {@link #run()} and drive their loops through {@link #registerLoop}, {@link #getValue},
     * {@link #getLimit} and {@link #increment}.
     */
    public static abstract class LightWeightThread implements Preemptible {
        public int kernelThreadId;
        public int threadId;
        public KernelThread parent;
        // Per-loop state, indexed by loop id. The owning kernel thread and the preempting thread
        // both touch every slot, so all of it is atomic.
        AtomicInteger[] values;          // current position of each loop
        AtomicIntegerArray limits;       // exclusive upper bound of each loop
        AtomicIntegerArray preempted;    // 1 if preempt() cut the loop's current run short
        AtomicIntegerArray remembered;   // position the loop had reached when it was cut short

        /** Creates a lightweight thread with a single loop slot (id 0). */
        public LightWeightThread(int kernelThreadId, int threadId, KernelThread parent) {
            this(kernelThreadId, threadId, parent, 1);
        }

        /**
         * Creates a lightweight thread with {@code loops} loop slots (ids 0 .. loops-1).
         *
         * @throws IllegalArgumentException if {@code loops} is less than 1
         */
        public LightWeightThread(int kernelThreadId, int threadId, KernelThread parent, int loops) {
            if (loops < 1) {
                throw new IllegalArgumentException("loops must be at least 1, got " + loops);
            }
            this.kernelThreadId = kernelThreadId;
            this.threadId = threadId;
            this.parent = parent;
            this.values = new AtomicInteger[loops];
            for (int i = 0; i < loops; i++) {
                values[i] = new AtomicInteger();
            }
            this.limits = new AtomicIntegerArray(loops);
            this.preempted = new AtomicIntegerArray(loops);
            this.remembered = new AtomicIntegerArray(loops);
        }

        /** Body of the lightweight thread; the default does nothing. */
        public void run() {
        }

        public void registerLoop(int name, int defaultValue, int limit) {
            // Consume the "cut short" flag: a saved position is used at most once, so a loop that
            // later runs to completion restarts from defaultValue the next time.
            boolean cutShort = preempted.getAndSet(name, 0) == 1;
            int saved = remembered.get(name);
            int start = (cutShort && saved >= defaultValue && saved < limit) ? saved : defaultValue;
            limits.set(name, limit);
            values[name].set(start);
        }

        public int increment(int name) {
            return values[name].incrementAndGet();
        }

        public boolean isPreempted(int name) {
            return preempted.get(name) == 1;
        }

        public int getLimit(int name) {
            return limits.get(name);
        }

        public int getValue(int name) {
            return values[name].get();
        }

        /** Sets loop {@code name} to {@code value} and returns it (usable in a for-loop initializer). */
        public int initialVar(int name, int value) {
            values[name].set(value);
            return value;
        }

        public void preempt(int id) {
            int limit = limits.get(id);
            while (true) {
                int seen = values[id].get();
                if (seen >= limit) {
                    // Not in progress: keep whatever position an earlier preemption saved.
                    return;
                }
                // Record position and flag before the jump. The loop can only observe the jump
                // through the atomic value, so once it exits it also sees these writes.
                remembered.set(id, seen);
                preempted.set(id, 1);
                if (values[id].compareAndSet(seen, limit)) {
                    return;
                }
                // The loop advanced in between; retry with its newer position. If it finished on
                // its own meanwhile, the flag stays set and the next registerLoop replays from the
                // last recorded position.
            }
        }

        public int getLoops() {
            return values.length;
        }
    }

    /** Demo: 5 kernel threads x 5 lightweight threads, preempted every 10 ms by a java.util.Timer. Runs forever. */
    public static void main(String[] args) throws InterruptedException {
        List<KernelThread> kernelThreads = new ArrayList<>();
        for (int i = 0; i < 5; i++) {
            KernelThread kt = new KernelThread();
            for (int j = 0; j < 5; j++) {
                LightWeightThread lightWeightThread = new LightWeightThread(i, j, kt) {
                    @Override
                    public void run() {
                        while (this.parent.isScheduled(this)) {
                            System.out.println(String.format("%d %d", this.kernelThreadId, this.threadId));
                            // registerLoop picks the start: 0, or the position a preempted run reached.
                            // Resetting the counter in the loop header would throw that position away.
                            registerLoop(0, 0, 10000000);
                            for (; getValue(0) < getLimit(0); increment(0)) {
                                Math.sqrt(getValue(0));
                            }
                            if (isPreempted(0)) {
                                System.out.println(String.format("%d %d: %d was preempted !%d < %d", this.kernelThreadId, this.threadId, 0, getValue(0), getLimit(0)));
                            }
                        }
                    }
                };
                kt.addLightWeightThread(lightWeightThread);
            }
            kernelThreads.add(kt);
        }
        for (KernelThread kt : kernelThreads) {
            kt.start();
        }
        Timer timer = new Timer();
        timer.schedule(new TimerTask() {
            @Override
            public void run() {
                for (KernelThread kt : kernelThreads) {
                    kt.preempt();
                }
            }
        }, 10, 10);
        for (KernelThread kt : kernelThreads) {
            kt.join();
        }
    }
}
This is the C version of the same thing:
#include <pthread.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <errno.h>
#include <ctype.h>
#include <time.h>
#include <math.h>
/* Report a failed call that returns its error number instead of setting errno (as the
   pthread_* calls in this file do): store it in errno so perror() can describe it, then exit. */
#define handle_error_en(en, msg) \
do { errno = en; perror(msg); exit(EXIT_FAILURE); } while (0)
/* Report a failed call that has already set errno (e.g. calloc, strdup), then exit. */
#define handle_error(msg) \
do { perror(msg); exit(EXIT_FAILURE); } while (0)
/* One cooperatively scheduled "user" thread plus the loop state the timer can cut short.
 * limit, value and remembered each hold num_loops entries, indexed by loop number. */
struct lightweight_thread {
    int thread_num;                 /* index within its kernel thread */
    /* Set to 1 by the kernel thread while this thread is scheduled and cleared to 0 to
     * deschedule it; despite the name, 1 means "may keep running". */
    volatile int preempted;
    int num_loops;
    int *limit;                     /* exclusive upper bound of each loop */
    /* Current position of each loop; the timer overwrites it with limit[] to end the loop.
     * NOTE(review): volatile does not make these cross-thread accesses race-free in C11;
     * _Atomic int would. */
    volatile int *value;
    int *remembered;                /* position saved when a loop was cut short; -1 = none */
    int kernel_thread_num;
    /* Body run while the thread is scheduled. Returns int to match lightweight_thread_function,
     * which is stored here; calling through a mismatched function-pointer type is undefined. */
    int (*user_function) (struct lightweight_thread*);
};
/* Per-kernel-thread record, shared by main(), the kernel thread itself and the timer thread. */
struct thread_info { /* Used as argument to thread_start() */
pthread_t thread_id; /* ID returned by pthread_create() */
int thread_num; /* Application-defined thread # */
char *argv_string; /* Program name: main() passes argv[0] */
int lightweight_threads_num; /* number of entries in user_threads */
struct lightweight_thread *user_threads; /* lightweight threads this kernel thread schedules */
volatile int running; /* thread_start() keeps scheduling while this is 1; main() clears it */
};
/* State of the timer thread that periodically preempts every lightweight thread. */
struct timer_thread {
pthread_t thread_id;
struct thread_info *all_threads; /* array of num_threads kernel-thread records */
int num_threads;
int lightweight_threads_num; /* NOTE(review): not set or read anywhere in this file; the timer uses all_threads[i].lightweight_threads_num */
int delay; /* preemption interval in milliseconds (main() sets 10) */
volatile int running; /* the timer keeps running while this is 1 */
};
static void *
timer_thread_start(void *arg) {
int iterations = 0;
struct timer_thread *timer_thread = arg;
int msec = 0, trigger = timer_thread->delay; /* 10ms */
clock_t before = clock();
while (timer_thread->running == 1 && iterations < 100000) {
do {
for (int i = 0 ; i < timer_thread->num_threads; i++) {
for (int j = 0 ; j < timer_thread->all_threads[i].lightweight_threads_num; j++) {
// printf("Preempting kernel thread %d user thread %d\n", i, j);
timer_thread->all_threads[i].user_threads[j].preempted = 0;
}
}
for (int i = 0 ; i < timer_thread->num_threads; i++) {
for (int j = 0 ; j < timer_thread->all_threads[i].lightweight_threads_num; j++) {
// printf("Preempting kernel thread %d user thread %d\n", i, j);
for (int loop = 0; loop < timer_thread->all_threads[i].user_threads[j].num_loops; loop++) {
timer_thread->all_threads[i].user_threads[j].remembered[loop] = timer_thread->all_threads[i].user_threads[j].value[loop];
timer_thread->all_threads[i].user_threads[j].value[loop] = timer_thread->all_threads[i].user_threads[j].limit[loop];
}
}
}
clock_t difference = clock() - before;
msec = difference * 1000 / CLOCKS_PER_SEC;
iterations++;
} while ( msec < trigger && iterations < 100000 );
// printf("Time taken %d seconds %d milliseconds (%d iterations)\n",
// msec/1000, msec%1000, iterations);
}
return 0;
}
/* Kernel-thread entry point.
 *
 * Prints an address near the top of this thread's stack and makes an upper-cased copy of
 * argv_string. It then round-robins over its lightweight threads until `running` is cleared.
 * Each lightweight thread runs until its `preempted` flag is cleared and its loops are pushed
 * past their limits (the timer thread does this).
 *
 * Returns the upper-cased copy, allocated with strdup(); the joining thread frees it.
 */
static void *
thread_start(void *arg)
{
    struct thread_info *tinfo = arg;
    char *uargv;

    printf("Thread %d: top of stack near %p; argv_string=%s\n",
           tinfo->thread_num, (void *) &tinfo, tinfo->argv_string);
    uargv = strdup(tinfo->argv_string);
    if (uargv == NULL)
        handle_error("strdup");
    for (char *p = uargv; *p != '\0'; p++) {
        /* toupper() requires an unsigned char value; plain char is negative for
           non-ASCII bytes on many platforms, which is undefined behaviour. */
        *p = toupper((unsigned char) *p);
    }

    while (tinfo->running == 1) {
        for (int i = 0; i < tinfo->lightweight_threads_num; i++) {
            tinfo->user_threads[i].preempted = 0;
        }
        int previous = -1;
        /* Also stop mid-round once shutdown is requested, instead of first starting every
           remaining lightweight thread (each of which only returns when preempted). */
        for (int i = 0; i < tinfo->lightweight_threads_num && tinfo->running == 1; i++) {
            if (previous != -1) {
                tinfo->user_threads[previous].preempted = 0;
            }
            /* 1 = scheduled: the user function keeps looping while this stays non-zero. */
            tinfo->user_threads[i].preempted = 1;
            tinfo->user_threads[i].user_function(&tinfo->user_threads[i]);
            previous = i;
        }
    }
    return uargv;
}
/* Prepares loop `index` of `m` to run from `value` up to (not including) `limit`.
 *
 * If the timer cut this loop short, remembered[index] holds the position it had reached, and
 * the loop resumes there when that position lies inside [value, limit). The checkpoint is
 * always consumed (reset to -1). Without this, a later registration would replay a stale
 * position, or jump straight to the limit if the timer had stored it there, and never do
 * any work.
 *
 * NOTE(review): remembered[] is also written by the timer thread without synchronisation.
 */
void
register_loop(int index, int value, struct lightweight_thread* m, int limit) {
    int checkpoint = m->remembered[index];
    int start = value;
    if (checkpoint != -1 && checkpoint >= value && checkpoint < limit) {
        start = checkpoint;
    }
    m->remembered[index] = -1;
    m->limit[index] = limit;
    m->value[index] = start;
}
/* User function shared by every lightweight thread.
 *
 * While the kernel thread keeps this thread scheduled (preempted != 0), it registers loop 0
 * with a limit of 100000000, counts through it calling sqrt() on each position, and prints a
 * line whenever a pass ends. A pass ends early when value[0] is pushed up to limit[0] from
 * another thread. Always returns 0.
 */
int
lightweight_thread_function(struct lightweight_thread* m)
{
    while (m->preempted) {
        register_loop(0, 0, m, 100000000);
        while (m->value[0] < m->limit[0]) {
            sqrt(m->value[0]);
            m->value[0]++;
        }
        printf("Kernel thread %d User thread %d ran\n", m->kernel_thread_num, m->thread_num);
    }
    return 0;
}
/* Allocates `num_threads` lightweight threads belonging to kernel thread `kernel_thread_num`.
 *
 * Each thread gets one loop slot, lightweight_thread_function as its body, and a remembered
 * position of -1 ("no checkpoint"). Exits the process if any allocation fails, instead of
 * handing back a NULL per-loop array for a later dereference. The caller frees each thread's
 * remembered/value/limit arrays and then the returned array.
 */
struct lightweight_thread*
create_lightweight_threads(int kernel_thread_num, int num_threads) {
    struct lightweight_thread *lightweight_threads =
        calloc(num_threads, sizeof(*lightweight_threads));
    if (lightweight_threads == NULL)
        handle_error("calloc lightweight threads");
    for (int i = 0; i < num_threads; i++) {
        struct lightweight_thread *t = &lightweight_threads[i];
        t->kernel_thread_num = kernel_thread_num;
        t->thread_num = i;
        t->num_loops = 1;
        t->user_function = lightweight_thread_function;
        t->remembered = calloc(t->num_loops, sizeof(*t->remembered));
        t->value = calloc(t->num_loops, sizeof(*t->value));
        t->limit = calloc(t->num_loops, sizeof(*t->limit));
        if (t->remembered == NULL || t->value == NULL || t->limit == NULL)
            handle_error("calloc lightweight thread loops");
        for (int j = 0; j < t->num_loops; j++) {
            t->remembered[j] = -1;
        }
    }
    return lightweight_threads;
}
/* Demo entry point.
 *
 * Starts `num_threads` kernel threads (option -t, default 5), each round-robining over
 * `num_threads` lightweight threads, plus a timer thread that preempts all of them every
 * timer_info->delay milliseconds. When the timer thread returns, the kernel threads are asked
 * to stop. A second timer run keeps preempting until they have all exited, and then everything
 * is joined and freed.
 */
int
main(int argc, char *argv[])
{
    int s, opt, num_threads;
    pthread_attr_t attr;
    pthread_attr_t timer_attr;
    size_t stack_size;
    void *res;
    void *timer_result;     /* pthread_join() stores a void * here; an int would be overrun */

    /* Stack size for every thread created below (there is no option to change it).
       NOTE(review): must not be below PTHREAD_STACK_MIN, or pthread_attr_setstacksize() fails. */
    stack_size = 16384ul;
    num_threads = 5;
    while ((opt = getopt(argc, argv, "t:")) != -1) {
        switch (opt) {
        case 't':
            num_threads = strtoul(optarg, NULL, 0);
            break;
        default:
            fprintf(stderr, "Usage: %s [-t thread-size] arg...\n",
                    argv[0]);
            exit(EXIT_FAILURE);
        }
    }
    /* A zero, negative or overflowed count would leave nothing sensible to allocate. */
    if (num_threads <= 0) {
        fprintf(stderr, "%s: -t needs a positive thread count\n", argv[0]);
        exit(EXIT_FAILURE);
    }

    /* Initialize thread creation attributes. */
    s = pthread_attr_init(&attr);
    if (s != 0)
        handle_error_en(s, "pthread_attr_init");
    s = pthread_attr_init(&timer_attr);
    if (s != 0)
        handle_error_en(s, "pthread_attr_init timer_s");
    if (stack_size > 0) {
        s = pthread_attr_setstacksize(&attr, stack_size);
        int t = pthread_attr_setstacksize(&timer_attr, stack_size);
        if (t != 0)
            handle_error_en(t, "pthread_attr_setstacksize timer");
        if (s != 0)
            handle_error_en(s, "pthread_attr_setstacksize");
    }

    /* Allocate memory for pthread_create() arguments. */
    struct thread_info *tinfo = calloc(num_threads, sizeof(*tinfo));
    if (tinfo == NULL)
        handle_error("calloc");
    for (int tnum = 0; tnum < num_threads; tnum++) {
        tinfo[tnum].running = 1;
    }
    struct timer_thread *timer_info = calloc(1, sizeof(*timer_info));
    /* Check before the first dereference. */
    if (timer_info == NULL)
        handle_error("calloc timer thread");
    timer_info->running = 1;
    timer_info->delay = 10;
    timer_info->num_threads = num_threads;
    timer_info->all_threads = tinfo;

    /* Create the kernel threads, each with its own set of lightweight threads. */
    for (int tnum = 0; tnum < num_threads; tnum++) {
        tinfo[tnum].thread_num = tnum + 1;
        tinfo[tnum].argv_string = argv[0];
        struct lightweight_thread *lightweight_threads = create_lightweight_threads(tnum, num_threads);
        tinfo[tnum].user_threads = lightweight_threads;
        tinfo[tnum].lightweight_threads_num = num_threads;
        /* The pthread_create() call stores the thread ID into
           corresponding element of tinfo[]. */
        s = pthread_create(&tinfo[tnum].thread_id, &attr,
                           &thread_start, &tinfo[tnum]);
        if (s != 0)
            handle_error_en(s, "pthread_create");
    }
    s = pthread_create(&timer_info->thread_id, &timer_attr,
                       &timer_thread_start, timer_info);
    if (s != 0)
        handle_error_en(s, "pthread_create");
    s = pthread_attr_destroy(&attr);
    if (s != 0)
        handle_error_en(s, "pthread_attr_destroy");

    /* The demo lasts until the timer thread has used up its own iteration budget. */
    s = pthread_join(timer_info->thread_id, &timer_result);
    if (s != 0)
        handle_error_en(s, "pthread_join");
    printf("Joined timer thread\n");

    /* Ask every kernel thread to stop. */
    for (int tnum = 0; tnum < num_threads; tnum++) {
        tinfo[tnum].running = 0;
    }
    /* A lightweight thread only returns once it is preempted. Without a timer, the kernel
       threads would never get back to checking `running`, and the joins below would hang.
       So run the timer again until they have all exited. */
    s = pthread_create(&timer_info->thread_id, &timer_attr,
                       &timer_thread_start, timer_info);
    if (s != 0)
        handle_error_en(s, "pthread_create shutdown timer");
    s = pthread_attr_destroy(&timer_attr);
    if (s != 0)
        handle_error_en(s, "pthread_attr_destroy timer");

    /* Now join with each kernel thread, and display its returned value. */
    for (int tnum = 0; tnum < num_threads; tnum++) {
        s = pthread_join(tinfo[tnum].thread_id, &res);
        if (s != 0)
            handle_error_en(s, "pthread_join");
        printf("Joined with thread %d; returned value was %s\n",
               tinfo[tnum].thread_num, (char *) res);
        free(res);      /* Free memory allocated by thread */
    }

    /* Stop the shutdown timer before freeing the lightweight threads it walks. */
    timer_info->running = 0;
    s = pthread_join(timer_info->thread_id, &timer_result);
    if (s != 0)
        handle_error_en(s, "pthread_join");

    for (int tnum = 0; tnum < num_threads; tnum++) {
        for (int user_thread_num = 0; user_thread_num < num_threads; user_thread_num++) {
            free(tinfo[tnum].user_threads[user_thread_num].remembered);
            free((void *) tinfo[tnum].user_threads[user_thread_num].value);
            free(tinfo[tnum].user_threads[user_thread_num].limit);
        }
        free(tinfo[tnum].user_threads);
    }
    free(timer_info);
    free(tinfo);
    exit(EXIT_SUCCESS);
}
The Rust version uses `unsafe` to share the table of lightweight threads between the kernel threads and the timer thread.
extern crate timer;
extern crate chrono;
use std::sync::Arc;
use std::thread;
use std::sync::atomic::{AtomicI32, Ordering};
use std::{time};
/// One lightweight thread plus the per-loop state the timer thread uses to cut its loops short.
/// `limit`, `value` and `remembered` hold one entry per loop (`num_loops` of them).
struct LightweightThread {
/// Position of this thread within its kernel thread (1-based in `main`).
thread_num: i32,
/// The kernel thread stores 1 before running the body, and the body keeps looping while it is 1;
/// storing 0 lets the body return.
preempted: AtomicI32,
num_loops: i32,
/// Exclusive upper bound of each loop.
limit: Vec<AtomicI32>,
/// Current position of each loop; the timer overwrites it with `limit` to end the loop early.
value: Vec<AtomicI32>,
/// Position the timer observed just before forcing the loop to its limit.
remembered: Vec<AtomicI32>,
kernel_thread_num: i32,
/// Body run each time the kernel thread schedules this lightweight thread.
lightweight_thread: fn(&mut LightweightThread)
}
/// Prepares loop `loopindex` of `lwt` to run from `initial_value` up to (not including) `limit`.
///
/// If the previous run of the loop was cut short, the loop resumes from that checkpoint
/// instead. A run counts as cut short when the position in `remembered` is below the limit
/// that run was registered with.
///
/// The checkpoint is loaded once, so the test and the restored value always agree even if the
/// timer thread overwrites it concurrently. After resuming, it is consumed (set to the new
/// limit), so a loop that later runs to completion starts from `initial_value` again instead of
/// replaying a stale position.
fn register_loop(loopindex: usize, initial_value: i32, limit: i32, lwt: &mut LightweightThread) {
    let checkpoint = lwt.remembered[loopindex].load(Ordering::Relaxed);
    let previous_limit = lwt.limit[loopindex].load(Ordering::Relaxed);
    let interrupted = checkpoint < previous_limit;
    let start = if interrupted { checkpoint } else { initial_value };
    if interrupted {
        // A stored value >= the limit never counts as "cut short" on the next registration.
        lwt.remembered[loopindex].store(limit, Ordering::Relaxed);
    }
    lwt.value[loopindex].store(start, Ordering::Relaxed);
    lwt.limit[loopindex].store(limit, Ordering::Relaxed);
}
/// Body shared by every lightweight thread.
///
/// Registers loop 0 (from 0 up to 1000000), then, while the thread is marked scheduled
/// (`preempted == 1`), counts through the loop calling `sqrt` on each position. Once the thread
/// is no longer marked scheduled, it prints which kernel/user thread it was and returns.
fn lightweight_thread(lwt: &mut LightweightThread) {
    register_loop(0usize, 0, 1000000, lwt);
    while lwt.preempted.load(Ordering::Relaxed) == 1 {
        loop {
            let position = lwt.value[0].load(Ordering::Relaxed);
            if position >= lwt.limit[0].load(Ordering::Relaxed) {
                break;
            }
            let _ = f64::sqrt(f64::from(position));
            lwt.value[0].fetch_add(1, Ordering::Relaxed);
        }
    }
    println!("Kernel thread {} User thread {}", lwt.kernel_thread_num, lwt.thread_num)
}
fn main() {
println!("Hello, world!");
let timer = timer::Timer::new();
static mut threads:Vec<LightweightThread> = Vec::new();
let mut thread_handles = Vec::new();
for kernel_thread_num in 1..=5 {
let thread_join_handle = thread::spawn(move || {
for i in 1..=5 {
let mut lthread = LightweightThread {
thread_num: i,
preempted: AtomicI32::new(0),
num_loops: 1,
limit: Vec::new(),
value: Vec::new(),
remembered: Vec::new(),
kernel_thread_num: kernel_thread_num.clone(),
lightweight_thread: lightweight_thread
};
lthread.limit.push(AtomicI32::new(-1));
lthread.value.push(AtomicI32::new(-1));
lthread.remembered.push(AtomicI32::new(1));
unsafe {
threads.push(lthread);
}
}
loop {
let mut previous:Option<&mut LightweightThread> = None;
unsafe {
for (_pos, current_thread) in threads.iter_mut().enumerate() {
if current_thread.kernel_thread_num != kernel_thread_num {
continue;
}
if !previous.is_none() {
previous.unwrap().preempted.store(0, Ordering::Relaxed)
}
current_thread.preempted.store(1, Ordering::Relaxed);
(current_thread.lightweight_thread)(current_thread);
previous = Some(current_thread);
// println!("Running")
}
}
} // loop forever
}); // thread
thread_handles.push(thread_join_handle);
} // thread generation
let timer_handle = thread::spawn(move || {
unsafe {
loop {
for thread in threads.iter() {
thread.preempted.store(0, Ordering::Relaxed);
}
let mut previous:Option<usize> = None;
for (index, thread) in threads.iter_mut().enumerate() {
if !previous.is_none() {
threads[previous.unwrap()].preempted.store(0, Ordering::Relaxed);
}
previous = Some(index);
for loopindex in 0..thread.num_loops {
thread.remembered[loopindex as usize].store(thread.value[loopindex as usize].load(Ordering::Relaxed), Ordering::Relaxed);
thread.value[loopindex as usize].store(thread.limit[loopindex as usize].load(Ordering::Relaxed), Ordering::Relaxed);
}
thread.preempted.store(1, Ordering::Relaxed);
}
let ten_millis = time::Duration::from_millis(10);
thread::sleep(ten_millis);
} // loop
} // unsafe
}); // end of thread
timer_handle.join();
for thread in thread_handles {
thread.join();
}
}