0

I wonder if it's safe to replace Dictionary with ConcurrentDictionary and what modifications should I do to for ex. TryAdd, TryGetValue, removing locks, etc.?

protected class SubscriptionManager
{
    private readonly DeribitV2Client _client;
    private readonly Dictionary<string, SubscriptionEntry> _subscriptionMap;

    public SubscriptionManager(DeribitV2Client client)
    {
        _client = client;
        _subscriptionMap = new Dictionary<string, SubscriptionEntry>();
    }

    public async Task<SubscriptionToken> Subscribe(ISubscriptionChannel channel, Action<Notification> callback)
    {
        if (callback == null)
        {
            return SubscriptionToken.Invalid;
        }

        var channelName = channel.ToChannelName();
        TaskCompletionSource<SubscriptionToken> taskSource = null;
        SubscriptionEntry entry;

        lock (_subscriptionMap)
        {
            if (!_subscriptionMap.TryGetValue(channelName, out entry))
            {
                entry = new SubscriptionEntry();
                if (!_subscriptionMap.TryAdd(channelName, entry))
                {
                    _client.Logger?.Error("Subscribe: Could not add internal item for channel {Channel}", channelName);
                    return SubscriptionToken.Invalid;
                }

                taskSource = new TaskCompletionSource<SubscriptionToken>();
                entry.State = SubscriptionState.Subscribing;
                entry.SubscribeTask = taskSource.Task;
            }

            // Entry already exists but is completely unsubscribed
            if (entry.State == SubscriptionState.Unsubscribed)
            {
                taskSource = new TaskCompletionSource<SubscriptionToken>();
                entry.State = SubscriptionState.Subscribing;
                entry.SubscribeTask = taskSource.Task;
            }

            // Already subscribed - Put the callback in there and let's go
            if (entry.State == SubscriptionState.Subscribed)
            {
                _client.Logger?.Debug("Subscribe: Subscription for channel already exists. Adding callback to list (Channel: {Channel})", channelName);
                var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
                entry.Callbacks.Add(callbackEntry);
                return callbackEntry.Token;
            }

            // We are in the middle of unsubscribing from the channel
            if (entry.State == SubscriptionState.Unsubscribing)
            {
                _client.Logger?.Debug("Subscribe: Channel is unsubscribing. Abort subscribe (Channel: {Channel})", channelName);
                return SubscriptionToken.Invalid;
            }
        }

        // Only one state left: Subscribing

        // We are already subscribing
        if (taskSource == null && entry.State == SubscriptionState.Subscribing)
        {
            _client.Logger?.Debug("Subscribe: Channel is already subscribing. Waiting for the task to complete ({Channel})", channelName);

            var subscribeResult = entry.SubscribeTask != null && await entry.SubscribeTask != SubscriptionToken.Invalid;

            if (!subscribeResult && entry.State != SubscriptionState.Subscribed)
            {
                _client.Logger?.Debug("Subscribe: Subscription has failed. Abort subscribe (Channel: {Channel})", channelName);
                return SubscriptionToken.Invalid;
            }

            _client.Logger?.Debug("Subscribe: Subscription was successful. Adding callback (Channel: {Channel}", channelName);
            var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
            entry.Callbacks.Add(callbackEntry);
            return callbackEntry.Token;
        }

        if (taskSource == null)
        {
            _client.Logger?.Error("Subscribe: Invalid execution state. Missing TaskCompletionSource (Channel: {Channel}", channelName);
            return SubscriptionToken.Invalid;
        }

        try
        {
            var subscribeResponse = await _client.Send(
            IsPrivateChannel(channelName) ? "private/subscribe" : "public/subscribe",
            new { channels = new[] { channelName } },
            new ListJsonConverter<string>()).ConfigureAwait(false);

            var response = subscribeResponse.ResultData;

            if (response.Count != 1 || response[0] != channelName)
            {
                _client.Logger?.Debug("Subscribe: Invalid result (Channel: {Channel}): {@Response}", channelName, response);
                entry.State = SubscriptionState.Unsubscribed;
                entry.SubscribeTask = null;
                Debug.Assert(taskSource != null, nameof(taskSource) + " != null");
                taskSource.SetResult(SubscriptionToken.Invalid);
            }
            else
            {
                _client.Logger?.Debug("Subscribe: Successfully subscribed. Adding callback (Channel: {Channel})", channelName);

                var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
                entry.Callbacks.Add(callbackEntry);
                entry.State = SubscriptionState.Subscribed;
                entry.SubscribeTask = null;
                Debug.Assert(taskSource != null, nameof(taskSource) + " != null");
                taskSource.SetResult(callbackEntry.Token);
            }
        }
        catch (Exception e)
        {
            entry.State = SubscriptionState.Unsubscribed;
            entry.SubscribeTask = null;
            Debug.Assert(taskSource != null, nameof(taskSource) + " != null");
            taskSource.SetException(e);
        }

        return await taskSource.Task;
    }

