1

I am implementing a caching layer for my ASP.NET Core 3.1 Web API.

Starting Implementation

// Minimal cache abstraction. Note: Get<T> has no way to signal "missing" apart
// from its return value — the extension methods in this post compare the result
// against default(T) to detect a miss, so default values are effectively uncacheable.
public interface ICache
{
    T Get<T>(string key);
    void Set<T>(string key, T value);
}

public static class ICacheExtensions
{
    /// <summary>
    /// Read-through lookup: returns the cached value for <paramref name="key"/>,
    /// or invokes <paramref name="factory"/> on a miss (a miss being a value equal
    /// to default(T)). Non-default factory results are written back to the cache.
    /// </summary>
    public static T GetOrCreate<T>(this ICache cache, string key, Func<T> factory)
    {
        var comparer = EqualityComparer<T>.Default;
        var cached = cache.Get<T>(key);

        // Hit: anything other than default(T) is considered cached.
        if (!comparer.Equals(cached, default(T)))
        {
            return cached;
        }

        var created = factory();

        // Never store default(T) — it is indistinguishable from a miss.
        if (!comparer.Equals(created, default(T)))
        {
            cache.Set(key, created);
        }

        return created;
    }

    /// <summary>
    /// Asynchronous counterpart of <see cref="GetOrCreate{T}"/>; the factory
    /// produces its value via a Task.
    /// </summary>
    public static async Task<T> GetOrCreateAsync<T>(this ICache cache, string key, Func<Task<T>> factory)
    {
        var comparer = EqualityComparer<T>.Default;
        var cached = cache.Get<T>(key);

        if (!comparer.Equals(cached, default(T)))
        {
            return cached;
        }

        var created = await factory().ConfigureAwait(false);

        if (!comparer.Equals(created, default(T)))
        {
            cache.Set(key, created);
        }

        return created;
    }
}

This works fine, but one known problem I'm trying to address is that it is susceptible to cache stampedes. If my API is handling many requests that all try to access the same key using one of the GetOrCreate methods at the same time, they will each run a parallel instance of the factory function. This means redundant work and wasted resources.

What I have attempted to do is introduce mutexes to ensure that only one instance of the factory function can run per cache key.

Introduce Mutexes

// Unchanged from the first listing. Callers below still treat a default(T)
// result from Get<T> as a cache miss (see the double-checked lookups).
public interface ICache
{
    T Get<T>(string key);
    void Set<T>(string key, T value);
}

public static class ICacheExtensions
{
    /// <summary>
    /// Read-through lookup guarded by a named (OS-level) Mutex per cache key so
    /// that only one caller at a time runs the factory for a given key.
    /// </summary>
    public static T GetOrCreate<T>(this ICache cache, string key, Func<T> factory)
    {
        using var mutex = new Mutex(false, key);
        var value = cache.Get<T>(key);

        if (EqualityComparer<T>.Default.Equals(value, default(T)))
        {
            mutex.WaitOne();

            try
            {
                // Fixed: the original declared a new `var value` here, which does
                // not compile (CS0136, redeclares the outer local) — and even if it
                // did, the outer `value` returned below would stay default(T).
                // Re-read under the lock: a previous holder may have populated the key.
                value = cache.Get<T>(key);

                if (EqualityComparer<T>.Default.Equals(value, default(T)))
                {
                    value = factory();
                    if (!EqualityComparer<T>.Default.Equals(value, default(T)))
                    {
                        cache.Set(key, value);
                    }
                }
            }
            finally
            {
                mutex.ReleaseMutex();
            }
        }

        return value;
    }

    /// <summary>
    /// Async variant. NOTE(review): Mutex is thread-affine — ReleaseMutex must be
    /// called by the thread that called WaitOne, but `await factory()` may resume
    /// on a different thread, so this method can throw at runtime (this is the
    /// failure the surrounding post describes).
    /// </summary>
    public static async Task<T> GetOrCreateAsync<T>(this ICache cache, string key, Func<Task<T>> factory)
    {
        using var mutex = new Mutex(false, key);
        var value = cache.Get<T>(key);

        if (EqualityComparer<T>.Default.Equals(value, default(T)))
        {
            mutex.WaitOne();

            try
            {
                // Same shadowing fix as in GetOrCreate above.
                value = cache.Get<T>(key);

                if (EqualityComparer<T>.Default.Equals(value, default(T)))
                {
                    value = await factory().ConfigureAwait(false);
                    if (!EqualityComparer<T>.Default.Equals(value, default(T)))
                    {
                        cache.Set(key, value);
                    }
                }
            }
            finally
            {
                mutex.ReleaseMutex();
            }
        }

        return value;
    }
}

This works great for GetOrCreate(), but GetOrCreateAsync() throws an exception. It turns out mutexes are thread-affine: ReleaseMutex() must be called on the same thread that called WaitOne(), and with async/await the code after an await often resumes on a different thread, so the runtime throws. I found this other SO question that describes some workarounds and decided to go with a custom task scheduler. SingleThreadedTaskScheduler schedules tasks on a pool containing exactly one thread, and I intend to interact with the mutex only from that thread.

