1

I wrote a socket class to wrap all the work with asynchronous methods boost::asio, did it for the sake of code reduction, just inherit from this class and use its methods! Is there any flaws, because there is uncertainty that the implementation is in places with UB or bugs!

#include <boost/asio.hpp>

#include <memory>
#include <string>
#include <utility>

namespace network {
    enum Type {
        UDP,
        TCP
    };

    template <typename socket_type, typename resolver_type, typename endpoint_iter_type>
    struct SocketImpl : public std::enable_shared_from_this<SocketImpl<socket_type, resolver_type, endpoint_iter_type>> {
    public:
        typedef std::function<void()> ConnectCallback, PromoteCallback, PostCallback;
        typedef std::function<void(size_t)> WriteCallback;
        typedef std::function<void(const uint8_t *, size_t)> ReadCallback;
        typedef std::function<void(const std::string &)> ErrorCallback;

        explicit SocketImpl(const boost::asio::strand<typename boost::asio::io_service::executor_type> &executor)
            : socket_(executor), resolver_(executor), timeout_(executor) {}

        explicit SocketImpl(socket_type sock)
            : resolver_(sock.get_executor()), timeout_(sock.get_executor()), socket_(std::move(sock)) {}

        void Post(const PostCallback &callback);

        auto Get() { return this->shared_from_this(); }

        void Connect(std::string Host, std::string Port, const ConnectCallback &connect_callback, const ErrorCallback &error_callback);

        virtual void Send(const uint8_t *message_data, size_t size, const WriteCallback &write_callback, const ErrorCallback &error_callback) = 0;

        virtual void Read(size_t size, const ReadCallback &read_callback, const ErrorCallback &error_callback) = 0;

        template <typename Handler> void Await(boost::posix_time::time_duration ms, Handler f);

        virtual void Disconnect();

        ~SocketImpl();

    protected:
        void stop_await();
        
        virtual void do_resolve(std::string host, std::string port, const SocketImpl::ConnectCallback &connect_callback,
            const SocketImpl::ErrorCallback &error_callback) = 0;

        void deadline();

        resolver_type resolver_;

        endpoint_iter_type endpoint_iter_;

        socket_type socket_;

        boost::asio::deadline_timer timeout_;

        boost::asio::streambuf buff_;
    };

    template <Type t>
    struct Socket
        : public SocketImpl<boost::asio::ip::tcp::socket, boost::asio::ip::tcp::resolver, boost::asio::ip::tcp::resolver::iterator> {
        explicit Socket(const boost::asio::strand<typename boost::asio::io_service::executor_type> &executor) : SocketImpl(executor) {}

        explicit Socket(boost::asio::ip::tcp::socket sock) : SocketImpl(std::move(sock)) {
            if (socket_.is_open())
                is_connected = true;
        }

        void Send(const uint8_t *message_data, size_t size, const WriteCallback &write_callback, const ErrorCallback &error_callback) override {
            auto self = Get();
            Post([this, self, message_data, size, write_callback, error_callback] {
                boost::asio::async_write(socket_, boost::asio::buffer(message_data, size),
                    [this, self, write_callback, error_callback](boost::system::error_code ec, std::size_t bytes_transferred) {
                        if (!ec) {
                            write_callback(bytes_transferred);
                        } else {
#ifdef OS_WIN
                            SetThreadUILanguage(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US));
#endif
                            error_callback(ec.message());
                        }
                    });
            });
        }