    public async Task<bool> Unsubscribe(SubscriptionToken token)
    {
        string channelName;
        SubscriptionEntry entry;
        SubscriptionCallback callbackEntry;
        TaskCompletionSource<bool> taskSource;

        lock (_subscriptionMap)
        {
            (channelName, entry, callbackEntry) = GetEntryByToken(token);

            if (string.IsNullOrEmpty(channelName) || entry == null || callbackEntry == null)
            {
                _client.Logger?.Warning("Unsubscribe: Could not find token {token}", token.Token);
                return false;
            }

            switch (entry.State)
            {
                case SubscriptionState.Subscribing:
                    _client.Logger?.Debug("Unsubscribe: Channel is currently subscribing. Abort unsubscribe (Channel: {Channel})", channelName);
                    return false;
                case SubscriptionState.Unsubscribed:
                case SubscriptionState.Unsubscribing:
                    _client.Logger?.Debug("Unsubscribe: Channel is unsubscribed or unsubscribing. Remove callback (Channel: {Channel})", channelName);
                    entry.Callbacks.Remove(callbackEntry);
                    return true;
                case SubscriptionState.Subscribed:
                    if (entry.Callbacks.Count > 1)
                    {
                        _client.Logger?.Debug("Unsubscribe: There are still callbacks left. Remove callback but don't unsubscribe (Channel: {Channel})", channelName);
                        entry.Callbacks.Remove(callbackEntry);
                        return true;
                    }

                    _client.Logger?.Debug("Unsubscribe: No callbacks left. Unsubscribe and remove callback (Channel: {Channel})", channelName);
                    break;
                default:
                    return false;
            }

            // At this point it's only possible that the entry-State is Subscribed
            // and the callback list is empty after removing this callback.
            // Hence we unsubscribe at the server now
            entry.State = SubscriptionState.Unsubscribing;
            taskSource = new TaskCompletionSource<bool>();
            entry.UnsubscribeTask = taskSource.Task;
        }

        try
        {
            var unsubscribeResponse = await _client.Send(
            IsPrivateChannel(channelName) ? "private/unsubscribe" : "public/unsubscribe",
            new { channels = new[] { channelName } },
            new ListJsonConverter<string>()).ConfigureAwait(false);

            var response = unsubscribeResponse.ResultData;

            if (response.Count != 1 || response[0] != channelName)
            {
                entry.State = SubscriptionState.Subscribed;
                entry.UnsubscribeTask = null;
                taskSource.SetResult(false);
            }
            else
            {
                entry.Callbacks.Remove(callbackEntry);
                entry.State = SubscriptionState.Unsubscribed;
                entry.UnsubscribeTask = null;
                taskSource.SetResult(true);
            }
        }
        catch (Exception e)
        {
            entry.State = SubscriptionState.Subscribed;
            entry.UnsubscribeTask = null;
            taskSource.SetException(e);
        }

        return await taskSource.Task;
    }

