I am looking for a TensorFlow way of implementing something similar to Python's list.index() function.
Given a matrix and a value to find, I want to know the first occurrence of the value in each row of the matrix.
For example, in plain Python:

    # m is a [batch_size, 100] matrix of integers (here: a list of lists)
    val = 23
    result = [0] * batch_size
    for i, row_elems in enumerate(m):
        result[i] = row_elems.index(val)
I cannot assume that 'val' appears only once in each row; otherwise I would have implemented it using tf.argmax(m == val). In my case it is important to get the index of the first occurrence of 'val', not just any occurrence.
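
To make the "first occurrence" requirement explicit without depending on how argmax breaks ties, one option is to replace every non-matching position with a sentinel index and take the row-wise minimum. Below is a minimal plain-Python sketch of that idea. The function name first_index_per_row and the sentinel choice (the row length, for rows that do not contain 'val') are assumptions made for illustration. In TensorFlow the same two steps would presumably map to an elementwise select such as tf.where over column indices from tf.range, followed by tf.reduce_min along axis 1, but that mapping is not tested here.

```python
def first_index_per_row(m, val):
    """Return, for each row of m, the index of the first element equal to val.

    Rows that do not contain val get len(row) as a sentinel. This is an
    assumption of the sketch; list.index() would raise ValueError instead.
    """
    result = []
    for row in m:
        n = len(row)
        # Keep the position where the element matches, otherwise the sentinel n.
        candidates = [j if x == val else n for j, x in enumerate(row)]
        # The smallest surviving position is the first occurrence.
        result.append(min(candidates, default=n))
    return result


m = [
    [5, 23, 7, 23],   # 23 appears twice; the first is at index 1
    [23, 23, 1, 0],   # first at index 0
    [4, 8, 15, 16],   # 23 is absent, so the sentinel 4 is returned
]
print(first_index_per_row(m, 23))  # [1, 0, 4]
```

Unlike list.index(), this version never raises: a result equal to the row length signals that 'val' was not found, which tends to be easier to handle in a batched setting.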