Network connections in Java all appear to ultimately use `SocketImpl` to perform the actual network access. These are created by the `Socket` and `ServerSocket` classes.
Therefore, network access can be blocked by replacing the `SocketImpl` that is used. Specifically, the replacement `SocketImpl` should block any external network traffic, while delegating any local (loopback) network traffic to the standard `SocketImpl`.
To replace the `SocketImpl` globally, a JUnit 5 `Extension` could be written that installs custom socket factories on `Socket` and `ServerSocket`:
/**
 * JUnit 5 extension that, once per JVM, installs {@code DisableRemoteSocketImpl} as the socket
 * implementation used by both {@code Socket} and {@code ServerSocket}.
 */
public class DisableRemoteSocketsExtension implements BeforeAllCallback {
    // The JDK allows each socket factory to be set only once per JVM (a second call throws
    // SocketException), so remember whether this extension has already installed them.
    private static final AtomicBoolean APPLIED = new AtomicBoolean(false);

    @Override
    public void beforeAll(ExtensionContext context) throws Exception {
        // compareAndSet makes "check and mark" a single atomic step, so test classes starting
        // concurrently (parallel execution) cannot both attempt to install the factories.
        if (APPLIED.compareAndSet(false, true)) {
            System.out.println("Globally disabling non-loopback sockets");
            Socket.setSocketImplFactory(DisableRemoteSocketImpl::forSocket);
            ServerSocket.setSocketFactory(DisableRemoteSocketImpl::forServerSocket);
        }
    }
}
To ensure that this extension is always used, without requiring the developer to annotate each test with `@ExtendWith`, the extension can be registered automatically using JUnit's automatic extension registration (a service-loader file plus a configuration property):
src/test/resources/META-INF/services/org.junit.jupiter.api.extension.Extension
com.example.DisableRemoteSocketsExtension
src/test/resources/junit-platform.properties
junit.jupiter.extensions.autodetection.enabled=true
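Many of the classes and methods involved are package-private or `protected` within the `java.net` package. Reflective workarounds are therefore needed to delegate to the underlying `SocketImpl` methods in order to forward local traffic. On JDK 9 and later this reflection only works if the test JVM is started with `--add-opens java.base/java.net=ALL-UNNAMED`. The code below also relies on `SocketImpl.createPlatformSocketImpl`, which was added in JDK 13. Here is one possible implementation: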
/**
 * A {@link SocketImpl} that only allows loopback (local) traffic.
 *
 * <p>Every permitted operation is forwarded to a real JDK {@code SocketImpl}. Attempts to connect
 * to, or bind on, a non-loopback address fail with {@link UnsupportedOperationException}.
 *
 * <p>Almost all of {@code SocketImpl}'s methods are {@code protected}, and the JDK's platform
 * implementations are not public, so forwarding is done reflectively. At runtime this needs
 * {@code --add-opens java.base/java.net=ALL-UNNAMED}.
 */
public class DisableRemoteSocketImpl extends SocketImpl {
    // Matches "localhost", IPv4 literals in 127/8 (including short forms such as "127.1"), and
    // IPv6 ::1 written with any number of zero groups.
    private static final Pattern LOOPBACK_PATTERN =
            Pattern.compile("^localhost$|^127(?:\\.[0-9]+){0,2}\\.[0-9]+$|^(?:0*:)*?:?0*1$");

    private static final Class<?>[] NO_PARAMETERS = {};
    private static final Object[] NO_ARGUMENTS = {};

    /** The real implementation that all permitted operations are forwarded to. */
    private final SocketImpl delegate;

    /**
     * @param delegate the implementation that permitted operations are forwarded to
     */
    public DisableRemoteSocketImpl(SocketImpl delegate) {
        this.delegate = delegate;
    }

    /**
     * Factory used for {@code Socket}.
     *
     * @return a restricted impl wrapping a SOCKS impl around a client platform impl
     */
    public static DisableRemoteSocketImpl forSocket() {
        // Mimics implementation in Socket.setImpl()
        return new DisableRemoteSocketImpl(newSocksSocketImpl(getDefaultSocketImpl(false)));
    }

    /**
     * Factory used for {@code ServerSocket}.
     *
     * @return a restricted impl wrapping a server-mode platform impl
     */
    public static DisableRemoteSocketImpl forServerSocket() {
        // Mimics implementation in ServerSocket.setImpl(), which requests the platform impl
        // in server mode (createPlatformSocketImpl(true)), unlike the client path above.
        return new DisableRemoteSocketImpl(getDefaultSocketImpl(true));
    }

    @Override
    protected void create(boolean stream) throws IOException {
        callDelegate("create", new Class<?>[]{boolean.class}, new Object[]{stream});
    }

    @Override
    protected void connect(String host, int port) throws IOException {
        requireLoopbackAddress(host);
        callDelegate("connect", new Class<?>[]{String.class, int.class}, new Object[]{host, port});
    }

    @Override
    protected void connect(InetAddress address, int port) throws IOException {
        requireLoopbackAddress(address);
        callDelegate("connect", new Class<?>[]{InetAddress.class, int.class}, new Object[]{address, port});
    }

    @Override
    protected void connect(SocketAddress address, int timeout) throws IOException {
        if (!(address instanceof InetSocketAddress)) {
            throw new UnsupportedOperationException("Unsupported address type: " + address);
        }
        InetSocketAddress socketAddress = (InetSocketAddress) address;
        // When the address is already resolved, check the IP that will actually be dialled.
        // This also accepts host names other than "localhost" that resolve to loopback.
        if (socketAddress.isUnresolved()) {
            requireLoopbackAddress(socketAddress.getHostString());
        } else {
            requireLoopbackAddress(socketAddress.getAddress());
        }
        callDelegate("connect", new Class<?>[]{SocketAddress.class, int.class}, new Object[]{address, timeout});
    }

    @Override
    protected void bind(InetAddress host, int port) throws IOException {
        // Wildcard addresses (0.0.0.0 / ::) are not loopback addresses, so they are rejected too.
        // Local servers should bind to InetAddress.getLoopbackAddress().
        requireLoopbackAddress(host);
        callDelegate("bind", new Class<?>[]{InetAddress.class, int.class}, new Object[]{host, port});
    }

    @Override
    protected void listen(int backlog) throws IOException {
        callDelegate("listen", new Class<?>[]{int.class}, new Object[]{backlog});
    }

    @Override
    protected void accept(SocketImpl s) throws IOException {
        // NOTE(review): with a custom server impl, ServerSocket hands in a SocketImpl created by
        // the client factory (i.e. another DisableRemoteSocketImpl), not a platform impl. Confirm
        // the delegate's accept() tolerates that on the target JDK.
        callDelegate("accept", new Class<?>[]{SocketImpl.class}, new Object[]{s});
    }

    @Override
    protected InputStream getInputStream() throws IOException {
        return callDelegate("getInputStream", NO_PARAMETERS, NO_ARGUMENTS);
    }

    @Override
    protected OutputStream getOutputStream() throws IOException {
        return callDelegate("getOutputStream", NO_PARAMETERS, NO_ARGUMENTS);
    }

    @Override
    protected int available() throws IOException {
        return callDelegate("available", NO_PARAMETERS, NO_ARGUMENTS);
    }

    @Override
    protected void close() throws IOException {
        callDelegate("close", NO_PARAMETERS, NO_ARGUMENTS);
    }

    @Override
    protected void sendUrgentData(int data) throws IOException {
        callDelegate("sendUrgentData", new Class<?>[]{int.class}, new Object[]{data});
    }

    // The following SocketImpl methods are not abstract. Their base versions read this wrapper's
    // own (never-populated) address/port fields or throw "not implemented", so forward them too.

    @Override
    protected void shutdownInput() throws IOException {
        callDelegate("shutdownInput", NO_PARAMETERS, NO_ARGUMENTS);
    }

    @Override
    protected void shutdownOutput() throws IOException {
        callDelegate("shutdownOutput", NO_PARAMETERS, NO_ARGUMENTS);
    }

    @Override
    protected boolean supportsUrgentData() {
        return callDelegateUnchecked("supportsUrgentData");
    }

    @Override
    protected InetAddress getInetAddress() {
        return callDelegateUnchecked("getInetAddress");
    }

    @Override
    protected int getPort() {
        return callDelegateUnchecked("getPort");
    }

    @Override
    protected int getLocalPort() {
        return callDelegateUnchecked("getLocalPort");
    }

    @Override
    public void setOption(int optID, Object value) throws SocketException {
        delegate.setOption(optID, value);
    }

    @Override
    public Object getOption(int optID) throws SocketException {
        return delegate.getOption(optID);
    }

    /** Throws unless {@code host} is a loopback host name or IP literal (see LOOPBACK_PATTERN). */
    private void requireLoopbackAddress(String host) {
        if (!LOOPBACK_PATTERN.matcher(host).matches()) {
            throw new UnsupportedOperationException("Attempted to connect to remote host: " + host);
        }
    }

    /** Throws unless {@code address} is a loopback address. */
    private void requireLoopbackAddress(InetAddress address) {
        if (!address.isLoopbackAddress()) {
            throw new UnsupportedOperationException("Attempted to connect to remote host: " + address);
        }
    }

    /**
     * Creates the JDK's platform socket implementation via the non-public
     * {@code SocketImpl.createPlatformSocketImpl(boolean)} factory.
     *
     * @param server {@code true} for a {@code ServerSocket}, {@code false} for a client {@code Socket}
     */
    private static SocketImpl getDefaultSocketImpl(boolean server) {
        try {
            Method factoryMethod = SocketImpl.class.getDeclaredMethod("createPlatformSocketImpl", Boolean.TYPE);
            factoryMethod.setAccessible(true);
            return (SocketImpl) factoryMethod.invoke(null, server);
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }

    /** Wraps {@code delegate} in the JDK's non-public {@code java.net.SocksSocketImpl}. */
    private static SocketImpl newSocksSocketImpl(SocketImpl delegate) {
        try {
            Constructor<? extends SocketImpl> constructor = Class.forName("java.net.SocksSocketImpl")
                    .asSubclass(SocketImpl.class).getDeclaredConstructor(SocketImpl.class);
            constructor.setAccessible(true);
            return constructor.newInstance(delegate);
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reflectively invokes a method declared on {@code SocketImpl} against {@link #delegate}.
     * Exceptions thrown by the target are rethrown unwrapped.
     *
     * @throws IOException if the delegate method throws one
     */
    @SuppressWarnings("unchecked")
    private <T> T callDelegate(String methodName, Class<?>[] parameterTypes, Object[] args) throws IOException {
        try {
            Method method = SocketImpl.class.getDeclaredMethod(methodName, parameterTypes);
            method.setAccessible(true);
            return (T) method.invoke(delegate, args);
        } catch (InvocationTargetException invocationTargetException) {
            Throwable e = invocationTargetException.getCause();
            if (e instanceof IOException) {
                throw (IOException) e;
            } else if (e instanceof RuntimeException) {
                throw (RuntimeException) e;
            } else if (e instanceof Error) {
                throw (Error) e;
            } else {
                throw new AssertionError(invocationTargetException);
            }
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Variant of {@link #callDelegate} for no-argument methods whose signature does not allow
     * {@code IOException}. An unexpected {@code IOException} is wrapped in an unchecked one.
     */
    private <T> T callDelegateUnchecked(String methodName) {
        try {
            return callDelegate(methodName, NO_PARAMETERS, NO_ARGUMENTS);
        } catch (IOException e) {
            throw new java.io.UncheckedIOException(e);
        }
    }
}