
Consider a list of several million objects like:

case class Point(val name:String, val x:Double, val y:Double)

I need, for a given Point target, to pick the 10 other points which are closest to the target.

val target = Point("myPoint", 34, 42)
val points = List(...) // list of several million points

def distance(p1: Point, p2: Point): Double = ??? // return the distance between two points

val closest10 = points.sortWith((a, b) => {
  distance(a, target) < distance(b, target)
}).take(10)

This method works but is very slow. The whole list is sorted exhaustively for each target request, whereas beyond the 10 closest points I don't care about ordering at all. I don't even need the 10 closest points to be returned in the correct order.

Ideally, I'm looking for a "return the first 10 and ignore the rest" kind of method.

A naive solution I can think of goes like this: sort by buckets of 1000, take the first bucket, sort it by buckets of 100, take the first bucket, sort it by buckets of 10, take the first bucket, done.

I guess this must be a very common problem in CS, so before rolling my own solution based on this naive approach, I'd like to know whether there is a state-of-the-art way of doing this, or whether a standard method already exists.

TL;DR: how do I get the first 10 items of an unsorted list without having to sort the whole list?

Jivan
  • Don't sort. Just do a linear traversal and pick closest 10 points. – sarveshseri Nov 29 '17 at 12:12
  • @SarveshKumarSingh I don't really get that, how can you do a linear traversal (aka, hunting for the next closest point if I understand well) if the list is not sorted, or without looking at all the points? Do I need a specific data structure to do that? – Jivan Nov 29 '17 at 12:16
  • @Jivan: Start by taking the first 10 elements of your sequence. Then go through the rest of it, element by element. Every time you find one that is closer to the target than what you already have, drop the element furthest away from the target from your result set (the 10-element sequence) and replace it with the one you just found. At the end of the traversal, you have a sequence containing the 10 elements closest to your target. – Marth Nov 29 '17 at 12:20
  • @Marth genius. Sarvesh, I guess this is what you had in mind. Thanks to both. – Jivan Nov 29 '17 at 12:21
  • @Jivan, what about using QuickSelect to get the 10th, 9th, ..., 1st biggest/smallest element? QuickSelect has average complexity O(N). https://en.wikipedia.org/wiki/Quickselect – Artavazd Balayan Nov 29 '17 at 12:46
  • @Artavazd The problem with QuickSelect is that you need `O(1)` or `C` random access and update. But there are very few data structures in Scala which have these characteristics - http://docs.scala-lang.org/overviews/collections/performance-characteristics.html – sarveshseri Nov 29 '17 at 13:02
  • @sarvesh-kumar-singh, but there is still Array in Scala, with O(1) access. The post doesn't say whether the author needs to mutate this array (e.g. removing elements, which is the most expensive operation). – Artavazd Balayan Nov 29 '17 at 15:16
  • For algorithmic questions like this, they've almost certainly been documented before. So https://en.wikipedia.org/wiki/Selection_algorithm – The Archetypal Paul Dec 03 '17 at 09:35
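
A minimal sketch of the single-pass approach described in Marth's comment, for illustration only. `closestLinear` is a hypothetical name, and `distance` is assumed here to be plain Euclidean distance, which the question leaves unspecified.

```scala
case class Point(name: String, x: Double, y: Double)

// Assumption: Euclidean distance; the question does not define `distance`.
def distance(p1: Point, p2: Point): Double =
  math.hypot(p1.x - p2.x, p1.y - p2.y)

// Walks the list once, keeping a buffer of at most n (distance, point) pairs.
// Whenever a closer point shows up, it replaces the farthest point in the buffer.
// Cost is roughly O(N * n) comparisons; the full list is never sorted.
def closestLinear(points: List[Point], target: Point, n: Int): Vector[Point] =
  if (n <= 0) Vector.empty
  else {
    val best = points.foldLeft(Vector.empty[(Double, Point)]) { (buf, p) =>
      val d = distance(p, target)
      if (buf.size < n) buf :+ ((d, p)) // buffer not full yet
      else {
        // index of the farthest point currently kept
        val worst = buf.indices.reduceLeft((i, j) => if (buf(i)._1 >= buf(j)._1) i else j)
        if (d < buf(worst)._1) buf.updated(worst, (d, p)) else buf
      }
    }
    best.map(_._2)
  }
```

The result is in no particular order, which matches the requirement in the question. For small n the linear scan of the buffer is cheap; for larger n a heap avoids rescanning the buffer on every replacement.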

4 Answers


Below is a bare-bones method, adapted from this SO answer, for picking the n smallest integers from a list (it can be extended to handle more complex data structures):

def nSmallest(n: Int, list: List[Int]): List[Int] = {
  def update(l: List[Int], e: Int): List[Int] =
    if (e < l.head) (e :: l.tail).sortWith(_ > _) else l

  list.drop(n).foldLeft( list.take(n).sortWith(_ > _) )( update(_, _) )
}

nSmallest( 5, List(3, 2, 8, 2, 9, 1, 5, 5, 9, 1, 7, 3, 4) )
// res1: List[Int] = List(3, 2, 2, 1, 1)

Please note that the output is in descending order (largest of the n smallest first).
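
As a rough sketch of how the method above might be extended to more complex data, the version below ranks arbitrary elements by a `Double` key. `nSmallestBy` is a hypothetical name, and the Euclidean distance used in the example is an assumption, since the question does not define `distance`.

```scala
// Hypothetical generalisation of nSmallest: ranks elements by a Double key.
def nSmallestBy[A](n: Int, list: List[A])(key: A => Double): List[A] = {
  // The buffer holds (key, element) pairs, sorted with the largest key first.
  def update(l: List[(Double, A)], e: A): List[(Double, A)] = {
    val k = key(e)
    if (k < l.head._1) ((k, e) :: l.tail).sortWith(_._1 > _._1) else l
  }

  if (n <= 0) Nil
  else {
    val seed = list.take(n).map(e => (key(e), e)).sortWith(_._1 > _._1)
    list.drop(n).foldLeft(seed)(update).map(_._2)
  }
}

case class Point(name: String, x: Double, y: Double)

val target = Point("myPoint", 34, 42)
val sample = List(Point("a", 30, 40), Point("b", 0, 0), Point("c", 35, 42), Point("d", 100, 100))

// Assumption: Euclidean distance to the target is the ranking key.
val closest2 = nSmallestBy(2, sample)(p => math.hypot(p.x - target.x, p.y - target.y))
// closest2.map(_.name) is List("a", "c"): farthest of the two first, as with nSmallest
```

Caching the key in the pair means each distance is computed once per element rather than once per comparison.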

Leo C

I was looking at this and wondered if a PriorityQueue might be useful.

import scala.collection.mutable.PriorityQueue

case class Point(val name:String, val x:Double, val y:Double)
val target = Point("myPoint", 34, 42)
val points = List(...) //list of points

def distance(p1: Point, p2: Point): Double = ??? //distance between two points

//load points-priority-queue with first 10 points
val ppq = PriorityQueue(points.take(10): _*)(
  //max-heap on distance to target, so ppq.head is the farthest of the kept points
  Ordering.fromLessThan[Point]((a, b) => distance(a, target) < distance(b, target))
)

//step through everything after the first 10
points.drop(10).foldLeft(distance(ppq.head,target))((mxDst,nextPnt) => 
  if (mxDst > distance(nextPnt,target)) {
    ppq.dequeue()             //drop current far point
    ppq.enqueue(nextPnt)      //load replacement point
    distance(ppq.head,target) //return new max distance
  } else mxDst)

val result: Seq[Point] = ppq.dequeueAll  //10 closest points, farthest first
jwvh

Here is how it can be done with QuickSelect. I used an in-place QuickSelect. Basically, for every target point we calculate the distance between the target and all points, then use QuickSelect to get the k-th smallest distance (the k-th order statistic). Whether this is faster than sorting depends on factors like the number of points, the number of nearest points requested and the number of targets. On my machine, for 3 million randomly generated points, 10 target points and 10 nearest points per target, it is about 2 times faster than sorting:

Number of points: 3000000
Number of targets: 10
Number of nearest: 10
QuickSelect: 10737 ms.
Sort: 20763 ms.
Results from QuickSelect are valid

Code:

import scala.annotation.tailrec
import scala.concurrent.duration.Deadline
import scala.util.Random

case class Point(val name: String, val x: Double, val y: Double)

class NearestPoints(val points: Seq[Point]) {
  private case class PointWithDistance(p: Point, d: Double) extends Ordered[PointWithDistance] {
    def compare(that: PointWithDistance): Int = d.compareTo(that.d)
  }
  def distance(p1: Point, p2: Point): Double = {
    Math.sqrt(Math.pow(p2.x - p1.x, 2) + Math.pow(p2.y - p1.y, 2))
  }
  def get(target: Point, n: Int): Seq[Point] = {
    val pd = points.map(p => PointWithDistance(p, distance(p, target))).toArray
    (1 to n).map(i => quickselect(i, pd).get.p)
  }
  // In-place QuickSelect from https://gist.github.com/mooreniemi/9e45d55c0410cad0a9eb6d62a5b9b7ae
  def quickselect[T <% Ordered[T]](k: Int, xs: Array[T]): Option[T] = {
    def randint(lo: Int, hi: Int): Int =
      lo + scala.util.Random.nextInt((hi - lo) + 1)

    @inline
    def swap[T](xs: Array[T], i: Int, j: Int): Unit = {
      val t = xs(i)
      xs(i) = xs(j)
      xs(j) = t
    }

    def partition[T <% Ordered[T]](xs: Array[T], l: Int, r: Int): Int = {
      var pivotIndex = randint(l, r)
      val pivotValue = xs(pivotIndex)
      swap(xs, r, pivotIndex)
      pivotIndex = l

      var i = l
      while (i <= r - 1) {
        if (xs(i) < pivotValue) {
          swap(xs, i, pivotIndex)
          pivotIndex = pivotIndex + 1
        }
        i = i + 1
      }
      swap(xs, r, pivotIndex)
      pivotIndex
    }

    @tailrec
    def quickselect0[T <% Ordered[T]](xs: Array[T], l: Int, r: Int, k: Int): T = {
      if (l == r) {
        xs(l)
      } else {
        val pivotIndex = partition(xs, l, r)
        // Compare explicitly rather than matching on the exact values -1/0/1 returned by `compare`
        if (k == pivotIndex) xs(k)
        else if (k < pivotIndex) quickselect0(xs, l, pivotIndex - 1, k)
        else quickselect0(xs, pivotIndex + 1, r, k)
      }
    }

    xs match {
      case _ if xs.isEmpty => None
      case _ if k < 1 || k > xs.length => None
      case _ => Some(quickselect0(xs, 0, xs.size - 1, k - 1))
    }
  }
}

object QuickSelectVsSort {
  def main(args: Array[String]): Unit = {
    val rnd = new Random(42L)
    val MAX_N: Int = 3000000
    val NUM_OF_NEARESTS: Int = 10
    val NUM_OF_TARGETS: Int = 10
    println(s"Number of points: $MAX_N")
    println(s"Number of targets: $NUM_OF_TARGETS")
    println(s"Number of nearest: $NUM_OF_NEARESTS")

    // Generate random points
    val points = (1 to MAX_N)
      .map(x => Point(x.toString, rnd.nextDouble, rnd.nextDouble))

    // Generate target points
    val targets = (1 to NUM_OF_TARGETS).map(x => Point(s"Target$x", rnd.nextDouble, rnd.nextDouble))

    var start = Deadline.now
    val np = new NearestPoints(points)
    val viaQuickSelect = targets.map { case target =>
      val nearest = np.get(target, NUM_OF_NEARESTS)
      nearest
    }
    var end = Deadline.now
    println(s"QuickSelect: ${(end - start).toMillis} ms.")

    start = Deadline.now
    val viaSort = targets.map { case target =>
      val closest = points.sortWith((a, b) => {
        np.distance(a, target) < np.distance(b, target)
      }).take(NUM_OF_NEARESTS)
      closest
    }
    end = Deadline.now
    println(s"Sort: ${(end - start).toMillis} ms.")

    // Validate
    assert(viaQuickSelect.length == viaSort.length)
    viaSort.zipWithIndex.foreach { case (p, idx) =>
      assert(p == viaQuickSelect(idx))
    }
    println("Results from QuickSelect are valid")
  }
}
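
When the order of the n results does not matter (the question says it doesn't), a single selection with k = n may be enough: in-place partitioning leaves every element to the left of the selected index no larger than the selected element. The following is a sketch under that assumption and is not part of the benchmark above; `selectInPlace` and `nearestUnordered` are hypothetical names.

```scala
import scala.util.Random

case class Point(name: String, x: Double, y: Double)

// Rearranges `ds` (and `ps` in lockstep) so that the k smallest distances
// end up at indices 0 until k, in no particular order.
// Requires 1 <= k <= ds.length.
def selectInPlace(ds: Array[Double], ps: Array[Point], k: Int, rnd: Random): Unit = {
  def swap(i: Int, j: Int): Unit = {
    val d = ds(i); ds(i) = ds(j); ds(j) = d
    val p = ps(i); ps(i) = ps(j); ps(j) = p
  }
  var lo = 0
  var hi = ds.length - 1
  while (lo < hi) {
    // Lomuto partition around a randomly chosen pivot
    swap(lo + rnd.nextInt(hi - lo + 1), hi)
    val pivot = ds(hi)
    var store = lo
    var i = lo
    while (i < hi) {
      if (ds(i) < pivot) { swap(i, store); store += 1 }
      i += 1
    }
    swap(store, hi)
    // Continue only in the side that contains index k - 1
    if (store < k - 1) lo = store + 1
    else if (store > k - 1) hi = store - 1
    else { lo = store; hi = store } // pivot landed exactly at k - 1: done
  }
}

def nearestUnordered(points: Seq[Point], target: Point, n: Int): Vector[Point] = {
  val ps = points.toArray
  val ds = ps.map(p => math.hypot(p.x - target.x, p.y - target.y))
  val k = math.min(n, ps.length)
  if (k <= 0) Vector.empty
  else {
    selectInPlace(ds, ps, k, new Random())
    ps.take(k).toVector
  }
}
```

This performs one average-O(N) selection per target instead of n of them, at the price of returning the nearest points unordered.
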
Artavazd Balayan

For finding the top n elements in a list you can Quicksort it and terminate early. That is, terminate at the point where you know there are n elements that are bigger than the pivot. See my implementation in the Rank class of Apache Jackrabbit (in Java though), which does just that.
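
The early-termination idea can be sketched as a partial quicksort: partition as usual, always recurse into the left part, and recurse into the right part only while it still overlaps the first n positions. This is an illustrative Scala sketch, not the Jackrabbit `Rank` code; `partialSort` and `smallest` are hypothetical names, and it picks the n smallest values (flip the comparison to get the largest).

```scala
// Sorts only as much of `xs` as needed so that xs(0 until n) holds the
// n smallest values in ascending order. The rest of the array stays unsorted.
def partialSort(xs: Array[Double], n: Int): Unit = {
  def swap(i: Int, j: Int): Unit = { val t = xs(i); xs(i) = xs(j); xs(j) = t }

  def sort(lo: Int, hi: Int): Unit =
    if (lo < hi) {
      // Lomuto partition with the middle element as pivot
      swap(lo + (hi - lo) / 2, hi)
      val pivot = xs(hi)
      var store = lo
      for (i <- lo until hi)
        if (xs(i) < pivot) { swap(i, store); store += 1 }
      swap(store, hi)

      sort(lo, store - 1)
      // Early termination: skip the right part when it starts at or beyond index n
      if (store + 1 < n) sort(store + 1, hi)
    }

  if (n > 0) sort(0, xs.length - 1)
}

def smallest(values: Seq[Double], n: Int): Vector[Double] = {
  val xs = values.toArray
  partialSort(xs, n)
  xs.take(math.max(n, 0)).toVector
}

smallest(Seq[Double](3, 2, 8, 2, 9, 1, 5, 5, 9, 1, 7, 3, 4), 5)
// Vector(1.0, 1.0, 2.0, 2.0, 3.0)
```

Unlike a plain selection, this returns the n values already sorted, so it does slightly more work than the question strictly requires.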

michid