
How do I compute the set difference of the rows of two arrays in TensorFlow?

Example: I want to remove every row of b from a:

import numpy as np

a = np.array([[1, 0, 1], [2, 0, 1], [3, 0, 1], [0, 0, 0]])
b = np.array([[1, 0, 1], [2, 0, 1]])

Expected result:

array([[3, 0, 1], 
       [0, 0, 0]])

It can probably be done with tf.sets.set_difference(), but I fail to see how.

This can be done in NumPy, but I'm after a TensorFlow solution so that I can offload the operation to a GPU, since it is computationally expensive for large arrays.
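For reference, here is a minimal NumPy sketch of the desired behaviour. It is one possible approach, not necessarily the one the asker had in mind, and the function name `row_set_difference_np` is hypothetical. It compares every row of `a` against every row of `b` by broadcasting, so it builds an intermediate boolean array of shape `(len(a), len(b), a.shape[1])`. That memory cost is what makes it expensive for large arrays.

```python
import numpy as np

def row_set_difference_np(a, b):
    """Return the rows of `a` that do not appear as rows in `b` (order kept)."""
    # matches[i, j] is True when row i of a equals row j of b
    matches = (a[:, None, :] == b[None, :, :]).all(axis=-1)
    # keep only the rows of a that match no row of b
    return a[~matches.any(axis=1)]

a = np.array([[1, 0, 1], [2, 0, 1], [3, 0, 1], [0, 0, 0]])
b = np.array([[1, 0, 1], [2, 0, 1]])
print(row_set_difference_np(a, b))
# [[3 0 1]
#  [0 0 0]]
```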

lukostaz
`tf.sets.set_difference` operates only on the last (innermost) dimension of your tensors, and all other dimensions of the two input tensors must match; only the last one may differ in size. You're trying to remove entire rows, i.e., you're operating on dimension `-2` of your set, so this won't work with `set_difference`. – GPhilo Nov 29 '17 at 14:35
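To illustrate the comment, here is a small sketch. It assumes TensorFlow 2.x, where the op is exposed as `tf.sets.difference`, and the tensors are made-up examples. It shows that the difference is computed separately for each row over the last dimension, so whole rows are never treated as set elements.

```python
import tensorflow as tf

# Two rank-2 tensors: all dimensions except the last must match.
x = tf.constant([[1, 2, 3], [4, 5, 6]])
y = tf.constant([[1, 9], [6, 7]])

# The difference is taken row by row over the last dimension:
# row 0: {1, 2, 3} - {1, 9} = {2, 3}
# row 1: {4, 5, 6} - {6, 7} = {4, 5}
d = tf.sets.difference(x, y)  # returns a tf.sparse.SparseTensor
print(tf.sparse.to_dense(d).numpy())
```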

1 Answer


What about this solution:

import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session)

def diff(tens_x, tens_y):
    """Return the rows of tens_x that do not appear in tens_y."""
    i = tf.constant(0)
    score_list = tf.constant(dtype=tf.int32, value=[])

    def cond(score_list, i):
        # loop once over every row of tens_y
        return tf.less(i, tf.shape(tens_y)[0])

    def body(score_list, i):
        # collect the indices of the rows of tens_x that differ from row i of tens_y
        locs_1 = tf.not_equal(tf.gather(tens_y, i), tens_x)
        locs = tf.reduce_any(locs_1, axis=1)
        ans = tf.reshape(tf.cast(tf.where(locs), dtype=tf.int32), [-1])
        score_list = tf.concat([score_list, ans], axis=0)
        return [score_list, tf.add(i, 1)]

    all_scores, _ = tf.while_loop(cond, body, loop_vars=[score_list, i],
                                  shape_invariants=[tf.TensorShape([None]), i.get_shape()])

    # a row of tens_x differs from every row of tens_y exactly when its
    # index was collected once per row of tens_y
    uniq, _, counts = tf.unique_with_counts(all_scores)
    keep = tf.reshape(tf.where(tf.equal(counts, tf.shape(tens_y)[0])), [-1])  # flatten [N, 1] -> [N]
    return tf.gather(tens_x, tf.gather(uniq, keep))


if __name__ == '__main__':
    tens_x = tf.constant([[1, 0, 1], [2, 0, 1], [3, 0, 1], [0, 0, 0]])
    tens_y = tf.constant([[1, 0, 1], [2, 0, 1]])

    results = diff(tens_x, tens_y)

    with tf.Session() as sess:
        ans_ = sess.run(results)
        print(ans_)

[[3 0 1]
 [0 0 0]]
Mahdi Ghelichi
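As an alternative that avoids `tf.while_loop`, here is a broadcasting sketch. It assumes TensorFlow 2.x in eager mode, and the function name `row_set_difference` is hypothetical. It compares all pairs of rows in a single vectorized op, which tends to map well onto a GPU. The trade-off is that it materializes a `len(x) × len(y) × d` boolean tensor, so very large inputs may need to be processed in chunks of `y`.

```python
import tensorflow as tf

@tf.function
def row_set_difference(x, y):
    """Return the rows of `x` that do not occur as rows in `y`."""
    # eq[i, j] is True when row i of x equals row j of y; shape [len(x), len(y)]
    eq = tf.reduce_all(tf.equal(tf.expand_dims(x, 1), tf.expand_dims(y, 0)), axis=-1)
    # keep only the rows of x that match no row of y
    keep = tf.logical_not(tf.reduce_any(eq, axis=1))
    return tf.boolean_mask(x, keep)

x = tf.constant([[1, 0, 1], [2, 0, 1], [3, 0, 1], [0, 0, 0]])
y = tf.constant([[1, 0, 1], [2, 0, 1]])
print(row_set_difference(x, y).numpy())
```

Unlike the loop above, this returns a 2-D result directly and keeps the rows of `x` in their original order.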