I would like to have a custom thread pool satisfying the following requirements:
- Real threads are preallocated according to the pool capacity. The actual work is free to use the standard .NET thread pool, if needed to spawn concurrent tasks.
- The pool must be able to return the number of idle threads. The returned number may be less than the actual number of the idle threads, but it must not be greater. Of course, the more accurate the number the better.
- Queuing work to the pool should return a corresponding
Task
, which should play nicely with the Task-based API.
- NEW The max job capacity (or degree of parallelism) should be adjustable dynamically. Reducing the capacity does not have to take effect immediately, but increasing it should.
The rationale for the first item is depicted below:
- The machine is not supposed to be running more than N work items concurrently, where N is relatively small - between 10 and 30.
- The work is fetched from the database and if K items are fetched then we want to make sure that there are K idle threads to start the work right away. A situation where work is fetched from the database, but remains waiting for the next available thread is unacceptable.
The last item also explains the reason for having the idle thread count - I am going to fetch that many work items from the database. It also explains why the reported idle thread count must never be higher than the actual one - otherwise I might fetch more work than can be immediately started.
Anyway, here is my implementation along with a small program to test it (BJE stands for Background Job Engine):
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
namespace TaskStartLatency
{
/// <summary>
/// Fixed-size pool of dedicated background threads that execute synchronous work items
/// queued through a private <see cref="TaskScheduler"/>.
/// </summary>
/// <remarks>
/// A thread slot is reserved at the moment work is queued (not when a worker dequeues it),
/// so <see cref="IdleThreadCount"/> never reports more idle threads than can start new work
/// immediately. The count may briefly under-report, because a slot is released only after
/// the worker has finished executing the task.
/// </remarks>
public class BJEThreadPool
{
    /// <summary>
    /// Scheduler that pushes tasks onto the shared queue and keeps the idle-slot counter.
    /// </summary>
    private sealed class InternalTaskScheduler : TaskScheduler
    {
        // Live worker threads minus queued-or-running tasks. May go negative when more
        // work is queued than there are threads; the public getter clamps it at zero.
        private int m_idleThreadCount;
        private readonly BlockingCollection<Task> m_bus;

        public InternalTaskScheduler(int threadCount, BlockingCollection<Task> bus)
        {
            m_idleThreadCount = threadCount;
            m_bus = bus;
        }

        /// <summary>
        /// Executes a dequeued task on the calling worker thread and then releases the
        /// slot that was reserved for it in <see cref="QueueTask"/>.
        /// </summary>
        public void RunInline(Task task)
        {
            try
            {
                TryExecuteTask(task);
            }
            catch
            {
                // The action is responsible itself for the error handling, for the time being...
            }
            finally
            {
                // Always release the slot, otherwise one failure would shrink the
                // reported capacity for good.
                Interlocked.Increment(ref m_idleThreadCount);
            }
        }

        /// <summary>Number of free slots, never negative.</summary>
        public int IdleThreadCount
        {
            get { return Math.Max(0, Volatile.Read(ref m_idleThreadCount)); }
        }

        #region Overrides of TaskScheduler

        /// <summary>Reserves a slot and hands the task to the worker threads.</summary>
        protected override void QueueTask(Task task)
        {
            // Reserve before publishing so a concurrent reader of IdleThreadCount can
            // never see the task queued but the slot still counted as idle.
            Interlocked.Decrement(ref m_idleThreadCount);
            try
            {
                m_bus.Add(task);
            }
            catch
            {
                // The task never reached the queue; give the slot back.
                Interlocked.Increment(ref m_idleThreadCount);
                throw;
            }
        }

        /// <summary>
        /// Runs the task on the calling thread. The idle counter is not touched here: a
        /// previously queued task stays in the queue and its slot is released when a
        /// worker later dequeues it.
        /// </summary>
        protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
        {
            return TryExecuteTask(task);
        }

        /// <summary>Returns a snapshot of the tasks still waiting in the queue.</summary>
        protected override IEnumerable<Task> GetScheduledTasks()
        {
            return m_bus.ToArray();
        }

        #endregion

