You can do this with normal TensorFlow operations such as `tf.where`, `tf.math.squared_difference`, `tf.math.argmin`, and `tf.gather_nd`. Here is an example with one negative and one positive target value:
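```python
import tensorflow as tf

t = tf.random.normal((5, 2))
print(t, '\n')

closest_neighbors = [-1, 2]  # target values to look up
for c in closest_neighbors:
    # Squared distance from every element of t to c, shape (5, 2)
    tensor = tf.math.squared_difference(t, c)
    # Row of the smallest distance in each column, shape (2,)
    indices = tf.math.argmin(tensor, axis=0)
    # Smallest distance found in column 0 and in column 1
    a = tensor[indices[0], 0]
    b = tensor[indices[1], 1]
    # Keep the [row, col] index of whichever column had the closer match
    final_indices = tf.where(tf.less(a, b), [indices[0], 0], [indices[1], 1])
    closest_value = tf.gather_nd(t, final_indices)
    print('Closest value to {} is {}'.format(c, closest_value))
```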
Output (your numbers will differ, since `t` is random):

```
tf.Tensor(
[[ 0.9975055  -2.148285  ]
 [-2.27254    -1.2470466 ]
 [-1.0182583   1.1855317 ]
 [-0.7712745   0.63082063]
 [-0.5022545   0.08102719]], shape=(5, 2), dtype=float32)

Closest value to -1 is -1.0182583332061768
Closest value to 2 is 1.185531735420227
```
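The loop above assumes `t` has exactly two columns, because it only compares `a` and `b`, the best matches in column 0 and column 1. A minimal sketch that drops that assumption (assuming you only need the values, not their positions) is to flatten the tensor, compute squared distances to all targets at once via broadcasting, and pick the nearest element per target with `tf.math.argmin` and `tf.gather`. The function name `closest_values` is hypothetical, not part of TensorFlow.

```python
import tensorflow as tf

def closest_values(t, targets):
    """Return, for each target, the element of `t` closest to it.

    Hypothetical helper: `t` is flattened first, so the search runs over
    every element regardless of the tensor's shape.
    """
    flat = tf.reshape(t, [-1])                                     # shape (N,)
    targets = tf.cast(tf.reshape(targets, [-1, 1]), flat.dtype)    # shape (K, 1)
    # Broadcasting (N,) against (K, 1) gives a (K, N) matrix of squared distances
    dist = tf.math.squared_difference(flat, targets)
    # Flat index of the nearest element for each target, shape (K,)
    idx = tf.math.argmin(dist, axis=1)
    return tf.gather(flat, idx)

# Same values as the printed tensor in the output above
t = tf.constant([[0.9975055, -2.148285],
                 [-2.27254, -1.2470466],
                 [-1.0182583, 1.1855317],
                 [-0.7712745, 0.63082063],
                 [-0.5022545, 0.08102719]])
print(closest_values(t, [-1, 2]).numpy())
```

With that tensor, this picks -1.0182583 for -1 and 1.1855317 for 2, matching the loop, and it also handles tensors with any number of columns. If several elements are equally close, `tf.math.argmin` returns the first one.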