The proper solution is probably to use NIO in some way; I already commented elsewhere on how Hadoop does this with NIO underneath.
But the simpler solution is in Dexter's answer. I also came across an answer from EJP who suggests using a BufferedOutputStream to control when data goes out. So I combined the two to arrive at the TimedOutputStream shown below. It does not give complete control over output buffering towards the remote side (much of it is done by the OS), but combining an appropriate buffer size with a write timeout provides at least some control (see the second program, which tests the TimedOutputStream).
I have not fully tested the TimedOutputStream, so do your own due diligence.
Edit: updated the write method for a better correlation between buffer size and write timeout, and tweaked the test program. Added comments about the unsafe asynchronous close of a socket output stream.
import java.io.*;
import java.util.concurrent.*;
/**
 * A {@link BufferedOutputStream}-like stream that sets time-out tasks on write operations
 * (typically when the buffer is flushed). If a write timeout occurs, the underlying output stream is closed
 * (which may not be appropriate when sockets are used, see also comments on {@link TimedOutputStream#interruptWriteOut}).
 * <p>
 * A {@link ScheduledExecutorService} (typically a {@link ScheduledThreadPoolExecutor}) is required to
 * schedule the time-out tasks. A {@code ScheduledThreadPoolExecutor} should have
 * {@link ScheduledThreadPoolExecutor#setRemoveOnCancelPolicy(boolean)} set to {@code true}
 * to prevent a huge task queue (every write schedules a task that is cancelled afterwards).
 * If no executor is provided in the constructor, one is created and it is shut down by {@link #close()}.
 * <p>
 * Unlike {@link BufferedOutputStream}, the write and flush methods are not synchronized:
 * use an instance from one thread at a time.
 * @author vanOekel
 */
public class TimedOutputStream extends FilterOutputStream {

    /** Write timeout in milliseconds; default is 50 seconds. */
    protected int timeoutMs = 50_000;
    /** {@code true} when {@link #executor} was created by this stream and must be shut down on {@link #close()}. */
    protected final boolean closeExecutor;
    /** Executor used to schedule the {@link TimeoutTask}s. */
    protected final ScheduledExecutorService executor;
    /** Time-out task of the most recent (or currently running) write operation. */
    protected ScheduledFuture<?> timeoutTask;
    /** Set to {@code true} by the time-out task when a write did not complete in time. */
    protected volatile boolean writeTimedout;
    /** Exception thrown while closing the underlying stream after a write timeout, if any. */
    protected volatile IOException writeTimeoutCloseException;

    /* *** new methods not in BufferedOutputStream *** */

    /**
     * Sets the maximum time a single write (buffer flush) may take before the
     * underlying output stream is closed. Default timeout is 50 seconds.
     * @param timeoutMs the timeout in milliseconds, must be positive.
     * @throws IllegalArgumentException if {@code timeoutMs <= 0} (a zero or negative delay
     * would close the stream as soon as the first write is attempted).
     */
    public void setTimeoutMs(int timeoutMs) {
        if (timeoutMs <= 0) {
            throw new IllegalArgumentException("Timeout must be positive: " + timeoutMs);
        }
        this.timeoutMs = timeoutMs;
    }

    /** @return the write timeout in milliseconds. */
    public int getTimeoutMs() {
        return timeoutMs;
    }

    /** @return {@code true} if a write timeout occurred and the underlying output stream was closed. */
    public boolean isWriteTimeout() {
        return writeTimedout;
    }

    /**
     * If a write timeout occurs and closing the underlying output-stream caused an exception,
     * then this method will return a non-null IOException.
     */
    public IOException getWriteTimeoutCloseException() {
        return writeTimeoutCloseException;
    }

    /** @return the executor used to schedule the write time-out tasks. */
    public ScheduledExecutorService getScheduledExecutor() {
        return executor;
    }

    /**
     * See {@link BufferedOutputStream#close()}.
     * <br>Flushes and closes the underlying stream and, if the executor was created by this stream,
     * shuts the executor down (also when flushing or closing fails).
     */
    @Override
    public void close() throws IOException {
        try {
            super.close(); // calls flush via FilterOutputStream.
        } finally {
            if (closeExecutor) {
                executor.shutdownNow();
            }
        }
    }

    /* ** Mostly a copy of java.io.BufferedOutputStream and updated with time-out options. *** */

    /** The internal buffer where data is stored. */
    protected byte[] buf;
    /** The number of valid bytes in {@link #buf}. */
    protected int count;

