I have been trying to implement a parallel merge sort in Scala, but even with 8 cores
the built-in .sorted is still about twice as fast.
edit:
I rewrote most of the code to minimize object creation. Now it runs about as fast as .sorted.
Input file with 1.2M integers:
- 1.333580 seconds (my implementation)
- 1.439293 seconds (.sorted)
How should I parallelize this?
New implementation
object Mergesort extends App
{
//=====================================================================================================================
// UTILITY
/** Ad-hoc ordering shared by the sort routines in this object.
  *
  * Two `Int`s or two `String`s are compared with their natural `compare`.
  * Every other pairing (for example an `Int` against a `String`, or two
  * `Double`s) is reported as equal, so mixed input is not totally ordered.
  */
implicit object comp extends Ordering[Any] {
  def compare(a: Any, b: Any) = a match {
    case x: Int =>
      b match {
        case y: Int => x compare y
        case _      => 0
      }
    case s: String =>
      b match {
        case t: String => s compare t
        case _         => 0
      }
    case _ => 0
  }
}
//=====================================================================================================================
// MERGESORT
// Ranges longer than this are split recursively by mergesort_split;
// ranges of this length or shorter are handed to inssort instead.
val THRESHOLD = 30
/** In-place insertion sort of the half-open range a[left, right), ordered by `comp`.
  *
  * An element is only moved past neighbours that are strictly greater, so
  * elements that compare equal keep their relative order.
  *
  * @param a     array to sort; modified in place
  * @param left  inclusive start index of the range
  * @param right exclusive end index of the range
  * @return the same array instance `a`
  */
def inssort[A](a: Array[A], left: Int, right: Int): Array[A] = {
  var next = left + 1
  while (next < right) {
    val pending = a(next)
    // Slide larger elements one slot to the right until the gap reaches pending's position.
    var hole = next
    while (hole > left && comp.lt(pending, a(hole - 1))) {
      a(hole) = a(hole - 1)
      hole -= 1
    }
    a(hole) = pending
    next += 1
  }
  a
}
/** Merges the adjacent sorted runs a[left, mid) and a[mid, right) into one sorted run, in place.
  *
  * The merge is stable: when two elements compare equal under `comp`, the one
  * from the left run is written first, and elements within a run keep their order.
  * (The earlier variant copied the right run reversed and scanned inward from
  * both ends, which reorders equal elements.)
  *
  * @param a     array holding both runs; positions [left, right) are overwritten with the merged run
  * @param temp  scratch array indexed like `a`; positions [left, right) are overwritten
  * @param left  inclusive start of the left run
  * @param right exclusive end of the right run
  * @param mid   exclusive end of the left run and inclusive start of the right run
  * @return the same array instance `a`
  */
def mergesort_merge[A](a: Array[A], temp: Array[A], left: Int, right: Int, mid: Int) : Array[A] = {
  // Snapshot both runs so that `a` can be overwritten while merging.
  var k = left
  while (k < right) { temp(k) = a(k); k += 1 }

  var i = left   // next unread element of the left run
  var j = mid    // next unread element of the right run
  k = left       // next slot to fill in `a`
  while (i < mid && j < right) {
    // Take from the right run only when it is strictly smaller, so ties favour the left run.
    if (comp.lt(temp(j), temp(i))) { a(k) = temp(j); j += 1 }
    else { a(k) = temp(i); i += 1 }
    k += 1
  }
  // At most one of the two runs still has elements; copy them over unchanged.
  while (i < mid) { a(k) = temp(i); i += 1; k += 1 }
  while (j < right) { a(k) = temp(j); j += 1; k += 1 }
  a
}
/** Recursively sorts the half-open range a[left, right) in place.
  *
  * Ranges longer than THRESHOLD are halved, each half is sorted recursively and
  * the halves are combined with mergesort_merge; shorter ranges go to inssort.
  * Ranges of length 0 or 1 are returned untouched.
  *
  * @param a     array to sort; modified in place
  * @param temp  scratch array passed through to mergesort_merge
  * @param left  inclusive start index of the range
  * @param right exclusive end index of the range
  * @return the same array instance `a`
  */
def mergesort_split[A](a: Array[A], temp: Array[A], left: Int, right: Int): Array[A] = {
  val span = right - left
  // The original `if (right-left == 1) a` was a discarded expression; make the base case a real branch.
  if (span <= 1) a
  else if (span > THRESHOLD) {
    // left + span / 2 cannot overflow, unlike (left + right) / 2 for very large indices.
    val mid = left + span / 2
    mergesort_split(a, temp, left, mid)
    mergesort_split(a, temp, mid, right)
    mergesort_merge(a, temp, left, right, mid)
  }
  else
    inssort(a, left, right)
}
/** Sorts `a` in place and returns it.
  *
  * Allocates one scratch array of the same length (hence the ClassTag bound)
  * and delegates to mergesort_split over the whole index range.
  *
  * @param a array to sort; modified in place
  * @return the same array instance `a`
  */
def mergesort[A: ClassTag](a: Array[A]): Array[A] = {
  val n = a.size
  val scratch = new Array[A](n)
  mergesort_split(a, scratch, 0, n)
}
Previous implementation
Input file with 1.2M integers:
- 4.269937 seconds (my implementation)
- 1.831767 seconds (.sorted)
What sort of tricks are there to make it faster and cleaner?
object Mergesort extends App
{
//=====================================================================================================================
// UTILITY
// Reference instant (System.nanoTime) captured when this line runs; dbg timestamps are relative to it.
val StartNano = System.nanoTime
/** Prints `msg` prefixed with the milliseconds elapsed since StartNano
  * (zero-padded to five digits) followed by " DBG ".
  */
def dbg(msg: String) = {
  val elapsedMillis = ((System.nanoTime - StartNano)/1e6).toInt
  val prefix = "%05d DBG ".format(elapsedMillis)
  println(prefix + msg)
}
/** Evaluates `work` once, prints how long it took in seconds, and returns its result.
  *
  * @param work by-name computation to measure
  * @return whatever `work` evaluated to
  */
def time[T](work: =>T) = {
  val begin = System.nanoTime
  val outcome = work
  val elapsedSeconds = (System.nanoTime - begin)/1e9
  println("%f seconds".format(elapsedSeconds))
  outcome
}
/** Ad-hoc ordering used by merge and sort below.
  *
  * `Int` pairs and `String` pairs use their natural `compare`; any other
  * combination of argument types is treated as equal (result 0).
  */
implicit object comp extends Ordering[Any] {
  def compare(a: Any, b: Any) = a match {
    case x: Int =>
      b match {
        case y: Int => x compare y
        case _      => 0
      }
    case s: String =>
      b match {
        case t: String => s compare t
        case _         => 0
      }
    case _ => 0
  }
}
//=====================================================================================================================
// MERGESORT
/** Merges two lists that are already sorted under `comp` into one sorted Stream.
  *
  * The tail of the result is built on demand via `#::`. When the two heads
  * compare equal, the element from `left` is emitted first.
  *
  * @param left  first sorted list
  * @param right second sorted list
  * @return a Stream containing every element of both lists in `comp` order
  */
def merge[A](left: List[A], right: List[A]): Stream[A] =
  if (left.isEmpty) right.toStream
  else if (right.isEmpty) left.toStream
  else if (comp.lteq(left.head, right.head)) left.head #:: merge(left.tail, right)
  else right.head #:: merge(left, right.tail)
/** Merge-sorts `input` under `comp`.
  *
  * When `length` is below 100 the list is sorted directly with `sortWith`.
  * Otherwise the list is split at `length / 2`, both halves are sorted
  * recursively and the results are combined with merge.
  *
  * @param input  list to sort
  * @param length number of elements in `input`, supplied by the caller so it is not recomputed
  * @return a new sorted list
  */
def sort[A](input: List[A], length: Int): List[A] =
  if (length < 100) input.sortWith(comp.lt)
  else if (input.isEmpty || input.tail.isEmpty) input // nothing to order in a list of 0 or 1 elements
  else {
    val half = length / 2
    val (front, back) = input.splitAt(half)
    val sortedFront = sort(front, half)
    // The back half receives the extra element when length is odd.
    val sortedBack = sort(back, half + length % 2)
    merge(sortedFront, sortedBack).toList
  }
/** Sequential merge sort entry point: measures the list once and delegates to sort. */
def msort[A](input: List[A]): List[A] = {
  val n = input.length
  sort(input, n)
}
//=====================================================================================================================
// PARALLELIZATION
//val cores = Runtime.getRuntime.availableProcessors
//dbg("Detected %d cores.".format(cores))
//lazy implicit val ec = ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(cores))
/** Combines two futures of sorted lists into a future of their merged list.
  *
  * The merge runs once both inputs have completed. The implicit `order` is
  * accepted but not consulted here: merge compares elements with `comp`.
  *
  * @param fa future of the first sorted list
  * @param fb future of the second sorted list
  * @return a future holding the merged list
  */
def futuremerge[A](fa: Future[List[A]], fb: Future[List[A]])(implicit order: Ordering[A], ec: ExecutionContext) =
  fa.flatMap { a =>
    fb.map { b => merge(a, b).toList }
  }
/** Merge sort that sorts sub-lists in separate Futures.
  *
  * Lists with more than 500 elements (according to `length`) are split in two,
  * each half is sorted by a recursive call, and the two futures are combined
  * with futuremerge. Smaller lists are sorted by msort inside a single Future.
  *
  * @param input  list to sort
  * @param length number of elements in `input`
  * @return a future holding the sorted list
  */
def parallel_msort[A](input: List[A], length: Int)(implicit order: Ordering[A]): Future[List[A]] = {
  if (length > 500) {
    // Split only on this branch: leaf calls never use the halves, so splitting there is wasted copying.
    val middle = length / 2
    val (left, right) = input splitAt middle
    // Both futures are created before being combined, so the halves can run concurrently.
    val fl = parallel_msort(left, middle)
    val fr = parallel_msort(right, middle + length%2)
    futuremerge(fl, fr)
  }
  else {
    Future(msort(input))
  }
}
//=====================================================================================================================
// MAIN
// Timed run: read "in.txt", sort it with parallel_msort and block until the result is ready.
// The first line is a header; when its first space-separated token is "i" the
// remaining lines are parsed as Ints, otherwise they are kept as Strings.
val results = time({
  val source = Source.fromFile("in.txt")
  try {
    val src = source.getLines
    val header = src.next.split(" ").toVector
    // toList reads every line eagerly, so the file is fully consumed before it is closed below.
    val lines = if (header(0) == "i") src.map(_.toInt).toList else src.toList
    val f = parallel_msort(lines, lines.length)
    Await.result(f, concurrent.duration.Duration.Inf)
  }
  // Release the file handle even if reading, parsing or sorting throws.
  finally source.close()
})
// Baseline timing: sort the same data with the library's .sorted.
println("Sorted as comparison...")
// NOTE(review): input_folder is not defined anywhere in the visible code, and the timed
// run above opens plain "in.txt" — confirm which path is intended.
val comparison_file = Source.fromFile(input_folder+"in.txt")
val sorted_src = comparison_file.getLines
try {
  sorted_src.next // skip the header line
  // NOTE(review): the lines are sorted as raw Strings here (no toInt), unlike the timed run above.
  time(sorted_src.toList.sorted)
}
// Release the file handle once the comparison run is done.
finally comparison_file.close()
// Write the sorted result to "out.txt" as UTF-8, one element per line; the writer is always closed.
val writer = new PrintWriter("out.txt", "UTF-8")
try {
  val rendered = results.mkString("\n")
  writer.print(rendered)
}
finally {
  writer.close()
}
}