I want to compute the pairwise squared distances for a batch of features in TensorFlow. I have a simple implementation that tiles the original tensors and then applies element-wise operations:
import tensorflow as tf

def pairwise_l2_norm2(x, y, scope=None):
    with tf.op_scope([x, y], scope, 'pairwise_l2_norm2'):
        size_x = tf.shape(x)[0]
        size_y = tf.shape(y)[0]
        # Replicate x along a new last axis: (m, d) -> (m, d, n).
        xx = tf.expand_dims(x, -1)
        xx = tf.tile(xx, tf.pack([1, 1, size_y]))
        # Replicate y the same way, then transpose: (n, d, m) -> (m, d, n).
        yy = tf.expand_dims(y, -1)
        yy = tf.tile(yy, tf.pack([1, 1, size_x]))
        yy = tf.transpose(yy, perm=[2, 1, 0])
        diff = tf.sub(xx, yy)
        square_diff = tf.square(diff)
        # Sum over the feature axis to get the (m, n) result.
        square_dist = tf.reduce_sum(square_diff, 1)
        return square_dist
This function takes as input two matrices of shape (m, d) and (n, d) and computes the squared distance between every pair of row vectors. The output is a matrix of shape (m, n) with elements 'd_ij = dist(x_i, y_j)', where dist is the squared L2 distance.
The problem is that I have a large batch and high-dimensional features: with large 'm, n, d', replicating the tensors creates intermediates of shape (m, d, n), which consumes a lot of memory. I'm looking for another way to implement this that does not increase memory usage and stores only the final (m, n) distance tensor, something like a double loop over the rows of the original tensors.
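To make the intended semantics concrete, here is a plain-Python sketch of the double loop I have in mind. It is only a reference for what the result should be, not a TensorFlow solution: the name pairwise_sq_dist_loops is made up for illustration, and it assumes x and y are lists of equal-length rows. It never builds an (m, d, n) intermediate, only the (m, n) output.

```python
def pairwise_sq_dist_loops(x, y):
    """Reference semantics only: x has m rows, y has n rows, each of length d."""
    result = []
    for x_row in x:          # loop over the m rows of x
        row = []
        for y_row in y:      # loop over the n rows of y
            # Squared L2 distance between one row of x and one row of y.
            row.append(sum((a - b) ** 2 for a, b in zip(x_row, y_row)))
        result.append(row)
    return result            # m lists of n values, i.e. an (m, n) matrix


x = [[0.0, 0.0], [1.0, 1.0]]               # m = 2, d = 2
y = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]   # n = 3, d = 2
print(pairwise_sq_dist_loops(x, y))
# [[1.0, 4.0, 2.0], [1.0, 2.0, 0.0]]
```

Ideally the TensorFlow version would produce the same (m, n) values while keeping peak memory close to the size of that output.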