12

I've written code for Project Euler's Challenge 14 in both Haskell and C++ (ideone links). Both versions cache the results of previous calculations in an array.

Using ghc -O2 and g++ -O3 respectively, the C++ runs 10-15 times faster than the Haskell version.

Whilst I understand the Haskell version may run slower, and that Haskell is a nicer language to write in, I'd like to know what code changes I can make to the Haskell version to make it run faster (ideally within a factor of 2 or 3 of the C++ version).


Haskell code is here:

import Data.Array
import Data.Word
import Data.List

collatz_array = 
  let
    upperbound = 1000000
    a = array (1, upperbound) [(i :: Word64, f i :: Int) | i <- [1..upperbound]]
    f i = i `seq`
      let
        check_f i = i `seq` if i <= upperbound then a ! i else f i
      in
        if (i == 1) then 0 else (check_f ((if (even i) then i else 3 * i + 1) `div` 2)) + 1
  in a

main = 
  putStrLn $ show $ 
   foldl1' (\(x1,x2) (y1,y2) -> if (x2 >= y2) then (x1, x2) else (y1, y2)) $! (assocs collatz_array)

Edit:

I've now also done a version using unboxed mutable arrays. It is still 5 times slower than the C++ version, but a significant improvement. The code is on ideone here.

I'd like to know improvements to the mutable array version which bring it closer to the C++ version.

Martijn Pieters
Clinton
  • Just FYI, compiling with `-fllvm` improves performance by ~10% on my machine. – Greg E. Jun 04 '12 at 05:49
  • Your `seq`s make no difference; both your functions are strict in `i`. GHC used to be quite bad at 64-bit arithmetic on 32-bit platforms, but I don't know what platform you're using. – augustss Jun 04 '12 at 09:24
  • This doesn't explain your performance problem, but neither your C++ (at least what nanothief posted) nor your Haskell code produces the correct answer. I can't compile your C++, but I have a pure Haskell solution of about the same length as your code that is around 25% faster on my machine and produces the correct result. At this point about half the time looks like overhead associated with starting a Haskell program. – Philip JF Jun 04 '12 at 10:04
  • @PhilipJF In what way does the code not produce the correct result? Note that Clinton uses a slightly different step: for odd `n`, he goes directly to `(3*n+1)/2` instead of taking two steps for that. Thus he gets different chain lengths, but the starting points of the longest chains are the same. – Daniel Fischer Jun 04 '12 at 11:04
  • @DanielFischer Exactly, the problem description defines chain length such that going to (3n+1)/2 increases the length by 2. He has the right starting point, but the wrong length. – Philip JF Jun 04 '12 at 19:24
  • @PhilipJF But fortunately, only the starting point is needed. Now, to prove that both ways of counting chain length always produce the same value as the starting point is a different problem (I'm not sure that that always holds, for example consider 7 and 44; but it may be that there are always starts of longer chains between such pairs). – Daniel Fischer Jun 04 '12 at 19:33
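
The two ways of counting discussed in the comments above can be compared directly. The following is a minimal sketch, not code from the question: `stepsStandard` and `stepsShortcut` are hypothetical names, and neither function memoises anything. It prints both counts for the pair 7 and 44 mentioned in the last comment.

```haskell
-- Standard count: each application of n/2 or 3n+1 is one step.
stepsStandard :: Int -> Int
stepsStandard 1 = 0
stepsStandard n
    | even n    = 1 + stepsStandard (n `div` 2)
    | otherwise = 1 + stepsStandard (3 * n + 1)

-- Shortcut count, as in the question: an odd n goes straight to (3n+1)/2,
-- and that combined move is counted as a single step.
stepsShortcut :: Int -> Int
stepsShortcut 1 = 0
stepsShortcut n
    | even n    = 1 + stepsShortcut (n `div` 2)
    | otherwise = 1 + stepsShortcut ((3 * n + 1) `div` 2)

main :: IO ()
main = mapM_ report [7, 44]
  where
    -- Prints (start, standard steps, shortcut steps).
    report n = print (n, stepsStandard n, stepsShortcut n)
```

This prints `(7,16,11)` and `(44,16,12)`: the two starts tie under the standard count but not under the shortcut count, so the two orderings are not identical in general.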

2 Answers

4

Some problems with your (mutable array) code:

  • You use a fold to find the maximal chain length. For that, the array has to be converted to an association list, which costs time and allocation that the C++ version doesn't need.
  • You use even and div for testing evenness and dividing by 2, respectively. These are slow. g++ optimises both operations to the faster bit operations (on platforms where that is supposedly faster, at least), but GHC doesn't do these low-level optimisations (yet), so for the time being they have to be done by hand.
  • You use readArray and writeArray. They perform bounds checks that the C++ code doesn't, and that also takes time. Once the other problems are dealt with, the checks amount to a significant portion of the running time (ca. 25% on my box), since the algorithm does a lot of reads and writes.

Incorporating that into the implementation, I get

import Data.Array.ST
import Data.Array.Base
import Control.Monad.ST
import Data.Bits

collatz_array :: ST s (STUArray s Int Int)
collatz_array = do
    let upper = 10000000
    arr <- newArray (0,upper) 0
    unsafeWrite arr 2 1
    let check i
            | upper < i = return arr
            | i .&. 1 == 0 = do
                l <- unsafeRead arr (i `shiftR` 1)
                unsafeWrite arr i (l+1)
                check (i+1)
            | otherwise = do
                let j = (3*i+1) `shiftR` 1
                    find k l
                        | upper < k = find (next k) $! l+1
                        | k < i     = do
                            m <- unsafeRead arr k
                            return (m+l)
                        | otherwise = do
                            m <- unsafeRead arr k
                            if m == 0
                              then do
                                  n <- find (next k) 1
                                  unsafeWrite arr k n
                                  return (n+l)
                              else return (m+l)
                          where
                            next h
                                | h .&. 1 == 0 = h `shiftR` 1
                                | otherwise = (3*h+1) `shiftR` 1
                l <- find j 1
                unsafeWrite arr i l
                check (i+1)
    check 3

collatz_max :: ST s (Int,Int)
collatz_max = do
    car <- collatz_array
    (_,upper) <- getBounds car
    let find w m i
            | upper < i = return (w,m)
            | otherwise = do
                l <- unsafeRead car i
                if m < l
                  then find i l (i+1)
                  else find w m (i+1)
    find 1 0 2

main :: IO ()
main = print (runST collatz_max)

And the timings, both for 10 million (judging by the output format, cccoll is the C++ binary and stcoll the Haskell one):

$ time ./cccoll
8400511 429

real    0m0.210s
user    0m0.200s
sys     0m0.009s
$ time ./stcoll
(8400511,429)

real    0m0.341s
user    0m0.307s
sys     0m0.033s

which doesn't look too bad.

Important note: This code only works with a 64-bit GHC, because intermediate chain elements exceed the 32-bit range. (On Windows in particular, you need ghc-7.6.1 or later; earlier GHCs were 32-bit even on 64-bit Windows.) On 32-bit systems, one would have to use Integer or a 64-bit integer type (Int64 or Word64) for following the chains, at a drastic performance cost: in 32-bit GHCs the primitive 64-bit operations (arithmetic and shifts) are implemented as foreign calls to C functions (fast foreign calls, but still much slower than direct machine ops).
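
To see how large the intermediate values get for a given start, a small sketch along these lines may help. It is not part of the code above: `peakIntermediate` and `fitsInInt32` are hypothetical helpers, the start value 113383 is just an illustrative choice, and the chain is followed in Int64 so that the check itself cannot overflow on these inputs.

```haskell
import Data.Bits (shiftR, (.&.))
import Data.Int (Int32, Int64)

-- One shortcut step, carried out in 64-bit arithmetic.
next :: Int64 -> Int64
next h
    | h .&. 1 == 0 = h `shiftR` 1
    | otherwise    = (3 * h + 1) `shiftR` 1

-- Largest value of 3*h+1 that is computed on the way from a start to 1.
-- This is the quantity that must fit into the integer type being used.
peakIntermediate :: Int64 -> Int64
peakIntermediate = go 0
  where
    go best 1 = best
    go best h
        | h .&. 1 == 0 = go best (next h)
        | otherwise    = go (max best (3 * h + 1)) (next h)

-- Does a value fit into a signed 32-bit integer?
fitsInInt32 :: Int64 -> Bool
fitsInInt32 v = v <= fromIntegral (maxBound :: Int32)

main :: IO ()
main = do
    let start = 113383 :: Int64  -- illustrative start value, an assumption
        peak  = peakIntermediate start
    print (start, peak, fitsInInt32 peak)
```

If the last component is `False` for any start below the upper bound, a 32-bit Int is not wide enough for following the chains.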

Daniel Fischer
  • For odd `h`, ``(3*h+1) `shiftR` 1`` is the same as `(shiftR h 1) + h + 1`, which may be faster on some machines. – Philip JF Jun 04 '12 at 19:31
  • Indeed. Doesn't produce a reliably measurable difference on mine, so if there's a difference, it's smaller than the natural jittering here. But on machines with slow multiplication, that's definitely something to try. – Daniel Fischer Jun 04 '12 at 19:43
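
The identity from the comment above can be checked with a few lines. This is only a sketch: `viaMultiply` and `viaShiftAdd` are hypothetical names, and the check covers odd values only, which is the only case where the answer's code applies this step. Writing h = 2k+1, both sides come out as 3k+2.

```haskell
import Data.Bits (shiftR)

-- The step as written in the answer: multiply, add one, then halve.
viaMultiply :: Int -> Int
viaMultiply h = (3 * h + 1) `shiftR` 1

-- The multiplication-free variant suggested in the comment.
viaShiftAdd :: Int -> Int
viaShiftAdd h = (h `shiftR` 1) + h + 1

main :: IO ()
main = print (all (\h -> viaMultiply h == viaShiftAdd h) [1, 3 .. 999999])
```

This prints `True`. For even `h` the two expressions differ by one (for example 3 versus 4 at `h = 2`), so the rewrite is only valid in the odd branch.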
2

The ideone site uses GHC 6.8.2, which is getting pretty old. With GHC 7.4.1, the difference is much smaller.

With ghc:

$ ghc -O2 euler14.hs && time ./euler14
(837799,329)
./euler14  0.63s user 0.04s system 98% cpu 0.685 total

With g++ 4.7.0:

$ g++ --std=c++0x -O3 euler14.cpp && time ./a.out
8400511 429
./a.out  0.24s user 0.01s system 99% cpu 0.252 total

For me, the GHC version is only 2.7 times slower than the C++ version. Also, the two programs aren't giving the same result, which is not a good sign, especially for benchmarking.

David Miani
  • Oops, I posted the 10 million test, not the 1 million one. The link is corrected. Note that the C++ version did 10 million 2.7 times faster than Haskell did 1 million. – Clinton Jun 04 '12 at 07:35