Here is how it can be done with QuickSelect. I used an in-place QuickSelect. Basically, for every target point we calculate the distance between each point and the target, and use QuickSelect to get the k-th smallest distance (the k-th order statistic). Whether this algorithm is faster than sorting depends on factors such as the number of points, the number of nearest points requested and the number of targets. On my machine, for 3 million randomly generated points, 10 target points and 10 nearest points requested, it is about 2 times faster than the sort-based approach:
Number of points: 3000000
Number of targets: 10
Number of nearest: 10
QuickSelect: 10737 ms.
Sort: 20763 ms.
Results from QuickSelect are valid
Code:
import scala.annotation.tailrec
import scala.concurrent.duration.Deadline
import scala.util.Random
/** A named point in the 2D plane.
  *
  * @param name label identifying the point
  * @param x    horizontal coordinate
  * @param y    vertical coordinate
  */
case class Point(name: String, x: Double, y: Double)
/** Finds the points closest to a target point using an in-place QuickSelect.
  *
  * @param points candidate points searched by [[get]]
  */
class NearestPoints(val points: Seq[Point]) {

  /** A candidate point paired with its distance to the current target; ordered by that distance. */
  private case class PointWithDistance(p: Point, d: Double) extends Ordered[PointWithDistance] {
    def compare(that: PointWithDistance): Int = java.lang.Double.compare(d, that.d)
  }

  /** Euclidean distance between `p1` and `p2`. */
  def distance(p1: Point, p2: Point): Double = {
    Math.sqrt(Math.pow(p2.x - p1.x, 2) + Math.pow(p2.y - p1.y, 2))
  }

  /** Returns the `n` points closest to `target`, nearest first.
    *
    * If fewer than `n` points are available, all of them are returned (still ordered by
    * distance); if `n <= 0` or there are no points, the result is empty. Each point appears
    * at most once in the result, even when several points are equally distant.
    *
    * @param target point to measure distances from
    * @param n      maximum number of nearest points to return
    * @return up to `n` points ordered by ascending distance to `target`
    */
  def get(target: Point, n: Int): Seq[Point] = {
    val pd = points.map(p => PointWithDistance(p, distance(p, target))).toArray
    val k = math.min(n, pd.length)
    if (k <= 0) {
      Vector.empty
    } else {
      // A single selection pass moves the k smallest distances into pd(0 until k).
      // Re-running quickselect once per rank would re-partition the array each time, which
      // can return the same point twice when distances tie.
      quickselect(k, pd)
      pd.take(k).sortWith(_ < _).map(_.p).toVector
    }
  }

  // In-place QuickSelect from https://gist.github.com/mooreniemi/9e45d55c0410cad0a9eb6d62a5b9b7ae
  /** Selects the `k`-th smallest element (1-based) of `xs`.
    *
    * The array is reordered in place: on a successful return the selected element sits at
    * index `k - 1`, no element before it is greater than it and no element after it is
    * smaller than it.
    *
    * @param k  1-based rank of the element to select
    * @param xs array to select from; mutated by this call
    * @param ev view of `T` as `Ordered[T]` used for comparisons
    * @return the `k`-th smallest element, or `None` if `k` is outside `1..xs.length`
    */
  def quickselect[T](k: Int, xs: Array[T])(implicit ev: T => Ordered[T]): Option[T] = {
    def swap(i: Int, j: Int): Unit = {
      val t = xs(i)
      xs(i) = xs(j)
      xs(j) = t
    }

    // Partitions xs(l..r) around a randomly chosen pivot and returns the pivot's final index:
    // elements smaller than the pivot end up before it, all others after it.
    def partition(l: Int, r: Int): Int = {
      swap(r, l + scala.util.Random.nextInt((r - l) + 1))
      val pivotValue = xs(r)
      var store = l
      var i = l
      while (i < r) {
        if (ev(xs(i)) < pivotValue) {
          swap(i, store)
          store += 1
        }
        i += 1
      }
      swap(r, store)
      store
    }

    // Narrows xs(l..r) until the element belonging at index `idx` is in place.
    @tailrec
    def select(l: Int, r: Int, idx: Int): T =
      if (l == r) {
        xs(l)
      } else {
        val pivotIndex = partition(l, r)
        if (idx == pivotIndex) xs(idx)
        else if (idx < pivotIndex) select(l, pivotIndex - 1, idx)
        else select(pivotIndex + 1, r, idx)
      }

    // Covers the empty array as well: no k satisfies 1 <= k <= 0.
    if (k < 1 || k > xs.length) None
    else Some(select(0, xs.length - 1, k - 1))
  }
}
/** Benchmark that times nearest-point lookup via QuickSelect against a full sort and
  * checks that both approaches return the same points.
  */
object QuickSelectVsSort {

  /** Milliseconds elapsed between `since` and now. */
  private def millisSince(since: Deadline): Long = (Deadline.now - since).toMillis

  def main(args: Array[String]): Unit = {
    // Fixed seed so every run generates the same points and targets.
    val rnd = new Random(42L)
    val pointCount: Int = 3000000
    val nearestCount: Int = 10
    val targetCount: Int = 10

    println(s"Number of points: $pointCount")
    println(s"Number of targets: $targetCount")
    println(s"Number of nearest: $nearestCount")

    // Generate random points
    val points =
      for (i <- 1 to pointCount) yield Point(i.toString, rnd.nextDouble(), rnd.nextDouble())

    // Generate target points
    val targets =
      for (i <- 1 to targetCount) yield Point(s"Target$i", rnd.nextDouble(), rnd.nextDouble())

    // QuickSelect-based lookup; the timed region includes constructing NearestPoints.
    val quickSelectStart = Deadline.now
    val np = new NearestPoints(points)
    val viaQuickSelect = targets.map(target => np.get(target, nearestCount))
    println(s"QuickSelect: ${millisSince(quickSelectStart)} ms.")

    // Sort-based lookup: order all points by distance to the target and keep the first few.
    val sortStart = Deadline.now
    val viaSort = targets.map { target =>
      val byDistance =
        points.sortWith((a, b) => np.distance(a, target) < np.distance(b, target))
      byDistance.take(nearestCount)
    }
    println(s"Sort: ${millisSince(sortStart)} ms.")

    // Validate
    assert(viaQuickSelect.length == viaSort.length)
    viaSort.zip(viaQuickSelect).foreach { case (expected, actual) =>
      assert(expected == actual)
    }
    println("Results from QuickSelect are valid")
  }
}