        void Read(size_t size, const ReadCallback &read_callback, const ErrorCallback &error_callback) override {
            auto self = Get();
            Post([this, self, size, read_callback, error_callback] {
                boost::asio::async_read(socket_, boost::asio::buffer(buff_.prepare(size)),
                    [this, self, read_callback, error_callback](boost::system::error_code ec, std::size_t length) {
                        stop_await();
                        if (!ec) {
                            const uint8_t *data = boost::asio::buffer_cast<const uint8_t *>(buff_.data());
                            read_callback(data, length);
                        } else {
#ifdef OS_WIN
                            SetThreadUILanguage(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US));
#endif
                            error_callback(ec.message());
                        }
                        buff_.consume(length);
                    });
            });
        }

        bool IsConnected() const { return is_connected; }

        void ReadUntil(std::string until_str, const ReadCallback &read_callback, const ErrorCallback &error_callback) {
            auto self = Get();
            Post([this, self, until_str = std::move(until_str), read_callback, error_callback] {
                boost::asio::async_read_until(socket_, buff_, until_str,
                    [this, read_callback, error_callback](boost::system::error_code ec, std::size_t bytes_transferred) {
                        stop_await();
                        if (!ec) {
                            const uint8_t *data = boost::asio::buffer_cast<const uint8_t *>(buff_.data());
                            read_callback(data, bytes_transferred);
                        } else {
#ifdef OS_WIN
                            SetThreadUILanguage(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US));
#endif
                            error_callback(ec.message());
                        }
                        buff_.consume(bytes_transferred);
                    });
            });
        }

    protected:
        void do_resolve(std::string host, std::string port, const SocketImpl::ConnectCallback &connect_callback,
            const SocketImpl::ErrorCallback &error_callback) override {
            auto self = Get();
            resolver_.async_resolve(host, port,
                [this, self, connect_callback, error_callback](
                    boost::system::error_code ec, boost::asio::ip::tcp::resolver::iterator endpoints) {
                    stop_await();
                    if (!ec) {
                        endpoint_iter_ = std::move(endpoints);
                        do_connect(endpoint_iter_, connect_callback, error_callback);
                    } else {
#ifdef OS_WIN
                        SetThreadUILanguage(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US));
#endif
                        error_callback("Unable to resolve host: " + ec.message());
                    }
                });
        }

        void do_connect(boost::asio::ip::tcp::resolver::iterator endpoints, const SocketImpl::ConnectCallback &connect_callback,
            const SocketImpl::ErrorCallback &error_callback) {
            auto self = Get();
            boost::asio::async_connect(socket_, std::move(endpoints),
                [this, self, connect_callback, error_callback](
                    boost::system::error_code ec, [[maybe_unused]] const boost::asio::ip::tcp::resolver::iterator &) {
                    stop_await();
                    if (!ec) {
                        connect_callback();
                    } else {
#ifdef OS_WIN
                        SetThreadUILanguage(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US));
#endif
                        error_callback("Unable to connect host: " + ec.message());
                    }
                });
        }

        bool is_connected = false;
    };

    template <>
    struct Socket<UDP>
        : public SocketImpl<boost::asio::ip::udp::socket, boost::asio::ip::udp::resolver, boost::asio::ip::udp::resolver::iterator> {
    public:
        explicit Socket(const boost::asio::strand<typename boost::asio::io_service::executor_type> &executor) : SocketImpl(executor) {}

        explicit Socket(boost::asio::ip::udp::socket sock) : SocketImpl(std::move(sock)) {}

        void Send(const uint8_t *message_data, size_t size, const WriteCallback &write_callback, const ErrorCallback &error_callback) override {
            auto self = Get();
            Post([this, self, message_data, size, write_callback, error_callback] {
                socket_.async_send_to(boost::asio::buffer(message_data, size), *endpoint_iter_,
                    [this, self, write_callback, error_callback](boost::system::error_code ec, size_t bytes_transferred) {
                        if (!ec) {
                            write_callback(bytes_transferred);
                        } else {
#ifdef OS_WIN
                            SetThreadUILanguage(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US));
#endif
                            error_callback(ec.message());
                        }
                    });
            });
        }

        void Read(size_t size, const ReadCallback &read_callback, const ErrorCallback &error_callback) override {
            auto self = Get();
            Post([this, self, size, read_callback, error_callback] {
                boost::asio::ip::udp::endpoint endpoint = *endpoint_iter_;
                socket_.async_receive_from(boost::asio::buffer(buff_.prepare(size)), endpoint,
                    [this, self, read_callback, error_callback](boost::system::error_code ec, size_t bytes_transferred) {
                        stop_await();
                        if (!ec) {
                            const auto *data = boost::asio::buffer_cast<const uint8_t *>(buff_.data());
                            read_callback(data, bytes_transferred);
                        } else {
                            error_callback(ec.message());
                        }
                        buff_.consume(bytes_transferred);
                    });
            });
        }

        void Promote(const PromoteCallback &callback);

    protected:
        void do_resolve(std::string host, std::string port, const SocketImpl::ConnectCallback &connect_callback,
            const SocketImpl::ErrorCallback &error_callback) override {
            auto self = Get();
            resolver_.async_resolve(host, port,
                [this, self, connect_callback, error_callback](
                    boost::system::error_code ec, boost::asio::ip::udp::resolver::iterator endpoints) {
                    stop_await();
                    if (!ec) {
                        endpoint_iter_ = std::move(endpoints);
                        boost::asio::ip::udp::endpoint endpoint = *endpoint_iter_;
                        socket_.open(endpoint.protocol());

                        connect_callback();
                    } else {
#ifdef OS_WIN
                        SetThreadUILanguage(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US));
#endif
                        error_callback("Unable to resolve host: " + ec.message());
                    }
                });
        }
    };

    void Socket<UDP>::Promote(const PromoteCallback &callback) {
        auto self = Get();
        Post([this, self, callback] {
            endpoint_iter_++;
            socket_.cancel();
            callback();
        });
    }

    template <typename socket_type, typename resolver_type, typename endpoint_iter_type>
    void SocketImpl<socket_type, resolver_type, endpoint_iter_type>::Post(const SocketImpl::PostCallback &callback) {
        post(socket_.get_executor(), callback);
    }

    template <typename socket_type, typename resolver_type, typename endpoint_iter_type>
    void SocketImpl<socket_type, resolver_type, endpoint_iter_type>::Connect(std::string Host, std::string Port,
        const SocketImpl::ConnectCallback &connect_callback, const SocketImpl::ErrorCallback &error_callback) {
        auto self = Get();
        Post([this, self, Host, Port, connect_callback, error_callback] { do_resolve(Host, Port, connect_callback, error_callback); });
    }

    template <typename socket_type, typename resolver_type, typename endpoint_iter_type>
    template <typename Handler>
    void SocketImpl<socket_type, resolver_type, endpoint_iter_type>::Await(boost::posix_time::time_duration ms, Handler f) {
        auto self = Get();
        Post([this, ms, self, f] {
            timeout_.expires_from_now(ms);
            timeout_.template async_wait([this, self, f](boost::system::error_code const &ec) {
                if (!ec) {
                    deadline(f);
                }
            });
        });
    }

    template <typename socket_type, typename resolver_type, typename endpoint_iter_type>
    void SocketImpl<socket_type, resolver_type, endpoint_iter_type>::Disconnect() {
        auto self = Get();
        Post([this, self] {
#ifdef OS_WIN
            SetThreadUILanguage(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US));
#endif
            timeout_.cancel();
            resolver_.cancel();
            if (socket_.is_open()) socket_.cancel();
        });
    }

    template <typename socket_type, typename resolver_type, typename endpoint_iter_type>
    void SocketImpl<socket_type, resolver_type, endpoint_iter_type>::stop_await() {
        timeout_.cancel();
    }

    template <typename socket_type, typename resolver_type, typename endpoint_iter_type>
    void SocketImpl<socket_type, resolver_type, endpoint_iter_type>::deadline() {
        if (timeout_.expires_at() <= boost::asio::deadline_timer::traits_type::now()) {
            timeout_.cancel();
            socket_.cancel();
        } else {
            auto self(Get());
            timeout_.async_wait([this, self](boost::system::error_code ec) {
                if (!ec) {
                    deadline();
                }
            });
        }
    }

    template <typename socket_type, typename resolver_type, typename endpoint_iter_type>
    SocketImpl<socket_type, resolver_type, endpoint_iter_type>::~SocketImpl() {
        if (socket_.is_open()) socket_.close();
    }
} // namespace network