        /// <summary>Removes one slot; called by a worker thread when it exits.</summary>
        public void DecrementIdleThreadCount()
        {
            Interlocked.Decrement(ref m_idleThreadCount);
        }
    }

    /// <summary>
    /// Owns one worker thread that consumes tasks from the shared queue until the pool's
    /// cancellation token is signalled.
    /// </summary>
    private class ThreadContext
    {
        private readonly InternalTaskScheduler m_ts;
        private readonly BlockingCollection<Task> m_bus;
        private readonly CancellationTokenSource m_cts;
        public readonly Thread Thread;

        /// <summary>Creates the background thread and starts it immediately.</summary>
        public ThreadContext(string name, InternalTaskScheduler ts, BlockingCollection<Task> bus, CancellationTokenSource cts)
        {
            m_ts = ts;
            m_bus = bus;
            m_cts = cts;
            Thread = new Thread(Start)
            {
                IsBackground = true,
                Name = name
            };
            Thread.Start();
        }

        // Worker loop: blocks on the queue, runs each task, exits on cancellation.
        private void Start()
        {
            try
            {
                foreach (var task in m_bus.GetConsumingEnumerable(m_cts.Token))
                {
                    m_ts.RunInline(task);
                }
            }
            catch (OperationCanceledException)
            {
                // Expected when Terminate() cancels the token.
            }
            // This thread is gone, so it no longer contributes a slot.
            m_ts.DecrementIdleThreadCount();
        }
    }

    private readonly InternalTaskScheduler m_ts;
    private readonly CancellationTokenSource m_cts = new CancellationTokenSource();
    private readonly BlockingCollection<Task> m_bus = new BlockingCollection<Task>();
    private readonly List<ThreadContext> m_threadCtxs = new List<ThreadContext>();

    /// <summary>Creates the pool and starts <paramref name="threadCount"/> worker threads.</summary>
    /// <param name="threadCount">Number of dedicated threads to preallocate.</param>
    public BJEThreadPool(int threadCount)
    {
        m_ts = new InternalTaskScheduler(threadCount, m_bus);
        for (int i = 0; i < threadCount; ++i)
        {
            m_threadCtxs.Add(new ThreadContext("BJE Thread " + i, m_ts, m_bus, m_cts));
        }
    }

    /// <summary>
    /// Cancels the pool-wide token and blocks until every worker thread has exited.
    /// Safe to call more than once.
    /// </summary>
    public void Terminate()
    {
        m_cts.Cancel();
        foreach (var t in m_threadCtxs)
        {
            t.Thread.Join();
        }
    }

    /// <summary>Queues a cancellable action; it receives the pool-wide cancellation token.</summary>
    /// <returns>A task that represents the queued action.</returns>
    public Task Run(Action<CancellationToken> action)
    {
        return Task.Factory.StartNew(() => action(m_cts.Token), m_cts.Token, TaskCreationOptions.DenyChildAttach, m_ts);
    }

    /// <summary>Queues an action for execution on a pool thread.</summary>
    /// <returns>A task that represents the queued action.</returns>
    public Task Run(Action action)
    {
        return Task.Factory.StartNew(action, m_cts.Token, TaskCreationOptions.DenyChildAttach, m_ts);
    }

