It appears the solution is to combine Bruno's suggestion with Paulo's solution.
Paulo's solution lets us customize the behavior of our SSLSocket or SSLServerSocket by using delegates.
Bruno's suggestion lets us tell the default SSL implementation to use our modified sockets.
Here is what I did:
- Create a delegate ServerSocket class (MyServerSocket)
- Create a delegate ServerSocketFactory class (MyServerSocketFactory)
- Create a delegate SocketFactory class (MySocketFactory)
- Create a delegate Socket class (MySocket)
- Create XorInputStream (source linked in the original answer; not reproduced here)
- Create XorOutputStream (source linked in the original answer; not reproduced here)
On the server side:
// Initialisation as usual (SSLContext, key/trust managers, etc. are elided here)
...
// SSL socket factory from the configured context; it is used below to layer SSL
// over an already-connected plain socket rather than to open its own connection.
sslSocketFactory = sslContext.getSocketFactory();
// Start from the default plain (non-SSL) server socket factory and wrap it, so that
// every accepted connection comes back as a MySocket (XOR-wrapped streams).
serverSocketFactory = ServerSocketFactory.getDefault();
serverSocketFactory = new MyServerSocketFactory(serverSocketFactory);
serverSocket = serverSocketFactory.createServerSocket(port);
...
// accept() on MyServerSocket returns a MySocket. The (Socket) cast is redundant
// because accept() is already declared to return Socket.
Socket s = (Socket) serverSocket.accept();
// Layer SSL on top of the accepted obfuscated socket. Host is null on the server side;
// the port passed is the peer's remote port (Socket.getPort()). autoClose = false means
// closing sslSocket does NOT close s, so s must be closed separately.
sslSocket = (SSLSocket) sslSocketFactory.createSocket(s, null, s.getPort(), false);
sslSocket.setUseClientMode(false);
// NOTE(review): SSL_RSA_WITH_RC4_128_MD5 relies on RC4 and MD5. RC4 cipher suites are
// prohibited by RFC 7465 and disabled by default in current JDKs, so the handshake will
// likely fail there. Choose a modern TLS suite supported by both peers.
sslSocket.setEnabledCipherSuites(new String[]{"SSL_RSA_WITH_RC4_128_MD5"});
// Require the client to present a certificate.
sslSocket.setNeedClientAuth(true);
...
On the client side:
// Open the plain TCP connection through MySocketFactory; createSocket(String, int)
// returns a MySocket, so all bytes sent or received are XOR-wrapped.
Socket s = new MySocketFactory(SocketFactory.getDefault()).createSocket(host, port);
// Layer SSL over the obfuscated socket. NOTE(review): `factory` is not defined in this
// snippet; presumably it is the SSLSocketFactory obtained from the client's SSLContext.
// autoClose = false means closing `socket` does not close `s`.
SSLSocket socket = (SSLSocket) factory.createSocket(s, host, port, false);
Source code of the delegate classes:
/**
 * ServerSocket decorator: every operation is forwarded to a wrapped server socket,
 * and each accepted connection is handed back wrapped in a {@code MySocket}.
 *
 * <p>The ServerSocket state inherited by this object is not bound by any method here;
 * binding, accepting and option handling all go through the delegate.
 */
public class MyServerSocket extends ServerSocket {

    /** The real server socket that all calls are forwarded to. */
    private final ServerSocket delegate;

    /**
     * Wraps an existing server socket.
     *
     * @param baseSocket the server socket to forward to
     * @throws IOException as declared by the {@code ServerSocket()} super constructor
     */
    public MyServerSocket(ServerSocket baseSocket) throws IOException {
        super();
        delegate = baseSocket;
    }

    // ---- connection lifecycle ----

    /** Accepts a connection on the delegate and wraps it in a {@code MySocket}. */
    @Override
    public Socket accept() throws IOException {
        Socket accepted = delegate.accept();
        return new MySocket(accepted);
    }

    @Override
    public void bind(SocketAddress endpoint) throws IOException {
        delegate.bind(endpoint);
    }

    @Override
    public void bind(SocketAddress endpoint, int backlog) throws IOException {
        delegate.bind(endpoint, backlog);
    }

    @Override
    public void close() throws IOException {
        delegate.close();
    }

    // ---- state queries ----

    @Override
    public ServerSocketChannel getChannel() {
        return delegate.getChannel();
    }

    @Override
    public InetAddress getInetAddress() {
        return delegate.getInetAddress();
    }

