30

This is essentially the same question as "How to short-circuit reduce on Stream?". However, since that question focuses on a Stream of boolean values, and its answer cannot be generalized for other types and reduce operations, I'd like to ask the more general question.

How can we make a reduce on a stream so that it short-circuits when it encounters an absorbing element for the reducing operation?

The typical mathematical case would be 0 for multiplication. This Stream:

int product = IntStream.of(2, 3, 4, 5, 0, 7, 8)
        .reduce(1, (a, b) -> a * b);

will consume the last two elements (7 and 8) regardless of the fact that once 0 has been encountered the product is known.
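This is easy to observe with a peek() counter (the counter here is just for illustration):

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;

public class ReduceDemo {
    public static void main(String[] args) {
        // count how many elements the reduce actually pulls from the source
        AtomicInteger consumed = new AtomicInteger();
        int product = IntStream.of(2, 3, 4, 5, 0, 7, 8)
                .peek(i -> consumed.incrementAndGet())
                .reduce(1, (a, b) -> a * b);
        // all seven elements are consumed, even though the result
        // is known to be 0 as soon as the 0 is encountered
        System.out.println(product + " after " + consumed.get() + " elements");
    }
}
```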

bowmore
  • 1
    Do you think that this is so common that it is worth adding a *conditional* to every multiplication? You’ll need lots of subsequent multiplications to compensate that. And then it might turn out that hotspot is smarter than you when it comes to multiplications in a loop… – Holger Sep 10 '15 at 08:11
  • @Holger I admit this is mostly an academic question, but I can think of at least a few other reductions that may possibly short circuit. (bitwise and vs 0, bitwise or vs 0xffff, fully filled partial data vs. merging partial data, ...) – bowmore Sep 10 '15 at 08:33
  • 5
    @Holger, more examples where short-circuit reduction might be helpful: intersecting the stream of sets (cancel once the intermediate result is empty), union of the stream of EnumSets (cancel once the intermediate result contains all possible values), joining the stream of strings with string length limit (adding ellipsis to the end if necessary). – Tagir Valeev Sep 10 '15 at 09:24
  • 3
    Since a `takeWhile` operation will [get added](http://stackoverflow.com/a/32304570/452775) to `Stream` in Java 9 these kind of things will get easier in the future. Whenever Java 9 is released... – Lii Sep 10 '15 at 14:14

4 Answers

14

Unfortunately the Stream API has limited capabilities to create your own short-circuit operations. A not-so-clean solution is to throw a RuntimeException and catch it. Here's an implementation for IntStream, but it can be generalized for other stream types as well:

public static int reduceWithCancelEx(IntStream stream, int identity, 
                      IntBinaryOperator combiner, IntPredicate cancelCondition) {
    class CancelException extends RuntimeException {
        private final int val;

        CancelException(int val) {
            this.val = val;
        }
    }

    try {
        return stream.reduce(identity, (a, b) -> {
            int res = combiner.applyAsInt(a, b);
            if(cancelCondition.test(res))
                throw new CancelException(res);
            return res;
        });
    } catch (CancelException e) {
        return e.val;
    }
}

Usage example:

int product = reduceWithCancelEx(
        IntStream.of(2, 3, 4, 5, 0, 7, 8).peek(System.out::println), 
        1, (a, b) -> a * b, val -> val == 0);
System.out.println("Result: "+product);

Output:

2
3
4
5
0
Result: 0

Note that even though it works with parallel streams, there is no guarantee that the other parallel tasks will be finished as soon as one of them throws an exception. Sub-tasks which have already started will likely run to completion, so you may process more elements than expected.
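As an aside, the cost of the throw itself can be reduced by suppressing stack-trace capture, via the protected four-argument RuntimeException constructor available since Java 7. A sketch of the modified exception class (drop-in for the local class above):

```java
// Variant of the cancellation exception that skips stack-trace capture,
// which makes throwing it considerably cheaper
class CancelException extends RuntimeException {
    final int val;

    CancelException(int val) {
        // message, cause, enableSuppression, writableStackTrace
        super(null, null, false, false);
        this.val = val;
    }
}
```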

Update: an alternative solution which is much longer, but more parallel-friendly. It's based on a custom spliterator which returns at most one element: the result of accumulating all the underlying elements. When used in sequential mode, it does all the work in a single tryAdvance call. When it's split, each part generates the corresponding single partial result, and those results are reduced by the Stream engine using the combiner function. Here's the generic version, but a primitive specialization is possible as well.

final static class CancellableReduceSpliterator<T, A> implements Spliterator<A>,
        Consumer<T>, Cloneable {
    private Spliterator<T> source;
    private final BiFunction<A, ? super T, A> accumulator;
    private final Predicate<A> cancelPredicate;
    private final AtomicBoolean cancelled = new AtomicBoolean();
    private A acc;

    CancellableReduceSpliterator(Spliterator<T> source, A identity,
            BiFunction<A, ? super T, A> accumulator, Predicate<A> cancelPredicate) {
        this.source = source;
        this.acc = identity;
        this.accumulator = accumulator;
        this.cancelPredicate = cancelPredicate;
    }

    @Override
    public boolean tryAdvance(Consumer<? super A> action) {
        if (source == null || cancelled.get()) {
            source = null;
            return false;
        }
        while (!cancelled.get() && source.tryAdvance(this)) {
            if (cancelPredicate.test(acc)) {
                cancelled.set(true);
                break;
            }
        }
        source = null;
        action.accept(acc);
        return true;
    }

    @Override
    public void forEachRemaining(Consumer<? super A> action) {
        tryAdvance(action);
    }

    @Override
    public Spliterator<A> trySplit() {
        if(source == null || cancelled.get()) {
            source = null;
            return null;
        }
        Spliterator<T> prefix = source.trySplit();
        if (prefix == null)
            return null;
        try {
            @SuppressWarnings("unchecked")
            CancellableReduceSpliterator<T, A> result = 
                (CancellableReduceSpliterator<T, A>) this.clone();
            result.source = prefix;
            return result;
        } catch (CloneNotSupportedException e) {
            throw new InternalError();
        }
    }

    @Override
    public long estimateSize() {
        // let's pretend we have the same number of elements
        // as the source, so the pipeline engine parallelizes it in the same way
        return source == null ? 0 : source.estimateSize();
    }

    @Override
    public int characteristics() {
        return source == null ? SIZED : source.characteristics() & ORDERED;
    }

    @Override
    public void accept(T t) {
        this.acc = accumulator.apply(this.acc, t);
    }
}

Methods analogous to Stream.reduce(identity, accumulator, combiner) and Stream.reduce(identity, combiner), but with a cancelPredicate:

public static <T, U> U reduceWithCancel(Stream<T> stream, U identity,
        BiFunction<U, ? super T, U> accumulator, BinaryOperator<U> combiner,
        Predicate<U> cancelPredicate) {
    return StreamSupport
            .stream(new CancellableReduceSpliterator<>(stream.spliterator(), identity,
                    accumulator, cancelPredicate), stream.isParallel()).reduce(combiner)
            .orElse(identity);
}

public static <T> T reduceWithCancel(Stream<T> stream, T identity,
        BinaryOperator<T> combiner, Predicate<T> cancelPredicate) {
    return reduceWithCancel(stream, identity, combiner, combiner, cancelPredicate);
}

Let's test both versions and count how many elements are actually processed, putting the 0 close to the end. Exception version:

AtomicInteger count = new AtomicInteger();
int product = reduceWithCancelEx(
        IntStream.range(-1000000, 100).filter(x -> x == 0 || x % 2 != 0)
                .parallel().peek(i -> count.incrementAndGet()), 1,
        (a, b) -> a * b, x -> x == 0);
System.out.println("product: " + product + "/count: " + count);
Thread.sleep(1000);
System.out.println("product: " + product + "/count: " + count);

Typical output:

product: 0/count: 281721
product: 0/count: 500001

So while the result is returned after only some of the elements have been processed, the tasks keep working in the background and the counter keeps increasing. Here's the spliterator version:

AtomicInteger count = new AtomicInteger();
int product = reduceWithCancel(
        IntStream.range(-1000000, 100).filter(x -> x == 0 || x % 2 != 0)
                .parallel().peek(i -> count.incrementAndGet()).boxed(), 
                1, (a, b) -> a * b, x -> x == 0);
System.out.println("product: " + product + "/count: " + count);
Thread.sleep(1000);
System.out.println("product: " + product + "/count: " + count);

Typical output:

product: 0/count: 281353
product: 0/count: 281353

All the tasks are actually finished when the result is returned.

Tagir Valeev
  • I think the update code will work, but at the moment copy pasting the code has problems. – bowmore Sep 10 '15 at 21:40
  • @bowmore, which problems? Probably some imports missing? – Tagir Valeev Sep 11 '15 at 00:56
  • Never mind, turns out my JDK 8 at home wasn't the latest yet. Type inference issues resolved when I installed the latest. – bowmore Sep 11 '15 at 06:12
  • Ah, I see. Javac 1.8.0_25 crashes with `NullPointerException` on this class. Very funny. – Tagir Valeev Sep 11 '15 at 06:17
  • For most scenarios, bailing out with an exception will be much more expensive than iterating and processing the entire stream… – Holger Nov 04 '16 at 10:18
  • 1
    @Holger, I actually did much deeper research after this post and cannot agree with you. If you disable exception stack trace, the throwing usually adds constant delay of ~200-300 ns which quite often not so much compared to the whole stream operation. Iterating via `forEachRemaining` is usually faster than with `tryAdvance` as the latter has to maintain the state, so in many cases even replacing existing short-circuiting operation with `forEachRemaining`+throw could add a performance gain (not to mention streams containing `flatMap`). – Tagir Valeev Nov 15 '16 at 10:24
  • 1
    “*If you disable exception stack trace*”—I think, we could stop right here. In your answer, stack traces aren’t disabled, but anyway, 200 to 300ns still are quite expensive compared to the two multiplications that it saves in the OP’s scenario. I would even go so far to say that the introduced conditional branching, executed four times in this scenario, alone might be more expensive than the saved two multiplications. I didn’t say that there are no beneficial scenarios at all, but examples like the OP’s clearly show, how most of the time the potential saving is heavily overrated… – Holger Nov 15 '16 at 11:07
5

A general short-circuiting static reduce method can be implemented using the spliterator of a stream. It even turned out to be not very complicated! Using spliterators often seems to be the way to go when you want to work with streams in a more flexible way.

public static <T> T reduceWithCancel(Stream<T> s, T acc, BinaryOperator<T> op, Predicate<? super T> cancelPred) {
    BoxConsumer<T> box = new BoxConsumer<T>();
    Spliterator<T> splitr = s.spliterator();

    while (!cancelPred.test(acc) && splitr.tryAdvance(box)) {
        acc = op.apply(acc, box.value);
    }

    return acc;
}

public static class BoxConsumer<T> implements Consumer<T> {
    T value = null;
    public void accept(T t) {
        value = t;
    }
}

Usage:

    int product = reduceWithCancel(
        Stream.of(1, 2, 0, 3, 4).peek(System.out::println),
        1, (acc, i) -> acc * i, i -> i == 0);

    System.out.println("Result: " + product);

Output:

1
2
0
Result: 0

The method could be generalised to perform other kinds of terminal operations.
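For instance, a short-circuiting forEach could be sketched along the same lines (forEachWithCancel is an illustrative name, not an existing Stream API method):

```java
import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.stream.Stream;

public class ForEachCancel {
    // Sketch: forEach that stops pulling elements as soon as
    // cancelPred matches one (hypothetical method, not in the JDK)
    public static <T> void forEachWithCancel(Stream<T> s, Consumer<? super T> action,
            Predicate<? super T> cancelPred) {
        Spliterator<T> splitr = s.spliterator();
        Object[] box = new Object[1];
        while (splitr.tryAdvance(t -> box[0] = t)) {
            @SuppressWarnings("unchecked")
            T value = (T) box[0];
            if (cancelPred.test(value))
                return; // short-circuit: the rest of the stream is never consumed
            action.accept(value);
        }
    }
}
```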

This is based loosely on this answer about a take-while operation.

I don't know anything about the parallelisation potential of this.

Lii
  • 2
    Note that in my solution `cancelPredicate` tested the result of reduction, not the next element. In this case it's actually better (for example, `65536*65536 == 0` in Java, though neither of arguments is zero). Your answer can be easily adapted to do the same. I have a spliterator-based idea which is parallel-friendly, but need some time to code it correctly... – Tagir Valeev Sep 10 '15 at 09:07
  • 1
    @TagirValeev: Making this run in parallel would be an interesting exercise in spliterators. Please post if you finished anything! – Lii Sep 10 '15 at 09:20
  • 1
    @TagirValeev: *"Note that in my solution cancelPredicate tested the result of reduction, not the next element."* On a second thought I think this is better too. Not just in this case but in general. The operation could result in a value on which you'd want to break. I've edited that into the answer. It also saved me a line! – Lii Sep 10 '15 at 09:49
  • 1
    Proof-of-concept parallel-friendly solution added to my answer. Probably I will add something based on it to my library... – Tagir Valeev Sep 10 '15 at 10:10
4

My own take on this is not to use reduce() per se, but to use an existing short-circuiting terminal operation.

noneMatch() or allMatch() can be used for this with a Predicate that has a side effect. Admittedly also not the cleanest solution, but it does achieve the goal:

AtomicInteger product = new AtomicInteger(1);
IntStream.of(2, 3, 4, 5, 0, 7, 8)
        .peek(System.out::println)
        .noneMatch(i -> {
            if (i == 0) {
                product.set(0);
                return true;
            }
            int oldValue = product.get();
            while (oldValue != 0 && !product.compareAndSet(oldValue, i * oldValue)) {
                oldValue = product.get();
            }
            return oldValue == 0;
        });
System.out.println("Result: " + product.get());

It short-circuits and can be made parallel.
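A sketch of the object-stream generalization mentioned in the comments could look like this (reduceWithCancelMatch is an illustrative name); note that it requires a commutative combiner, since in parallel the threads fold elements into the shared accumulator in arbitrary order:

```java
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BinaryOperator;
import java.util.function.Predicate;
import java.util.stream.Stream;

public class NoneMatchReduce {
    // Sketch: reduce via noneMatch with a side-effecting predicate.
    // The combiner must be commutative for parallel streams.
    public static <T> T reduceWithCancelMatch(Stream<T> stream, T identity,
            BinaryOperator<T> combiner, Predicate<? super T> cancelPred) {
        AtomicReference<T> result = new AtomicReference<>(identity);
        stream.noneMatch(t -> {
            // atomically fold the element into the running result
            T updated = result.updateAndGet(acc -> combiner.apply(acc, t));
            return cancelPred.test(updated); // true => short-circuit
        });
        return result.get();
    }

    public static void main(String[] args) {
        int product = reduceWithCancelMatch(
                Stream.of(2, 3, 4, 5, 0, 7, 8), 1, (a, b) -> a * b, v -> v == 0);
        System.out.println("Result: " + product);
    }
}
```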

bowmore
  • Interesting solution, though it works with commutative combiners only. Usually commutativity is not required in Stream API for combiner functions. Nevertheless upvoted. I generalized this for object streams and custom identity/combiner/predicate methods, it's a little bit faster than my spliterator solution. – Tagir Valeev Sep 11 '15 at 02:18
  • Well I fixed a small mistake in `estimateSize()` of my spliterator (returning `Long.MAX_VALUE` caused stream engine to produce much more parallel tasks than necessary). Now my parallel version works times faster than generalized version of yours, though sequential is still somewhat slower than yours. – Tagir Valeev Sep 11 '15 at 06:20
2

This is how it can be done since Java 9, after the introduction of takeWhile:

int[] last = {1};
int product = IntStream.of(2, 3, 4, 5, 0, 7, 8)
    .takeWhile(i -> last[0] != 0).reduce(1, (a, b) -> (last[0] = a) * b);
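Note that this trick is sequential-only (it mutates shared state), and because last[0] is only updated during the accumulation step, takeWhile still lets one element past the 0 through; a peek counter (added here just for illustration) makes that visible:

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;

public class TakeWhileDemo {
    public static void main(String[] args) {
        int[] last = {1};
        AtomicInteger seen = new AtomicInteger();
        int product = IntStream.of(2, 3, 4, 5, 0, 7, 8)
                .takeWhile(i -> last[0] != 0)
                .peek(i -> seen.incrementAndGet()) // counts elements that pass
                .reduce(1, (a, b) -> (last[0] = a) * b);
        // 2, 3, 4, 5, 0 and 7 pass the predicate; only 8 is skipped,
        // because last[0] becomes 0 only while 7 is being accumulated
        System.out.println(product + " after " + seen.get() + " elements");
    }
}
```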
Kaplan