    /// <summary>
    /// Number of threads able to start new work immediately. Never over-reports and is
    /// never negative; zero after <see cref="Terminate"/> has returned.
    /// </summary>
    public int IdleThreadCount
    {
        get { return m_ts.IdleThreadCount; }
    }
}
/// <summary>
/// Interactive console check for <see cref="BJEThreadPool"/>: fills every pool thread with
/// an action that blocks on a shared <see cref="TaskCompletionSource{TResult}"/>, then
/// completes, faults or cancels them all depending on the key pressed and verifies the
/// idle thread count afterwards.
/// </summary>
class Program
{
    static void Main()
    {
        const int THREAD_COUNT = 32;
        var pool = new BJEThreadPool(THREAD_COUNT);
        var tcs = new TaskCompletionSource<bool>();
        var tasks = new List<Task>();
        var allRunning = new CountdownEvent(THREAD_COUNT);

        // Queue exactly as many actions as the pool reports idle threads.
        for (int i = pool.IdleThreadCount; i > 0; --i)
        {
            var index = i;
            tasks.Add(pool.Run(cancellationToken =>
            {
                Console.WriteLine("Started action " + index);
                allRunning.Signal();
                // Blocks the pool thread until the key press below resolves tcs,
                // or until the pool-wide token is cancelled by Terminate().
                tcs.Task.Wait(cancellationToken);
                Console.WriteLine(" Ended action " + index);
            }));
        }
        Console.WriteLine("pool.IdleThreadCount = " + pool.IdleThreadCount);
        allRunning.Wait();
        Debug.Assert(pool.IdleThreadCount == 0);

        int expectedIdleThreadCount = THREAD_COUNT;
        Console.WriteLine("Press [c]ancel, [e]rror, [a]bort or any other key");
        switch (Console.ReadKey().KeyChar)
        {
            case 'c':
                Console.WriteLine("Cancel All");
                tcs.TrySetCanceled();
                break;
            case 'e':
                Console.WriteLine("Error All");
                tcs.TrySetException(new Exception("Failed"));
                break;
            case 'a':
                Console.WriteLine("Abort All");
                pool.Terminate();
                // Terminated worker threads no longer count as idle.
                expectedIdleThreadCount = 0;
                break;
            default:
                Console.WriteLine("Done All");
                tcs.TrySetResult(true);
                break;
        }

        try
        {
            Task.WaitAll(tasks.ToArray());
        }
        catch (AggregateException exc)
        {
            Console.WriteLine(exc.Flatten().InnerException.Message);
        }

        // A task is observable as complete slightly before its worker thread restores the
        // idle counter, so allow the counter a bounded time to settle instead of sampling
        // it once and failing spuriously.
        Debug.Assert(SpinWait.SpinUntil(
            () => pool.IdleThreadCount == expectedIdleThreadCount,
            TimeSpan.FromSeconds(5)));

        pool.Terminate();
        Console.WriteLine("Press any key");
        Console.ReadKey();
    }
}
}
It is a very simple implementation and it appears to be working. However, there is a problem - the BJEThreadPool.Run
method does not accept asynchronous methods. I.e. my implementation does not allow me to add the following overloads:
/// <summary>
/// Queues an asynchronous, cancellable work item on the pool's scheduler and returns a
/// proxy task for the work item's own task.
/// </summary>
/// <remarks>
/// The scheduler-level task finishes as soon as <paramref name="action"/> returns its
/// <see cref="Task"/> (typically at its first incomplete await), so a scheduler that
/// tracks busy threads per executed task stops counting this work item at that point,
/// while the asynchronous remainder is still in flight.
/// </remarks>
public Task Run(Func<CancellationToken, Task> action)
{
    Task<Task> launcher = Task.Factory.StartNew(
        () => action(m_cts.Token),
        m_cts.Token,
        TaskCreationOptions.DenyChildAttach,
        m_ts);
    return launcher.Unwrap();
}
/// <summary>
/// Queues an asynchronous work item on the pool's scheduler and returns a proxy task for
/// the work item's own task.
/// </summary>
/// <remarks>
/// The scheduler-level task finishes as soon as <paramref name="action"/> returns its
/// <see cref="Task"/> (typically at its first incomplete await), so a scheduler that
/// tracks busy threads per executed task stops counting this work item at that point,
/// while the asynchronous remainder is still in flight.
/// </remarks>
public Task Run(Func<Task> action)
{
    Task<Task> launcher = Task.Factory.StartNew(
        action,
        m_cts.Token,
        TaskCreationOptions.DenyChildAttach,
        m_ts);
    return launcher.Unwrap();
}
The pattern I use in InternalTaskScheduler.RunInline
does not work in this case.
So, my question is how to add the support for asynchronous work items? I am fine with changing the entire design as long as the requirements outlined at the beginning of the post are upheld.
EDIT
I would like to clarify the intended usage of the desired pool. Please observe the following code:
// Read the idle count once: it can change between two reads, and the amount of work
// fetched must match the value that was actually checked.
int idleThreadCount = pool.IdleThreadCount;
if (idleThreadCount <= 0)
{
    // Nothing can start immediately, so do not lock any work in the database.
    return;
}
foreach (var jobData in FetchFromDB(idleThreadCount))
{
    pool.Run(CreateJobAction(jobData));
}
Notes:
- The code is going to be run periodically, say every 1 minute.
- The code is going to be run concurrently by multiple machines watching the same database.
- FetchFromDB is going to use the technique described in "Using SQL Server as a DB queue with multiple clients" to atomically fetch and lock the work from the DB.
- CreateJobAction is going to invoke the code denoted by jobData (the job code) and close the work upon the completion of that code. The job code is out of my control and it could be pretty much anything - heavy CPU bound code or light asynchronous IO bound code, badly written synchronous IO bound code or a mix of all. It could run for minutes and it could run for hours. Closing the work is my code and it would be asynchronous IO bound code. Because of this, the signature of the returned job action is that of an asynchronous method.
Item 2 underlines the importance of correctly identifying the number of idle threads. If there are 900 pending work items and 10 agent machines I cannot allow an agent to fetch 300 work items and queue them on the thread pool. Why? Because it is most unlikely that the agent will be able to run 300 work items concurrently. It will run some, sure enough, but others will be waiting in the thread pool work queue. Suppose it will run 100 and let 200 wait (even though 100 is probably far-fetched). This yields 3 fully loaded agents and 7 idle ones. But only 300 work items out of 900 are actually being processed concurrently!!!
My goal is to maximize the spread of the work amongst the available agents. Ideally, I should evaluate the load of an agent and the "heaviness" of the pending work, but it is a formidable task and is reserved for the future versions. Right now, I wish to assign each agent the max job capacity with the intention to provide the means to increase/decrease it dynamically without restarting the agents.
Next observation. The work can take quite a long time to run and it could be all synchronous code. As far as I understand it is undesirable to utilize thread pool threads for such kind of work.
EDIT 2
There is a statement that TaskScheduler
is only for the CPU bound work. But what if I do not know the nature of the work? I mean it is a general purpose Background Job Engine and it runs thousands of different kinds of jobs. I do not have the means to tell "that job is CPU bound", "that one is synchronous IO bound" and "yet another one is asynchronous IO bound". I wish I could, but I cannot.
EDIT 3
At the end, I do not use the SemaphoreSlim
, but neither do I use the TaskScheduler
- it finally trickled down my thick skull that it is inappropriate and plain wrong, plus it makes the code overly complex.
Still, I failed to see how SemaphoreSlim
is the way. The proposed pattern:
/// <summary>
/// Waits asynchronously for a free semaphore slot, starts the work produced by
/// <paramref name="taskGenerator"/>, awaits it, and releases the slot afterwards whether
/// the work succeeded or threw.
/// </summary>
/// <param name="taskGenerator">Factory invoked once the slot has been acquired.</param>
public async Task Enqueue(Func<Task> taskGenerator)
{
    await semaphore.WaitAsync();
    try
    {
        Task pending = taskGenerator();
        await pending;
    }
    finally
    {
        // Give the slot back even when the work faulted or was cancelled.
        semaphore.Release();
    }
}
This pattern expects taskGenerator to either be asynchronous IO bound code or to open a new thread otherwise. However, I have no means to determine whether the work to be executed is one or the other. Plus, as I have learned from "SemaphoreSlim.WaitAsync continuation code", if the semaphore is unlocked, the code following the WaitAsync() is going to run on the same thread, which is not very good for me.
Anyway, below is my implementation, in case anyone fancies. Unfortunately, I am yet to understand how to reduce the pool thread count dynamically, but this is a topic for another question.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
namespace TaskStartLatency
{
/// <summary>
/// Contract of the Background Job Engine (BJE) thread pool: queues synchronous or
/// asynchronous work, reports how many threads are idle, and can be resized or stopped.
/// </summary>
public interface IBJEThreadPool
{
/// <summary>
/// Requests a new total number of pool threads. The BJEThreadPool implementation in
/// this file only supports growing; asking for fewer threads than currently exist
/// throws <see cref="NotSupportedException"/>.
/// </summary>
void SetThreadCount(int threadCount);
/// <summary>
/// Stops the pool. The implementation in this file cancels the token handed to
/// cancellable work items and blocks until its worker threads have exited.
/// </summary>
void Terminate();
/// <summary>Queues a synchronous action; the returned task represents its execution.</summary>
Task Run(Action action);
/// <summary>Queues a synchronous action that receives a cancellation token.</summary>
Task Run(Action<CancellationToken> action);
/// <summary>Queues an asynchronous work item; the returned task represents the whole work item.</summary>
Task Run(Func<Task> action);
/// <summary>Queues an asynchronous work item that receives a cancellation token.</summary>
Task Run(Func<CancellationToken, Task> action);
/// <summary>
/// Number of threads currently reported as idle. Callers use it to decide how much
/// work to fetch, so it is meant never to exceed the real number of idle threads.
/// </summary>
int IdleThreadCount { get; }
}
/// <summary>
/// Pool of dedicated background threads. Each thread takes one work item at a time from a
/// shared queue and stays blocked until that item (synchronous or asynchronous) has
/// completed, so the number of work items in flight never exceeds the number of threads.
/// </summary>
/// <remarks>
/// A thread slot is reserved when work is posted, not when a worker dequeues it, so
/// <see cref="IdleThreadCount"/> never reports more idle threads than can start new work
/// immediately. It may briefly under-report, because the slot is released only after the
/// work item's task has been completed.
/// </remarks>
public class BJEThreadPool : IBJEThreadPool
{
    /// <summary>
    /// One queued unit of work: the delegate to run plus the completion source whose task
    /// is handed back to the caller of <c>Run</c>.
    /// </summary>
    private sealed class WorkItem
    {
        private readonly Func<CancellationToken, Task> m_body;