    @Override
    public int getLocalPort() {
        return delegate.getLocalPort();
    }

    @Override
    public SocketAddress getLocalSocketAddress() {
        return delegate.getLocalSocketAddress();
    }

    @Override
    public boolean isBound() {
        return delegate.isBound();
    }

    @Override
    public boolean isClosed() {
        return delegate.isClosed();
    }

    // ---- socket options ----

    @Override
    public synchronized int getReceiveBufferSize() throws SocketException {
        return delegate.getReceiveBufferSize();
    }

    @Override
    public synchronized void setReceiveBufferSize(int size) throws SocketException {
        delegate.setReceiveBufferSize(size);
    }

    @Override
    public boolean getReuseAddress() throws SocketException {
        return delegate.getReuseAddress();
    }

    @Override
    public void setReuseAddress(boolean on) throws SocketException {
        delegate.setReuseAddress(on);
    }

    @Override
    public synchronized int getSoTimeout() throws IOException {
        return delegate.getSoTimeout();
    }

    @Override
    public synchronized void setSoTimeout(int timeout) throws SocketException {
        delegate.setSoTimeout(timeout);
    }

    @Override
    public void setPerformancePreferences(int connectionTime, int latency, int bandwidth) {
        delegate.setPerformancePreferences(connectionTime, latency, bandwidth);
    }

    /** Returns the delegate's description, so the wrapper prints like the real socket. */
    @Override
    public String toString() {
        return delegate.toString();
    }
}
/**
 * {@link ServerSocketFactory} decorator that wraps every server socket created by a base
 * factory in a {@code MyServerSocket}, so that accepted connections are {@code MySocket}s.
 *
 * <p>All {@code createServerSocket} overloads are wrapped, including the no-argument
 * (unbound) variant.
 */
public class MyServerSocketFactory extends ServerSocketFactory {

    /** Factory that creates the real server sockets. */
    private final ServerSocketFactory baseFactory;

    /**
     * Creates a wrapping factory.
     *
     * @param baseFactory factory that creates the real server sockets
     * @throws NullPointerException if {@code baseFactory} is null
     */
    public MyServerSocketFactory(ServerSocketFactory baseFactory) {
        if (baseFactory == null) {
            throw new NullPointerException("baseFactory");
        }
        this.baseFactory = baseFactory;
    }

    /**
     * Creates an unbound wrapped server socket; bind it later via {@code bind(...)}.
     *
     * <p>Without this override, the inherited {@code ServerSocketFactory} implementation
     * always throws, even when the base factory supports unbound sockets.
     *
     * @throws IOException if the base factory cannot create an unbound server socket
     */
    @Override
    public ServerSocket createServerSocket() throws IOException {
        return new MyServerSocket(baseFactory.createServerSocket());
    }

    /** Creates a wrapped server socket bound to {@code port}. */
    @Override
    public ServerSocket createServerSocket(int port) throws IOException {
        return new MyServerSocket(baseFactory.createServerSocket(port));
    }

    /** Creates a wrapped server socket bound to {@code port} with the given backlog. */
    @Override
    public ServerSocket createServerSocket(int port, int backlog) throws IOException {
        return new MyServerSocket(baseFactory.createServerSocket(port, backlog));
    }

    /**
     * Creates a wrapped server socket bound to {@code port} on {@code ifAddress} with the
     * given backlog.
     */
    @Override
    public ServerSocket createServerSocket(int port, int backlog, InetAddress ifAddress)
            throws IOException {
        return new MyServerSocket(baseFactory.createServerSocket(port, backlog, ifAddress));
    }
}
/**
 * {@link Socket} decorator whose byte streams are wrapped in {@code XorInputStream} /
 * {@code XorOutputStream} using a single-byte pattern. All other operations are forwarded
 * to the wrapped base socket.
 *
 * <p>Each stream wrapper is created on first use and reused afterwards. NOTE(review): the
 * exact byte transformation lives in XorInputStream/XorOutputStream, which are not shown
 * here. Presumably both peers must use the same pattern.
 */
public class MySocket extends Socket {

    /** Pattern used by the single-argument constructor. */
    public static final byte DEFAULT_PATTERN = (byte) 0xAC;

    /** The real socket that all calls are forwarded to. */
    private final Socket baseSocket;

    /** Pattern handed to the XOR stream wrappers. */
    private final byte pattern;

    /** Lazily created wrappers around the base socket's streams. */
    private XorInputStream xorInputStream;
    private XorOutputStream xorOutputStream;

