0

My MCVE (minimal, complete and verifiable example) of an SSL relay server:

#pragma once

#include <stdint.h>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
#include <asio.hpp>
#include <asio/ssl.hpp>

namespace test
{

namespace setup
{
    // Upper bound for MessageHeader::messageLength; longer frames are rejected by the readers.
    constexpr uint32_t maxMessageSize = 1024u * 1024u;

    // Number of session slots kept by SSLRelayServer (valid ids are 0 .. maxSessionsNum - 1).
    constexpr uint32_t maxSessionsNum = 10u;
}

// Codes stored in MessageHeader::messageType (transmitted as a uint32_t).
enum class MessageType
{
    LOG_ON,        // 0: frame is a LogOn
    TEXT_MESSAGE   // 1: frame is a TextMessage
};

/// Prefix of every frame: the message type followed by the total frame length
/// in bytes (the header itself included, e.g. LogOn uses sizeof(LogOn)).
/// Frames are sent and parsed as raw struct bytes, so both peers are assumed
/// to share struct layout and byte order.
class MessageHeader
{
public:
    uint32_t messageType;
    uint32_t messageLength;

    MessageHeader(uint32_t type, uint32_t length)
        : messageType{type}
        , messageLength{length}
    {
    }
};

/// LOG_ON frame: registers the sending connection in session `sessionId`,
/// as client 0 when isClient0 != 0, otherwise as client 1.
class LogOn
{
public:
    MessageHeader header;
    // Zero-initialised so a frame sent without setting them does not put
    // indeterminate stack bytes on the wire.
    uint32_t      sessionId = 0;
    uint32_t      isClient0 = 0;

    LogOn() : header(static_cast<uint32_t>(MessageType::LOG_ON), sizeof(LogOn)) {}
};

/// TEXT_MESSAGE frame: the header is immediately followed by
/// messageLength - sizeof(TextMessage) payload bytes, reachable through `data`.
/// Note: a flexible array member is a compiler extension, not ISO C++.
class TextMessage
{
public:
    MessageHeader header;
    uint8_t       data[];

    TextMessage()
        : header(static_cast<uint32_t>(MessageType::TEXT_MESSAGE),
                 static_cast<uint32_t>(sizeof(TextMessage)))
    {
    }
};

class ClientSocket;

/// The two connections paired under one session id. Each pointer stays null
/// until the corresponding client has logged on. The defaults make this hold
/// even for `Session s;` / `new Session` (previously indeterminate).
class Session
{
public:
    ClientSocket* pClient0 = nullptr;
    ClientSocket* pClient1 = nullptr;
};

/// Looks up the session for sessionId through the SSLRelayServer singleton;
/// defined after SSLRelayServer.
Session* getSession(uint32_t sessionId);

class ClientSocket
{
public:
    bool useTLS;

    std::shared_ptr<asio::ip::tcp::socket> socket;
    std::shared_ptr<asio::ssl::stream<asio::ip::tcp::socket>> socketSSL;

    Session* pSession;
    bool     isClient0;

    std::recursive_mutex writeBufferLock;
    std::vector<char>    readBuffer;
    uint32_t             readPos;

    ClientSocket(asio::ip::tcp::socket& socket) : useTLS(false)
    {
        this->socket = std::make_shared<asio::ip::tcp::socket>(std::move(socket));
        this->readBuffer.resize(setup::maxMessageSize + sizeof(MessageHeader));
        this->readPos = 0;
    }

 ClientSocket(asio::ssl::stream<asio::ip::tcp::socket>& socket) : useTLS(true)
    {
        this->socketSSL = std::make_shared<asio::ssl::stream<asio::ip::tcp::socket>>(std::move(socket));
        this->readBuffer.resize(setup::maxMessageSize + sizeof(MessageHeader));
        this->readPos = 0;
    }