    public IEnumerable<Action<Notification>> GetCallbacks(string channel)
    {
        if (_subscriptionMap.TryGetValue(channel, out var entry))
        {
            foreach (var callbackEntry in entry.Callbacks)
            {
                yield return callbackEntry.Action;
            }
        }
    }

    public void Reset()
    {
        _subscriptionMap.Clear();
    }

    private static bool IsPrivateChannel(string channel)
    {
        return channel.StartsWith("user.");
    }

    private (string channelName, SubscriptionEntry entry, SubscriptionCallback callbackEntry) GetEntryByToken(SubscriptionToken token)
    {
        lock (_subscriptionMap)
        {
            foreach (var kvp in _subscriptionMap)
            {
                foreach (var callbackEntry in kvp.Value.Callbacks)
                {
                    if (callbackEntry.Token == token)
                    {
                        return (kvp.Key, kvp.Value, callbackEntry);
                    }
                }
            }
        }

        return (null, null, null);
    }
}

GitHub

My attempt

public class SubscriptionToken
{
    public static readonly SubscriptionToken Invalid = new(Guid.Empty);

    public SubscriptionToken(Guid token)
    {
        Token = token;
    }

    public Guid Token { get; }
}

public class SubscriptionCallback
{
    public SubscriptionCallback(SubscriptionToken token, Action<Notification> action)
    {
        Token = token;
        Action = action;
    }

    public Action<Notification> Action { get; }
    public SubscriptionToken Token { get; }
}

public class SubscriptionEntry
{
    public List<SubscriptionCallback> Callbacks { get; } = new();
    public Task<SubscriptionToken>? SubscribeTask { get; set; }
    public Task<bool>? UnsubscribeTask { get; set; }
    public SubscriptionState State { get; set; } = SubscriptionState.Unsubscribed;
}

public class SubscriptionManager
{
    private readonly DeribitClient _client;
    private readonly ConcurrentDictionary<string, SubscriptionEntry> _subscriptions = new();

    public SubscriptionManager(DeribitClient client)
    {
        _client = client ?? throw new ArgumentNullException(nameof(client));
    }