        public WorkItem(Func<CancellationToken, Task> body)
        {
            m_body = body;
            TaskCompletionSource = new TaskCompletionSource<object>();
        }

        public TaskCompletionSource<object> TaskCompletionSource { get; private set; }

        /// <summary>
        /// Invokes the delegate. Returns null when the work was synchronous and has
        /// already finished by the time this method returns.
        /// </summary>
        public Task Run(CancellationToken ct)
        {
            return m_body(ct);
        }
    }

    private readonly CancellationTokenSource m_ctsTerminateAll = new CancellationTokenSource();
    private readonly BlockingCollection<WorkItem> m_bus = new BlockingCollection<WorkItem>();
    private readonly LinkedList<Thread> m_threads = new LinkedList<Thread>();
    // Guards m_threads, which is not safe for concurrent use.
    private readonly object m_threadsLock = new object();
    // Live worker threads minus queued-or-running work items. May go negative when more
    // work is posted than there are threads; the public getter clamps it at zero.
    private int m_idleThreadCount;
    // Process-wide counter used only to give every thread a unique name.
    private static int s_threadCount;

    /// <summary>Creates the pool and starts <paramref name="threadCount"/> worker threads.</summary>
    public BJEThreadPool(int threadCount)
    {
        lock (m_threadsLock)
        {
            ReserveAdditionalThreads(threadCount);
        }
    }