    /**
     * Wraps {@code baseSocket} using {@link #DEFAULT_PATTERN}.
     *
     * @param baseSocket socket to forward to
     * @throws NullPointerException if {@code baseSocket} is null
     */
    public MySocket(Socket baseSocket) {
        this(baseSocket, DEFAULT_PATTERN);
    }

    /**
     * Wraps {@code baseSocket} using the given XOR pattern.
     *
     * @param baseSocket socket to forward to
     * @param pattern byte pattern passed to the XOR stream wrappers
     * @throws NullPointerException if {@code baseSocket} is null
     */
    public MySocket(Socket baseSocket, byte pattern) {
        if (baseSocket == null) {
            throw new NullPointerException("baseSocket");
        }
        this.baseSocket = baseSocket;
        this.pattern = pattern;
    }

    /**
     * Returns the XOR-wrapped input stream of the base socket.
     *
     * <p>The base socket is queried on every call, so the {@code Socket} contract holds:
     * an IOException is thrown once the socket is closed, unconnected, or has its input
     * shut down, instead of handing back a stale cached wrapper.
     */
    @Override
    public InputStream getInputStream() throws IOException {
        InputStream raw = baseSocket.getInputStream();
        if (xorInputStream == null) {
            xorInputStream = new XorInputStream(raw, pattern);
        }
        return xorInputStream;
    }

    /**
     * Returns the XOR-wrapped output stream of the base socket.
     *
     * <p>The base socket is queried on every call, so an IOException is thrown once the
     * socket is closed, unconnected, or has its output shut down.
     */
    @Override
    public OutputStream getOutputStream() throws IOException {
        OutputStream raw = baseSocket.getOutputStream();
        if (xorOutputStream == null) {
            xorOutputStream = new XorOutputStream(raw, pattern);
        }
        return xorOutputStream;
    }

    @Override
    public void bind(SocketAddress bindpoint) throws IOException {
        baseSocket.bind(bindpoint);
    }

    @Override
    public synchronized void close() throws IOException {
        baseSocket.close();
    }

    @Override
    public void connect(SocketAddress endpoint) throws IOException {
        baseSocket.connect(endpoint);
    }

    @Override
    public void connect(SocketAddress endpoint, int timeout) throws IOException {
        baseSocket.connect(endpoint, timeout);
    }

    @Override
    public SocketChannel getChannel() {
        return baseSocket.getChannel();
    }

    @Override
    public InetAddress getInetAddress() {
        return baseSocket.getInetAddress();
    }

    @Override
    public boolean getKeepAlive() throws SocketException {
        return baseSocket.getKeepAlive();
    }

    @Override
    public InetAddress getLocalAddress() {
        return baseSocket.getLocalAddress();
    }

    @Override
    public int getLocalPort() {
        return baseSocket.getLocalPort();
    }

    @Override
    public SocketAddress getLocalSocketAddress() {
        return baseSocket.getLocalSocketAddress();
    }

    @Override
    public boolean getOOBInline() throws SocketException {
        return baseSocket.getOOBInline();
    }

    @Override
    public int getPort() {
        return baseSocket.getPort();
    }

    @Override
    public synchronized int getReceiveBufferSize() throws SocketException {
        return baseSocket.getReceiveBufferSize();
    }

    @Override
    public SocketAddress getRemoteSocketAddress() {
        return baseSocket.getRemoteSocketAddress();
    }

    @Override
    public boolean getReuseAddress() throws SocketException {
        return baseSocket.getReuseAddress();
    }

    @Override
    public synchronized int getSendBufferSize() throws SocketException {
        return baseSocket.getSendBufferSize();
    }

    @Override
    public int getSoLinger() throws SocketException {
        return baseSocket.getSoLinger();
    }

    @Override
    public synchronized int getSoTimeout() throws SocketException {
        return baseSocket.getSoTimeout();
    }

    @Override
    public boolean getTcpNoDelay() throws SocketException {
        return baseSocket.getTcpNoDelay();
    }

    @Override
    public int getTrafficClass() throws SocketException {
        return baseSocket.getTrafficClass();
    }

    @Override
    public boolean isBound() {
        return baseSocket.isBound();
    }

    @Override
    public boolean isClosed() {
        return baseSocket.isClosed();
    }

    @Override
    public boolean isConnected() {
        return baseSocket.isConnected();
    }

    @Override
    public boolean isInputShutdown() {
        return baseSocket.isInputShutdown();
    }