    public async Task<SubscriptionToken> SubscribeAsync(string channel, Action<Notification>? callback)
    {
        if (callback == null)
        {
            throw new ArgumentNullException(nameof(callback));
        }

        TaskCompletionSource<SubscriptionToken>? tcs = null;

        if (_subscriptions.TryGetValue(channel, out var entry))
        {
            if (entry.State == SubscriptionState.Subscribed)
            {
                Log.Debug("Subscribe: Subscription for channel already exists. Adding callback to list (Channel: {Channel})", channel);

                var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
                entry.Callbacks.Add(callbackEntry);
                return callbackEntry.Token;
            }

            if (entry.State == SubscriptionState.Unsubscribing)
            {
                Log.Debug("Subscribe: Channel is unsubscribing. Abort subscribe (Channel: {Channel})", channel);
                return SubscriptionToken.Invalid;
            }

            if (entry.State == SubscriptionState.Unsubscribed)
            {
                Log.Debug("Subscribe: Entry already exists but is completely unsubscribed (Channel: {Channel})", channel);

                tcs = new TaskCompletionSource<SubscriptionToken>();
                entry.State = SubscriptionState.Subscribing;
                entry.SubscribeTask = tcs.Task;
            }
        }
        else
        {
            tcs = new TaskCompletionSource<SubscriptionToken>();
            entry = new SubscriptionEntry
            {
                State = SubscriptionState.Subscribing,
                SubscribeTask = tcs.Task
            };

            if (!_subscriptions.TryAdd(channel, entry))
            {
                Log.Error("Subscribe: Could not add internal item for channel {Channel}", channel);
                return SubscriptionToken.Invalid;
            }
        }

        if (tcs == null && entry.State == SubscriptionState.Subscribing)
        {
            Log.Debug("Subscribe: Channel is already subscribing. Waiting for the task to complete ({Channel})", channel);

            var subscribeResult = entry.SubscribeTask != null && await entry.SubscribeTask.ConfigureAwait(false) != SubscriptionToken.Invalid;

            if (!subscribeResult && entry.State != SubscriptionState.Subscribed)
            {
                Log.Debug("Subscribe: Subscription has failed. Abort subscribe (Channel: {Channel})", channel);
                return SubscriptionToken.Invalid;
            }

            Log.Debug("Subscribe: Subscription was successful. Adding callback (Channel: {Channel}", channel);

            var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
            entry.Callbacks.Add(callbackEntry);
            return callbackEntry.Token;
        }

        if (tcs == null)
        {
            Log.Error("Subscribe: Invalid execution state. Missing TaskCompletionSource (Channel: {Channel}", channel);
            return SubscriptionToken.Invalid;
        }

        try
        {
            var method = IsPrivateChannel(channel) ? "private/subscribe" : "public/subscribe";
            var @params = new Dictionary<string, string[]>
            {
                { "channels", new[] { channel } }
            };
            var subscribeResponse = await _client.SendAsync<Notification>(method, @params).ConfigureAwait(false);

            if (subscribeResponse == null)
            {
                Log.Debug("Subscribe: Invalid result (Channel: {Channel}): {@Response}", channel, subscribeResponse);

                entry.State = SubscriptionState.Unsubscribed;
                entry.SubscribeTask = null;

                Debug.Assert(tcs != null);

                tcs.SetResult(SubscriptionToken.Invalid);
            }
            else
            {
                Log.Debug("Subscribe: Successfully subscribed. Adding callback (Channel: {Channel})", channel);

                var callbackEntry = new SubscriptionCallback(new SubscriptionToken(Guid.NewGuid()), callback);
                entry.Callbacks.Add(callbackEntry);

                entry.State = SubscriptionState.Subscribed;
                entry.SubscribeTask = null;

                Debug.Assert(tcs != null);

                tcs.SetResult(callbackEntry.Token);
            }
        }
        catch (Exception ex)
        {
            entry.State = SubscriptionState.Unsubscribed;
            entry.SubscribeTask = null;

            Debug.Assert(tcs != null);

            tcs.SetException(ex);
        }

        return await tcs.Task.ConfigureAwait(false);
    }

    public async Task<bool> UnsubscribeAsync(SubscriptionToken token)
    {
        TaskCompletionSource<bool> tcs;

        var (channel, entry, callbackEntry) = GetEntryByToken(token);

        if (string.IsNullOrEmpty(channel) || entry == null || callbackEntry == null)
        {
            Log.Warning("UnsubscribeAsync: Could not find token {token}", token.Token);
            return false;
        }

        switch (entry.State)
        {
            case SubscriptionState.Subscribing:
                Log.Debug("UnsubscribeAsync: Channel is currently subscribing. Abort unsubscribe (Channel: {Channel})", channel);
                return false;
            case SubscriptionState.Unsubscribed:
            case SubscriptionState.Unsubscribing:
                Log.Debug("UnsubscribeAsync: Channel is unsubscribed or unsubscribing. Remove callback (Channel: {Channel})", channel);
                entry.Callbacks.Remove(callbackEntry);
                return true;
            case SubscriptionState.Subscribed when entry.Callbacks.Count > 1:
                Log.Debug("UnsubscribeAsync: There are still callbacks left. Remove callback but don't unsubscribe (Channel: {Channel})", channel);
                entry.Callbacks.Remove(callbackEntry);
                return true;
            case SubscriptionState.Subscribed:
                Log.Debug("UnsubscribeAsync: No callbacks left. UnsubscribeAsync and remove callback (Channel: {Channel})", channel);
                tcs = new TaskCompletionSource<bool>();
                entry.State = SubscriptionState.Unsubscribing;
                entry.UnsubscribeTask = tcs.Task;
                break;
            default:
                return false;
        }

        try
        {
            var method = IsPrivateChannel(channel) ? "private/unsubscribe" : "public/unsubscribe";
            var @params = new Dictionary<string, string[]>
            {
                { "channels", new[] { channel } }
            };
            var unsubscribeResponse = await _client.SendAsync<Notification>(method, @params).ConfigureAwait(false);

            if (unsubscribeResponse == null)
            {
                entry.State = SubscriptionState.Subscribed;
                entry.UnsubscribeTask = null;

                tcs.SetResult(false);
            }
            else
            {
                entry.Callbacks.Remove(callbackEntry);

                entry.State = SubscriptionState.Unsubscribed;
                entry.UnsubscribeTask = null;

                tcs.SetResult(true);
            }
        }
        catch (Exception ex)
        {
            entry.State = SubscriptionState.Subscribed;
            entry.UnsubscribeTask = null;

            tcs.SetException(ex);
        }

        return await tcs.Task.ConfigureAwait(false);
    }