    public TimedOutputStream(OutputStream out) {
        this(out, null);
    }

    public TimedOutputStream(OutputStream out, ScheduledExecutorService executor) {
        this(out, 8192, executor);
    }

    public TimedOutputStream(OutputStream out, int size) {
        this(out, size, null);
    }

    /**
     * @param out the underlying output stream.
     * @param size the buffer size, which is also the maximum amount of data written per timed write.
     * @param executor executor for the time-out tasks, or {@code null} to create one that is shut down on {@link #close()}.
     * @throws IllegalArgumentException if {@code size <= 0}.
     */
    public TimedOutputStream(OutputStream out, int size, ScheduledExecutorService executor) {
        super(out);
        if (size <= 0) {
            throw new IllegalArgumentException("Buffer size <= 0");
        }
        if (executor == null) {
            // Create the pool directly: the concrete type returned by
            // Executors.newScheduledThreadPool is not specified, so casting it is fragile.
            ScheduledThreadPoolExecutor stp = new ScheduledThreadPoolExecutor(1);
            stp.setRemoveOnCancelPolicy(true);
            this.executor = stp;
            closeExecutor = true;
        } else {
            this.executor = executor;
            closeExecutor = false;
        }
        buf = new byte[size];
    }

    /**
     * Writes the buffered data (if any) to the underlying stream and optionally flushes it,
     * guarded by a time-out task that closes the underlying stream when the operation takes longer
     * than {@link #getTimeoutMs()}. Called by all the write-methods and {@link #flush()}.
     * @param flushOut if {@code true}, also call {@code flush()} on the underlying stream.
     * @throws IOException if writing fails, if the time-out task cannot be scheduled,
     * or if the time-out task fired even though the write itself returned normally
     * (the underlying stream has then been closed, so later data would be lost silently).
     */
    protected void flushBuffer(boolean flushOut) throws IOException {
        if (count == 0 && !flushOut) {
            return;
        }
        final int timeout = getTimeoutMs();
        final ScheduledFuture<?> task;
        try {
            task = executor.schedule(new TimeoutTask(this), timeout, TimeUnit.MILLISECONDS);
        } catch (RejectedExecutionException e) {
            // Typically the executor was shut down (e.g. writing after close()).
            throw new IOException("Cannot schedule write time-out task", e);
        }
        timeoutTask = task;
        boolean cancelled;
        try {
            if (count > 0) {
                out.write(buf, 0, count);
                count = 0;
            }
            if (flushOut) {
                out.flush(); // in case out is also buffered, this will do the actual write.
            }
        } finally {
            // cancel(false) fails if the time-out task already ran or is running.
            cancelled = task.cancel(false);
        }
        // Only reached when the write completed normally: report a time-out that raced with it.
        if (!cancelled) {
            throw new IOException("Write timed out after " + timeout + " ms, output stream was closed");
        }
    }

    /** Time-out task: closes the underlying output stream of the given {@link TimedOutputStream}. */
    protected class TimeoutTask implements Runnable {
        protected final TimedOutputStream tout;

        public TimeoutTask(TimedOutputStream tout) {
            this.tout = tout;
        }

        @Override public void run() {
            tout.interruptWriteOut();
        }
    }

    /**
     * Closes the outputstream after a write timeout.
     * If sockets are used, calling {@link java.net.Socket#shutdownOutput()} is probably safer
     * since the behavior of an async close of the outputstream is undefined.
     */
    protected void interruptWriteOut() {
        try {
            writeTimedout = true;
            out.close();
        } catch (IOException e) {
            writeTimeoutCloseException = e;
        }
    }

    /**
     * See {@link BufferedOutputStream#write(int b)}
     */
    @Override
    public void write(int b) throws IOException {
        if (count >= buf.length) {
            flushBuffer(false);
        }
        buf[count++] = (byte) b;
    }