SingleThreadedTaskScheduler

/// <summary>
/// A TaskScheduler that runs every queued task on one dedicated background
/// thread, in FIFO order. Inline execution is refused, so work never runs on
/// the caller's thread — all tasks execute on the single worker.
/// </summary>
internal sealed class SingleThreadedTaskScheduler : TaskScheduler, IDisposable
{
    private readonly Thread _worker;
    private BlockingCollection<Task> _queue;

    public SingleThreadedTaskScheduler()
    {
        _queue = new BlockingCollection<Task>();
        _worker = new Thread(ConsumeQueue) { IsBackground = true };
        _worker.Start();
    }

    // Worker loop: drains the queue until CompleteAdding is called in Dispose.
    private void ConsumeQueue()
    {
        foreach (var task in _queue.GetConsumingEnumerable())
        {
            TryExecuteTask(task);
        }
    }

    // Snapshot for debugger support.
    protected override IEnumerable<Task> GetScheduledTasks() => _queue.ToArray();

    protected override void QueueTask(Task task) => _queue.Add(task);

    // Never execute on the calling thread; everything must go through _worker.
    protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued) => false;

    public void Dispose()
    {
        _queue?.CompleteAdding();
        _worker?.Join();
        _queue?.Dispose();
        _queue = null;
    }
}

GetOrCreateAsync with SingleThreadedTaskScheduler

// Single dedicated thread used for every WaitOne/ReleaseMutex call, so the
// thread-affine Mutex always sees the same owner thread.
private static readonly TaskScheduler _mutexTaskScheduler = new SingleThreadedTaskScheduler();

/// <summary>
/// Async read-through lookup that marshals all Mutex operations onto
/// _mutexTaskScheduler's single thread to avoid the cross-thread ReleaseMutex error.
/// NOTE(review): this compiles and no longer throws, but — as the accepted answer
/// explains — a Mutex is re-entrant for its owning thread, so funneling every
/// WaitOne through one thread means each acquisition succeeds immediately and
/// no in-process mutual exclusion actually happens.
/// </summary>
public static async Task<T> GetOrCreateAsync<T>(this ICache cache, string key, Func<Task<T>> factory)
{
    using var mutex = new Mutex(false, key);
    var value = cache.Get<T>(key);

    if (EqualityComparer<T>.Default.Equals(value, default(T)))
    {
        // Fixed: the original misspelled ConfigureAwait as "ConfiureAwait"
        // here and in the finally block — a compile error.
        await Task.Factory
            .StartNew(() => mutex.WaitOne(), CancellationToken.None, TaskCreationOptions.None, _mutexTaskScheduler)
            .ConfigureAwait(false);

        try
        {
            // Fixed: assign the outer local instead of redeclaring it (the
            // original `var value` was a CS0136 error and would have left the
            // returned value stale).
            value = cache.Get<T>(key);

            if (EqualityComparer<T>.Default.Equals(value, default(T)))
            {
                value = await factory().ConfigureAwait(false);
                if (!EqualityComparer<T>.Default.Equals(value, default(T)))
                {
                    cache.Set(key, value);
                }
            }
        }
        finally
        {
            await Task.Factory
                .StartNew(() => mutex.ReleaseMutex(), CancellationToken.None, TaskCreationOptions.None, _mutexTaskScheduler)
                .ConfigureAwait(false);
        }
    }

    return value;
}

With this implementation, the exception is resolved, but GetOrCreateAsync still calls the factory function many times in a cache stampede scenario. Am I missing something?

I've also tried using SemaphoreSlim instead of Mutex which should play nicer with async/await. The issue here is that Linux doesn't support named semaphores so I'd have to keep all my semaphores in a Dictionary<string, SemaphoreSlim> and that would be too cumbersome to manage.

