As is often the case when unit testing code that involves a static method (Task.Run in this case), you will likely need to pass a dependency into your class that wraps the static call, so that you can add behaviour to it in the tests.
As @Rich suggests in his answer, you could do this by passing in a TaskScheduler. Your test version can then maintain a count of the tasks as they are enqueued.
Making a test TaskScheduler is actually a little ugly because of protection levels, but at the bottom of this post I have included one that wraps an existing TaskScheduler (e.g. you could use TaskScheduler.Default).
Unfortunately, you would also need to change your calls like
Task.Run(() => DoSomething());
to something like
Task.Factory.StartNew(
    () => DoSomething(),
    CancellationToken.None,
    TaskCreationOptions.DenyChildAttach,
    myTaskScheduler);
which is basically what Task.Run does under the hood, except that Task.Run uses TaskScheduler.Default. You could of course wrap that up in a helper method somewhere.
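One way such a helper might look is sketched below. The TaskRunner class and its Run method are hypothetical names made up for illustration; they are not part of the framework. The sketch only covers plain Action delegates, so unlike Task.Run it does not unwrap async delegates.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper (not a framework type): starts work on an injected scheduler.
public sealed class TaskRunner
{
    private readonly TaskScheduler scheduler;

    public TaskRunner(TaskScheduler scheduler)
    {
        if (scheduler == null) throw new ArgumentNullException("scheduler");
        this.scheduler = scheduler;
    }

    // Same arguments as the Task.Factory.StartNew call above,
    // but using the injected scheduler instead of a fixed one.
    public Task Run(Action action)
    {
        return Task.Factory.StartNew(
            action,
            CancellationToken.None,
            TaskCreationOptions.DenyChildAttach,
            scheduler);
    }
}
```

Production code would construct it with TaskScheduler.Default, while a test would pass in a counting scheduler and call runner.Run(() => DoSomething()) as before.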
Alternatively, if you are not squeamish about some riskier reflection in your test code, you could hijack the TaskScheduler.Default property, so that you can still just use Task.Run:
var defaultSchedulerField = typeof(TaskScheduler).GetField("s_defaultTaskScheduler", BindingFlags.Static | BindingFlags.NonPublic);
var scheduler = new TestTaskScheduler(TaskScheduler.Default);
defaultSchedulerField.SetValue(null, scheduler);
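(The private field name is from TaskScheduler.cs line 285. It is an implementation detail, so check that it still exists in the framework version you are targeting.)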
So for example, this test would pass using my TestTaskScheduler below and the reflection trick:
[Test]
public void Can_count_tasks()
{
    // Given
    var originalScheduler = TaskScheduler.Default;
    var defaultSchedulerField = typeof(TaskScheduler).GetField("s_defaultTaskScheduler", BindingFlags.Static | BindingFlags.NonPublic);
    var testScheduler = new TestTaskScheduler(originalScheduler);
    defaultSchedulerField.SetValue(null, testScheduler);
    try
    {
        // When
        Task.Run(() => {});
        Task.Run(() => {});
        Task.Run(() => {});
        // Then
        testScheduler.TaskCount.Should().Be(3);
    }
    finally
    {
        // Clean up: restore the real scheduler even if the assertion fails
        defaultSchedulerField.SetValue(null, originalScheduler);
    }
}
Here is the test task scheduler:
using System.Collections.Generic;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
public class TestTaskScheduler : TaskScheduler
{
    // The methods we forward to are protected, so look them up via reflection once.
    private static readonly MethodInfo queueTask = GetProtectedMethodInfo("QueueTask");
    private static readonly MethodInfo tryExecuteTaskInline = GetProtectedMethodInfo("TryExecuteTaskInline");
    private static readonly MethodInfo getScheduledTasks = GetProtectedMethodInfo("GetScheduledTasks");
    private readonly TaskScheduler taskScheduler;
    private int taskCount;
    public TestTaskScheduler(TaskScheduler taskScheduler)
    {
        this.taskScheduler = taskScheduler;
    }
    public int TaskCount
    {
        get { return taskCount; }
    }
    protected override void QueueTask(Task task)
    {
        // QueueTask can be called from several threads at once, so count atomically.
        Interlocked.Increment(ref taskCount);
        CallProtectedMethod(queueTask, task);
    }
    protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
    {
        return (bool)CallProtectedMethod(tryExecuteTaskInline, task, taskWasPreviouslyQueued);
    }
    protected override IEnumerable<Task> GetScheduledTasks()
    {
        return (IEnumerable<Task>)CallProtectedMethod(getScheduledTasks);
    }
    // Forwards the call to the wrapped scheduler.
    private object CallProtectedMethod(MethodInfo methodInfo, params object[] args)
    {
        return methodInfo.Invoke(taskScheduler, args);
    }
    private static MethodInfo GetProtectedMethodInfo(string methodName)
    {
        return typeof(TaskScheduler).GetMethod(methodName, BindingFlags.Instance | BindingFlags.NonPublic);
    }
}
Or, tidied up using ReflectionMagic as suggested by @hgcummings in the comments:
var scheduler = new TestTaskScheduler(TaskScheduler.Default);
typeof(TaskScheduler).AsDynamicType().s_defaultTaskScheduler = scheduler;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using ReflectionMagic;
public class TestTaskScheduler : TaskScheduler
{
    // The dynamic wrapper from ReflectionMagic lets us call the protected methods directly.
    private readonly dynamic taskScheduler;
    private int taskCount;
    public TestTaskScheduler(TaskScheduler taskScheduler)
    {
        this.taskScheduler = taskScheduler.AsDynamic();
    }
    public int TaskCount
    {
        get { return taskCount; }
    }
    protected override void QueueTask(Task task)
    {
        // QueueTask can be called from several threads at once, so count atomically.
        Interlocked.Increment(ref taskCount);
        taskScheduler.QueueTask(task);
    }
    protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
    {
        return taskScheduler.TryExecuteTaskInline(task, taskWasPreviouslyQueued);
    }
    protected override IEnumerable<Task> GetScheduledTasks()
    {
        return taskScheduler.GetScheduledTasks();
    }
}
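If you go the dependency route rather than hijacking TaskScheduler.Default, the scheduler does not have to wrap another one at all. The following is a rough sketch of that alternative, not the approach used above: CountingTaskScheduler and Demo are hypothetical names, and the scheduler pushes work straight onto the thread pool via the protected TryExecuteTask method, so no reflection is needed.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical alternative: counts tasks and runs them on the thread pool itself.
public class CountingTaskScheduler : TaskScheduler
{
    private int taskCount;

    public int TaskCount
    {
        get { return Volatile.Read(ref taskCount); }
    }

    protected override void QueueTask(Task task)
    {
        Interlocked.Increment(ref taskCount);
        // TryExecuteTask is protected on TaskScheduler, so a subclass may call it directly.
        ThreadPool.QueueUserWorkItem(_ => TryExecuteTask(task));
    }

    protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
    {
        // Decline inlining; the queued thread pool work item will run the task.
        return false;
    }

    protected override IEnumerable<Task> GetScheduledTasks()
    {
        // Intended for debugger support; this sketch does not track pending tasks.
        return Enumerable.Empty<Task>();
    }
}

public static class Demo
{
    public static int CountThreeTasks()
    {
        var scheduler = new CountingTaskScheduler();
        var tasks = new Task[3];
        for (var i = 0; i < tasks.Length; i++)
        {
            tasks[i] = Task.Factory.StartNew(
                () => { },
                CancellationToken.None,
                TaskCreationOptions.DenyChildAttach,
                scheduler);
        }
        Task.WaitAll(tasks);
        return scheduler.TaskCount;
    }
}
```

Here Demo.CountThreeTasks() returns 3, since each StartNew call goes through QueueTask exactly once. The trade-off is that calls to Task.Run elsewhere in the code are not counted, because they never see this scheduler.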