17

I am looking for a TensorFlow way of implementing something similar to Python's list.index() function.

Given a matrix and a value to find, I want to know the first occurrence of the value in each row of the matrix.

For example,

m is a <batch_size, 100> matrix of integers
val = 23

result = [0] * batch_size
for i, row_elems in enumerate(m):
  result[i] = row_elems.index(val)

I cannot assume that 'val' appears only once in each row; otherwise I would have implemented it using tf.argmax(m == val). In my case, it is important to get the index of the first occurrence of 'val', not just any occurrence.

Igor Tsvetkov

4 Answers

19

It seems that tf.argmax works like np.argmax (according to the test), which returns the first index when the max value occurs more than once. You can use tf.argmax(tf.cast(tf.equal(m, val), tf.int32), axis=1) to get what you want. However, the behavior of tf.argmax is currently undefined when the max value occurs more than once.
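To make the tie-breaking point concrete, here is a minimal plain-Python sketch of the same idea. The helper names `argmax_first` and `first_index_via_argmax` are hypothetical, and the sketch assumes an argmax that always returns the smallest index on ties; it is not TensorFlow code.

```python
def argmax_first(values):
    """Index of the maximum, breaking ties in favour of the smallest index."""
    best = 0
    for i, v in enumerate(values):
        # Strict comparison keeps the earliest index when values are equal.
        if v > values[best]:
            best = i
    return best


def first_index_via_argmax(rows, val):
    # Build a 0/1 equality mask per row, then take the argmax of that mask.
    return [argmax_first([int(x == val) for x in row]) for row in rows]


rows = [[0, 0, 3, 0, 3],
        [3, 0, 3, 3, 0],
        [0, 3, 0, 0, 0],
        [0, 0, 0, 0, 0]]  # the last row contains no 3
print(first_index_via_argmax(rows, 3))  # [2, 0, 1, 0]
```

Under that tie-breaking assumption, a row with no match yields 0, which cannot be told apart from a match in column 0.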

If you are worried about undefined behavior, you can apply tf.argmin on the return value of tf.where as @Igor Tsvetkov suggested. For example,

# test with tensorflow r1.0
import tensorflow as tf

val = 3
m = tf.placeholder(tf.int32)
m_feed = [[0  ,   0, val,   0, val],
          [val,   0, val, val,   0],
          [0  , val,   0,   0,   0]]

tmp_indices = tf.where(tf.equal(m, val))
result = tf.segment_min(tmp_indices[:, 1], tmp_indices[:, 0])

with tf.Session() as sess:
    print(sess.run(result, feed_dict={m: m_feed})) # [2, 0, 1]

Note that tf.segment_min will raise InvalidArgumentError if any row contains no val. In your code, row_elems.index(val) will likewise raise an exception when row_elems does not contain val.
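For checking expected results without a session, the where-then-segment-minimum logic can be mirrored in plain Python. This is only a sketch: the function name `first_index_where_segment_min` is hypothetical, and raising ValueError for a row without a match is an assumption chosen to match list.index(), not a statement about what TensorFlow does.

```python
def first_index_where_segment_min(rows, val):
    # Like tf.where on the equality mask: (row, column) pairs of every match,
    # listed in row-major order.
    coords = [(r, c)
              for r, row in enumerate(rows)
              for c, x in enumerate(row)
              if x == val]
    # Like a per-row segment minimum over the column indices.
    firsts = {}
    for r, c in coords:
        firsts[r] = min(c, firsts.get(r, c))
    missing = [r for r in range(len(rows)) if r not in firsts]
    if missing:
        # Assumption: fail loudly, as list.index() does for a missing value.
        raise ValueError("no match in rows %s" % missing)
    return [firsts[r] for r in range(len(rows))]


m_feed = [[0, 0, 3, 0, 3],
          [3, 0, 3, 3, 0],
          [0, 3, 0, 0, 0]]
print(first_index_where_segment_min(m_feed, 3))  # [2, 0, 1]
```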

Jenny
  • This is very helpful! what if we wanted to update val to become a new_val? I asked this question here: https://stackoverflow.com/questions/45684445/tensorflow-update-first-matching-element-in-each-row – reese0106 Aug 14 '17 at 23:49
  • The TF documentation on `argmax` specifically states: "Note that in case of ties the identity of the return value is not guaranteed." Which leads me to believe that you can't rely on `argmax` returning the first value like numpy does, which I suspect is because of non-deterministic behavior on distributed devices like GPUs. – David Parks Oct 12 '18 at 23:34
  • Since TF 2.3, the [documentation](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/math/argmax) for `tf.argmax` *does* guarantee that "*In case of identity returns the smallest index.*" – P-Gn Apr 29 '21 at 21:04
4

Looks a little ugly but works (assuming m and val are both tensors):

# tf.unpack/tf.pack were renamed to tf.unstack/tf.stack in TensorFlow 1.0
idx = []
for t in tf.unstack(m, axis=0):
    idx.append(tf.reduce_min(tf.where(tf.equal(t, val))))
idx = tf.stack(idx, axis=0)

EDIT: As Yaroslav Bulatov mentioned, you could achieve the same result with tf.map_fn:

def index1d(t):
    return tf.reduce_min(tf.where(tf.equal(t, val)))

idx = tf.map_fn(index1d, m, dtype=tf.int64)
Dmitriy Danevskiy
3

Here is another solution to the problem, assuming there is a hit on every row.

import tensorflow as tf

val = 3
m = tf.constant([
    [0  ,   0,   val,   0, val],
    [val,   0,   val, val,   0],
    [0  , val,     0,   0,   0]])

# replace each entry with its column index if it equals val, otherwise with an out-of-range index (the number of columns)
match_indices = tf.where(                          # [[5, 5, 2, 5, 4],
    tf.equal(val, m),                              #  [0, 5, 2, 3, 5],
    x=tf.range(tf.shape(m)[1]) * tf.ones_like(m),  #  [5, 1, 5, 5, 5]]
    y=(tf.shape(m)[1])*tf.ones_like(m))

result = tf.reduce_min(match_indices, axis=1)

with tf.Session() as sess:
    print(sess.run(result)) # [2, 0, 1]
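The masking trick can be traced in plain Python to see what happens when the every-row-has-a-hit assumption fails. This is a hedged sketch with a hypothetical function name (`first_index_masked_min`), mirroring the logic above rather than calling TensorFlow; it assumes every row is non-empty.

```python
def first_index_masked_min(rows, val):
    result = []
    for row in rows:
        n_cols = len(row)
        # Keep the column index where the entry matches; otherwise substitute
        # n_cols, which is one past the last valid index.
        candidates = [c if x == val else n_cols for c, x in enumerate(row)]
        result.append(min(candidates))
    return result


rows = [[0, 0, 3, 0, 3],
        [3, 0, 3, 3, 0],
        [0, 3, 0, 0, 0],
        [0, 0, 0, 0, 0]]  # the last row contains no 3
print(first_index_masked_min(rows, 3))  # [2, 0, 1, 5]
```

In this sketch a row without a hit comes back as the number of columns, so a caller can detect it by comparing the result against the row width.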
trudolf
2

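Here is a solution that also handles rows that do not contain the element (taken from a DeepMind GitHub repository). Note that it returns the position just past the first occurrence (index + 1), capped at the row length, so a row without the element yields the row length.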

def get_first_occurrence_indices(sequence, eos_idx):
    '''
    args:
        sequence: [batch, length]
        eos_idx: scalar
    '''
    batch_size, maxlen = sequence.get_shape().as_list()
    eos_idx = tf.convert_to_tensor(eos_idx)
    tensor = tf.concat(
            [sequence, tf.tile(eos_idx[None, None], [batch_size, 1])], axis = -1)
    index_all_occurrences = tf.where(tf.equal(tensor, eos_idx))
    index_all_occurrences = tf.cast(index_all_occurrences, tf.int32)
    index_first_occurrences = tf.segment_min(
        index_all_occurrences[:, 1], index_all_occurrences[:, 0])
    index_first_occurrences.set_shape([batch_size])
    index_first_occurrences = tf.minimum(index_first_occurrences + 1, maxlen)
    
    return index_first_occurrences

Example usage:

import tensorflow as tf
mat = tf.Variable([[1,2,3,4,5], [2,3,4,5,6], [3,4,5,6,7], [0,0,0,0,0]], dtype = tf.int32)
idx = 3
first_occurrences = get_first_occurrence_indices(mat, idx)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print(sess.run(first_occurrences))  # [3, 2, 1, 5]
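The sentinel-column idea can be checked by hand with a plain-Python equivalent. The function name `first_occurrence_lengths` is hypothetical, and the sketch assumes all rows have the same length, as the [batch, length] tensor above does.

```python
def first_occurrence_lengths(rows, eos):
    maxlen = len(rows[0])  # assumes every row has the same length
    lengths = []
    for row in rows:
        padded = list(row) + [eos]   # appended sentinel guarantees a match
        first = padded.index(eos)    # 0-based index of the first match
        lengths.append(min(first + 1, maxlen))
    return lengths


mat = [[1, 2, 3, 4, 5],
       [2, 3, 4, 5, 6],
       [3, 4, 5, 6, 7],
       [0, 0, 0, 0, 0]]
print(first_occurrence_lengths(mat, 3))  # [3, 2, 1, 5]
```

Because of the cap at maxlen, a row whose only match sits in the last column and a row with no match both map to maxlen, so this variant cannot distinguish the two cases.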
Yanghoon