    // Starts n more worker threads. The caller must hold m_threadsLock.
    private void ReserveAdditionalThreads(int n)
    {
        for (int i = 0; i < n; ++i)
        {
            var index = Interlocked.Increment(ref s_threadCount) - 1;
            var t = new Thread(Start)
            {
                IsBackground = true,
                Name = "BJE Thread " + index
            };
            // Count the thread before it starts so it can never decrement first.
            Interlocked.Increment(ref m_idleThreadCount);
            t.Start();
            m_threads.AddLast(t);
        }
    }

    // Worker loop: blocks on the queue, runs one work item to completion, repeats until
    // the pool is terminated.
    private void Start()
    {
        try
        {
            foreach (var workItem in m_bus.GetConsumingEnumerable(m_ctsTerminateAll.Token))
            {
                // Block this dedicated thread until the (possibly asynchronous) work is
                // done; RunWork never faults because it captures every exception.
                RunWork(workItem).Wait();
            }
        }
        catch (OperationCanceledException)
        {
            // Expected when Terminate() cancels the token.
        }
        catch
        {
            // Should never happen - log the error
        }
        // This thread is gone, so it no longer contributes a slot.
        Interlocked.Decrement(ref m_idleThreadCount);
    }

    // Runs one work item and transfers its outcome (success, cancellation or failure) to
    // the completion source returned to the caller, then releases the reserved slot.
    private async Task RunWork(WorkItem workItem)
    {
        try
        {
            var task = workItem.Run(m_ctsTerminateAll.Token);
            if (task != null)
            {
                await task;
            }
            workItem.TaskCompletionSource.TrySetResult(null);
        }
        catch (OperationCanceledException)
        {
            workItem.TaskCompletionSource.TrySetCanceled();
        }
        catch (Exception exc)
        {
            workItem.TaskCompletionSource.TrySetException(exc);
        }
        finally
        {
            // Release the slot reserved in PostWork, whatever the outcome.
            Interlocked.Increment(ref m_idleThreadCount);
        }
    }

