-1

I made a single server socket that I want to accept multiple connections in a multithreaded fashion, but there's an issue: it drops messages from clients for no apparent reason.

Each socket is handled by its own thread, so my guess was that it shouldn't be an issue (maybe it is).

Here is the code

#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <windows.h>
#include <WinSock2.h>
#include <WS2tcpip.h>

#include <atomic>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <mutex>
#include <sstream>
#include <string>
#include <string_view>
#include <thread>
#include <vector>

// Taken from: https://stackoverflow.com/a/46104456/6119618
static std::string wsa_error_to_string(int wsa_error)
{
    char msgbuf [256];   // for a message up to 255 bytes.
    msgbuf [0] = '\0';    // Microsoft doesn't guarantee this on man page.

    FormatMessage(
        FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, // flags
        nullptr,                                                    // lpsource
        wsa_error,                                                  // message id
        MAKELANGID (LANG_NEUTRAL, SUBLANG_DEFAULT),                 // languageid
        msgbuf,                                                     // output buffer
        sizeof (msgbuf),                                            // size of msgbuf, bytes
        nullptr
    );

    if (! *msgbuf)
        sprintf (msgbuf, "%d", wsa_error);  // provide error # if no string available
    return msgbuf;
}

// Prints MSG to std::cerr and terminates the program.
// assert(0) gives the usual diagnostic in debug builds; std::abort() guarantees
// termination in builds where NDEBUG compiles the assert away.
#define PRINT_ERROR_AND_TERMINATE(MSG) do { std::cerr << (MSG) << std::endl; assert(0); std::abort(); } while(0)

/// Winsock initialization guard.
///
/// The constructor calls WSAStartup requesting version 2.2 and records whether
/// it succeeded; the destructor calls WSACleanup only for a successful startup,
/// so the two calls always stay balanced.
struct wsa_lifetime
{
    wsa_lifetime()
    {
        const int result = ::WSAStartup(MAKEWORD(2,2), &wsa_data);
        is_initialized = (result == 0);
        assert(is_initialized);
    }

    ~wsa_lifetime()
    {
        if (is_initialized)
            ::WSACleanup();
    }

    // A copy would run WSACleanup a second time for a single WSAStartup.
    wsa_lifetime(const wsa_lifetime &) = delete;
    wsa_lifetime &operator=(const wsa_lifetime &) = delete;

    WSAData wsa_data {};         // filled in by WSAStartup
    bool is_initialized {false}; // true only if WSAStartup returned 0
};

// Single file-scope instance: Winsock is started before main() runs and cleaned
// up after main() returns.
static wsa_lifetime wsa_lifetime;

/// Creates an IPv4 stream socket (AF_INET, SOCK_STREAM, protocol 0).
///
/// @return The new socket handle; asserts that it is not INVALID_SOCKET.
static SOCKET socket_create()
{
    const SOCKET handle = ::socket(AF_INET, SOCK_STREAM, 0);
    assert(INVALID_SOCKET != handle);
    return handle;
}

/// Closes a socket handle.
///
/// @param socket  Handle to close. It is passed by value, so the caller's copy
///                is NOT reset; the caller must not use the handle afterwards.
static void socket_destroy(SOCKET socket)
{
    // The former `socket = INVALID_SOCKET;` only modified this local copy and
    // had no observable effect, so it has been removed.
    ::closesocket(socket);
}

/// Binds a socket to a local IPv4 address and port.
///
/// @param socket   Socket handle to bind.
/// @param address  IPv4 address in dotted-decimal text form (e.g. "127.0.0.1").
/// @param port     Port number in host byte order.
///
/// On failure an error message is reported through PRINT_ERROR_AND_TERMINATE.
static void socket_bind(SOCKET socket, const char *address, uint16_t port)
{
    sockaddr_in addr {};
    addr.sin_family = AF_INET;
    addr.sin_port = htons(port);

    // inet_pton returns 1 only when the text was converted successfully.
    if (inet_pton(AF_INET, address, &addr.sin_addr.s_addr) != 1)
    {
        PRINT_ERROR_AND_TERMINATE(std::string("socket_bind: cannot parse IPv4 address: ") + address);
        return; // never bind to a half-initialized address
    }

    int bind_result = ::bind(socket, reinterpret_cast<SOCKADDR *>(&addr), static_cast<int>(sizeof(addr)));
    if (bind_result == SOCKET_ERROR)
    {
        const int err = WSAGetLastError();
        PRINT_ERROR_AND_TERMINATE("socket_bind failed (WSA error " + std::to_string(err) + "): " + wsa_error_to_string(err));
    }
}

/// Connects a socket to a remote IPv4 address and port (blocking call).
///
/// @param socket   Socket handle to connect.
/// @param address  IPv4 address in dotted-decimal text form (e.g. "127.0.0.1").
/// @param port     Port number in host byte order.
///
/// On failure an error message is reported through PRINT_ERROR_AND_TERMINATE.
static void socket_connect(SOCKET socket, const char *address, uint16_t port)
{
    sockaddr_in addr {};
    addr.sin_family = AF_INET;
    addr.sin_port = htons(port);

    // inet_pton returns 1 only when the text was converted successfully.
    if (inet_pton(AF_INET, address, &addr.sin_addr.s_addr) != 1)
    {
        PRINT_ERROR_AND_TERMINATE(std::string("socket_connect: cannot parse IPv4 address: ") + address);
        return; // never connect to a half-initialized address
    }

    int connect_result = ::connect(socket, reinterpret_cast<SOCKADDR *>(&addr), static_cast<int>(sizeof(addr)));
    if (connect_result == SOCKET_ERROR)
    {
        const int err = WSAGetLastError();
        PRINT_ERROR_AND_TERMINATE("socket_connect failed (WSA error " + std::to_string(err) + "): " + wsa_error_to_string(err));
    }
}

/// Puts a bound socket into the listening state with the maximum backlog
/// (SOMAXCONN).
///
/// @param socket  Bound socket handle.
///
/// On failure an error message is reported through PRINT_ERROR_AND_TERMINATE.
static void socket_listen(SOCKET socket)
{
    int listen_result = ::listen(socket, SOMAXCONN);
    if (listen_result == SOCKET_ERROR)
    {
        const int err = WSAGetLastError();
        PRINT_ERROR_AND_TERMINATE("socket_listen failed (WSA error " + std::to_string(err) + "): " + wsa_error_to_string(err));
    }
}

/// Blocks until a connection arrives on a listening socket and accepts it.
/// The peer address is not retrieved.
///
/// @param socket  Listening socket handle.
/// @return Handle of the accepted connection. On failure an error message is
///         reported through PRINT_ERROR_AND_TERMINATE.
static SOCKET socket_accept(SOCKET socket)
{
    SOCKET accepted_socket = ::accept(socket, nullptr, nullptr);
    if (accepted_socket == INVALID_SOCKET)
    {
        const int err = WSAGetLastError();
        PRINT_ERROR_AND_TERMINATE("socket_accept failed (WSA error " + std::to_string(err) + "): " + wsa_error_to_string(err));
    }
    return accepted_socket;
}

/// Receives up to buffer_size bytes from a connected socket (blocking call).
///
/// @param socket       Connected socket handle.
/// @param buffer       Destination buffer.
/// @param buffer_size  Capacity of `buffer` in bytes.
/// @param flags        Flags forwarded to recv().
/// @return Number of bytes stored in `buffer`. 0 means the connection is over:
///         recv() itself returned 0, the peer reset the connection
///         (WSAECONNRESET), or another error occurred and was reported.
///
/// TCP is a byte stream: one call may return part of a message or several
/// messages concatenated.
static size_t socket_recv(SOCKET socket, char *buffer, size_t buffer_size, int flags = 0)
{
    // recv() takes an int length; clamp instead of letting the cast wrap.
    constexpr size_t max_chunk = static_cast<size_t>(std::numeric_limits<int>::max());
    const int length = static_cast<int>(buffer_size < max_chunk ? buffer_size : max_chunk);

    const int bytes_received = ::recv(socket, buffer, length, flags);
    if (bytes_received == SOCKET_ERROR)
    {
        const int err = WSAGetLastError();
        if (err != WSAECONNRESET) // a reset just means the client disconnected
            PRINT_ERROR_AND_TERMINATE("socket_recv failed (WSA error " + std::to_string(err) + "): " + wsa_error_to_string(err));
        // Never let SOCKET_ERROR (-1) escape as a huge size_t.
        return 0;
    }
    return static_cast<size_t>(bytes_received);
}

/// Sends up to data_size bytes on a connected socket.
///
/// @param socket     Connected socket handle.
/// @param data       Bytes to send.
/// @param data_size  Number of bytes available at `data`.
/// @param flags      Flags forwarded to send().
/// @return Number of bytes actually sent, which may be less than data_size.
///         0 is returned when the peer reset the connection (WSAECONNRESET)
///         or another error occurred and was reported.
static size_t socket_send(SOCKET socket, const char *data, size_t data_size, int flags = 0)
{
    // send() takes an int length; clamp instead of letting the cast wrap.
    constexpr size_t max_chunk = static_cast<size_t>(std::numeric_limits<int>::max());
    const int length = static_cast<int>(data_size < max_chunk ? data_size : max_chunk);

    const int bytes_sent = ::send(socket, data, length, flags);
    if (bytes_sent == SOCKET_ERROR)
    {
        const int err = WSAGetLastError();
        if (err != WSAECONNRESET) // a reset just means the client disconnected
            PRINT_ERROR_AND_TERMINATE("socket_send failed (WSA error " + std::to_string(err) + "): " + wsa_error_to_string(err));
        // Never let SOCKET_ERROR (-1) escape as a huge size_t.
        return 0;
    }
    return static_cast<size_t>(bytes_sent);
}

// Held while a server-side client thread writes one line to std::cout, so
// output from different client threads is not interleaved.
static std::mutex output_mutex;

int main()
{
    const char *server_address = "127.0.0.1";
    uint16_t server_port = 23456;
    bool server_terminate = false;

    std::thread server_thread([server_address, server_port, &server_terminate](){
        SOCKET server = socket_create();
        socket_bind(server, server_address, server_port);
        socket_listen(server);

        std::vector<SOCKET> clients;
        std::vector<std::thread> client_threads;

        while (!server_terminate)
        {
            SOCKET incoming_client = socket_accept(server);
            if (server_terminate)
                break;

            clients.push_back(incoming_client);
            size_t client_id = clients.size();

            std::thread incoming_client_thread([&incoming_client, client_id](){
                const size_t data_size = 1024;
                char data[data_size];

                while (true)
                {
                    size_t bytes_received = socket_recv(incoming_client, data, data_size, 0);
                    if (bytes_received == 0)
                        break;

                    std::string_view client_message(data, bytes_received);
                    {
                        std::unique_lock lock(output_mutex);
                        std::cout << "Client (" << client_id << "): " << client_message << std::endl;
                    }
                }
            });
            client_threads.push_back(std::move(incoming_client_thread));
        }

        for (std::thread &client_thread: client_threads)
            if (client_thread.joinable())
                client_thread.join();
    });

    std::vector<SOCKET> clients;
    std::vector<std::thread> client_threads;

    for (int i = 0; i < 4; i++)
    {
        SOCKET client = socket_create();
        clients.push_back(client);
    }

    for (SOCKET client : clients)
    {
        std::thread client_thread([server_address, server_port, client](){
            socket_connect(client, server_address, server_port);

            for (int i = 0; i < 10; i++)
            {
                std::string data_str = (std::stringstream() << "hello " << i).str();
                socket_send(client, data_str.c_str(), data_str.size());

                using namespace std::chrono_literals;
                std::this_thread::sleep_for(100ms + 1ms * (rand() % 100));
            }
        });
        client_threads.push_back(std::move(client_thread));
    }

    for (std::thread &client_thread : client_threads)
        if (client_thread.joinable())
            client_thread.join();

    for (SOCKET client: clients)
        socket_destroy(client);
    clients.clear();

    server_terminate = true;

    SOCKET dummy_socket = socket_create();
    socket_connect(dummy_socket, server_address, server_port); // just to unblock server's socket_accept() blocking call
    socket_destroy(dummy_socket);

    if (server_thread.joinable())
        server_thread.join();

    return 0;
}

Possible output:

Client (2): hello 0
Client (2): hello 0
Client (3): hello 1
Client (2): hello 2
Client (1): hello 3
Client (4): hello 4
Client (3): hello 5
Client (2): hello 6
Client (1): hello 7
Client (4): hello 8
Client (3): hello 9

I expected each client to send 10 messages, 40 in total, but some messages are dropped, as you can see. I think nothing should be dropped even with UDP transport, because all the work is done on my loopback network.
Wireshark registers all the messages

IC_
  • 1,624
  • 1
  • 23
  • 57
  • You detach server threads. Thus servers highly probably exit before they finish their works. – 273K Mar 25 '22 at 14:54
  • @273K I don't think this is the issue. `accept` is a blocking call, if i add an endless loop just before `return 0;` the output is the same – IC_ Mar 25 '22 at 15:46
  • 1
    While the problem may be in the code you've posted, without seeing the code you haven't posted, it's essentially impossible to guess where the problem may be. My immediate guess would be that some of the client threads haven't even started running before you shut down--you loop through, and join the ones that are joinable. But if any isn't joinable, that (at least usually) means it hasn't started running yet, so you're exiting before they've even started. Wait until they're all joinable, then join them all, then exit. – Jerry Coffin Mar 25 '22 at 17:21
  • @JerryCoffin i've updated the code, now it is self-sufficient to copy, paste, and compile – IC_ Mar 26 '22 at 12:16
  • What does UDP have to do with it? This is TCP. Where is the client code? – user207421 Mar 27 '22 at 04:18
  • 1
    I may be confused, but is there a good reason for the `incoming_client_thread` lambda to capture `incoming_client` by reference rather than by copy? – Hasturkun Mar 30 '22 at 15:54
  • @hasturkun did you mean vice versa? If yes, then it doesn't change anything. It's just a mistake I made when I ported my actual code to a minimal example – IC_ Mar 30 '22 at 16:14
  • 1
    No, I mean that the value changes within the loop, but the thread holds a reference to the variable, not its value. The first thread might not be calling `socket_recv` on the same socket once a second `socket_accept` succeeds. – Hasturkun Mar 30 '22 at 16:18
  • @hasturkun oh, you're right. This was the source of the issue. Thank you – IC_ Mar 31 '22 at 03:51

1 Answer

1

When constructing the lambda incoming_client_thread, you capture incoming_client by reference and not by copy.

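Since this variable is reassigned at the start of each loop iteration by socket_accept (and goes out of scope at the end of every iteration, leaving the reference dangling), a thread that holds the reference may end up calling socket_recv on a different socket once another socket_accept succeeds. Capturing incoming_client by value gives every thread its own copy of the socket handle.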

Hasturkun
  • 35,395
  • 6
  • 71
  • 104