What's a fast implementation of median in Scala?

This is what I found on rosetta code:

  def median(s: Seq[Double])  =
  {
    val (lower, upper) = s.sortWith(_<_).splitAt(s.size / 2)
    if (s.size % 2 == 0) (lower.last + upper.head) / 2.0 else upper.head
  }

I don't like it because it does a sort. I know there are ways to compute the median in linear time.

EDIT:

I would like to have a set of median functions that I can use in various scenarios:

  1. fast, in place median computation that can be done in linear time
  2. median that works on a stream that you can traverse multiple times, but you can only keep O(log n) values in memory like this
  3. median that works on a stream, where you can hold at most O(log n) values in memory, and you can traverse the stream at most once (is this even possible?)

Please only post code that compiles and correctly computes the median. For simplicity, you may assume that all inputs contain an odd number of values.

dsg
  • A "good" algorithm is much more complicated. Google for "Median of Medians" or "Median of five". – Landei Jan 11 '11 at 20:41
  • A well implemented (i.e. library) sorting algorithm might prove faster in your application's reality than some implementation of some allegedly linear time algorithm. As for the above code, you might leave out the split and do indexed access instead, depending on the kind of Seq implementation you assume. – Raphael Jan 17 '11 at 11:07
  • I don't think the third scenario is possible. Let's say I got numbers from 1000 to 1500, for example. The median is 1250. Now, if I start getting numbers below 1000, the median will decrease by one until it reaches 1000. Likewise, if I start getting numbers above 1500, the median will increase until it reaches 1500. So you need to keep all numbers seen so far. – Daniel C. Sobral Jan 21 '11 at 20:21
  • Quick Google search gave me [this](http://valis.cs.uiuc.edu/~sariel/research/CG/applets/linear_prog/median.html) and [this](http://en.wikipedia.org/wiki/Selection_algorithm). Basically, what you are looking for is the selection algorithm. Scala version left as an exercise for the reader. – Taylor Leese Jan 11 '11 at 20:36

1 Answer

Immutable Algorithm

The first algorithm indicated by Taylor Leese is quadratic in the worst case but linear on average. That, however, depends on the pivot selection. So I'm providing here a version with pluggable pivot selection, along with both a random pivot and a median of medians pivot (which guarantees linear time).

import scala.annotation.tailrec

@tailrec def findKMedian(arr: Array[Double], k: Int)(implicit choosePivot: Array[Double] => Double): Double = {
    val a = choosePivot(arr)
    val (s, b) = arr partition (a >)
    if (s.size == k) a
    // The following test is used to avoid infinite repetition
    else if (s.isEmpty) {
        val (s, b) = arr partition (a ==)
        if (s.size > k) a
        else findKMedian(b, k - s.size)
    } else if (s.size < k) findKMedian(b, k - s.size)
    else findKMedian(s, k)
}

def findMedian(arr: Array[Double])(implicit choosePivot: Array[Double] => Double) = findKMedian(arr, (arr.size - 1) / 2)
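
The special case that avoids infinite repetition exists because elements equal to the pivot end up in the "not smaller" half. As a minimal sketch of an alternative formulation (not the code above), the input can be split three ways so that the pivot and its duplicates are never recursed into. The names `ThreeWaySelect`, `select` and `median` are made up for this example, it uses `Vector[Double]` instead of arrays, and it assumes the input contains no NaN.

```scala
import scala.annotation.tailrec
import scala.util.Random

object ThreeWaySelect {
  // Returns the element that would sit at index k (0-based) if xs were sorted.
  // Assumes 0 <= k < xs.length and that xs holds no NaN.
  @tailrec
  def select(xs: Vector[Double], k: Int): Double = {
    val pivot   = xs(Random.nextInt(xs.length))
    val smaller = xs.filter(_ < pivot)
    val equal   = xs.count(_ == pivot) // at least 1, so every recursion shrinks the input
    if (k < smaller.length) select(smaller, k)
    else if (k < smaller.length + equal) pivot
    else select(xs.filter(_ > pivot), k - smaller.length - equal)
  }

  // Lower median, using the same (size - 1) / 2 index convention as findMedian.
  def median(xs: Seq[Double]): Double = select(xs.toVector, (xs.length - 1) / 2)
}
```

Because the pivot itself is always excluded from the recursive call, no separate emptiness test is needed; the price is an extra pass over the input per step.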

Random Pivot (quadratic, linear average), Immutable

This is the random pivot selection. Analysis of algorithms with random factors is trickier than normal, because it deals largely with probability and statistics.

def chooseRandomPivot(arr: Array[Double]): Double = arr(scala.util.Random.nextInt(arr.size))

Median of Medians (linear), Immutable

The median of medians method, which guarantees linear time when used with the algorithm above. First, an algorithm to compute the median of up to 5 numbers, which is the basis of the median of medians algorithm. This one was provided by Rex Kerr in this answer; the median of medians algorithm depends a lot on its speed.

def medianUpTo5(five: Array[Double]): Double = {
  def order2(a: Array[Double], i: Int, j: Int) = {
    if (a(i)>a(j)) { val t = a(i); a(i) = a(j); a(j) = t }
  }

  def pairs(a: Array[Double], i: Int, j: Int, k: Int, l: Int) = {
    if (a(i)<a(k)) { order2(a,j,k); a(j) }
    else { order2(a,i,l); a(i) }
  }

  if (five.length < 2) return five(0)
  order2(five,0,1)
  if (five.length < 4) return (
    if (five.length==2 || five(2) < five(0)) five(0)
    else if (five(2) > five(1)) five(1)
    else five(2)
  )
  order2(five,2,3)
  if (five.length < 5) pairs(five,0,1,2,3)
  else if (five(0) < five(2)) { order2(five,1,4); pairs(five,1,4,2,3) }
  else { order2(five,3,4); pairs(five,0,1,3,4) }
}

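And then the median of medians algorithm itself. In its textbook form, it guarantees that the chosen pivot is greater than at least 30% of the list and smaller than another 30%, which is enough to guarantee the linearity of the previous algorithm. Note that the textbook form finds the median of the group medians with the selection algorithm itself, whereas the code below simply recurses on the medians, so the 30% bound is not strictly guaranteed here (a comment below raises this point). Look up the Wikipedia link provided in another answer for details.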

def medianOfMedians(arr: Array[Double]): Double = {
    val medians = arr grouped 5 map medianUpTo5 toArray;
    if (medians.size <= 5) medianUpTo5 (medians)
    else medianOfMedians(medians)
}
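
For comparison, here is a minimal sketch of the textbook variant, in which the median of the group medians is found by calling the selection routine rather than by recursing on the pivot function. The names `Bfprt`, `select` and `pivotOf` are made up for this sketch, it works on `Vector[Double]` rather than the arrays used above, it sorts each group of five instead of using `medianUpTo5`, and it assumes the input holds no NaN.

```scala
object Bfprt {
  // k-th smallest element (0-based) using a median-of-medians pivot.
  // Assumes 0 <= k < xs.length and no NaN in xs.
  def select(xs: Vector[Double], k: Int): Double =
    if (xs.length <= 5) xs.sorted.apply(k)
    else {
      val pivot   = pivotOf(xs)
      val smaller = xs.filter(_ < pivot)
      val equal   = xs.count(_ == pivot)
      if (k < smaller.length) select(smaller, k)
      else if (k < smaller.length + equal) pivot
      else select(xs.filter(_ > pivot), k - smaller.length - equal)
    }

  // Median of the group medians, obtained through select itself
  // (mutual recursion) instead of through a recursive call to pivotOf.
  def pivotOf(xs: Vector[Double]): Double = {
    val medians = xs.grouped(5).map(g => g.sorted.apply((g.length - 1) / 2)).toVector
    select(medians, (medians.length - 1) / 2)
  }
}
```

This version has not been benchmarked here, so the timings reported in this answer do not apply to it.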

In-place Algorithm

So, here's an in-place version of the algorithm. I'm using a class that implements a partition in-place, with a backing array, so that the changes to the algorithms are minimal.

case class ArrayView(arr: Array[Double], from: Int, until: Int) {
    def apply(n: Int) = 
        if (from + n < until) arr(from + n)
        else throw new ArrayIndexOutOfBoundsException(n)

    def partitionInPlace(p: Double => Boolean): (ArrayView, ArrayView) = {
      var upper = until - 1
      var lower = from
      while (lower < upper) {
        while (lower < until && p(arr(lower))) lower += 1
        while (upper >= from && !p(arr(upper))) upper -= 1
        if (lower < upper) { val tmp = arr(lower); arr(lower) = arr(upper); arr(upper) = tmp }
      }
      (copy(until = lower), copy(from = lower))
    }

    def size = until - from
    def isEmpty = size <= 0

    override def toString = arr mkString ("ArrayView(", ", ", ")")
}; object ArrayView {
    def apply(arr: Array[Double]) = new ArrayView(arr, 0, arr.size)
}

@tailrec def findKMedianInPlace(arr: ArrayView, k: Int)(implicit choosePivot: ArrayView => Double): Double = {
    val a = choosePivot(arr)
    val (s, b) = arr partitionInPlace (a >)
    if (s.size == k) a
    // The following test is used to avoid infinite repetition
    else if (s.isEmpty) {
        val (s, b) = arr partitionInPlace (a ==)
        if (s.size > k) a
        else findKMedianInPlace(b, k - s.size)
    } else if (s.size < k) findKMedianInPlace(b, k - s.size)
    else findKMedianInPlace(s, k)
}

def findMedianInPlace(arr: Array[Double])(implicit choosePivot: ArrayView => Double) = findKMedianInPlace(ArrayView(arr), (arr.size - 1) / 2)

Random Pivot, In-place

I'm only implementing the random pivot for the in-place algorithms, as the median of medians would require more support than what is presently provided by the ArrayView class I defined.

def chooseRandomPivotInPlace(arr: ArrayView): Double = arr(scala.util.Random.nextInt(arr.size))
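
As a point of comparison, in-place selection can also be written directly against the array with a loop, without a view class. The following is a minimal sketch under that assumption, not the implementation benchmarked in this answer: `InPlaceSelect.select` is a made-up name, it uses a three-way (Dutch national flag) partition with a random pivot, it reorders the array it is given, and it assumes no NaN values.

```scala
import scala.util.Random

object InPlaceSelect {
  // Rearranges arr and returns its k-th smallest element (0-based).
  // Assumes 0 <= k < arr.length and no NaN in arr.
  def select(arr: Array[Double], k: Int): Double = {
    var lo = 0
    var hi = arr.length - 1
    while (lo < hi) {
      val pivot = arr(lo + Random.nextInt(hi - lo + 1))
      // Invariant: [lo, lt) < pivot, [lt, i) == pivot, (gt, hi] > pivot
      var lt = lo
      var i  = lo
      var gt = hi
      while (i <= gt) {
        val x = arr(i)
        if (x < pivot) { arr(i) = arr(lt); arr(lt) = x; lt += 1; i += 1 }
        else if (x > pivot) { arr(i) = arr(gt); arr(gt) = x; gt -= 1 }
        else i += 1
      }
      // Keep only the region that contains index k
      if (k < lt) hi = lt - 1
      else if (k > gt) lo = gt + 1
      else return pivot
    }
    arr(lo)
  }
}
```

The median under the convention used above would then be `InPlaceSelect.select(arr, (arr.length - 1) / 2)`.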

Histogram Algorithm (O(log(n)) memory), Immutable

So, about streams. It is impossible to do anything less than O(n) memory for a stream that can only be traversed once, unless you happen to know what the stream length is (in which case it ceases to be a stream in my book).
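
To make the O(n) bound concrete: a single pass is enough when every value seen so far is retained. A minimal sketch, not part of the benchmarked code, using the common two-heap technique; the class name `RunningMedian` and its methods are made up here, and NaN values are assumed not to occur.

```scala
import scala.collection.mutable

final class RunningMedian {
  private val ascending: Ordering[Double] = Ordering.fromLessThan[Double](_ < _)
  // Max-heap holding the lower half: its head is the largest of the small values
  private val low  = mutable.PriorityQueue.empty[Double](ascending)
  // Min-heap holding the upper half: its head is the smallest of the large values
  private val high = mutable.PriorityQueue.empty[Double](ascending.reverse)

  def add(x: Double): Unit = {
    if (low.isEmpty || x <= low.head) low.enqueue(x) else high.enqueue(x)
    // Rebalance so that low holds as many values as high, or exactly one more
    if (low.size > high.size + 1) high.enqueue(low.dequeue())
    else if (high.size > low.size) low.enqueue(high.dequeue())
  }

  // Lower median of everything added so far; throws if nothing was added
  def median: Double = low.head
}
```

Each `add` costs O(log n) time, but both heaps together still hold all n values, which is consistent with the memory bound stated above.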

Using buckets is also a bit problematic, but if we can traverse it multiple times, then we can know its size, maximum and minimum, and work from there. For example:

def findMedianHistogram(s: Traversable[Double]) = {
    def medianHistogram(s: Traversable[Double], discarded: Int, medianIndex: Int): Double = {
        // The buckets
        def numberOfBuckets = (math.log(s.size).toInt + 1) max 2
        val buckets = new Array[Int](numberOfBuckets)

        // The upper limit of each bucket
        val max = s.max
        val min = s.min
        val increment = (max - min) / numberOfBuckets
        val indices = (-numberOfBuckets + 1 to 0) map (max + increment * _)

        // Return the bucket a number is supposed to be in
        def bucketIndex(d: Double) = indices indexWhere (d <=)

        // Compute how many in each bucket
        s foreach { d => buckets(bucketIndex(d)) += 1 }

        // Now make the buckets cumulative
        val partialTotals = buckets.scanLeft(discarded)(_+_).drop(1)

        // The bucket where our target is at
        val medianBucket = partialTotals indexWhere (medianIndex <)

        // Keep track of how many numbers there are that are less 
        // than the median bucket
        val newDiscarded = if (medianBucket == 0) discarded else partialTotals(medianBucket - 1)

        // Test whether a number is in the median bucket
        def insideMedianBucket(d: Double) = bucketIndex(d) == medianBucket

        // Get a view of the target bucket
        val view = s.view filter insideMedianBucket

        // If all numbers in the bucket are equal, return that
        if (view forall (view.head ==)) view.head
        // Otherwise, recurse on that bucket
        else medianHistogram(view, newDiscarded, medianIndex)
    }

    medianHistogram(s, 0, (s.size - 1) / 2)
}

Test and Benchmark

To test the algorithms, I'm using Scalacheck, and comparing the output of each algorithm to the output of a trivial implementation with sorting. That assumes the sorting version is correct, of course.

I'm benchmarking each of the above algorithms with all provided pivot selections, plus a fixed pivot selection (halfway through the array, rounded down). Each algorithm is tested with three different input array sizes, three times for each size.

Here's the testing code:

import org.scalacheck.{Prop, Pretty, Test}
import Prop._
import Pretty._

def test(algorithm: Array[Double] => Double, 
         reference: Array[Double] => Double): String = {
    def prettyPrintArray(arr: Array[Double]) = arr mkString ("Array(", ", ", ")")
    val resultEqualsReference = forAll { (arr: Array[Double]) => 
        arr.nonEmpty ==> (algorithm(arr) == reference(arr)) :| prettyPrintArray(arr)
    }
    Test.check(Test.Params(), resultEqualsReference)(Pretty.Params(verbosity = 0))
}

import java.lang.System.currentTimeMillis

def bench[A](n: Int)(body: => A): Long = {
  val start = currentTimeMillis()
  1 to n foreach { _ => body }
  currentTimeMillis() - start
}

import scala.util.Random.nextDouble

def benchmark(algorithm: Array[Double] => Double,
              arraySizes: List[Int]): List[Iterable[Long]] = 
    for (size <- arraySizes)
    yield for (iteration <- 1 to 3)
        yield bench(50000)(algorithm(Array.fill(size)(nextDouble)))

def testAndBenchmark: String = {
    val immutablePivotSelection: List[(String, Array[Double] => Double)] = List(
        "Random Pivot"      -> chooseRandomPivot,
        "Median of Medians" -> medianOfMedians,
        "Midpoint"          -> ((arr: Array[Double]) => arr((arr.size - 1) / 2))
    )
    val inPlacePivotSelection: List[(String, ArrayView => Double)] = List(
        "Random Pivot (in-place)" -> chooseRandomPivotInPlace,
        "Midpoint (in-place)"     -> ((arr: ArrayView) => arr((arr.size - 1) / 2))
    )
    val immutableAlgorithms = for ((name, pivotSelection) <- immutablePivotSelection)
        yield name -> (findMedian(_: Array[Double])(pivotSelection))
    val inPlaceAlgorithms = for ((name, pivotSelection) <- inPlacePivotSelection)
        yield name -> (findMedianInPlace(_: Array[Double])(pivotSelection))
    val histogramAlgorithm = "Histogram" -> ((arr: Array[Double]) => findMedianHistogram(arr))
    val sortingAlgorithm = "Sorting" -> ((arr: Array[Double]) => arr.sorted.apply((arr.size - 1) / 2))
    val algorithms = sortingAlgorithm :: histogramAlgorithm :: immutableAlgorithms ::: inPlaceAlgorithms

    val formattingString = "%%-%ds  %%s" format (algorithms map (_._1.length) max)

    // Tests
    val testResults = for ((name, algorithm) <- algorithms)
        yield formattingString format (name, test(algorithm, sortingAlgorithm._2))

    // Benchmarks
    val arraySizes = List(100, 500, 1000)
    def formatResults(results: List[Long]) = results map ("%8d" format _) mkString

    val benchmarkResults: List[String] = for {
        (name, algorithm) <- algorithms
        results <- benchmark(algorithm, arraySizes).transpose
    } yield formattingString format (name, formatResults(results))

    val header = formattingString format ("Algorithm", formatResults(arraySizes.map(_.toLong)))

    "Tests" :: "*****" :: testResults ::: 
    ("" :: "Benchmark" :: "*********" :: header :: benchmarkResults) mkString ("", "\n", "\n")
}
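
The `bench` helper above times the very first executions as well, so JIT warm-up is included in the totals, and every timed call in `benchmark` also pays for `Array.fill`. A hedged variant, assuming a hypothetical helper named `benchWarm` (not used for the results below), runs some untimed iterations first and measures with `System.nanoTime`:

```scala
object Timing {
  // Hypothetical helper: evaluates body `warmup` times untimed, then `n` times
  // timed, and returns the elapsed wall-clock time of the timed runs in milliseconds.
  def benchWarm[A](warmup: Int, n: Int)(body: => A): Long = {
    var i = 0
    while (i < warmup) { body; i += 1 }
    val start = System.nanoTime()
    i = 0
    while (i < n) { body; i += 1 }
    (System.nanoTime() - start) / 1000000L
  }
}
```

For example, `Timing.benchWarm(5000, 50000)(algorithm(input))` with a pre-built `input` array would keep array generation out of the measurement; note that the in-place algorithms reorder `input`, so they would need a fresh copy per call.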

Results

Tests:

Tests
*****
Sorting                  OK, passed 100 tests.
Histogram                OK, passed 100 tests.
Random Pivot             OK, passed 100 tests.
Median of Medians        OK, passed 100 tests.
Midpoint                 OK, passed 100 tests.
Random Pivot (in-place)  OK, passed 100 tests.
Midpoint (in-place)      OK, passed 100 tests.

Benchmarks:

Benchmark
*********
Algorithm                   100     500    1000
Sorting                    1038    6230   14034
Sorting                    1037    6223   13777
Sorting                    1039    6220   13785
Histogram                  2918   11065   21590
Histogram                  2596   11046   21486
Histogram                  2592   11044   21606
Random Pivot                904    4330    8622
Random Pivot                902    4323    8815
Random Pivot                896    4348    8767
Median of Medians          3591   16857   33307
Median of Medians          3530   16872   33321
Median of Medians          3517   16793   33358
Midpoint                   1003    4672    9236
Midpoint                   1010    4755    9157
Midpoint                   1017    4663    9166
Random Pivot (in-place)     392    1746    3430
Random Pivot (in-place)     386    1747    3424
Random Pivot (in-place)     386    1751    3431
Midpoint (in-place)         378    1735    3405
Midpoint (in-place)         377    1740    3408
Midpoint (in-place)         375    1736    3408

Analysis

All algorithms (except the sorting version) have results that are compatible with average linear time complexity.

The median of medians, which guarantees linear time complexity in the worst case, is much slower than the random pivot.

In the immutable version, the fixed pivot selection is slightly slower than the random pivot (in place, the two are about even), and it may have much worse performance on non-random inputs.

The in-place version is about 2.3 to 2.5 times as fast as the immutable one, and further tests (not shown) seem to indicate this advantage grows with the size of the array.

I was very surprised by the histogram algorithm. It displayed linear average time complexity, and it's also about 33% faster than the median of medians. However, the input is random. The worst case is quadratic: I saw some examples of it while I was debugging the code.

Daniel C. Sobral
  • The three problems with this code are (a) it doesn't compile (recursive functions need an explicit return-type), (b) it *isn't* linear time (since the partition is O(n) and it is run O(n) times), and (c) it produces the wrong answer. Other than that, yeah. – Michael Lorton Jan 12 '11 at 06:03
  • @Malvolio A few bugs here and there, but nothing so crass as to think it runs O(n) times... ;-) Anyway, I don't care whether the algorithm works or had the correct complexity, I'm just translating someone else's algorithm which is claimed to be linear time into Scala. – Daniel C. Sobral Jan 12 '11 at 17:15
  • The code seemed to be a very faithful translation of the algorithm, which was why I was surprised it kept giving me the wrong answer, but I was too tired to debug it last night. Not really worth it, since the given algorithm is O(n²) and `scala.util.Sorting.quickSort(a); a(a.length / 2)` is O(n log n) and a lot shorter and easier to understand (and works) – Michael Lorton Jan 12 '11 at 19:45
  • @Malvolio This algorithm is (or seems to be) O(nlogn), as the size of Arr is, on average, halved each time. However, this analysis is superficial. The algorithm looks pretty much like a quicksort, but only half of the partition is recursed into, which makes it faster than quicksort already. Also, it doesn't need to descend all the way into 1-sized partitions. As for the bugs, they were off-by-one errors, mostly related to the fact that the original algorithm implicitly removed `a` from the partition, and a missing "arr" when declaring `a`. Off by one errors suck. – Daniel C. Sobral Jan 12 '11 at 20:45
  • While we're fixing the algorithm, I would take the `random` call out because (a) it's expensive and pointless and (b) `arr.length / 2` is no worse in the random case but for a sorted list (which is common in the real world), the algorithm becomes constant-time. – Michael Lorton Jan 12 '11 at 20:58
  • @Malvolio random is not 'pointless'. If you don't use random, someone might guess the strategy you're using, pick an example on which your program takes O(n^2) time and hang your server. The algorithm is correct and linear on average. – adamax Jan 17 '11 at 08:41
  • @Adamax -- yes, if you are trying to design a system that a malicious user would be allowed to upload megabytes of data, and timing out is not an acceptable solution, then yes, throwing away an otherwise huge advantage would be a good idea. On my home planet, that situation is fairly rare. – Michael Lorton Jan 18 '11 at 01:25
  • Malvolio, it is you rejecting valid statements on basis of assumptions that have yet to be verified by the OP (i.e. input characteristics). On my home planet, this is how you fail algorithms courses. – Raphael Jan 19 '11 at 08:44
  • If you are really concerned about that, you can always choose `arr.length / 2` in the first iteration in order to be fast on sorted lists and a random element after that in order to keep the better average runtime. By the way, the argument brought forth in the reference for linear runtime is crap since it can only explain $\mathcal{O}(n \log n)$. – Raphael Jan 19 '11 at 09:14
  • @Raphael The argument is completely valid. Suppose that each time the length of the array reduces by a factor of two. Then the first iteration takes n units of time, the second one n/2 units, the third one n/4 units, etc. which sums up to n+n/2+n/4 + ... = 2*n. Of course it's just an intuitive explanation, the rigorous proof can be found in any book on algorithms I suppose. – adamax Jan 20 '11 at 08:26
  • He writes: "Intuitively, the expected running time of this algorithm is linear because each recursive call takes on the average linear time, and each recursive call reduces the size of the problem by a constant factor." Merge- and Quicksort both fulfill these conditions but neither run in linear time. Therefore, the argument is invalid. What _you_ are doing is something different, and definitely valid, no doubt. – Raphael Jan 20 '11 at 11:02
  • I know the difference, but the requirements stated do not talk about that. "problem size" is incredibly fishy, since the median algorithm does not reduce the original problem's size either. Also, we talk about recursive algorithms, there are no "iterations". If you talked about the number of recursive calls or, even better, number of nodes in the call tree and its height, then we would be getting somewhere. Some things they teach at university are really necessary in order to talk properly about algorithms. – Raphael Jan 20 '11 at 21:50
  • Correctness is certainly relative to definitions. We apparently use different sets. – Raphael Jan 20 '11 at 22:31
  • Please, fellas, keep it classy. Any thoughts on implementations? – dsg Jan 21 '11 at 10:07
  • @dsg This algorithm gives you average linear time, because of the random selection. Analysis of algorithms using randomization is tougher and non-intuitive. It might even be the case that this algorithm has amortized linear complexity. The wikipedia mentions, without backing citation, that this random algorithm is better on average than the fully linear one. So.. what else do you want? – Daniel C. Sobral Jan 21 '11 at 12:35
  • @Raphael While I have some disagreements with you, the argument was not civil, so I'll delete the whole discussion from my side. – Daniel C. Sobral Jan 21 '11 at 20:23
  • @Daniel C. Sobral: This is fantastic! How about the other two questions that deal with streams? – dsg Jan 21 '11 at 22:56
  • @dsg There's a median of 5 algorithm on another question that's supposed to be much faster than mine. It's presently missing a case, but I'll benchmark it once it's fixed. Also, if it gets into shouting range of random pivot, I'll make an in-place version that uses it. As for streams... well, I'll add something. – Daniel C. Sobral Jan 22 '11 at 23:56
  • Wow! Fantastic answer! Definitely worth more than 50 pts. – dsg Jan 24 '11 at 02:15
  • @dsg You caught the answer mid-editing. I had it almost completely edited when my computer froze and I lost everything, so I then decided to save my edits as I went through the changes. Now it's done. I'll probably add in-place median of medians, just to see if it performs any better as an in-place algorithm. – Daniel C. Sobral Jan 24 '11 at 03:26
  • Sorry to hear about your computer, that must have been frustrating. Maybe your algorithms will make it into the api? – dsg Jan 24 '11 at 04:06
  • @dsg Doubtful. This sort of thing belongs in an applied math library of some sort -- statistics, financial, etc. – Daniel C. Sobral Jan 24 '11 at 18:04
  • If I am not mistaken the median of median implementation is wrong. It recursively calls itself instead of calling the median algorithm. The resulting pivot will not necessarily be larger or smaller than 30% of the elements. – Mishael Rosenthal May 18 '16 at 19:58