    bool writeSocket(uint8_t* pBuffer, uint32_t bufferSize)
    {
        try
        {
            std::unique_lock<std::recursive_mutex>
lock(this->writeBufferLock);

            size_t writtenBytes = 0;

            if (true == this->useTLS)
            {
                writtenBytes = asio::write(*this->socketSSL,
asio::buffer(pBuffer, bufferSize));
            }
            else
            {
                writtenBytes = asio::write(*this->socket,
asio::buffer(pBuffer, bufferSize));
            }

            return (writtenBytes == bufferSize);
        }
        catch (asio::system_error e)
        {
            std::cout << e.what() << std::endl;
        }
        catch (std::exception e)
        {
            std::cout << e.what() << std::endl;
        }
        catch (...)
        {
            std::cout << "Some other exception" << std::endl;
        }

        return false;
    }

    void asyncReadNextMessage(uint32_t messageSize)
    {
        auto readMessageLambda = [&](const asio::error_code errorCode, std::size_t length)
        {
            this->readPos += (uint32_t)length;

            if (0 != errorCode.value())
            {
                //send socket to remove
                 printf("errorCode= %u, message=%s\n", errorCode.value(), errorCode.message().c_str());
                //sendRemoveMeSignal();
                return;
            }

            if ((this->readPos < sizeof(MessageHeader)))
            {
                asyncReadNextMessage(sizeof(MessageHeader) - this->readPos);
                return;
            }

            MessageHeader* pMessageHeader = (MessageHeader*)this->readBuffer.data();

            if (pMessageHeader->messageLength > setup::maxMessageSize)
            {
                //Message to big - should disconnect ?
                this->readPos = 0;
                asyncReadNextMessage(sizeof(MessageHeader));
                return;
            }

            if (this->readPos < pMessageHeader->messageLength)
            {
                asyncReadNextMessage(pMessageHeader->messageLength - this->readPos);
                return;
            }

            MessageType messageType = (MessageType)pMessageHeader->messageType;

            switch(messageType)
            {
                case MessageType::LOG_ON:
                {
                    LogOn* pLogOn = (LogOn*)pMessageHeader;
                    printf("LOG_ON message sessionId=%u, isClient0=%u\n", pLogOn->sessionId, pLogOn->isClient0);

                    this->isClient0 = pLogOn->isClient0;
                    this->pSession  = getSession(pLogOn->sessionId);

                    if (this->isClient0)
                        this->pSession->pClient0 = this;
                    else
                        this->pSession->pClient1 = this;

                }
                break;
                case MessageType::TEXT_MESSAGE:
                {
                    TextMessage* pTextMessage = (TextMessage*)pMessageHeader;

                    if (nullptr != pSession)
                    {
                        if (this->isClient0)
                        {
                            if (nullptr != pSession->pClient1)
                            {
                                pSession->pClient1->writeSocket((uint8_t*)pTextMessage, pTextMessage->header.messageLength);
                            }
                        }
                        else
                        {
                            if (nullptr != pSession->pClient0)
                            {
                                pSession->pClient0->writeSocket((uint8_t*)pTextMessage, pTextMessage->header.messageLength);
                            }
                        }
                    }
                }
                break;
            }

            this->readPos = 0;
            asyncReadNextMessage(sizeof(MessageHeader));
        };

        if (true == this->useTLS)
        {
            this->socketSSL->async_read_some(asio::buffer(this->readBuffer.data() + this->readPos, messageSize), readMessageLambda);
        }
        else
        {
            this->socket->async_read_some(asio::buffer(this->readBuffer.data() + this->readPos, messageSize), readMessageLambda);
        }
    }
};

class SSLRelayServer
{
public:
    static SSLRelayServer* pSingleton;

    asio::io_context   ioContext;
    asio::ssl::context sslContext;

    std::vector<std::thread> workerThreads;

    asio::ip::tcp::acceptor* pAcceptor;
    asio::ip::tcp::endpoint* pEndpoint;

    bool useTLS;

    Session* sessions[setup::maxSessionsNum];

