
I discovered this while solving Problem 205 of Project Euler. The problem is as follows:

Peter has nine four-sided (pyramidal) dice, each with faces numbered 1, 2, 3, 4. Colin has six six-sided (cubic) dice, each with faces numbered 1, 2, 3, 4, 5, 6.

Peter and Colin roll their dice and compare totals: the highest total wins. The result is a draw if the totals are equal.

What is the probability that Pyramidal Pete beats Cubic Colin? Give your answer rounded to seven decimal places in the form 0.abcdefg

I wrote a naive solution using Guava:

import com.google.common.collect.Sets;
import com.google.common.collect.ImmutableSet;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.*;
import java.util.stream.Collectors;

public class Problem205 {
    public static void main(String[] args) {
        long startTime = System.currentTimeMillis();
        List<Integer> peter = Sets.cartesianProduct(Collections.nCopies(9, ImmutableSet.of(1, 2, 3, 4)))
                .stream()
                .map(l -> l
                        .stream()
                        .mapToInt(Integer::intValue)
                        .sum())
                .collect(Collectors.toList());
        List<Integer> colin = Sets.cartesianProduct(Collections.nCopies(6, ImmutableSet.of(1, 2, 3, 4, 5, 6)))
                .stream()
                .map(l -> l
                        .stream()
                        .mapToInt(Integer::intValue)
                        .sum())
                .collect(Collectors.toList());

        long startTime2 = System.currentTimeMillis();
        // IMPORTANT BIT HERE! v
        long solutions = peter
                .stream()
                .mapToLong(p -> colin
                        .stream()
                        .filter(c -> p > c)
                        .count())
                .sum();

        // IMPORTANT BIT HERE! ^
        System.out.println("Counting solutions took " + (System.currentTimeMillis() - startTime2) + "ms");

        System.out.println("Solution: " + BigDecimal
                .valueOf(solutions)
                .divide(BigDecimal
                                .valueOf((long) Math.pow(4, 9) * (long) Math.pow(6, 6)),
                        7,
                        RoundingMode.HALF_UP));
        System.out.println("Found in: " + (System.currentTimeMillis() - startTime) + "ms");
    }
}
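For reference, the divisor in the final statement counts every equally likely (Peter, Colin) outcome pair: 4^9 × 6^6 = 262,144 × 46,656 = 12,230,590,464, which does not fit in an `int`. The sketch below uses a hypothetical `OutcomeCount` class that is not part of the original program. It computes the same value with exact `long` arithmetic instead of `Math.pow`, which returns a `double`. For powers this small the `double` results happen to be exact, so the original cast is harmless.

```java
// Hypothetical helper: exact integer version of the divisor used in Problem205.
class OutcomeCount {
    // Integer power by repeated multiplication; Math.multiplyExact throws on overflow
    static long pow(long base, int exponent) {
        long result = 1;
        for (int i = 0; i < exponent; i++) {
            result = Math.multiplyExact(result, base);
        }
        return result;
    }

    // Number of equally likely (Peter, Colin) outcome pairs: 4^9 * 6^6
    static long totalOutcomes() {
        return Math.multiplyExact(pow(4, 9), pow(6, 6));
    }

    public static void main(String[] args) {
        System.out.println(pow(4, 9));       // 262144
        System.out.println(pow(6, 6));       // 46656
        System.out.println(totalOutcomes()); // 12230590464
        // Same value as the original expression, which goes through double
        System.out.println((long) Math.pow(4, 9) * (long) Math.pow(6, 6) == totalOutcomes()); // true
    }
}
```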

The code I have highlighted, which uses a simple filter(), count() and sum(), seems to run much faster on Java 9 than on Java 8. Specifically, Java 8 counts the solutions in 37465ms on my machine, while Java 9 does it in about 16000ms. The Java 9 figure is the same whether I run the class compiled with Java 8 or the one compiled with Java 9.

If I replace the streams code with what would seem to be the exact pre-streams equivalent:

long solutions = 0;
for (Integer p : peter) {
    long count = 0;
    for (Integer c : colin) {
        if (p > c) {
            count++;
        }
    }
    solutions += count;
}

It counts the solutions in about 35000ms, with no measurable difference between Java 8 and Java 9.

What am I missing here? Why is the streams code so much faster in Java 9, and why isn't the for loop?


I am running Ubuntu 16.04 LTS 64-bit. My Java 8 version:

java version "1.8.0_131"
Java(TM) SE Runtime Environment (build 1.8.0_131-b11)
Java HotSpot(TM) 64-Bit Server VM (build 25.131-b11, mixed mode)

My Java 9 version:

java version "9"
Java(TM) SE Runtime Environment (build 9+181)
Java HotSpot(TM) 64-Bit Server VM (build 9+181, mixed mode)
bcsb1001
  • If you're not using a micro-benchmarking test harness, such as JMH, you're doing it wrong and your conclusions about the execution speed may be wrong. – scottb Oct 07 '17 at 00:27
  • @scottb I did consider that, but I don't see how some code consistently running twice as fast in Java 9 than in Java 8 could be solely caused by lazy benchmarking. There has to be some good reason for it. – bcsb1001 Oct 07 '17 at 00:30
  • Java 9 has a bunch of [JVM enhancements](https://docs.oracle.com/javase/9/whatsnew/toc.htm#JSNEW-GUID-0564E449-0601-4EEC-B130-73ABD83074AC), such as improved caching and dynamic linking. – Mick Mnemonic Oct 07 '17 at 00:33
  • *"... but I don't see how some code consistently running twice as fast in Java 9 than in Java 8 could be solely caused by lazy benchmarking."* - Just because you can't see how doesn't mean that it can't happen. – Stephen C Oct 07 '17 at 00:36
  • Yep. Sounds like a benchmarking issue, not an actual performance difference. – Louis Wasserman Oct 07 '17 at 00:39
  • Do you have something _against_ better performance (perceived or otherwise)? – Kevin Anderson Oct 07 '17 at 02:14
  • The argument about poor benchmarking doesn't apply here, since the code isn't a microbenchmark - it is a complete standalone application that solves a certain task. So, the real question here is why JIT behaves so strangely in this particular example. – apangin Oct 07 '17 at 19:18

1 Answer

1. Why the stream works faster on JDK 9

In JDK 8, the Stream.count() implementation is rather dumb: it just iterates through the whole stream, adding 1L for each element.

This was fixed in JDK 9. Even though the bug report talks about SIZED streams, the new code improves non-sized streams, too.
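The SIZED case is also described in the Stream.count() API note from JDK 9 on: an implementation may skip executing the pipeline when it can compute the count directly from the source. The small experiment below sketches this. The class and method names are hypothetical, and peek() is there only to observe whether elements are actually traversed.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical demo class: peek() is used only to observe which elements get traversed.
class SizedCountDemo {
    // The source is SIZED and peek() does not change the element count, so the
    // pipeline stays SIZED; JDK 9+ may answer count() from the source size alone.
    static int visitedWhenSized(List<String> source) {
        AtomicInteger visited = new AtomicInteger();
        long count = source.stream()
                .peek(s -> visited.incrementAndGet())
                .count();
        if (count != source.size()) {
            throw new IllegalStateException("unexpected count " + count);
        }
        return visited.get();
    }

    // Even a filter that keeps everything clears SIZED, so count() must run the pipeline.
    static int visitedWhenFiltered(List<String> source) {
        AtomicInteger visited = new AtomicInteger();
        long count = source.stream()
                .peek(s -> visited.incrementAndGet())
                .filter(s -> true)
                .count();
        if (count != source.size()) {
            throw new IllegalStateException("unexpected count " + count);
        }
        return visited.get();
    }

    public static void main(String[] args) {
        List<String> l = Arrays.asList("A", "B", "C", "D");
        System.out.println("sized:    " + visitedWhenSized(l));
        System.out.println("filtered: " + visitedWhenFiltered(l));
    }
}
```

On an OpenJDK 9 or later runtime I would expect `sized: 0` and `filtered: 4`. On JDK 8 both lines should report 4. The question's pipeline starts with filter(), so it never takes this shortcut. Its speedup therefore has to come from the non-sized improvements mentioned above.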

If you replace .count() with the Java 8-style implementation, .mapToLong(e -> 1L).sum(), it will be slow again, even on JDK 9.
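The two counting forms are shown side by side below in a minimal sketch. The class `CountStyles` is hypothetical, the list is placeholder data rather than the real dice totals, and no timing is done. Both methods return the same count; the claim above is that on JDK 9 only the second one loses the speedup.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Hypothetical comparison class; the list in main is placeholder data, not the real dice totals.
class CountStyles {
    // Form used in the question: filter() followed by count().
    static long withCount(List<Integer> colin, int p) {
        return colin.stream().filter(c -> p > c).count();
    }

    // Java 8-style equivalent of count(): map each element to 1L, then sum.
    static long withMapToLongSum(List<Integer> colin, int p) {
        return colin.stream().filter(c -> p > c).mapToLong(e -> 1L).sum();
    }

    public static void main(String[] args) {
        // Totals 6..36, one of each, standing in for Colin's possible sums
        List<Integer> colin = IntStream.rangeClosed(6, 36).boxed().collect(Collectors.toList());
        System.out.println(withCount(colin, 20));        // 14
        System.out.println(withMapToLongSum(colin, 20)); // 14
    }
}
```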

2. Why the naive loop is slow

When you put all your code in the main method, it cannot be JIT-compiled efficiently. The method is executed only once, so it starts running in the interpreter. Later, when the JVM detects a hot loop, it switches from interpreted to compiled code on the fly. This is called on-stack replacement (OSR).
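One way to watch this happen is HotSpot's `-XX:+PrintCompilation` flag, which marks OSR compilations with a `%` in its output. The exact output format varies between JDK versions. The sketch below uses a hypothetical `OsrDemo` class whose hot loop sits in a method invoked exactly once. Any compiled code for that loop should therefore come from an OSR compilation rather than a normal method compilation.

```java
// Hypothetical demo. Run with:  java -XX:+PrintCompilation OsrDemo
// PrintCompilation flags OSR compilations with '%'; the exact output format varies by JDK version.
class OsrDemo {
    // Called only once, so its hot loop can only be sped up by an OSR compilation
    // triggered from the loop's back-edge counter, not by normal method compilation.
    static long sumInOneCall(int n) {
        long sum = 0;
        for (int i = 0; i < n; i++) {
            sum += i % 7;
        }
        return sum;
    }

    public static void main(String[] args) {
        System.out.println(sumInOneCall(100_000_000));
    }
}
```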

OSR compilations are often not as optimized as regular compiled methods. I've explained this in detail earlier, see this and this answer.

The JIT will produce better code if you put the inner loop in a separate method:

    long solutions = 0;
    for (Integer p : peter) {
        solutions += countLargerThan(colin, p);
    }

    ...

    private static int countLargerThan(List<Integer> colin, int p) {
        int count = 0;
        for (Integer c : colin) {
            if (p > c) {
                count++;
            }
        }
        return count;
    }

In this case, the countLargerThan method will be compiled normally, and performance will be better than with streams on both JDK 8 and JDK 9.
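A self-contained, runnable version of the same refactoring is sketched below. The class `ExtractedLoop` and the helpers `allTotals` and `countWins` are hypothetical. `allTotals` replaces the Guava cartesian product with plain nested loops but yields the same multiset of totals. `countLargerThan` is the extracted method from the answer. `main` uses a small configuration so it finishes instantly.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical standalone version of the extracted-method refactoring.
class ExtractedLoop {
    // Totals of every ordered roll of `dice` dice with faces 1..faces (faces^dice entries),
    // the same multiset the Guava cartesianProduct version produces.
    static List<Integer> allTotals(int dice, int faces) {
        List<Integer> totals = new ArrayList<>();
        totals.add(0);
        for (int d = 0; d < dice; d++) {
            List<Integer> next = new ArrayList<>(totals.size() * faces);
            for (int t : totals) {
                for (int f = 1; f <= faces; f++) {
                    next.add(t + f);
                }
            }
            totals = next;
        }
        return totals;
    }

    // Inner loop in its own small method, so the JIT can compile it as a regular method
    static int countLargerThan(List<Integer> colin, int p) {
        int count = 0;
        for (Integer c : colin) {
            if (p > c) {
                count++;
            }
        }
        return count;
    }

    // Outer loop: number of (Peter, Colin) pairs where Peter's total is higher
    static long countWins(List<Integer> peter, List<Integer> colin) {
        long solutions = 0;
        for (Integer p : peter) {
            solutions += countLargerThan(colin, p);
        }
        return solutions;
    }

    public static void main(String[] args) {
        // Small configuration so the demo finishes instantly; use allTotals(9, 4) and
        // allTotals(6, 6) for the full problem (about 1.2e10 comparisons).
        List<Integer> peter = allTotals(3, 4);
        List<Integer> colin = allTotals(2, 6);
        long solutions = countWins(peter, colin);
        long total = (long) peter.size() * colin.size();
        System.out.println(solutions + " / " + total);
    }
}
```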

apangin
  • Which improvements were made for `count()` on non `SIZED` streams? – Holger Oct 09 '17 at 09:08
  • @Holger New [`TerminalOp`](http://hg.openjdk.java.net/jdk9/jdk9/jdk/file/65464a307408/src/java.base/share/classes/java/util/stream/ReduceOps.java#l236) with a specialized sink. Now there is no intermediate mapping operation in the pipeline, and there is no extra indirection in reduce. – apangin Oct 09 '17 at 13:20
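The "specialized sink" point can be made concrete with a deliberately simplified sketch. This is not the JDK's ReduceOps code. `SinkSketch` and `CountingConsumer` are hypothetical names, and the sink-based version only handles sequential streams. It contrasts a terminal consumer that just increments a counter with the JDK 8-style shape, where every element passes through an extra mapToLong stage before a generic sum reduction.

```java
import java.util.function.Consumer;
import java.util.stream.Stream;

// Hypothetical, greatly simplified illustration; not the JDK's actual implementation.
class SinkSketch {
    // Terminal consumer that only bumps a counter: no per-element mapping step.
    static final class CountingConsumer<T> implements Consumer<T> {
        long count;

        @Override
        public void accept(T t) {
            count++;
        }
    }

    static <T> long countViaSink(Stream<T> stream) {
        CountingConsumer<T> sink = new CountingConsumer<>();
        stream.sequential().forEach(sink); // sequential only: the counter is not thread-safe
        return sink.count;
    }

    // JDK 8-style shape: an extra mapToLong stage feeding a generic sum reduction.
    static <T> long countViaMapAndSum(Stream<T> stream) {
        return stream.mapToLong(e -> 1L).sum();
    }

    public static void main(String[] args) {
        System.out.println(countViaSink(Stream.of(1, 2, 3, 4, 5).filter(x -> x % 2 == 1)));      // 3
        System.out.println(countViaMapAndSum(Stream.of(1, 2, 3, 4, 5).filter(x -> x % 2 == 1))); // 3
    }
}
```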