    @Override
    public boolean isOutputShutdown() {
        return baseSocket.isOutputShutdown();
    }

    /** Sends urgent data directly on the base socket; it does not pass through the XOR wrapper. */
    @Override
    public void sendUrgentData(int data) throws IOException {
        baseSocket.sendUrgentData(data);
    }

    @Override
    public void setKeepAlive(boolean on) throws SocketException {
        baseSocket.setKeepAlive(on);
    }

    @Override
    public void setOOBInline(boolean on) throws SocketException {
        baseSocket.setOOBInline(on);
    }

    @Override
    public void setPerformancePreferences(int connectionTime, int latency, int bandwidth) {
        baseSocket.setPerformancePreferences(connectionTime, latency, bandwidth);
    }

    @Override
    public synchronized void setReceiveBufferSize(int size) throws SocketException {
        baseSocket.setReceiveBufferSize(size);
    }

    @Override
    public void setReuseAddress(boolean on) throws SocketException {
        baseSocket.setReuseAddress(on);
    }

    @Override
    public synchronized void setSendBufferSize(int size) throws SocketException {
        baseSocket.setSendBufferSize(size);
    }

    @Override
    public void setSoLinger(boolean on, int linger) throws SocketException {
        baseSocket.setSoLinger(on, linger);
    }

    @Override
    public synchronized void setSoTimeout(int timeout) throws SocketException {
        baseSocket.setSoTimeout(timeout);
    }

    @Override
    public void setTcpNoDelay(boolean on) throws SocketException {
        baseSocket.setTcpNoDelay(on);
    }

    @Override
    public void setTrafficClass(int tc) throws SocketException {
        baseSocket.setTrafficClass(tc);
    }

    @Override
    public void shutdownInput() throws IOException {
        baseSocket.shutdownInput();
    }

    @Override
    public void shutdownOutput() throws IOException {
        baseSocket.shutdownOutput();
    }

    @Override
    public String toString() {
        return baseSocket.toString();
    }
}
/**
 * {@link SocketFactory} decorator that wraps every socket created by a base factory in a
 * {@code MySocket}, so the connection's byte streams go through the XOR wrappers.
 *
 * <p>Every {@code createSocket} overload is wrapped. Previously only
 * {@code createSocket(String, int)} was, so the other overloads silently returned
 * unwrapped sockets.
 */
public class MySocketFactory extends SocketFactory {

    /** Factory that creates the real sockets. */
    private final SocketFactory baseFactory;

    /**
     * Creates a wrapping factory.
     *
     * @param baseFactory factory that creates the real sockets
     * @throws NullPointerException if {@code baseFactory} is null
     */
    public MySocketFactory(SocketFactory baseFactory) {
        if (baseFactory == null) {
            throw new NullPointerException("baseFactory");
        }
        this.baseFactory = baseFactory;
    }

    /**
     * Creates an unconnected wrapped socket; {@code connect(...)} is forwarded to the
     * underlying socket.
     *
     * @throws IOException if the base factory cannot create an unconnected socket
     */
    @Override
    public Socket createSocket() throws IOException {
        return new MySocket(baseFactory.createSocket());
    }

    @Override
    public Socket createSocket(String host, int port) throws IOException, UnknownHostException {
        return new MySocket(baseFactory.createSocket(host, port));
    }

    @Override
    public Socket createSocket(String host, int port, InetAddress localHost, int localPort)
            throws IOException, UnknownHostException {
        return new MySocket(baseFactory.createSocket(host, port, localHost, localPort));
    }

    @Override
    public Socket createSocket(InetAddress host, int port) throws IOException {
        return new MySocket(baseFactory.createSocket(host, port));
    }

    @Override
    public Socket createSocket(InetAddress address, int port, InetAddress localAddress,
            int localPort) throws IOException {
        return new MySocket(baseFactory.createSocket(address, port, localAddress, localPort));
    }

    /**
     * Two wrapping factories are equal when they wrap equal base factories.
     *
     * <p>The previous {@code baseFactory.equals(obj)} was neither reflexive nor symmetric:
     * a factory was typically not equal to itself, yet it compared equal to its own base.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || obj.getClass() != getClass()) {
            return false;
        }
        return baseFactory.equals(((MySocketFactory) obj).baseFactory);
    }

    /** Consistent with {@link #equals(Object)}: derived solely from the base factory. */
    @Override
    public int hashCode() {
        return baseFactory.hashCode();
    }

    @Override
    public String toString() {
        return baseFactory.toString();
    }
}