    SSLRelayServer() : pAcceptor(nullptr), pEndpoint(nullptr), sslContext(asio::ssl::context::tlsv13_server)//sslContext(asio::ssl::context::sslv23)
    {
        this->useTLS     = false;
        this->pSingleton = this;

        //this->sslContext.set_options(asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2);
        this->sslContext.set_password_callback(std::bind(&SSLRelayServer::getPrivateKeyPEMFilePassword, this));
        this->sslContext.use_certificate_chain_file("server_cert.pem");
        this->sslContext.use_private_key_file("server_private_key.pem",
        asio::ssl::context::pem);
    }

    static SSLRelayServer* getSingleton()
    {
        return pSingleton;
    }

    std::string getPrivateKeyPEMFilePassword() const
    {
        return "";
    }

    void addClientSocket(asio::ip::tcp::socket& socket)
    {
        ClientSocket* pClientSocket = new ClientSocket(socket); // use smart pointers
        pClientSocket->asyncReadNextMessage(sizeof(MessageHeader));
    }

    void addSSLClientToken(asio::ssl::stream<asio::ip::tcp::socket>&sslSocket)
    {
        ClientSocket* pClientSocket = new ClientSocket(sslSocket); // use smart pointers
        pClientSocket->asyncReadNextMessage(sizeof(MessageHeader));
    }

    void handleAccept(asio::ip::tcp::socket& socket, const asio::error_code& errorCode)
    {
        if (!errorCode)
        {
            printf("accepted\n");

            if (true == socket.is_open())
            {
                asio::ip::tcp::no_delay no_delay_option(true);
                socket.set_option(no_delay_option);

                addClientSocket(socket);
            }
        }
    }

    void handleAcceptTLS(asio::ip::tcp::socket& socket, const asio::error_code& errorCode)
    {
        if (!errorCode)
        {
            printf("accepted\n");

            if (true == socket.is_open())
            {
                asio::ip::tcp::no_delay no_delay_option(true);

                asio::ssl::stream<asio::ip::tcp::socket> sslStream(std::move(socket), this->sslContext);

                try
                {
                    sslStream.handshake(asio::ssl::stream_base::server);
                    sslStream.lowest_layer().set_option(no_delay_option);

                    addSSLClientToken(sslStream);
                }
                catch (asio::system_error e)
                {
                    std::cout << e.what() << std::endl;
                    return;
                }
                catch (std::exception e)
                {
                    std::cout << e.what() << std::endl;
                    return;
                }
                catch (...)
                {
                    std::cout << "Other exception" << std::endl;
                    return;
                }

            }

        }
    }

    void startAccept()
    {
        auto acceptHandler = [this](const asio::error_code& errorCode, asio::ip::tcp::socket socket)
        {
            printf("acceptHandler\n");

            handleAccept(socket, errorCode);

            this->startAccept();
        };

        auto tlsAcceptHandler = [this](const asio::error_code& errorCode, asio::ip::tcp::socket socket)
        {
            printf("tlsAcceptHandler\n");

            handleAcceptTLS(socket, errorCode);

            this->startAccept();
        };

        if (true == this->useTLS)
        {
            this->pAcceptor->async_accept(tlsAcceptHandler);
        }
        else
        {
            this->pAcceptor->async_accept(acceptHandler);
        }
    }

    bool run(uint32_t servicePort, uint32_t threadsNum, bool useTLS)
    {
        this->useTLS = useTLS;

        this->pEndpoint = new asio::ip::tcp::endpoint(asio::ip::tcp::v4(), servicePort);
        this->pAcceptor = new asio::ip::tcp::acceptor(ioContext, *pEndpoint);

        this->pAcceptor->listen();

        this->startAccept();

        for (uint32_t threadIt = 0; threadIt < threadsNum; ++threadIt)
        {
            this->workerThreads.emplace_back([&]() {
#ifdef WINDOWS
                SetThreadDescription(GetCurrentThread(), L"SSLRelayServer worker thread");
#endif
                this->ioContext.run(); }
            );
        }

        return true;
    }