    private (string? channelName, SubscriptionEntry? entry, SubscriptionCallback? callbackEntry) GetEntryByToken(SubscriptionToken token)
    {
        foreach (var (key, value) in _subscriptions)
        {
            foreach (var callbackEntry in value.Callbacks.Where(callbackEntry => callbackEntry.Token == token))
            {
                return (key, value, callbackEntry);
            }
        }

        return (null, null, null);
    }

    public IEnumerable<Action<Notification>> GetCallbacks(string channel)
    {
        if (_subscriptions.TryGetValue(channel, out var entry))
        {
            foreach (var callbackEntry in entry.Callbacks)
            {
                yield return callbackEntry.Action;
            }
        }
    }

    private static bool IsPrivateChannel(string channel)
    {
        return channel.StartsWith("user.");
    }
}

nop
  • 4,711
  • 6
  • 32
  • 93
  • 1
    Well, your keys won't be sorted anymore. – ProgrammingLlama Apr 18 '22 at 08:49
  • Indeed, your question title and your question body are significantly different: replacing `Dictionary` with `ConcurrentDictionary` will usually be fine. Replacing `SortedDictionary` with `ConcurrentDictionary` is a different matter. – Jon Skeet Apr 18 '22 at 08:51
  • @JonSkeet, I edited it. It doesn't really matter if they are sorted or not. – nop Apr 18 '22 at 08:52
  • Well I'd suggest removing all the locking, and where you're using TryGetValue and then TryAdd, you can use GetOrAdd instead. (It's quite hard to follow the code with the odd indentation, I'm afraid.) Beyond that, you'll need to ask more specific questions really... I can't imagine what a suitable *answer* on this post would look like at the moment. – Jon Skeet Apr 18 '22 at 08:59
  • Are you asking about replacing the `Dictionary` with a `ConcurrentDictionary`, with no other modification in the code whatsoever? – Theodor Zoulias Apr 18 '22 at 09:05
  • @TheodorZoulias, yes and what modifications I should do, exactly what Jon Skeet said. – nop Apr 18 '22 at 09:10
  • @JonSkeet, that's a good one, you can wrap it in an answer. – nop Apr 18 '22 at 09:10
  • No, it's not really an answer - it's a couple of suggestions, but that's all. The question at the moment is too vague to answer clearly, IMO. – Jon Skeet Apr 18 '22 at 09:12
  • I don't even know what you mean by "breaking the locks" - but this still feels like a question I would be uncomfortable answering. Maybe someone else will do so. – Jon Skeet Apr 18 '22 at 09:25
  • Btw the `GetCallbacks` and `Reset` methods are accessing the `_subscriptionMap` dictionary without synchronization, so your existing code is not thread-safe. Why do you want to replace the `Dictionary` with a `ConcurrentDictionary`, to make your code faster? – Theodor Zoulias Apr 18 '22 at 09:35
  • @TheodorZoulias, because ConcurrentDictionary was meant to be used instead of `lock`ing. – nop Apr 18 '22 at 10:07
  • 1
    Using a `Dictionary` with a `lock` is completely fine. The `ConcurrentDictionary` is a convenience class, or a performance optimization if you are using it in hot paths. Don't replace it just out of principles. If it's not broken, don't fix it! – Theodor Zoulias Apr 18 '22 at 10:30
  • @TheodorZoulias, I edited my question. Added code of what I changed. It would be nice if you have a look at it. – nop Apr 18 '22 at 14:47
  • @JonSkeet, I added "My attempt" which is the answer of the question to my understanding. It would be nice if you have any suggestions. – nop Apr 18 '22 at 14:48
  • What is the expected usage pattern of the `SubscriptionManager` class? Is it supposed to handle concurrent invocations of all its members by multiple threads without race conditions or data corruption? More specifically is it supposed to survive two different threads calling `SubscribeAsync` with the same `channel` argument at the same time? – Theodor Zoulias Apr 18 '22 at 15:06
  • @TheodorZoulias, yes it should survive two subscriptions with same parameters at the same time. The whole point of the SubscriptionManager class is to split the subscriptions from the client class, so it fulfills the single responsibility principles. It's also easier to call the subscription methods like that and it's easier when we need to re-subscribe each time the client reconnects. – nop Apr 18 '22 at 15:16