I use it like this (C++ 17):

struct Client : Socket<TCP> { ... };

Happy to take advice on this structure! Thanks!

1 Answers1

3

That's a lot of code.

  1. Always compile with warnings enabled. This would have told you that members are not constructed in the order you list their initializers. Importantly, the second one is UB:

    explicit SocketImpl(socket_type sock)
        : resolver_(sock.get_executor()), timeout_(sock.get_executor()), socket_(std::move(sock)) {}
    

    Because socket_ is declared before timeout_, it will also be initialized before, meaning that sock.get_executor() is actually use-after-move. Oops. Fix it:

    explicit SocketImpl(socket_type sock)
        : resolver_(sock.get_executor()), socket_(std::move(sock)), timeout_(socket_.get_executor()) {}
    

    Now, even though the other constructor doesn't have such a problem, it's good practice to match declaration order there as well:

    explicit SocketImpl(Executor executor)
        : resolver_(executor)
        , socket_(executor)
        , timeout_(executor) {}
    
    explicit SocketImpl(socket_type sock)
        : resolver_(sock.get_executor())
        , socket_(std::move(sock))
        , timeout_(socket_.get_executor()) {}
    

    (Kudos for making constructors explicit)

  2. I'd implement any Impl class inline (the naming suggests that the entire class is "implementation detail" anyways).

  3. Destructors like this are busy-work:

    template <typename socket_type, typename resolver_type, typename endpoint_iter_type>
    SocketImpl<socket_type, resolver_type, endpoint_iter_type>::~SocketImpl() {
        if (socket_.is_open()) {
            socket_.close();
        }
    }
    

    The default destructor of socket_ will already do that. All you do is get in the way of the compiler to generate optimal, exception safe code. E.g. in this case close() might raise an exception. Did you want that?

  4. Consider taking arguments that hold resource by const-reference, or by value if you intend to std::move() from them.

    virtual void do_resolve(std::string host, std::string port,
                            ConnectCallback const&,
                            ErrorCallback const&) = 0;
    
  5. These instantiations:

    template <Type>
    struct Socket
        : public SocketImpl<boost::asio::ip::tcp::socket,
                            boost::asio::ip::tcp::resolver,
                            boost::asio::ip::tcp::resolver::iterator> {
    

    and

    template <>
    struct Socket<UDP>
        : public SocketImpl<boost::asio::ip::udp::socket,
                            boost::asio::ip::udp::resolver,
                            boost::asio::ip::udp::resolver::iterator> {
    

    Seem laborious. Why not use the generic templates and protocols from Asio directly? You could even throw in a free performance optimization by allowing callers to override the type-erased executor type:

    template <typename Protocol,
              typename Executor = boost::asio::any_io_executor>
    struct SocketImpl
        : public std::enable_shared_from_this<SocketImpl<Protocol, Executor>> {
      public:
        using base_type     = SocketImpl<Protocol, Executor>;
        using socket_type   = std::conditional_t<
            std::is_same_v<Protocol, boost::asio::ip::udp>,
            boost::asio::basic_datagram_socket<Protocol, Executor>,
            boost::asio::basic_socket<Protocol, Executor>>;
        using resolver_type =
            boost::asio::ip::basic_resolver<Protocol, Executor>;
        using endpoint_iter_type = typename resolver_type::iterator;
    

    Now your instantiations can just be:

    template <Type> struct Socket : public SocketImpl<boost::asio::ip::tcp> {
        // ...
    template <> struct Socket<UDP> : public SocketImpl<boost::asio::ip::udp> {
    

    with the exact behaviour you had, or better:

    using StrandEx = boost::asio::strand<boost::asio::io_context::executor_type>;
    
    template <Type> struct Socket : public SocketImpl<boost::asio::ip::tcp, StrandEx> {
        // ...
    template <> struct Socket<UDP> : public SocketImpl<boost::asio::ip::udp, StrandEx> {
    

    with the executor optimized for the strand as you were restricting it to anyways!

  6. Instead of repeating the type arguments:

    explicit Socket(boost::asio::ip::tcp::socket sock) : SocketImpl(std::move(sock)) {
    

    refer to exposed typedefs, so you have a single source of truth:

    explicit Socket(base_type::socket_type sock) : SocketImpl(std::move(sock)) {
    
  7. Pass executors by value. They're cheap to copy and you could even move from them since you're "sinking" them into you members

  8. In fact, just inherit the constructors whole-sale instead of repeating. So even:

    template <Type>
    struct Socket : public SocketImpl<boost::asio::ip::tcp, StrandEx> {
        explicit Socket(StrandEx executor) : SocketImpl(executor) {}
        explicit Socket(base_type::socket_type sock)
            : SocketImpl(std::move(sock)) {}
    

    Could just be:

    template <Type>
    struct Socket : public SocketImpl<boost::asio::ip::tcp, StrandEx> {
        using base_type::base_type;
    

    and land you with the exact same set of constructors.

  9. That constructor sets is_connected but it's lying about it. Because it sets it to true when the socket is merely open. You don't want this, nor do you need it.

    In your code, nobody is using that. What you might want in a deriving client, is a state machine. It's up to them. No need to add a racy, lying interface to your base class. Leave the responsibility where it belongs.

  10. Same with this:

     #ifdef OS_WIN
         SetThreadUILanguage(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US));
     #endif
    

    That's a violation of separation of concerns. You might want this behaviour, but your callers/users might want something else. Worse, this behaviour may break their code that had a different preference in place.

  11. Get() does nothing but obscure that it returns shared_from_this. If it is there to avoid explicitly qualifying with this-> (because the base class is a dependent type), just, again, use a using declaration:

    using std::enable_shared_from_this<SocketImpl>::shared_from_this;
    
  12. There's a big problem with PostCallback being std::function. It hides associated executor types! See boost::asio::bind_executor does not execute in strand for a detailed description of this issue

    In your case, there is absolutely no reason at all to type erase the Post argument, so don't:

    void Post(PostCallback callback) {
        post(socket_.get_executor(), std::move(callback));
    }
    

    Should be

    template <typename PostCallback>
    void Post(PostCallback&& callback) {
        post(socket_.get_executor(), std::forward<PostCallback>(callback));
    }
    

    I'd do the same for the other callback types.

    using ConnectCallback = std::function<void()>;
    using PromoteCallback = std::function<void()>;
    using WriteCallback   = std::function<void(size_t)>;
    using ReadCallback    = std::function<void(const uint8_t*, size_t)>;
    using ErrorCallback   = std::function<void(const std::string&)>;
    

    But I'll leave it as an exorcism for the reader for now.

  13. Socket<UDP>::Promote is a weird one. Firstly, I question the logic.

        void Socket<UDP>::Promote(const PromoteCallback &callback) {
            auto self = shared_from_this();
            Post([this, self, callback] {
                endpoint_iter_++;
                socket_.cancel();
                callback();
            });
        }
    

    I do not feel comfortable incrementing an endpoint_iter_ without checking whether it's not already past-the-end.

    Besides, nothing prevents this from running before async_resolve completes. I think it's cleaner to cancel the pending operations before incrementing that iterator.

    Finally, callback() is void(void) so - you're merely trying to synchronize on the completion of the task. I'd suggest a future for this:

       std::future<void> Promote() {
           return Post(std::packaged_task<void()>([this, self = shared_from_this()] {
               socket_.cancel(); // TODO wait for that to complete before incrementing
               endpoint_iter_++;
           }));
       }
    
  14. A class template with a template argument that is never used is a clear sign of the fact that it doesn't need to be a template. Socket<TCP> and Socket<UDP> are not related.

    Separate the conjoined twins makes their life easier:

    struct TCPSocket : SocketImpl<asio::ip::tcp, StrandEx> { /*...*/ };
    struct UDPSocket : SocketImpl<asio::ip::udp, StrandEx> { /*...*/ };
    

    If for some arcane reason you really want to have the template definition:

    template <Type> struct Socket;
    template <> struct Socket<TCP> : TCPSocket { using TCPSocket::TCPSocket; };
    template <> struct Socket<UDP> : UDPSocket { using UDPSocket::UDPSocket; };
    

    I hope that the triviality of it drives home the point that the types don't need to be related.

  15. deadline is missing code, you're calling it with a Handler callback, but it doesn't take any argument. I'll just make up the missing bits:

    template <typename Handler>
    void Await(boost::posix_time::time_duration ms, Handler f) {
        Post([this, self = shared_from_this(), ms, f] {
            timeout_.expires_from_now(ms);
            timeout_.template async_wait(
                [self, f = std::move(f)](error_code ec) {
                    if (!ec) {
                        asio::dispatch(std::move(f));
                    }
                });
        });
    }
    
  16. stop_await() makes it even more enigmatic: the fact that timeout_ is being canceled without regard for who posted Await, when and for how long does suggest that you wanted it to perform as a deadline, so indeed the user callback wasn't applicable really.

    However, I cannot explain why the timeout_ was being restarted automatically (although it wouldn't be if it were canceled, because of the if (!ec) check in the lambda. I admit I really cannot figure this out, so you'll have to decide what you wanted it to do yourself.

  17. The Read/ReadUntil interfaces are quite limited. I cannot see how I'd read a simple HTTP response with it for example. For my example, I'll just read the response headers, instead.

  18. You should always prefer using std::span or std::string_view over pairs of charT const*, size_t. It's just much less error prone, and much more expressive.

    using Data = std::basic_string_view<uint8_t>; // or span and similar
    // ... eg:
    using ReadCallback = std::function<void(Data)>;
    
  19. Wait, what is this?

    asio::ip::udp::endpoint endpoint = *endpoint_iter_;
    socket_.async_receive_from(
        asio::buffer(buff_.prepare(size)), endpoint,
    

    Did you mean to overwrite endpoints from the resolver results? That makes no sense.

    Note, async_receive_from uses the endpoint reference argument to indicate the source of an incoming message. You're passing a reference to a local variable here, which causes Undefined Behaviour because the async operation will be completing after the local variable disappeared.

    Instead, use a member variable.

    asio::ip::udp::endpoint sender_;
    
    void Read(size_t size, const ReadCallback &read_callback, const ErrorCallback &error_callback) override {
        Post([=, this, self = shared_from_this()] {
            socket_.async_receive_from(
                asio::buffer(buff_.prepare(size)), sender_,
    
  20. A streambuf seems overkill for most operations, but certainly for the datagram protocol. Also, you declare it in the base-class, which never uses it anywhere. Consider moving it to the TCP/UDP derived classes.

  21. The whole do_connect/do_resolve thing is something that I think needs to be in SocketImpl. It's largely identical for TCP/UDP and if that's not in the base class, and Read[Until]/Send are already per-protocol, I don't really see why you'd have a base-class at all.

    I'd switch on a IS_DATAGRAM property like

    static constexpr bool is_datagram = std::is_same_v<Protocol, asio::ip::udp>;
    

    And have one implementation:

    void do_resolve(std::string const& host, std::string const& port,
                    ConnectCallback connect_callback,
                    ErrorCallback   error_callback) {
        resolver_.async_resolve(
            host, port,
            [=, this, self = shared_from_this()](
                error_code ec, endpoint_iter_type endpoints) {
                stop_await();
                if (!ec) {
                    endpoint_iter_ = std::move(endpoints);
                    if constexpr (is_datagram) {
                        socket_.open(endpoint_iter_->endpoint().protocol());
                        connect_callback();
                    } else {
                        do_connect(endpoint_iter_, connect_callback,
                                   error_callback);
                    }
                } else {
                    error_callback("Unable to resolve host: " +
                                   ec.message());
                }
            });
    }
    

    If you're wondering how do_connect could compile

  22. For the virtual methods to make sense, there should be a shared interface, which you currently don't have. So either create a base class interface like:

    struct ISocket {
        virtual ~ISocket() = default;
        virtual void Send(Data msg, const WriteCallback &write_callback, const ErrorCallback &error_callback) = 0;
        virtual void Read(size_t size, const ReadCallback &read_callback, const ErrorCallback &error_callback) = 0;
    };
    
    template <typename Protocol, typename Executor = asio::any_io_executor>
    struct SocketImpl
        : public std::enable_shared_from_this<SocketImpl<Protocol, Executor>>
        , ISocket { // ....
    

    Note that this makes it more important to have the virtual destructor. (Although using make_shared<ConcreteType> can save you because the shared pointers contain the right deleter)

  23. Looks to me like the Await should also have been virtual. But you made it a member function template (actually the right thing to do IMO).

  24. Alternatively, instead of going for virtuals, embrace that you didn't need the shared interface based on dynamic polymorphism.

    If you ever need to make SocketImp behaviour dependent on the derived class, you can and make it CRTP (Curiously Recurring Template Pattern) instead.

    This is what I've done below.

Adapted Listing And Demo

Here's a listing with demo. The namespace network went from 313 to 229 lines.

#include <boost/asio.hpp>
#include <memory>
#include <string>

using Data = std::basic_string_view<uint8_t>; // or span and similar

namespace network {
    namespace asio = boost::asio;
    using boost::system::error_code;

    template <typename Protocol, typename Executor = asio::any_io_executor>
    struct SocketImpl
        : public std::enable_shared_from_this<SocketImpl<Protocol, Executor>> {
      public:
        using base_type = SocketImpl<Protocol, Executor>;

        static constexpr bool is_datagram = std::is_same_v<Protocol, asio::ip::udp>;
        using socket_type = std::conditional_t<is_datagram,
                               asio::basic_datagram_socket<Protocol, Executor>,
                               asio::basic_stream_socket<Protocol, Executor>>;
        using resolver_type      = asio::ip::basic_resolver<Protocol, Executor>;
        using endpoint_iter_type = typename resolver_type::iterator;

        using std::enable_shared_from_this<SocketImpl>::shared_from_this;

        using ConnectCallback = std::function<void()>;
        using PromoteCallback = std::function<void()>;
        using WriteCallback   = std::function<void(size_t)>;
        using ReadCallback    = std::function<void(Data)>;
        using ErrorCallback   = std::function<void(const std::string&)>;

        explicit SocketImpl(Executor executor)
            : resolver_(executor)
            , socket_(executor)
            , timeout_(executor) {}

        explicit SocketImpl(socket_type sock)
            : resolver_(sock.get_executor())
            , socket_(std::move(sock))
            , timeout_(socket_.get_executor()) {}

        template <typename Token> decltype(auto) Post(Token&& callback) {
            return asio::post(socket_.get_executor(), std::forward<Token>(callback));
        }

        void Connect(std::string Host, std::string Port,
                     const ConnectCallback& connect_callback,
                     const ErrorCallback&   error_callback) {
            Post([=, self = shared_from_this()] {
                self->do_resolve(Host, Port, connect_callback, error_callback);
            });
        }

        template <typename Handler>
        void Await(boost::posix_time::time_duration ms, Handler f) {
            Post([this, self = shared_from_this(), ms, f] {
                timeout_.expires_from_now(ms);
                timeout_.template async_wait(
                    [self, f = std::move(f)](error_code ec) {
                        if (!ec) {
                            asio::dispatch(std::move(f));
                        }
                    });
            });
        }

        void Disconnect() {
            Post([this, self = shared_from_this()] {
                timeout_.cancel();
                resolver_.cancel();
                if (socket_.is_open()) {
                    socket_.cancel();
                }
            });
        }

      protected:
        void stop_await() { timeout_.cancel(); }

        void do_resolve(std::string const& host, std::string const& port,
                        ConnectCallback connect_callback,
                        ErrorCallback   error_callback) {
            resolver_.async_resolve(
                host, port,
                [=, this, self = shared_from_this()](
                    error_code ec, endpoint_iter_type endpoints) {
                    stop_await();
                    if (!ec) {
                        endpoint_iter_ = std::move(endpoints);
                        if constexpr (is_datagram) {
                            socket_.open(endpoint_iter_->endpoint().protocol());
                            connect_callback();
                        } else {
                            do_connect(endpoint_iter_, connect_callback,
                                       error_callback);
                        }
                    } else {
                        error_callback("Unable to resolve host: " + ec.message());
                    }
                });
        }

        void do_connect(endpoint_iter_type endpoints,
                        ConnectCallback    connect_callback,
                        ErrorCallback      error_callback) {
            async_connect( //
                socket_, std::move(endpoints),
                [=, this, self = shared_from_this()](error_code ec, endpoint_iter_type) {
                    stop_await();
                    if (!ec) {
                        connect_callback();
                    } else {
                        error_callback("Unable to connect host: " + ec.message());
                    }
                });
        }

        resolver_type        resolver_;
        endpoint_iter_type   endpoint_iter_;
        socket_type          socket_;
        asio::deadline_timer timeout_;
    };

    using StrandEx = asio::strand<asio::io_context::executor_type>;

    struct TCPSocket : SocketImpl<asio::ip::tcp, StrandEx> {
        using base_type::base_type;

        void Send(Data msg, WriteCallback write_callback, ErrorCallback error_callback) {
            Post([=, this, self = shared_from_this()] {
                async_write(socket_, asio::buffer(msg),
                            [self, write_callback,
                             error_callback](error_code ec, size_t xfr) {
                                if (!ec) {
                                    write_callback(xfr);
                                } else {
                                    error_callback(ec.message());
                                }
                            });
            });
        }

        void Read(size_t size, ReadCallback read_callback, ErrorCallback error_callback) {
            Post([=, this, self = shared_from_this()] {
                async_read(
                    socket_, asio::buffer(buff_.prepare(size)),
                    [this, self, read_callback, error_callback](error_code ec,
                                                                size_t length) {
                        stop_await();
                        if (!ec) {
                            auto data =
                                asio::buffer_cast<const uint8_t*>(buff_.data());
                            read_callback({data, length});
                        } else {
                            error_callback(ec.message());
                        }
                        buff_.consume(length);
                    });
            });
        }

        void ReadUntil(std::string until_str, const ReadCallback &read_callback, const ErrorCallback &error_callback) {
            Post([=, this, self = shared_from_this()] {
                async_read_until(
                    socket_, buff_, until_str,
                    [=, this](error_code ec, size_t xfr) {
                        stop_await();
                        if (!ec) {
                            auto data =
                                asio::buffer_cast<const uint8_t*>(buff_.data());
                            read_callback({data, xfr});
                        } else {
                            error_callback(ec.message());
                        }
                        buff_.consume(xfr);
                    });
            });
        }

      protected:

        asio::streambuf buff_;
    };

    struct UDPSocket : SocketImpl<asio::ip::udp, StrandEx> {
        using base_type::base_type;

        void Send(Data msg, WriteCallback write_callback, ErrorCallback error_callback) {
            Post([=, this, self = shared_from_this()] {
                socket_.async_send_to( //
                    asio::buffer(msg), *endpoint_iter_,
                    [=](error_code ec, size_t xfr) {
                        if (!ec) {
                            write_callback(xfr);
                        } else {
                            error_callback(ec.message());
                        }
                    });
            });
        }

        void Read(size_t max_size, ReadCallback read_callback, ErrorCallback error_callback) {
            Post([=, this, self = shared_from_this()] {
                socket_.async_receive_from(
                    asio::buffer(buff_, max_size), sender_,
                    [this, self, read_callback, error_callback](error_code ec, size_t xfr) {
                        this->stop_await();
                        if (!ec) {
                            read_callback({buff_.data(), xfr});
                        } else {
                            error_callback(ec.message());
                        }
                    });
            });
        }

        std::future<void> Promote() {
            return Post(std::packaged_task<void()>([this, self = shared_from_this()] {
                socket_.cancel(); // TODO wait for that to complete before incrementing
                endpoint_iter_++;
            }));
        }

      protected:
        asio::ip::udp::endpoint sender_;
        std::array<uint8_t, 65530> buff_; // or whatever lower limit you accept
    };

    enum Type { UDP, TCP };
    template <Type> struct Socket;
    template <> struct Socket<TCP> : TCPSocket { using TCPSocket::TCPSocket; };
    template <> struct Socket<UDP> : UDPSocket { using UDPSocket::UDPSocket; };
} // namespace network

Now we can use the Socket<> or direct subtypes to define simplistic clients:

struct Client : network::Socket<network::TCP> {
    using network::Socket<network::TCP>::Socket;
};

struct UDPClient : network::UDPSocket {
    using network::UDPSocket::UDPSocket;
};

As you can see, I don't add any behaviour for now.

Instead, let's write a simple HTTP client and a simple UDP Echo client directly in main:

static void on_error(std::string const& s) {
    std::cout << "Error: " << s << std::endl;
}

int main() {
    asio::io_context io;
    static Data const request{reinterpret_cast<uint8_t const*>("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")};
    static Data const msg{reinterpret_cast<uint8_t const*>("Hello world!\n")};

    auto http = std::make_shared<Client>(make_strand(io));
    auto echo = std::make_shared<UDPClient>(make_strand(io));

    http->Connect(
        "example.com", "http",
        [http] {
            std::cout << "(Http) Connected" << std::endl;
            http->Await(seconds(2), [http] {
                std::cout << "(Http) Sending" << std::endl;
                http->Send(
                    request,
                    [http](size_t n) {
                        std::cout << "(Http) Sent " << n << std::endl;
                        http->ReadUntil(
                            "\r\n\r\n",
                            [http](Data response) {
                                std::cout << "(Http) Read: ";
                                std::cout.write(reinterpret_cast<char const*>(
                                                    response.data()),
                                                response.size());
                                std::cout << std::endl;
                                http->Disconnect();
                            },
                            on_error);
                    },
                    on_error);
            });
        },
        on_error);

    echo->Connect(
        "localhost", "echo",
        [echo] {
            std::cout << "(Echo) Connected" << std::endl;
            echo->Await(seconds(1), [echo] {
                std::cout << "(Echo) Sending" << std::endl;
                echo->Send(
                    msg,
                    [echo](size_t n) {
                        std::cout << "(Echo) Sent " << n << std::endl;
                        echo->Read(
                            msg.size(),
                            [echo](Data response) {
                                std::cout << "(Echo) Read: ";
                                std::cout.write(reinterpret_cast<char const*>(
                                                    response.data()),
                                                response.size());
                                std::cout << std::endl;

                                echo->Disconnect();
                            },
                            on_error);
                    },
                    on_error);
            });
        },
        on_error);

    std::cout << "START" << std::endl;
    io.run();
    std::cout << "END" << std::endl;
}

You can see it live on my box:

enter image description here

sehe
  • 374,641
  • 47
  • 450
  • 633
  • 1
    I'm not convinced it's actually that much simpler. Here's the same HTTP/Echo clients without any helper classes. It looks pretty much the same, just 240 fewer lines of code: http://coliru.stacked-crooked.com/a/aba4b8cd746c2c29 (demo https://imgur.com/a/j6oUtbt). And the `tcp_stream` actually does the correct timeout. And it would be easier to add the `Promote` support correctly: http://coliru.stacked-crooked.com/a/d0d09fbb692f399b (9 more lines of code) – sehe Apr 25 '22 at 02:41
  • 1
    I do realize that my short version supposes more understanding of Asio, but I'd like to submit that using your interface instead is also quite restrictive. It may work well for specific use cases (where `uint8_t` buffers, are the norm, e.g.) though. – sehe Apr 25 '22 at 02:43
  • Thank you very much, I didn't even expect such a comprehensive analysis of my garbage code! <3 – HamsterGamer Apr 25 '22 at 07:29
  • 1
    It's not garbage code by any measure :) It did the strands/post right (subtle) and avoided the many (many) pitfalls with using streambuf and conditional read_untils. The things I didn't mention are still there! I should learn to give more kudos in reviews. By the way, the other commenter was right, you should probably consider posting at codereview.stackexchange.com – sehe Apr 25 '22 at 12:04