    Session* getSession(uint32_t sessionId)
    {
        if (nullptr == this->sessions[sessionId])
        {
            this->sessions[sessionId] = new Session();
        }

        return this->sessions[sessionId];
    }
};

// Instance registered by the SSLRelayServer constructor; null until one exists.
// NOTE(review): this definition lives in a #pragma once header, so including the
// header from more than one .cpp yields a multiple-definition link error; move it
// to a single .cpp (or make it an inline variable once the project is on C++17).
SSLRelayServer* SSLRelayServer::pSingleton{nullptr};

/// Free-function entry point used by ClientSocket, which is declared before
/// SSLRelayServer is complete. Forwards to the server singleton.
/// @return whatever SSLRelayServer::getSession returns, or nullptr when no
///         SSLRelayServer has been constructed (previously a null dereference).
/// `inline` because this definition sits in a header.
inline Session* getSession(uint32_t sessionId)
{
    SSLRelayServer* pServer = SSLRelayServer::getSingleton();
    if (nullptr == pServer)
        return nullptr;

    return pServer->getSession(sessionId);
}

class Client
{
public:
    asio::ssl::context sslContext;

    std::shared_ptr<asio::ip::tcp::socket> socket;
    std::shared_ptr<asio::ssl::stream<asio::ip::tcp::socket>> socketSSL;

    asio::io_context ioContext;

    bool useTLS;
    bool isClient0;

    uint32_t             readDataIt;
    std::vector<uint8_t> readBuffer;

    std::thread listenerThread;

    Client() : sslContext(asio::ssl::context::tlsv13_client)//sslContext(asio::ssl::context::sslv23)
    {
        sslContext.load_verify_file("server_cert.pem");
        //sslContext.set_verify_mode(asio::ssl::verify_peer);

        using asio::ip::tcp;
        using std::placeholders::_1;
        using std::placeholders::_2;
        sslContext.set_verify_callback(std::bind(&Client::verifyCertificate, this, _1, _2));

        this->readBuffer.resize(setup::maxMessageSize);
        this->readDataIt = 0;
    }

    bool verifyCertificate(bool preverified, asio::ssl::verify_context& verifyCtx)
    {
        return true;
    }

    void listenerRunner() 
    {
#ifdef WINDOWS
        if (this->isClient0)
        {
            SetThreadDescription(GetCurrentThread(), L"listenerRunner client0");
        }
        else
        {
            SetThreadDescription(GetCurrentThread(), L"listenerRunner client1");
        }
#endif

        while (1==1)
        {
            asio::error_code errorCode;

            size_t transferred = 0;
            if (true == this->useTLS)
            {
                transferred = this->socketSSL->read_some(asio::buffer(this->readBuffer.data() + this->readDataIt, sizeof(MessageHeader) - this->readDataIt), errorCode);
            }
            else
            {
                transferred = this->socket->read_some(asio::buffer(this->readBuffer.data() + this->readDataIt, sizeof(MessageHeader) - this->readDataIt), errorCode);
            }

            this->readDataIt += transferred;

            if (0 != errorCode.value())
            {
                this->readDataIt = 0;
                continue;
            }

            if (this->readDataIt < sizeof(MessageHeader))
                continue;

            MessageHeader* pMessageHeader = (MessageHeader*)this->readBuffer.data();

            if (pMessageHeader->messageLength > setup::maxMessageSize)
            {
                exit(1);
            }

            bool resetSocket = false;

            while (pMessageHeader->messageLength > this->readDataIt)
            {
                printf("readDataIt=%u, threadId=%u\n", this->readDataIt, GetCurrentThreadId());

                {
                    //message not complete
                    if (true == this->useTLS)
                    {
                        transferred = this->socketSSL->read_some(asio::buffer(this->readBuffer.data() + this->readDataIt, pMessageHeader->messageLength - this->readDataIt), errorCode);
                    }
                    else
                    {
                        transferred = this->socket->read_some(asio::buffer(this->readBuffer.data() + this->readDataIt, pMessageHeader->messageLength - this->readDataIt), errorCode);
                    }

                    this->readDataIt += transferred;
                }

                if (0 != errorCode.value())
                {
                    exit(1);
                }
            }

            MessageType messageType = (MessageType)pMessageHeader->messageType;

            switch (messageType)
            {
                case MessageType::TEXT_MESSAGE:
                {
                    TextMessage* pTextMessage = (TextMessage*)pMessageHeader;
                    printf("TEXT_MESSAGE: %s\n", pTextMessage->data);
                }
                break;
            }

            this->readDataIt = 0;
        }
    }

