
A is a point and P is a list of points. I want to find the point P[i] closest to A, i.e. I want to find P[i_0] with:

i_0 = argmin_i ||A - P[i]||^2

I do it this way:

import numpy as np

# P is a list of 4 points
P = [np.array([-1, 0, 7, 3]), np.array([5, -2, 8, 1]), np.array([0, 2, -3, 4]), np.array([-9, 11, 3, 4])]
A = np.array([1, 2, 3, 4])
distance = float('inf')   # start at +infinity so the first point always wins
closest = None

for p in P:
    delta = sum((p - A)**2)   # squared Euclidean distance ||A - p||^2
    if delta < distance:
        distance = delta
        closest = p

print(closest)    # the closest point to A among all the points in P

It works, but how can I do this in a shorter, more Pythonic way?

More generally, in Python (even without NumPy), how can I find k_0 such that D[k_0] = min_k D[k], i.e. k_0 = argmin_k D[k]?

Basj
  • Do you have SciPy? [A data structure specifically designed for nearest-neighbor lookup would probably perform better than any array-based solution.](http://docs.scipy.org/doc/scipy-0.15.1/reference/generated/scipy.spatial.cKDTree.html) – user2357112 Dec 06 '15 at 09:48
  • @user2357112 I know that many modules implement nearest-neighbour lookup, but for this example I would like to code it myself (not optimized for performance, just shorter code). My question is more about how to implement `argmin` nicely in Python. – Basj Dec 06 '15 at 09:52
  • More Pythonic != shorter code. – user2357112 Dec 06 '15 at 10:00
  • @user2357112 : In this particular example it is `==`, see the [answer](http://stackoverflow.com/a/34116108/1422096). – Basj Dec 06 '15 at 10:01

4 Answers


A more Pythonic way of implementing the same algorithm you're using is to replace your loop with a call to min with a key function:

closest = min(P, key=lambda p: sum((p - A)**2))

Note that I'm using ** for exponentiation (^ is the bitwise XOR operator in Python).
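The same min/key pattern also covers the question's more general k_0 = argmin_k D[k] in plain Python, without NumPy: iterate over the indices and let the key look up each value. A minimal sketch; the list `D` below is hypothetical example data.

```python
D = [7.5, 3.2, 9.1, 3.2, 4.0]  # hypothetical values

# argmin over indices: min ranks each index k by D[k];
# on ties, min returns the first minimal item it encounters
k_0 = min(range(len(D)), key=D.__getitem__)
print(k_0, D[k_0])  # 1 3.2

# alternative: get the index and the value together from enumerate
idx, value = min(enumerate(D), key=lambda pair: pair[1])
print(idx, value)  # 1 3.2
```

The enumerate form avoids a second lookup when both the index and the minimum value are needed.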

Blckknght
  • Great! I didn't know this min/key pattern! – Basj Dec 06 '15 at 10:00
  • you can see your code in action here: [Simple (working) handwritten digit recognition](http://stackoverflow.com/questions/34116526/simple-working-handwritten-digit-recognition-how-to-improve-it)! – Basj Dec 06 '15 at 10:51

A fully vectorized approach in NumPy. It is similar to @MikeMüller's answer, but uses NumPy broadcasting to avoid a lambda function.

With the example data:

>>> P = [np.array([-1, 0, 7, 3]), np.array([5, -2, 8, 1]), np.array([0, 2, -3, 4]), np.array([-9, 11, 3, 4])]
>>> A = np.array([1, 2, 3, 4])

Then make P a 2D NumPy array:

>>> P = np.asarray(P)
>>> P
array([[-1,  0,  7,  3],
       [ 5, -2,  8,  1],
       [ 0,  2, -3,  4],
       [-9, 11,  3,  4]])

The closest point can then be computed in one line using NumPy:

>>> P[np.argmin(np.sum((P - A)**2, axis=1))]
array([-1,  0,  7,  3])

Note that P - A, with P.shape == (N, 4) and A.shape == (4,), broadcasts the subtraction across all rows of P (row i becomes P[i] - A).
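A short sketch of the same one-liner split into steps with the example data, so the intermediate shapes produced by broadcasting are visible. The variable names `diff`, `sq_dist` and `i_0` are only illustrative.

```python
import numpy as np

P = np.array([[-1, 0, 7, 3],
              [5, -2, 8, 1],
              [0, 2, -3, 4],
              [-9, 11, 3, 4]])
A = np.array([1, 2, 3, 4])

diff = P - A                          # A (shape (4,)) is applied to every row: shape (4, 4)
sq_dist = np.sum(diff ** 2, axis=1)   # one squared distance per row: shape (4,)
i_0 = np.argmin(sq_dist)              # index of the smallest squared distance

print(diff.shape, sq_dist.shape)  # (4, 4) (4,)
print(sq_dist)
print(i_0, P[i_0])
```

With this data the squared distances are 25, 66, 37 and 181, so i_0 is 0 and the closest point is [-1, 0, 7, 3].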

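For small N (the number of rows in P), the pure-Python min approach is probably faster. For large N, this vectorized version should be significantly faster, because the per-row loop runs inside NumPy's compiled code instead of the Python interpreter.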
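To check the small-N versus large-N claim on your own machine, a rough benchmark along these lines could be used. The sizes (10 and 20,000 rows of dimension 4), the random integer data and the repeat count are arbitrary assumptions, and no timings are claimed here.

```python
import timeit
import numpy as np

def closest_min(P_list, A):
    # pure-Python loop over rows via min with a key function
    return min(P_list, key=lambda p: np.sum((p - A) ** 2))

def closest_vectorized(P_arr, A):
    # broadcasting + argmin, as in the answer above
    return P_arr[np.argmin(np.sum((P_arr - A) ** 2, axis=1))]

rng = np.random.default_rng(0)
A = rng.integers(-10, 10, size=4)

for n in (10, 20_000):
    P_arr = rng.integers(-10, 10, size=(n, 4))
    P_list = list(P_arr)  # list of 1-D row arrays, like the question's P
    t_min = timeit.timeit(lambda: closest_min(P_list, A), number=3)
    t_vec = timeit.timeit(lambda: closest_vectorized(P_arr, A), number=3)
    print(f"N={n}: min/key {t_min:.4f} s, vectorized {t_vec:.4f} s")
```

Both functions return the first row with the minimal squared distance, since min and np.argmin each pick the first occurrence on ties.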

Imanol Luengo
  • Very fast method indeed! See comment on [this answer](http://stackoverflow.com/a/34116234/1422096). – Basj Dec 09 '15 at 17:20

A NumPy version as a one-liner:

closest = P[np.argmin(np.apply_along_axis(lambda p: np.sum((p - A) **2), 1, P))]
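The comment below points out that an earlier revision was missing a pair of parentheses. Assuming the misplaced version looked like `np.sum(p - A) ** 2`, this small sketch shows how far the two expressions can diverge for a single point from the question.

```python
import numpy as np

A = np.array([1, 2, 3, 4])
p = np.array([-1, 0, 7, 3])

sum_of_squares = np.sum((p - A) ** 2)  # squared distance: 4 + 4 + 16 + 1 = 25
square_of_sum = np.sum(p - A) ** 2     # (-2 - 2 + 4 - 1) ** 2 = (-1) ** 2 = 1

print(sum_of_squares, square_of_sum)  # 25 1
```

Because positive and negative coordinate differences cancel inside the sum, the squared sum is not a distance and can rank far points as close.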
Mike Müller
  • You're computing a slightly different value (the smallest sum, squared, rather than the sum *of* squares). The NumPy bits are probably useful, though. – Blckknght Dec 06 '15 at 19:31
  • @Blckknght Oops, I was missing a pair of parentheses. Fixed. Thanks for the hint. – Mike Müller Dec 06 '15 at 19:36
  • I compared both your solution and [@imaluengo's solution](http://stackoverflow.com/a/34122191/1422096) with P containing 5000 vectors of size 10. When doing 1000 different nearest-neighbour searches, it takes ~60 seconds with this method and ~0.6 seconds with @imaluengo's method. It would be interesting to see why `np.apply_along_axis` has such an effect. – Basj Dec 09 '15 at 17:17
  • Calling Python functions is slow. `lambda` is pure Python. – Mike Müller Dec 09 '15 at 17:25
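One way to see where the time goes is to count how often `np.apply_along_axis` invokes the Python callable: once per row, so N interpreter-level calls per search, whereas the broadcast version makes a fixed number of NumPy calls regardless of N. A minimal sketch using the 5000 × 10 size from the comment above; the random data and the `calls` counter are purely illustrative.

```python
import numpy as np

P = np.random.default_rng(1).integers(-10, 10, size=(5000, 10))
A = np.zeros(10, dtype=P.dtype)

calls = 0

def sq_dist(p):
    # row-wise squared distance; the counter tracks Python-level invocations
    global calls
    calls += 1
    return np.sum((p - A) ** 2)

d_apply = np.apply_along_axis(sq_dist, 1, P)   # calls sq_dist once per row
d_broadcast = np.sum((P - A) ** 2, axis=1)     # single vectorized expression

print(calls)  # one call per row of P
```

Both arrays hold the same distances; only the number of trips through the Python interpreter differs.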

Using the built-in min is the way to do this:

import math
p1 = [1, 2]
plst = [[1, 3], [10, 10], [5, 5]]
res = min(plst, key=lambda x: math.sqrt(pow(p1[0]-x[0], 2) + pow(p1[1]-x[1], 2)))
print(res)    # [1, 3]

Note that I used plain Python lists.

Netwave
  • The square root is not necessary, since the point with the smallest distance will also be the point with the smallest distance squared. – Blckknght Dec 06 '15 at 19:32
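Following that comment, a minimal sketch of the plain-list version with the square root dropped: since sqrt is monotonic, ranking by squared distance selects the same point. `math.dist` (Python 3.8+) is shown as an option for when the actual distance is wanted afterwards.

```python
import math

p1 = [1, 2]
plst = [[1, 3], [10, 10], [5, 5]]

# squared distance: same ordering as the true distance, no sqrt needed
res = min(plst, key=lambda x: sum((a - b) ** 2 for a, b in zip(p1, x)))
print(res)  # [1, 3]

# if the distance itself is needed, math.dist (Python 3.8+) computes it
print(math.dist(p1, res))  # 1.0
```

The zip-based key also works unchanged for points of any dimension, not just 2D.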