0

I'm using std::atomic to implement a lock-free stack based on compare-and-swap (CAS).
To solve the ABA problem, I used a tagged pointer.
The code is shown below.

// A pointer and a small version counter ("tag") packed into a single 64-bit
// word, so that both can be swapped together by one 64-bit atomic
// compare-and-swap.
//
// Caveats:
//  * Only the low 52 bits of the address are stored. A pointer with any of its
//    upper 12 bits set is silently truncated, and get() then returns a
//    different address than the one that was stored.
//  * The tag is 12 bits wide and wraps around after 4096 increments.
//  * Writing `tag` / `ptr` and then reading `full` (or the reverse) is type
//    punning through a union, and the unnamed struct is a compiler extension.
//    ISO C++ defines neither; this relies on the behaviour MSVC, GCC and Clang
//    provide.
template <typename T>
union tagged_ptr {
    struct
    {
        std::uint64_t tag : 12, ptr : 52;
    };
    // The same 64 bits viewed as one integer; this is the value handed to the
    // atomic operations.
    std::uint64_t full;

    // Rebuilds a tagged pointer from a packed word, e.g. one loaded from an
    // atomic.
    tagged_ptr(const std::uint64_t &full)
    {
        this->full = full;
    }

    // Packs `ptr` together with the counter `cnt`; only the low 12 bits of
    // `cnt` are kept. Together the two bit-fields cover all 64 bits.
    tagged_ptr(T *ptr = nullptr, std::uint16_t cnt = 0)
    {
        tag = cnt;
        this->ptr = reinterpret_cast<std::uint64_t>(ptr);
    }

    // Returns the stored address as a pointer, without the tag.
    // Const-qualified so that it can also be called on a const tagged_ptr.
    T *get() const
    {
        return reinterpret_cast<T *>(ptr);
    }
};

template <typename T>
class lfstack
{
    struct Node
    {
        T data;
        Node *next;
        Node(const T &data, Node *next = nullptr)
        {
            this->data = data;
            this->next = next;
        }
    };

    std::atomic_uint64_t m_top;
    std::atomic_size_t m_size;

  public:
    lfstack()
    {
        m_top = 0;
        m_size = 0;
    }

    ~lfstack()
    {
        Node *ptr = reinterpret_cast<Node *>(m_top.load(std::memory_order_relaxed)), *next;
        while (ptr)
        {
            next = ptr->next;
            delete ptr;
            ptr = next;
        }
    }

    size_t size()
    {
        return m_size.load();
    }

    bool empty()
    {
        return !size();
    }

    const T &top()
    {
        return tagged_ptr<Node>(m_top.load()).get()->data;
    }

    std::optional<T> pop()
    {
        tagged_ptr<Node> local_ptr(m_top.load(std::memory_order_relaxed));
        while (true)
        {
            if (!local_ptr.get())
                return std::nullopt;
            tagged_ptr<Node> local_next(local_ptr.get()->next, local_ptr.tag); 
            if (m_top.compare_exchange_weak(local_ptr.full, local_next.full))
            {
                T ret_val = std::move(local_ptr.get()->data);
                delete local_ptr.get();
                m_size.fetch_sub(1, std::memory_order_relaxed);
                return ret_val;
            }
        }
    }

    void push(const T &data)
    {
        tagged_ptr<Node> local_ptr(m_top.load(std::memory_order_relaxed)), new_ptr(new Node(data));
        while (true)
        {
            new_ptr.get()->next = local_ptr.get();
            new_ptr.tag = local_ptr.tag + 1;
            if (m_top.compare_exchange_weak(local_ptr.full, new_ptr.full))
            {
                m_size.fetch_add(1, std::memory_order_relaxed);
                break;
            }
        }
    }
};

I've tested it on several threads and it seems to work fine. (I tested it on the MSVC x64 compiler.)

But the line `tagged_ptr<Node> local_next(local_ptr.get()->next, local_ptr.tag);` doesn't look thread-safe to me. Am I right?

If another thread executes `delete local_ptr.get();` first, then dereferencing that pointer in `local_ptr.get()->next` would be undefined behaviour (a use-after-free). Should I defer the destruction of the node in such cases?

tongstar
  • 35
  • 5
  • 2
    Keep in mind that [5-level page tables](https://en.m.wikipedia.org/wiki/Intel_5-level_paging) are a thing: Intel Ice Lake and later processors can handle virtual addresses of up to 57 bits instead of the usual 48, so your tagged_ptr might end up truncating pointers on those platforms. (You would need to assign at least 56 bits to the ptr value for those processors, or ensure your app never uses a pointer with the higher bits set.) – Turtlefight May 03 '23 at 09:05
  • Type-punning through a union is undefined behaviour. You don't need it. – Caleth May 03 '23 at 10:47

0 Answers0