I need to find the `k` nearest neighbors of each object in a set. Each object stores its coordinates as attributes.
To solve this, I am trying to use `spatial.KDTree` from `scipy`. It works fine if I represent a point as a list or tuple, but it doesn't work with objects.
I implemented the `__getitem__` and `__len__` methods in my class, but the `KDTree` implementation asks my objects for a nonexistent coordinate axis (for example, the third coordinate of a 2-dimensional point).
Here is a simple script to reproduce the problem:
```python
from scipy import spatial


class Unit:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, index):
        if index == 0:
            return self.x
        elif index == 1:
            return self.y
        else:
            raise Exception('Unit coordinates are 2 dimensional')

    def __len__(self):
        return 2


# Plain tuples or lists work:
# points = [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]
# points = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]

# Unit objects do not:
points = [Unit(1, 1), Unit(2, 2), Unit(3, 3), Unit(4, 4), Unit(5, 5)]
tree = spatial.KDTree(points)

# result = tree.query((6, 6), 3)
result = tree.query(Unit(6, 6), 3)
print(result)
```
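One detail in the script above may explain the request for a third coordinate. When a class defines `__getitem__` but not `__iter__`, Python iterates over it by calling `__getitem__(0)`, `__getitem__(1)`, `__getitem__(2)`, ... and stops only when `IndexError` is raised; any other exception propagates. Assuming `KDTree` (through NumPy) converts each point by iterating over it, the call with index 2 would be that end-of-sequence probe, and the generic `Exception` aborts the conversion. The pure-Python sketch below (no SciPy involved) contrasts the original behavior, in a hypothetical `BrokenUnit` class, with a variant that raises `IndexError`:

```python
class BrokenUnit:
    """Same as the Unit in the question: raises a plain Exception."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, index):
        if index == 0:
            return self.x
        elif index == 1:
            return self.y
        raise Exception('Unit coordinates are 2 dimensional')

    def __len__(self):
        return 2


class Unit(BrokenUnit):
    """Variant that follows the sequence protocol."""

    def __getitem__(self, index):
        if index == 0:
            return self.x
        elif index == 1:
            return self.y
        # IndexError tells Python's fallback iteration that the sequence has ended
        raise IndexError('Unit coordinates are 2 dimensional')


print(list(Unit(1, 2)))  # iteration stops cleanly after two items
try:
    list(BrokenUnit(1, 2))  # __getitem__(2) raises Exception, which propagates
except Exception as exc:
    print('iteration failed:', exc)
```

With `IndexError`, `list(Unit(1, 2))` yields `[1, 2]`, which suggests the objects could then be passed to `KDTree` unchanged, although that depends on how a given SciPy version converts its input.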
I don't have to use this particular implementation, library, or even algorithm; the only requirement is that it works with objects.
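Since any algorithm is acceptable, a brute-force search is one option that works on the objects directly, with no separate coordinate array. This is a minimal sketch, assuming 2-D Euclidean distance, Python 3.8+ (for `math.dist`), and a hypothetical helper named `k_nearest`. It scans every object, so one query costs O(n log k) and querying every object costs O(n² log k): fine for small sets, slower than a k-d tree for large ones.

```python
import heapq
import math


class Unit:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return f"Unit({self.x}, {self.y})"


def k_nearest(objects, target, k):
    """Return the k objects closest to target (Euclidean), nearest first.

    Hypothetical helper: a linear scan, no index structure.
    """
    target_xy = (target.x, target.y)
    return heapq.nsmallest(
        k, objects, key=lambda o: math.dist((o.x, o.y), target_xy)
    )


points = [Unit(1, 1), Unit(2, 2), Unit(3, 3), Unit(4, 4), Unit(5, 5)]
nearest = k_nearest(points, Unit(6, 6), 3)
print(nearest)
```

`heapq.nsmallest` with a `key` behaves like `sorted(...)[:k]`, so ties keep their original order, and the returned list holds the original `Unit` instances rather than copies.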
P.S. I could add an `id` field to each object and move all coordinates into a separate array indexed by object `id`, but I would rather avoid that approach if possible.
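For comparison, the approach from the P.S. may not need an explicit `id` field: if the tree is built from coordinate tuples in the same order as the object list, the row index returned by `KDTree.query` is already the object's position in that list. A sketch, assuming SciPy is installed and `Unit` stores `x` and `y` as attributes:

```python
from scipy import spatial


class Unit:
    def __init__(self, x, y):
        self.x = x
        self.y = y


points = [Unit(1, 1), Unit(2, 2), Unit(3, 3), Unit(4, 4), Unit(5, 5)]

# Row i of the tree's data corresponds to points[i],
# so the list position plays the role of the id.
tree = spatial.KDTree([(u.x, u.y) for u in points])

# query returns the distances and the row indices of the k nearest rows
distances, indices = tree.query((6, 6), k=3)

# Map the row indices back to the original objects
nearest = [points[i] for i in indices]
```

To get neighbors for every object, the whole coordinate list can be passed to `query` with `k + 1`, since each point is normally its own nearest neighbor (duplicate coordinates can change which row comes first).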