    /**
     * Like {@link BufferedOutputStream#write(byte[], int, int)}
     * but with one big difference: the full buffer is always written
     * to the underlying outputstream. Large byte-arrays are chopped
     * into buffer-size pieces and written out piece by piece.
     * <br>This provides a closer relation to the write timeout
     * and the maximum (buffer) size of the write-operation to wait on.
     * @throws IndexOutOfBoundsException if {@code off} or {@code len} is negative or
     * {@code off + len > b.length}; checked before anything is buffered or written.
     */
    @Override
    public void write(byte[] b, int off, int len) throws IOException {
        // Validate first so that invalid arguments never cause a partial write to the underlying stream.
        if ((off | len | (off + len) | (b.length - (off + len))) < 0) {
            throw new IndexOutOfBoundsException("off=" + off + ", len=" + len + ", array length=" + b.length);
        }
        if (count >= buf.length) {
            flushBuffer(false);
        }
        if (len <= buf.length - count) {
            System.arraycopy(b, off, buf, count, len);
            count += len;
            return;
        }
        // Top up the current buffer and write it out.
        final int fill = buf.length - count;
        System.arraycopy(b, off, buf, count, fill);
        count += fill;
        flushBuffer(false);
        // Write the remaining data in full buffer-size pieces, each with its own time-out.
        final int remaining = len - fill;
        int start = off + fill;
        for (int i = 0; i < remaining / buf.length; i++) {
            System.arraycopy(b, start, buf, 0, buf.length);
            count = buf.length;
            flushBuffer(false);
            start += buf.length;
        }
        // Keep the tail in the buffer.
        count = remaining % buf.length;
        System.arraycopy(b, start, buf, 0, count);
    }

    /**
     * See {@link BufferedOutputStream#flush()}
     * <br>If a write timeout occurred (i.e. {@link #isWriteTimeout()} returns {@code true}),
     * then this method does nothing (the underlying stream was already closed by the time-out task).
     */
    @Override
    public void flush() throws IOException {
        // Protect against flushing before closing after a write-timeout.
        // If that happens, then "out" is already closed in interruptWriteOut.
        if (!isWriteTimeout()) {
            flushBuffer(true);
        }
    }
}
And the test program:
import java.io.*;
import java.net.*;
import java.util.concurrent.*;
/**
 * Test program for {@link TimedOutputStream} over local sockets.
 * <p>
 * Every server side sleeps {@code timeOut} ms before it starts reading. Client 1 uses a write timeout
 * of {@code timeOut / 2} and is expected to fail with a write timeout; the other clients use a write
 * timeout of {@code timeOut * 2} and are expected to complete their transfer.
 */
public class TestTimedSocketOut implements Runnable, Closeable {

    public static void main(String[] args) {
        TestTimedSocketOut m = new TestTimedSocketOut();
        try {
            m.run();
        } finally {
            m.close();
        }
    }

    final int clients = 3; // 2 is minimum, client 1 is expected to fail.
    final int timeOut = 1000;
    final int bufSize = 4096;
    final long maxWait = 5000L;
    // need a large array to write, else the OS just buffers everything and makes it work
    byte[] largeMsg = new byte[28_602];
    final ExecutorService tp = Executors.newCachedThreadPool();
    final ScheduledThreadPoolExecutor stp = new ScheduledThreadPoolExecutor(1);
    final ConcurrentLinkedQueue<Closeable> closeables = new ConcurrentLinkedQueue<Closeable>();
    final CountDownLatch[] serversReady = new CountDownLatch[clients];
    {
        // Create all latches before any thread is started: server-side threads read this array,
        // but they are started by the accept thread, which has no happens-before relation
        // with array writes made by run() after the accept thread was started.
        for (int i = 0; i < serversReady.length; i++) {
            serversReady[i] = new CountDownLatch(1);
        }
    }
    final CountDownLatch clientsDone = new CountDownLatch(clients);
    final CountDownLatch serversDone = new CountDownLatch(clients);
    ServerSocket ss;
    int port;