    // Reserves a slot, enqueues the work item and returns the caller-facing task.
    private Task PostWork(WorkItem workItem)
    {
        // Reserve before publishing so IdleThreadCount can never be observed counting a
        // thread as idle while work destined for it is already queued.
        Interlocked.Decrement(ref m_idleThreadCount);
        try
        {
            m_bus.Add(workItem);
        }
        catch
        {
            // The item never reached the queue; give the slot back.
            Interlocked.Increment(ref m_idleThreadCount);
            throw;
        }
        return workItem.TaskCompletionSource.Task;
    }

    #region Implementation of IBJEThreadPool

    /// <summary>
    /// Grows the pool to <paramref name="threadCount"/> threads; the new threads are
    /// started before the method returns.
    /// </summary>
    /// <exception cref="NotSupportedException">
    /// <paramref name="threadCount"/> is lower than the current number of threads.
    /// </exception>
    public void SetThreadCount(int threadCount)
    {
        lock (m_threadsLock)
        {
            var current = m_threads.Count;
            if (threadCount > current)
            {
                ReserveAdditionalThreads(threadCount - current);
            }
            else if (threadCount < current)
            {
                throw new NotSupportedException();
            }
        }
    }

    /// <summary>
    /// Cancels the pool-wide token and blocks until every worker thread has exited.
    /// Safe to call more than once.
    /// </summary>
    public void Terminate()
    {
        m_ctsTerminateAll.Cancel();

        Thread[] snapshot;
        lock (m_threadsLock)
        {
            snapshot = new Thread[m_threads.Count];
            m_threads.CopyTo(snapshot, 0);
        }
        // Join outside the lock so a concurrent SetThreadCount is not blocked meanwhile.
        foreach (var t in snapshot)
        {
            t.Join();
        }
    }

    /// <summary>Queues a synchronous action.</summary>
    public Task Run(Action action)
    {
        return PostWork(new WorkItem(ct => { action(); return null; }));
    }

    /// <summary>Queues a synchronous action that receives the pool-wide cancellation token.</summary>
    public Task Run(Action<CancellationToken> action)
    {
        return PostWork(new WorkItem(ct => { action(ct); return null; }));
    }

    /// <summary>Queues an asynchronous work item.</summary>
    public Task Run(Func<Task> action)
    {
        return PostWork(new WorkItem(ct => action()));
    }

    /// <summary>Queues an asynchronous work item that receives the pool-wide cancellation token.</summary>
    public Task Run(Func<CancellationToken, Task> action)
    {
        return PostWork(new WorkItem(action));
    }

    /// <summary>
    /// Number of threads able to start new work immediately. Never over-reports and is
    /// never negative; zero after <see cref="Terminate"/> has returned.
    /// </summary>
    public int IdleThreadCount
    {
        get { return Math.Max(0, Volatile.Read(ref m_idleThreadCount)); }
    }

