I wanted to write an efficient implementation of the Floyd-Warshall all-pairs shortest path algorithm in Haskell using `Vector`s to hopefully get good performance.
The implementation is quite straightforward, but instead of using a 3-dimensional |V|×|V|×|V| matrix, a 2-dimensional vector is used, since each step only ever reads the values from the previous k.
Thus, the algorithm is really just a series of steps where a 2D vector is passed in and a new 2D vector is generated. The final 2D vector contains the shortest path lengths between all pairs of nodes (i, j).
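For reference, this is roughly the recurrence that `distance` in the code below implements. Vertices are numbered from 0, so step k uses vertex k−1 as the intermediate node. w(i, j) (edge weight) and E (edge set) are notation introduced here, not names from the code. Only d^(k−1) is needed to compute d^(k), which is why one 2D vector per step is enough, and d^(|V|) is the final result.

```latex
d^{(0)}_{ij} =
\begin{cases}
0 & \text{if } i = j \\
w(i,j) & \text{if } (i,j) \in E \\
\infty & \text{otherwise}
\end{cases}
\qquad
d^{(k)}_{ij} = \min\left(d^{(k-1)}_{ij},\; d^{(k-1)}_{i,k-1} + d^{(k-1)}_{k-1,j}\right),
\quad k = 1, \dots, |V|
```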
My intuition told me that it would be important to make sure that the previous 2D vector was evaluated before each step, so I used `BangPatterns` on the `prev` argument to the `fw` function and the strict `foldl'`:
```haskell
{-# Language BangPatterns #-}
import Control.DeepSeq
import Control.Monad (forM_)
import Data.List (foldl')
import qualified Data.Map.Strict as M
import Data.Vector (Vector, (!), (//))
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as V hiding (length, replicate, take)

type Graph = Vector (M.Map Int Double)
type TwoDVector = Vector (Vector Double)

infinity :: Double
infinity = 1/0

-- calculate shortest path between all pairs in the given graph, if there are
-- negative cycles, return Nothing
allPairsShortestPaths :: Graph -> Int -> Maybe TwoDVector
allPairsShortestPaths g v =
  let initial = fw g v V.empty 0
      results = foldl' (fw g v) initial [1..v]
  in if negCycle results
       then Nothing
       else Just results
  where -- check for negative elements along the diagonal
        negCycle a = any not $ map (\i -> a ! i ! i >= 0) [0..(V.length a-1)]

-- one step of the Floyd-Warshall algorithm
fw :: Graph -> Int -> TwoDVector -> Int -> TwoDVector
fw g v !prev k = V.create $ do -- ← bang
  curr <- V.new v
  forM_ [0..(v-1)] $ \i ->
    V.write curr i $ V.create $ do
      ivec <- V.new v
      forM_ [0..(v-1)] $ \j -> do
        let d = distance g prev i j k
        V.write ivec j d
      return ivec
  return curr

distance :: Graph -> TwoDVector -> Int -> Int -> Int -> Double
distance g _ i j 0 -- base case; 0 if same vertex, edge weight if neighbours
  | i == j    = 0.0
  | otherwise = M.findWithDefault infinity j (g ! i)
distance _ a i j k = let c1 = a ! i ! j
                         c2 = (a ! i ! (k-1)) + (a ! (k-1) ! j)
                     in min c1 c2
```
However, when running this program on a 1000-node graph with 47978 edges, things do not look good at all. The memory usage is very high and the program takes way too long to run. The program was compiled with `ghc -O2`.
I rebuilt the program for profiling and limited the number of iterations to 50:

```haskell
results = foldl' (fw g v) initial [1..50]
```
I then ran the program with `+RTS -p -hc` and `+RTS -p -hd` (heap profiles broken down by cost centre and by closure description, respectively).
This is... interesting, but I guess it's showing that it's accumulating tonnes of thunks. Not good.
Ok, so after a few shots in the dark, I added a `deepseq` in `fw` to make sure `prev` really is evaluated:

```haskell
let d = prev `deepseq` distance g prev i j k
```
Now things look better, and I can actually run the program to completion with constant memory usage. It's obvious that the bang on the `prev` argument was not enough.
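A likely reason is that a bang pattern only evaluates `prev` to weak head normal form. For a boxed `Vector`, that means the outer vector exists, but its elements can still be unevaluated thunks. The same holds for the rows and the `Double`s written into them by the lazy `V.write`, each of them a `distance g prev i j k` thunk that keeps the previous matrix alive. As a minimal sketch of an alternative to `deepseq`, the hypothetical `fwStrict` below forces each value with a bang in `let` before writing it. It assumes the same `Graph`/`TwoDVector` types and `distance` recurrence as the question and is not a measured fix.

```haskell
{-# LANGUAGE BangPatterns #-}
import Control.Monad (forM_)
import qualified Data.Map.Strict as M
import Data.Vector (Vector, (!))
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV

type Graph = Vector (M.Map Int Double)
type TwoDVector = Vector (Vector Double)

infinity :: Double
infinity = 1/0

-- Same recurrence as the question's `distance`; step k uses vertex k-1.
distance :: Graph -> TwoDVector -> Int -> Int -> Int -> Double
distance g _ i j 0
  | i == j    = 0.0
  | otherwise = M.findWithDefault infinity j (g ! i)
distance _ a i j k = min (a ! i ! j) (a ! i ! (k-1) + a ! (k-1) ! j)

-- Hypothetical strict variant of `fw`: every Double and every row is
-- evaluated before MV.write stores it, so no stored thunk refers to `prev`.
fwStrict :: Graph -> Int -> TwoDVector -> Int -> TwoDVector
fwStrict g v prev k = V.create $ do
  curr <- MV.new v
  forM_ [0 .. v-1] $ \i -> do
    let !row = V.create $ do
          ivec <- MV.new v
          forM_ [0 .. v-1] $ \j -> do
            let !d = distance g prev i j k
            MV.write ivec j d
          return ivec
    MV.write curr i row
  return curr
```

It would be used exactly like `fw`, e.g. `foldl' (fwStrict g v) (fwStrict g v V.empty 0) [1..v]`.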
For comparison with the previous graphs, here is the memory usage for 50 iterations after adding the `deepseq`:
Ok, so things are better, but I still have some questions:
- Is this the correct solution for this space leak? Am I wrong in feeling that inserting a `deepseq` is a bit ugly?
- Is my usage of `Vector`s here idiomatic/correct? I'm building a completely new vector for every iteration and hoping that the garbage collector will delete the old `Vector`s.
- Are there any other things I could do to make this run faster with this approach?
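On the `Vector` questions, one commonly used alternative layout (a different technique from the question's vector of vectors, sketched here under stated assumptions) is a single flat, row-major `Data.Vector.Unboxed` vector. Unboxed vectors cannot store unevaluated elements, so once `foldl'` forces a step's vector to WHNF, every distance in it has been computed and the previous vector can be collected. The names `step`, `floydWarshall`, `initial` and `hasNegCycle`, and the 0-based `(from, to, weight)` edge-list input, are hypothetical. Here step k uses vertex k itself (k from 0) as the intermediate vertex.

```haskell
import Data.List (foldl')
import qualified Data.Vector.Unboxed as U

-- n x n distance matrix stored row-major: entry (i, j) is at index i * n + j.
type Dist = U.Vector Double

-- One Floyd-Warshall step with vertex k as the intermediate vertex.
-- U.generate writes every element into unboxed storage, so no thunks survive.
step :: Int -> Dist -> Int -> Dist
step n d k = U.generate (n * n) $ \idx ->
  let (i, j) = idx `quotRem` n
  in min (d U.! idx) (d U.! (i * n + k) + d U.! (k * n + j))

-- All n steps; foldl' forces each intermediate vector to WHNF, which for an
-- unboxed vector means it is fully evaluated.
floydWarshall :: Int -> Dist -> Dist
floydWarshall n d0 = foldl' (step n) d0 [0 .. n - 1]

-- Initial matrix from 0-based edges: 0 on the diagonal, infinity elsewhere,
-- then edge weights; parallel edges keep the cheapest weight (an assumption;
-- the question's M.insert keeps the last one instead).
initial :: Int -> [(Int, Int, Double)] -> Dist
initial n es = U.accum min base [(i * n + j, w) | (i, j, w) <- es]
  where
    base = U.generate (n * n) $ \idx ->
      let (i, j) = idx `quotRem` n in if i == j then 0 else 1/0

-- Same check as negCycle in the question: a negative diagonal entry.
hasNegCycle :: Int -> Dist -> Bool
hasNegCycle n d = any (\i -> d U.! (i * n + i) < 0) [0 .. n - 1]
```

This avoids both the `deepseq` and the per-row boxing, at the cost of doing the index arithmetic by hand.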
For reference, here is `graph.txt`: http://sebsauvage.net/paste/?45147f7caf8c5f29#7tiCiPovPHWRm1XNvrSb/zNl3ujF3xB3yehrxhEdVWw=
Here is `main`:
```haskell
main = do
  ls <- fmap lines $ readFile "graph.txt"
  let numVerts = head . map read . words . head $ ls
  let edges = map (map read . words) (tail ls)
  let g = V.create $ do
        g' <- V.new numVerts
        forM_ [0..(numVerts-1)] (\idx -> V.write g' idx M.empty)
        forM_ edges $ \[f,t,w] -> do
          -- subtract one from vertex IDs so we can index directly
          curr <- V.read g' (f-1)
          V.write g' (f-1) $ M.insert (t-1) (fromIntegral w) curr
        return g'
  let a = allPairsShortestPaths g numVerts
  case a of
    Nothing -> putStrLn "Negative cycle detected."
    Just a' -> do
      putStrLn $ "The shortest, shortest path has length "
              ++ show ((V.minimum . V.map V.minimum) a')
```
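On idiom, the mutable loop that builds `g` in `main` could also be written purely with `V.accum`, which applies an update function at each listed index in list order. This is a sketch: `buildGraph` is a hypothetical name, and it assumes the edge lines have already been parsed into 1-based `(from, to, weight)` triples. As with the original `M.insert` loop, a later duplicate edge overwrites an earlier one.

```haskell
import qualified Data.Map.Strict as M
import Data.Vector (Vector)
import qualified Data.Vector as V

type Graph = Vector (M.Map Int Double)

-- Start from n empty adjacency maps and insert each edge, shifting the
-- 1-based vertex IDs to 0-based indices like the original loop does.
buildGraph :: Int -> [(Int, Int, Double)] -> Graph
buildGraph n es =
  V.accum (\m (t, w) -> M.insert t w m)
          (V.replicate n M.empty)
          [(f - 1, (t - 1, w)) | (f, t, w) <- es]
```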