1 Answers1

1

A ConcurrentDictionary<K,V> is thread-safe in the sense that it protects its internal state from corruption. It doesn't protect the keys and values it contains, in case these are mutable objects.

In your case the values stored in the dictionary (SubscriptionEntry) are mutable objects. They have public setters, and they expose a public property of type List<SubscriptionCallback>. The List<T> class is not thread-safe. So, no, you can't replace the Dictionary with a ConcurrentDictionary the way you've shown in the question (the My attempt section). Here are some options:

  1. Make sure that the SubscriptionCallback type is immutable. If you want to change it, create a new SubscriptionCallback instance and discard the previous one.
  2. Keep the SubscriptionCallback mutable, but make it thread-safe.
  3. Just keep the Dictionary, and forget about switching to ConcurrentDictionary. The overhead of a lock is minuscule, provided that you are not doing anything non trivial while holding the lock. If you are doing only basic operations (Add/TryGetValue/Remove), it's unlikely that you'll notice any measurable contention, unless you are doing 100,000 operations per second or more.
Theodor Zoulias
  • 34,835
  • 7
  • 69
  • 104
  • Thank you for your answer! I understand. If it's ConcurrentDictionary, it all has to be immutable and thread-safe. One more question about the locking. Should I create a new `private readonly object _lock = new object();` or can I use the same Dictionary as a lock? just like it is on GitHub – nop Apr 18 '22 at 17:57
  • 1
    @nop as long as the `_subscriptionMap` is a private field, it's perfectly OK to use it as a locker. The guideline is to avoid using as lockers publicly exposed objects. – Theodor Zoulias Apr 18 '22 at 18:04
  • I made the final changes based on your opinion. May you have a final look at it before I accept it? https://www.toptal.com/developers/hastebin/omewilizem.csharp. I changed some stuff @ switch. – nop Apr 18 '22 at 19:10
  • @nop two observations: In the `GetCallbacks` [it's not a good idea](https://stackoverflow.com/questions/4608215/does-the-c-sharp-yield-free-a-lock) to `yield` while holding the `lock`. Also the `SubscriptionEntry` class has the public property `Callbacks` which is a `List`. This is a thread-safety issue. Readers might try to enumerate the `List` while it is updated by another thread, resulting in a exception. – Theodor Zoulias Apr 18 '22 at 20:27
  • Thanks! What do you suggest in order to fix these issues? – nop Apr 18 '22 at 21:49
  • @nop these issues are not easily fixable. It's doable, but it's not trivial, and it requires knowledge about how the API is expected to be used in order to do the proper compromises. Building thread-safe types is challenging in general, because many operations that are common in single-thread scenarios do not make sense in multithreaded scenarios (because they introduce race conditions). That's why you see unorthodox APIs like `GetOrAdd`/`AddOrUpdate` in concurrent collections. My suggestion is to search for answers, and if none is found then ask new questions for the specific issues. – Theodor Zoulias Apr 18 '22 at 22:02