    /** Starts the server and all clients, then waits (at most {@code maxWait} ms each) for clients and servers to finish. */
    @Override public void run() {
        stp.setRemoveOnCancelPolicy(true);
        try {
            ss = new ServerSocket();
            ss.bind(null);
            port = ss.getLocalPort();
            tp.execute(new SocketAccept());
            for (int i = 0; i < clients; i++) {
                ClientSideSocket css = new ClientSideSocket(i);
                closeables.add(css);
                tp.execute(css);
                // need sleep to ensure client 0 connects first.
                Thread.sleep(50L);
            }
            if (!clientsDone.await(maxWait, TimeUnit.MILLISECONDS)) {
                println("CLIENTS DID NOT FINISH");
            } else if (!serversDone.await(maxWait, TimeUnit.MILLISECONDS)) {
                println("SERVERS DID NOT FINISH");
            } else {
                println("Finished");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Closes the server socket and all sockets, and stops all thread pools. */
    @Override public void close() {
        try {
            if (ss != null) {
                ss.close();
            }
        } catch (Exception ignored) {
            // best-effort cleanup
        }
        Closeable c;
        while ((c = closeables.poll()) != null) {
            try {
                c.close();
            } catch (Exception ignored) {
                // best-effort cleanup
            }
        }
        tp.shutdownNow();
        println("Scheduled tasks executed: " + stp.getTaskCount() + ", max. threads: " + stp.getLargestPoolSize());
        stp.shutdownNow();
    }

    /** Accepts one connection per client and starts a server side for each. */
    class SocketAccept implements Runnable {
        @Override public void run() {
            try {
                for (int i = 0; i < clients; i++) {
                    ServerSideSocket sss = new ServerSideSocket(ss.accept(), i);
                    closeables.add(sss);
                    tp.execute(sss);
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    /** Server side: waits {@code timeOut} ms, reads the client's messages, answers and reads the client number. */
    class ServerSideSocket implements Runnable, Closeable {
        // volatile: closed and cleared from other threads via close().
        volatile Socket s;
        int number, cnumber;
        boolean completed;

        public ServerSideSocket(Socket s, int number) {
            this.s = s;
            this.number = number;
            cnumber = -1;
        }

        @Override public void run() {
            String t = "nothing";
            try {
                DataInputStream in = new DataInputStream(s.getInputStream());
                DataOutputStream out = new DataOutputStream(s.getOutputStream());
                serversReady[number].countDown();
                Thread.sleep(timeOut);
                t = in.readUTF();
                in.readFully(new byte[largeMsg.length], 0, largeMsg.length);
                t += in.readUTF();
                out.writeByte(1);
                out.flush();
                cnumber = in.readInt();
                completed = true;
            } catch (Exception e) {
                println("server side " + number + " stopped after " + e);
            } finally {
                println("server side " + number + " received: " + t);
                if (completed && cnumber != number) {
                    println("server side " + number + " expected client number " + number + " but got " + cnumber);
                }
                close();
                serversDone.countDown();
            }
        }

        @Override public void close() {
            TestTimedSocketOut.close(s);
            s = null;
        }
    }

    /** Client side: writes a large message through a {@link TimedOutputStream} and reports write timeouts. */
    class ClientSideSocket implements Runnable, Closeable {
        // volatile: closed and cleared from other threads via close().
        volatile Socket s;
        int number;

        public ClientSideSocket(int number) {
            this.number = number;
        }

        @SuppressWarnings("resource")
        @Override public void run() {
            byte b = -1;
            TimedOutputStream tout = null;
            try {
                s = new Socket();
                // Connect explicitly to loopback: the server is bound to the wildcard address.
                s.connect(new InetSocketAddress(InetAddress.getLoopbackAddress(), port));
                DataInputStream in = new DataInputStream(s.getInputStream());
                tout = new TimedOutputStream(s.getOutputStream(), bufSize, stp);
                if (number == 1) {
                    // expect fail
                    tout.setTimeoutMs(timeOut / 2);
                } else {
                    // expect all OK
                    tout.setTimeoutMs(timeOut * 2);
                }
                DataOutputStream out = new DataOutputStream(tout);
                if (!serversReady[number].await(maxWait, TimeUnit.MILLISECONDS)) {
                    throw new RuntimeException("Server side for client side " + number + " not ready.");
                }
                out.writeUTF("client side " + number + " starting transfer");
                out.write(largeMsg);
                out.writeUTF(" - client side " + number + " completed transfer");
                out.flush();
                b = in.readByte();
                out.writeInt(number);
                out.flush();
            } catch (Exception e) {
                println("client side " + number + " stopped after " + e);
            } finally {
                println("client side " + number + " result: " + b);
                if (tout != null) {
                    if (tout.isWriteTimeout()) {
                        println("client side " + number + " had write timeout, close exception: " + tout.getWriteTimeoutCloseException());
                    } else {
                        println("client side " + number + " had no write timeout");
                    }
                }
                close();
                clientsDone.countDown();
            }
        }

        @Override public void close() {
            TestTimedSocketOut.close(s);
            s = null;
        }
    }

    /** Closes the socket (if not null), ignoring any exception. */
    private static void close(Socket s) {
        try {
            if (s != null) {
                s.close();
            }
        } catch (Exception ignored) {
            // best-effort cleanup
        }
    }

    private static final long START_TIME = System.currentTimeMillis();

    /** Prints the message prefixed with the milliseconds elapsed since class initialization. */
    private static void println(String msg) {
        System.out.println((System.currentTimeMillis() - START_TIME) + "\t " + msg);
    }
}