I am trying to make a socket chat in Python for my coursework in university. I need to encrypt messages for communication between multiple clients and send it to them.
My encryption and message sending scheme is:
I generate a packet using the pickle library using dumps(), inside it I encrypt the message, then I send the packet to the server, there it is decrypted using the loads() method, then I'm using dumps() again and sending the packet to all recipients, and the message is decrypted there
Here is my problem:
When I make a connection between two clients, everything is fine, but when I use three or more clients, something strange happens. Messages from the first sender get through fine, but when the second client sends the message, one of the three clients throws a decryption error
'utf-8' codec can't decode byte 0x8b in position 1: invalid start byte
How it works:
Here is the scheme: We have client #1, client #2, client #3. Client #1 (#2, #3, doesn't matter) send the message (1,2,3,... messages) and everything is fine. Next client (#2 for example) to send message (sorry for my english) is sending it and then one of the other clients (#1 or #3) disconnects from the server with this error.
The strangest thing about this is that the encryption proceeds normally until a third or more client connects
Here is my encryptor class (I use a Cryptodome library for this):
class Encryptor:
def __init__(self) -> None:
self.key = b'\xd2P\x05\x0b\xd5\x8e\xa2&#!\xe9\x80k\x17\xc7V'
self.iv = b'!\xc5\x1b\xca\xe7)\x89\xc0\xf8\x9e;\x0c\xf3H\xb3)'
self.cipher = AES.new(self.key, AES.MODE_CBC, self.iv)
self.d_cipher = AES.new(self.key, AES.MODE_CBC, self.iv)
def pad(self, message) -> str:
return message + ((16 - len(message) % 16) * "{")
def encrypt(self, message) -> bytes:
return self.cipher.encrypt(self.pad(message).encode(ENCODING))
def decrypt(self, message) -> str:
decrypted_msg = self.d_cipher.decrypt(message).decode(ENCODING)
decrypted_msg = decrypted_msg.replace('{', '')
return decrypted_msg
I am using the pickle library to pass a dictionary over sockets, here is the client side code, (SERVER_ADDRESS) is the (IP, PORT) tuple where IP is socket.gethostbyname(socket.gethostname())
and PORT is 5050 for example, and ENCODING == "utf-8"
:
def __init__(self) -> None:
self._encryptor = Encryptor()
self._is_connected = False
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
def send_username(self, username) -> None:
self._username = username.encode(ENCODING)
self._socket.send(self._username)
def connect(self, username) -> None:
self._socket.connect(SERVER_ADDRESS)
self.send_username(username)
def send_data(self, data) -> None:
enc_data = self._encryptor.encrypt(data['data'])
data["data"] = enc_data
data = pickle.dumps(data)
self._socket.send(data)
def handle_messages(self) -> None:
while True:
packet = self._socket.recv(8192)
if packet:
packet = pickle.loads(packet)
if not packet:
print('\r' + "Disconnecting from the server")
self._socket.close()
try:
sys.exit(0)
except SystemExit:
os._exit(0)
username = packet['username'].decode(ENCODING)
dec_message = self._encryptor.decrypt(packet['data'])
print(f"\r[{username}] {dec_message}")
print("[YOU] ", end="", flush=True)
client = Client()
try:
client.connect("username")
except Exception as e:
print("Server is offline. Try again later")
print(e)
client._is_connected = True
thread = threading.Thread(target=client.handle_messages)
thread.start()
while client._is_connected:
msg = input()
packet = {
"type": "message",
"username": client._username,
"data": msg
}
client.send_data(packet)
Here is the server side:
def __init__(self) -> None:
self._clients = {}
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self._socket.bind(SERVER_ADDRESS)
def send_all(self, conn, packet) -> None:
packet = pickle.dumps(packet)
for client in self._clients:
if client != conn:
client.send(packet)
def handle_client(self, conn, addr) -> None:
username = conn.recv(8192).decode(ENCODING)
self._clients[conn] = username
is_connected = True
try:
while is_connected:
packet = conn.recv(8192)
if not packet:
break
else:
packet = pickle.loads(packet)
self.send_all(conn, packet)
finally:
print(f"Client {addr} has been disconnected")
self._clients.pop(conn)
conn.close()
def run_server(self) -> None:
self._socket.listen()
print(f"[LISTENING] Server is listening on {SERVER_IP}")
while True:
conn, addr = self._socket.accept()
thread = threading.Thread(target=self.handle_client,
args=(conn, addr))
thread.start()
Thank you so much.