    void run(uint32_t sessionId, bool isClient0, bool useTLS, uint32_t servicePort)
    {
        this->useTLS    = useTLS;
        this->isClient0 = isClient0;

        if (useTLS)
        {
            socketSSL = std::make_shared<asio::ssl::stream<asio::ip::tcp::socket>>(ioContext, sslContext);
        }
        else
        {
            socket = std::make_shared<asio::ip::tcp::socket>(ioContext);
        }

        asio::ip::tcp::resolver resolver(ioContext);

        asio::ip::tcp::resolver::results_type endpoints = resolver.resolve(asio::ip::tcp::v4(), "127.0.0.1", std::to_string(servicePort));

        asio::ip::tcp::no_delay no_delay_option(true);

        if (true == useTLS)
        {
            asio::ip::tcp::endpoint sslEndpoint = asio::connect(socketSSL->lowest_layer(), endpoints);
            socketSSL->handshake(asio::ssl::stream_base::client);
            socketSSL->lowest_layer().set_option(no_delay_option);
        }
        else
        {
            asio::ip::tcp::endpoint endpoint = asio::connect(*socket, endpoints);
            socket->set_option(no_delay_option);
        }

        this->listenerThread = std::thread(&Client::listenerRunner, this);

        LogOn logOn;
        logOn.isClient0 = isClient0;
        logOn.sessionId = sessionId;

        const uint32_t logOnSize = sizeof(logOn);

        if (true == useTLS)
        {
            size_t transferred = asio::write(*socketSSL, asio::buffer(&logOn, sizeof(LogOn)));
        }
        else
        {
            size_t transferred = asio::write(*socket, asio::buffer(&logOn, sizeof(LogOn)));
        }

        uint32_t counter = 0;

        while (1 == 1)
        {
            std::string number  = std::to_string(counter);
            std::string message;

            if (this->isClient0)
            {
                message = "Client0: " + number;
            }
            else
            {
                message = "Client1: " + number;
            }

            TextMessage textMessage;
            textMessage.header.messageLength += message.size() + 1;

            if (this->useTLS)
            {
                size_t transferred = asio::write(*socketSSL, asio::buffer(&textMessage, sizeof(TextMessage)));
                transferred        = asio::write(*socketSSL, asio::buffer(message.c_str(), message.length() + 1));
            }
            else
            {
                size_t transferred = asio::write(*socket, asio::buffer(&textMessage, sizeof(TextMessage)));
                transferred        = asio::write(*socket, asio::buffer(message.c_str(), message.length() + 1));
            }

            ++counter;
            //Sleep(1000);
        }
    }
};

/// Thread entry for one demo client: connects, logs on to sessionId and sends
/// messages until the process ends. Exceptions (e.g. connection refused) are
/// printed instead of escaping the thread, which would call std::terminate.
/// `inline` because this definition sits in a header.
inline void clientTest(uint32_t sessionId, bool isClient0, bool useTLS, uint32_t servicePort)
{
#ifdef WINDOWS
    if (isClient0)
    {
        SetThreadDescription(GetCurrentThread(), L"Client0");
    }
    else
    {
        SetThreadDescription(GetCurrentThread(), L"Client1");
    }
#endif

    try
    {
        Client client;

        client.run(sessionId, isClient0, useTLS, servicePort);

        // Keep `client` alive if run() ever returns; portable replacement for Sleep(1000).
        for (;;)
        {
            std::this_thread::sleep_for(std::chrono::seconds(1));
        }
    }
    catch (const std::exception& e)
    {
        std::cout << e.what() << std::endl;
    }
}

/// Demo: starts a TLS relay on port 777 with one worker thread, then runs two
/// clients (client 0 and client 1) in session 0 and waits for them.
/// NOTE(review): ports below 1024 usually need elevated privileges on Unix.
/// `inline` because this definition sits in a header.
inline void SSLRelayTest()
{
    SSLRelayServer relayServer;

    // One worker: ClientSocket's session bookkeeping is not synchronised.
    const uint32_t threadsNum  = 1;
    const bool     useTLS      = true;
    const uint32_t servicePort = 777;

    if (!relayServer.run(servicePort, threadsNum, useTLS))
    {
        std::cout << "relay server failed to start" << std::endl;
        return;
    }

    // Pause before starting the clients (portable replacement for Sleep(5000)).
    std::this_thread::sleep_for(std::chrono::seconds(5));

    std::vector<std::thread> threads;

    const uint32_t sessionId = 0;
    threads.emplace_back(clientTest, sessionId, true, useTLS, servicePort);
    threads.emplace_back(clientTest, sessionId, false, useTLS, servicePort);

    for (std::thread& clientThread : threads)
    {
        clientThread.join();
    }
}

}