    #endregion
}
/// <summary>Task helpers used by the sample program.</summary>
public static class Extensions
{
    /// <summary>
    /// Returns a task that mirrors the outcome of <paramref name="task"/> but is cancelled
    /// as soon as <paramref name="token"/> is signalled, even if <paramref name="task"/>
    /// itself never completes.
    /// </summary>
    /// <param name="task">The task to observe.</param>
    /// <param name="token">Token that cancels the returned task.</param>
    /// <returns>
    /// A task that completes when <paramref name="task"/> completes, rethrowing its
    /// exception if it failed, or that is cancelled when <paramref name="token"/> fires first.
    /// </returns>
    public static Task WithCancellation(this Task task, CancellationToken token)
    {
        // Name the scheduler explicitly: the scheduler-less overload binds to
        // TaskScheduler.Current, which would queue this continuation onto whatever custom
        // scheduler the caller happens to run on, possibly one whose threads are all busy.
        return task.ContinueWith(
            t => t.GetAwaiter().GetResult(),
            token,
            TaskContinuationOptions.None,
            TaskScheduler.Default);
    }
}
/// <summary>
/// Interactive console check for <see cref="BJEThreadPool"/> with asynchronous work items:
/// fills every pool thread with an async action awaiting a shared
/// <see cref="TaskCompletionSource{TResult}"/>, then completes, faults or cancels them all
/// depending on the key pressed and verifies the idle thread count afterwards.
/// </summary>
class Program
{
    static void Main()
    {
        const int THREAD_COUNT = 16;
        var pool = new BJEThreadPool(THREAD_COUNT);
        var tcs = new TaskCompletionSource<bool>();
        var tasks = new List<Task>();
        var allRunning = new CountdownEvent(THREAD_COUNT);

        // Queue exactly as many work items as the pool reports idle threads.
        for (int i = pool.IdleThreadCount; i > 0; --i)
        {
            var index = i;
            tasks.Add(pool.Run(async ct =>
            {
                Console.WriteLine("Started action " + index);
                allRunning.Signal();
                // Stays pending until the key press below resolves tcs, or until the
                // pool-wide token is cancelled by Terminate().
                await tcs.Task.WithCancellation(ct);
                Console.WriteLine(" Ended action " + index);
            }));
        }
        Console.WriteLine("pool.IdleThreadCount = " + pool.IdleThreadCount);
        allRunning.Wait();
        Debug.Assert(pool.IdleThreadCount == 0);

        int expectedIdleThreadCount = THREAD_COUNT;
        Console.WriteLine("Press [c]ancel, [e]rror, [a]bort or any other key");
        // NOTE(review): the messages below omit their first letter, presumably because
        // Console.ReadKey() echoes the pressed key right before them - confirm.
        switch (Console.ReadKey().KeyChar)
        {
            case 'c':
                Console.WriteLine("ancel All");
                tcs.TrySetCanceled();
                break;
            case 'e':
                Console.WriteLine("rror All");
                tcs.TrySetException(new Exception("Failed"));
                break;
            case 'a':
                Console.WriteLine("bort All");
                pool.Terminate();
                // Terminated worker threads no longer count as idle.
                expectedIdleThreadCount = 0;
                break;
            default:
                Console.WriteLine("Done All");
                tcs.TrySetResult(true);
                break;
        }

        try
        {
            Task.WaitAll(tasks.ToArray());
        }
        catch (AggregateException exc)
        {
            Console.WriteLine(exc.Flatten().InnerException.Message);
        }

        // A work item's task is completed slightly before its worker thread restores the
        // idle counter, so allow the counter a bounded time to settle instead of sampling
        // it once and failing spuriously.
        Debug.Assert(SpinWait.SpinUntil(
            () => pool.IdleThreadCount == expectedIdleThreadCount,
            TimeSpan.FromSeconds(5)));

        pool.Terminate();
        Console.WriteLine("Press any key");
        Console.ReadKey();
    }
}
}