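The most convenient multi-threading paradigm for a Merge Sort is the fork-join paradigm. The fork/join framework (ForkJoinPool, RecursiveTask) has been part of the JDK since Java 7; the shared common pool used in the examples below (ForkJoinPool.commonPool()) was added in Java 8. The following code demonstrates a Merge Sort using fork-join.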
import java.util.*;
import java.util.concurrent.*;
/**
 * Fork-join merge sort that produces a new sorted list.
 *
 * <p>Each task takes a private copy of the list it is given, so the caller's list is never
 * modified. The sort is stable: elements that compare as equal keep their input order.
 *
 * @param <N> element type, ordered by its natural ordering
 */
public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
    private final List<N> elements;

    /**
     * Creates a task that sorts a copy of {@code elements}.
     *
     * @param elements the values to sort; copied, never modified
     */
    public MergeSort(List<N> elements) {
        this.elements = new ArrayList<>(elements);
    }

    /**
     * Splits the list in half, sorts both halves as subtasks and merges the results.
     *
     * @return a list holding the same elements in ascending order
     */
    @Override
    protected List<N> compute() {
        if (this.elements.size() <= 1) {
            return this.elements;
        }
        final int pivot = this.elements.size() / 2;
        MergeSort<N> leftTask = new MergeSort<N>(this.elements.subList(0, pivot));
        MergeSort<N> rightTask = new MergeSort<N>(this.elements.subList(pivot, this.elements.size()));
        // invokeAll runs one half on this thread and joins the forked half in the
        // order the framework recommends, instead of fork/fork then join in fork order.
        invokeAll(leftTask, rightTask);
        return merge(leftTask.join(), rightTask.join());
    }

    /**
     * Merges two sorted lists into a new sorted list without modifying either input.
     *
     * @param left  sorted list whose elements win ties
     * @param right sorted list
     * @return a new list containing every element of both inputs in ascending order
     */
    private List<N> merge(List<N> left, List<N> right) {
        List<N> sorted = new ArrayList<>(left.size() + right.size());
        int leftIndex = 0;
        int rightIndex = 0;
        // Walk both lists by index; removing from the head of an ArrayList would
        // shift every remaining element on each step.
        while (leftIndex < left.size() && rightIndex < right.size()) {
            // "<= 0" takes the left element on a tie, which keeps the sort stable.
            if (left.get(leftIndex).compareTo(right.get(rightIndex)) <= 0) {
                sorted.add(left.get(leftIndex++));
            } else {
                sorted.add(right.get(rightIndex++));
            }
        }
        // At most one of the two still has elements left over.
        sorted.addAll(left.subList(leftIndex, left.size()));
        sorted.addAll(right.subList(rightIndex, right.size()));
        return sorted;
    }

    /** Sorts a small sample list on the common pool and prints the result. */
    public static void main(String[] args) {
        ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
        List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(Arrays.asList(7,2,9,10,1)));
        System.out.println("result: " + result);
    }
}
While much less straightforward, the following variant of the code eliminates the excessive copying of the ArrayList. The initial unsorted list is only created once, and the calls to subList do not need to perform any copying themselves. Previously we copied the list each time the algorithm forked. Also, when merging, instead of creating a new list and copying values into it each time, we now reuse the left list and insert the values from the right list into it. By avoiding the extra copy step we improve performance. We use a LinkedList here because inserts are rather cheap compared to an ArrayList. We also eliminate the call to remove, which can be expensive on an ArrayList as well.
import java.util.*;
import java.util.concurrent.*;
/**
 * Fork-join merge sort that reads the input through sublist views and builds its result in
 * linked lists.
 *
 * <p>The input list is only read, never modified. Single elements are wrapped in a
 * {@link LinkedList}, and every merge inserts the right half into the left half's list
 * instead of allocating a fresh one. The sort is stable.
 *
 * @param <N> element type, ordered by its natural ordering
 */
public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
    private final List<N> elements;

    /**
     * Creates a task that sorts {@code elements}.
     *
     * @param elements the values to sort; kept by reference and only read
     */
    public MergeSort(List<N> elements) {
        this.elements = elements;
    }

    /**
     * Splits the list in half, sorts both halves as subtasks and merges the results.
     *
     * @return a new list holding the same elements in ascending order
     */
    @Override
    protected List<N> compute() {
        if (this.elements.size() <= 1) {
            return new LinkedList<>(this.elements);
        }
        final int pivot = this.elements.size() / 2;
        MergeSort<N> leftTask = new MergeSort<N>(this.elements.subList(0, pivot));
        MergeSort<N> rightTask = new MergeSort<N>(this.elements.subList(pivot, this.elements.size()));
        // invokeAll runs one half on this thread and joins the forked half in the
        // order the framework recommends, instead of fork/fork then join in fork order.
        invokeAll(leftTask, rightTask);
        return merge(leftTask.join(), rightTask.join());
    }

    /**
     * Merges {@code right} into {@code left}, keeping {@code left} sorted.
     *
     * @param left  sorted list that receives the merged result; modified in place
     * @param right sorted list whose elements are inserted into {@code left}; not modified
     * @return {@code left}, now containing every element of both lists in ascending order
     */
    private List<N> merge(List<N> left, List<N> right) {
        int leftIndex = 0;
        int rightIndex = 0;
        while (leftIndex < left.size() && rightIndex < right.size()) {
            // "<= 0" leaves the left element in front on a tie, which keeps the sort stable.
            if (left.get(leftIndex).compareTo(right.get(rightIndex)) <= 0) {
                leftIndex++;
            } else {
                // Insert the smaller right element in front of the current left element;
                // leftIndex then moves on so it still points at that same left element.
                left.add(leftIndex++, right.get(rightIndex++));
            }
        }
        // Left side exhausted first: whatever remains on the right belongs at the end.
        if (rightIndex < right.size()) {
            left.addAll(right.subList(rightIndex, right.size()));
        }
        return left;
    }

    /** Sorts a small sample list on the common pool and prints the result. */
    public static void main(String[] args) {
        ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
        List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(Arrays.asList(7,2,9,-7,777777,10,1)));
        System.out.println("result: " + result);
    }
}
We can also improve the code one step further by using iterators instead of calling get directly when performing the merge. Calling get by index on a LinkedList has poor (linear) time performance, because the list is traversed internally on every call; using an iterator eliminates that slow-down. The call to next on an iterator is constant time, as opposed to linear time for the call to get. The following code is modified to use iterators instead.
import java.util.*;
import java.util.concurrent.*;
/**
 * Fork-join merge sort that reads the input through sublist views and merges linked lists
 * with list iterators.
 *
 * <p>The input list is only read, never modified. Each merge walks both halves with a
 * {@link ListIterator} and inserts the right half's elements into the left half's list, so
 * no positional {@code get} calls are made on the linked lists. The sort is stable.
 *
 * @param <N> element type, ordered by its natural ordering
 */
public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
    private final List<N> elements;

    /**
     * Creates a task that sorts {@code elements}.
     *
     * @param elements the values to sort; kept by reference and only read
     */
    public MergeSort(List<N> elements) {
        this.elements = elements;
    }

    /**
     * Splits the list in half, sorts both halves as subtasks and merges the results.
     *
     * @return a new list holding the same elements in ascending order
     */
    @Override
    protected List<N> compute() {
        if (this.elements.size() <= 1) {
            return new LinkedList<>(this.elements);
        }
        final int pivot = this.elements.size() / 2;
        MergeSort<N> leftTask = new MergeSort<N>(this.elements.subList(0, pivot));
        MergeSort<N> rightTask = new MergeSort<N>(this.elements.subList(pivot, this.elements.size()));
        // invokeAll runs one half on this thread and joins the forked half in the
        // order the framework recommends, instead of fork/fork then join in fork order.
        invokeAll(leftTask, rightTask);
        return merge(leftTask.join(), rightTask.join());
    }

    /**
     * Merges {@code right} into {@code left}, keeping {@code left} sorted.
     *
     * @param left  sorted list that receives the merged result; modified in place
     * @param right sorted list whose elements are inserted into {@code left}; not modified
     * @return {@code left}, now containing every element of both lists in ascending order
     */
    private List<N> merge(List<N> left, List<N> right) {
        ListIterator<N> leftIter = left.listIterator();
        ListIterator<N> rightIter = right.listIterator();
        while (leftIter.hasNext() && rightIter.hasNext()) {
            N rightElement = rightIter.next();
            // "<= 0" leaves the left element in front on a tie, which keeps the sort stable.
            if (leftIter.next().compareTo(rightElement) <= 0) {
                // Left element stays where it is; un-read the right element for the next round.
                rightIter.previous();
            } else {
                // Step back so the insertion lands in front of the left element just read;
                // that left element is compared again on the next round.
                leftIter.previous();
                leftIter.add(rightElement);
            }
        }
        // Left side exhausted first: its iterator sits at the end, so add() appends.
        while (rightIter.hasNext()) {
            leftIter.add(rightIter.next());
        }
        return left;
    }

    /** Sorts a small sample list on the common pool and prints the result. */
    public static void main(String[] args) {
        ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
        List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(Arrays.asList(7,2,9,-7,777777,10,1)));
        System.out.println("result: " + result);
    }
}
Finally, the most complex version of the code: this iteration sorts entirely in place. Only the initial ArrayList is created, and no additional collections are ever created (the sublists are views onto the original list, not copies). As such, the logic is particularly difficult to follow (so I saved it for last), but in terms of memory use it should be as close to an ideal implementation as we can get.
import java.util.*;
import java.util.concurrent.*;
/**
 * Fork-join merge sort that sorts the given list in place.
 *
 * <p>Subtasks work on {@code subList} views of the same backing list, so no element is ever
 * copied into another collection. Elements are only repositioned with {@code set}, so the
 * list passed in must support {@code set}. The sort is stable.
 *
 * @param <N> element type, ordered by its natural ordering
 */
public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
    private final List<N> elements;

    /**
     * Creates a task that sorts {@code elements} in place.
     *
     * @param elements the list to sort; kept by reference and reordered by this task
     */
    public MergeSort(List<N> elements) {
        this.elements = elements;
    }

    /**
     * Splits the list in half, sorts both halves in place as subtasks and merges them.
     *
     * @return the same list that was passed to the constructor, now in ascending order
     */
    @Override
    protected List<N> compute() {
        if (this.elements.size() <= 1) {
            return this.elements;
        }
        final int pivot = this.elements.size() / 2;
        MergeSort<N> leftTask = new MergeSort<N>(this.elements.subList(0, pivot));
        MergeSort<N> rightTask = new MergeSort<N>(this.elements.subList(pivot, this.elements.size()));
        // The halves cover disjoint index ranges, so the subtasks never touch the same slot.
        // invokeAll returns only once both halves are sorted.
        invokeAll(leftTask, rightTask);
        merge(pivot);
        return this.elements;
    }

    /**
     * Merges the sorted runs {@code [0, pivot)} and {@code [pivot, size)} of the list in place.
     *
     * <p>Whenever the front of the right run is smaller than the front of the left run, the
     * whole stretch of right-run elements that are smaller is rotated in front of the
     * remaining left run in one step. This needs no scratch storage; the price is that
     * elements may be moved more than once.
     *
     * @param pivot index of the first element of the right run
     */
    private void merge(int pivot) {
        final int end = this.elements.size();
        int leftIndex = 0;      // first element of the left run not yet in final position
        int rightIndex = pivot; // first element of the right run not yet in final position
        while (leftIndex < rightIndex && rightIndex < end) {
            N leftElement = this.elements.get(leftIndex);
            if (leftElement.compareTo(this.elements.get(rightIndex)) <= 0) {
                // Already in order; ties stay put, which keeps the sort stable.
                leftIndex++;
                continue;
            }
            // Find every right-run element that must precede leftElement.
            int runEnd = rightIndex + 1;
            while (runEnd < end && leftElement.compareTo(this.elements.get(runEnd)) > 0) {
                runEnd++;
            }
            int runLength = runEnd - rightIndex;
            // Rotating [leftIndex, runEnd) right by runLength moves that stretch to the front
            // and shifts the rest of the left run behind it, preserving order within both.
            Collections.rotate(this.elements.subList(leftIndex, runEnd), runLength);
            leftIndex += runLength;
            rightIndex = runEnd;
        }
    }

    /** Sorts a small sample list on the common pool and prints the result. */
    public static void main(String[] args) {
        ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
        List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(new ArrayList<>(Arrays.asList(5,9,8,7,6,1,2,3,4))));
        System.out.println("result: " + result);
    }
}