What does this sample do? It runs an SSL relay server on localhost port 777 that connects two clients and lets them exchange text messages.

Problem: When I run this sample, the server reports the error "errorCode= 167772441, message=decryption failed or bad record mac (SSL routines)" in `void asyncReadNextMessage(uint32_t messageSize)`. I found out that it is caused by the client, which reads from and writes to the same SSL socket from two separate threads. Setting the variable `useTLS` to false runs the sample over a plain socket without errors, which shows that the problem is specific to the SSL socket. Apparently TLS is not a full-duplex protocol (I did not know that). I can't synchronize reads and writes with a mutex: when the socket is blocked in a read and no message arrives, a write to the socket would be blocked forever. In the thread "Boost ASIO, SSL: How do strands help the implementation?" someone recommended using strands. Someone else replied that asio strands only serialize the execution of the read and write handlers, not the operations themselves, so strands do not fix the problem.

I expect there is some way to synchronize reads and writes on an SSL socket. I'm 100% sure that the problem lies in synchronizing the reads and writes, because an example where one thread does both the reads and the writes worked. However, that client always expects a message to be available to read, so if none arrives, all writes are blocked. Can this be solved without using separate sockets for reading and writing?

0x90h
  • 21
  • 3
  • *"Apparently TLS is not full-duplex protocol"* - it is. This is not a problem of the protocol itself but of a specific implementation. OpenSSL keeps a structure with the current TLS state, which needs to be updated both when reading and when writing, so one needs to apply the usual protection for resources shared between threads. – Steffen Ullrich Nov 14 '22 at 20:40
  • Ok, so if I call ioContext.run() on only one thread inside the client and use asio::async_write and asio::async_read instead of asio::write and asio::read, this should work because of the implicit strand. If it does not, does that mean it is not possible with asio? – 0x90h Nov 14 '22 at 21:20
  • I'm not familiar with the asio internals. But note the information about threads at the end of [the asio SSL documentation](https://www.boost.org/doc/libs/1_75_0/doc/html/boost_asio/overview/ssl.html) – Steffen Ullrich Nov 14 '22 at 21:23

1 Answer

2

Okay, I figured it out by writing many different code samples that use SSL sockets. While an asio::io_context is already running, you can't simply start asio::async_write or asio::async_read from a thread that is not running inside the strand associated with that socket.

So when the code contains `asio::async_write(*this->socketSSL, asio::buffer(pBuffer, bufferSize), asio::bind_executor(readWriteStrand, writeMessageLambda));` but the calling thread is not running inside `readWriteStrand`, the operation must first be posted to that strand. The example below does this for a read; a write is posted the same way: `asio::post(ioContext, asio::bind_executor(readWriteStrand, [&]() {asio::async_read(*this->socketSSL, asio::buffer(readBuffer.data() + this->readDataIt, messageSize), asio::bind_executor(readWriteStrand, readMessageLambda)); }));`. Note that the posted lambda runs later on the strand, so anything captured by reference (`[&]`) must outlive it.

0x90h
  • 21
  • 3