Raymond Saltrelli
  • 4,071
  • 2
  • 33
  • 52
  • Why are you writing a [caching layer](https://github.com/thangchung/awesome-dotnet-core#caching)? For fun? There's lots of [pre-existing libraries](https://github.com/quozd/awesome-dotnet#caching) to do caching. – mason Apr 07 '21 at 18:07
  • Let me rephrase. The caching layer exists and is in use and I'm trying to address the stampede problem. – Raymond Saltrelli Apr 07 '21 at 18:31
  • That's kind of sidestepping the point. The reason you're having issues is that you (or whoever) rolled your own caching layer. You're trying to reinvent the wheel. Why not take advantage of an existing caching framework, which has solved these problems, and has been tested by thousands of other devs? It really shouldn't be that much work to swap it around, and it would let you avoid having to maintain a lot of complex code. – mason Apr 07 '21 at 18:34
  • Thanks for your input. – Raymond Saltrelli Apr 07 '21 at 18:36
  • 1
    @mason: I'm not aware of any caching library for .NET that has async support *and* shares creation function invocation results. I've [started one](https://gist.github.com/StephenCleary/39a2cd0aa3c705a984a4dbbea8275fe9), but it's not prod-ready. – Stephen Cleary Apr 07 '21 at 18:53
  • @StephenCleary Yeah so that's the crux of my problem. How to get mutexes and async/await to play nice together. The fact that it's a caching solution is just for context. – Raymond Saltrelli Apr 07 '21 at 19:03
  • `that would be too cumbersome to manage` well, if it's the difference between a working cache and a non-working one... – Ian Kemp Apr 07 '21 at 19:59

2 Answers

3

The linked solution only works when using a named mutex to synchronize asynchronous code across processes. It won't work to synchronize code within the same process: a Mutex can be acquired recursively by the thread that already owns it, so once every WaitOne call is funneled onto the same dedicated thread, each acquisition succeeds immediately — it's as if the mutex isn't there at all.

I'd have to keep all my semaphores in a Dictionary<string, SemaphoreSlim> and that would be too cumbersome to manage.

If you need a non-recursive named mutex, named Semaphores (which don't work on Linux) or managing your own dictionary is really the only way to go.

I have an AsyncCache<T> that I've been working on but isn't prod-ready yet. It tries to look like a cache of Task<T> instances but is actually a cache of TaskCompletionSource<T> instances.

Stephen Cleary
  • 437,863
  • 77
  • 675
  • 810
  • Thanks for the response. I was afraid of this. Ah well .... So I did some experimenting with ConcurrentDictionary that seems to work. The remaining issue is how to keep that dictionary from just growing unbounded. I'll have to evict semaphores for stale keys at some point. – Raymond Saltrelli Apr 07 '21 at 20:00
  • See my answer for what I ended up going with based on Stephen's answer. – Raymond Saltrelli Apr 07 '21 at 20:44
0

Using semaphores appears to work. Credit to Stephen Cleary for confirming that this was a better route than Mutexes.

/// <summary>
/// Read-through cache helper with stampede protection: concurrent callers for
/// the same key block on a per-key SemaphoreSlim while exactly one of them runs
/// the factory. A miss is detected by comparing against default(T), so default
/// values are never cached.
/// </summary>
public static async Task<T> GetOrCreateAsync<T>(this ICache cache, string key, Func<Task<T>> factory)
{
    // Fixed: removed the leftover `using var mutex = new Mutex(false, key);` —
    // it was never used in this version and allocated an OS handle per call.
    var value = cache.Get<T>(key);

    if (EqualityComparer<T>.Default.Equals(value, default(T)))
    {
        WaitOne(key);

        try
        {
            // Fixed: assign the outer local instead of redeclaring it (the
            // original `var value` was a CS0136 error and would have left the
            // returned value stale). Re-read: the previous semaphore holder may
            // have populated the key while we waited.
            value = cache.Get<T>(key);

            if (EqualityComparer<T>.Default.Equals(value, default(T)))
            {
                value = await factory().ConfigureAwait(false);
                if (!EqualityComparer<T>.Default.Equals(value, default(T)))
                {
                    cache.Set(key, value);
                }
            }

            // Success: wake every queued waiter — they will re-read the cache.
            ReleaseAll(key);
        }
        catch (Exception)
        {
            // Failure: hand the semaphore to exactly one waiter so it can retry.
            ReleaseOne(key);
            throw;
        }
    }

    return value;
}

// One SemaphoreSlim per cache key currently being populated. Entries are added
// lazily by WaitOne and removed by ReleaseAll once the factory has completed.
private static readonly ConcurrentDictionary<string, SemaphoreSlim> _semaphores = new ConcurrentDictionary<string, SemaphoreSlim>();

// Enters the per-key semaphore, creating it on first use with one free slot.
// NOTE(review): this is a blocking Wait inside an async call chain — consider
// an async variant built on SemaphoreSlim.WaitAsync.
private static void WaitOne(string key)
{
    var semaphore = _semaphores.GetOrAdd(key, k => new SemaphoreSlim(1, int.MaxValue));
    semaphore.Wait();
}

// Error path: releases a single slot so exactly one waiting caller retries the
// factory. The fallback factory uses initialCount 0 so a freshly created
// semaphore ends up with one free slot after the Release below.
private static void ReleaseOne(string key)
{
    var semaphore = _semaphores.GetOrAdd(key, k => new SemaphoreSlim(0, int.MaxValue));
    semaphore.Release();
}

// Success path: remove the key's semaphore, wake every queued waiter (they will
// re-read the now-populated cache), then dispose it.
private static void ReleaseAll(string key)
{
    // Fixed: the previous code used the IDictionary Remove(key, out value)
    // extension method, which performs a non-atomic lookup-then-remove on a
    // ConcurrentDictionary; TryRemove is the dictionary's atomic primitive.
    if (_semaphores.TryRemove(key, out var semaphore))
    {
        semaphore.Release(int.MaxValue);
        semaphore.Dispose();
        // NOTE(review): waiters released above may still be inside Wait() when
        // Dispose runs; SemaphoreSlim.Dispose is not safe concurrently with
        // other operations — confirm this race is acceptable.
    }
}
Raymond Saltrelli
  • 4,071
  • 2
  • 33
  • 52