
This is very similar to the question "How to find an index of the first matching element in TensorFlow".

I tried the solution from that question, but the difference here is that val is not a single number; it is a tensor.

For example:

arr = np.array([[1, 1, 1],
                [1, 0, 1],
                [0, 0, 1]])
val = np.array([1, 0, 1])


# some TensorFlow magic happens here!

result = 1

I know I could use a while loop, but that seems messy. I could also try a mapped function, but is there something more elegant?

dtracers



Here's one way -

(arr == val).all(axis=-1).argmax()

Sample run -

In [977]: arr
Out[977]: 
array([[1, 1, 1],
       [1, 0, 1],
       [0, 0, 1]])

In [978]: val
Out[978]: array([1, 0, 1])

In [979]: (arr == val).all(axis=1).argmax()
Out[979]: 1
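Breaking the one-liner into its three steps may make it clearer why it works. This is a small sketch using the question's arrays; the intermediate names mask and row_matches are just illustrative:

```python
import numpy as np

arr = np.array([[1, 1, 1],
                [1, 0, 1],
                [0, 0, 1]])
val = np.array([1, 0, 1])

# Broadcasting compares val against every row: a (3, 3) boolean array
mask = arr == val
# A row matches only if all of its elements match: shape (3,)
row_matches = mask.all(axis=-1)
# argmax on a boolean array returns the index of the first True
idx = row_matches.argmax()

print(mask)
print(row_matches)  # [False  True False]
print(idx)          # 1
```

Note that argmax stops at the first True, so if several rows match you get the earliest one.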

Might be more performant with views (viewing each row as one single element, so rows are compared in one shot) -

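import numpy as np

# https://stackoverflow.com/a/44999009/ @Divakar
def view1D(a): # a is array
    # View each row of a 2D array as one opaque np.void element,
    # so whole rows can be compared with a single elementwise ==
    a = np.ascontiguousarray(a)  # the view needs C-contiguous rows
    void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
    return a.view(void_dt).ravel()

# val[None] turns val into a 1-row 2D array; arr and val must share a dtype
out = (view1D(arr) == view1D(val[None])).argmax()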

Extension to n-dim cases

Extending to n-dim array cases would need a few more steps -

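import numpy as np

def first_match_index_along_axis(arr, val, axis):
    # Place val along `axis`, with length-1 axes everywhere else
    s = [None]*arr.ndim
    s[axis] = Ellipsis
    # Index with a tuple: current NumPy treats a list as a fancy (array) index
    mask = val[tuple(s)] == arr
    # Slices along `axis` that match completely, then the first one (flat index)
    idx = mask.all(axis=axis,keepdims=True).argmax()
    # Convert the flat index to coordinates over the remaining axes
    shp = list(arr.shape)
    del shp[axis]
    return np.unravel_index(idx, shp)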

Sample runs -

In [74]: arr = np.random.randint(0,9,(4,5,6,7))

In [75]: first_match_index_along_axis(arr, arr[2,:,1,0], axis=1)
Out[75]: (2, 1, 0)

In [76]: first_match_index_along_axis(arr, arr[2,1,3,:], axis=3)
Out[76]: (2, 1, 3)
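The key step in first_match_index_along_axis is the index tuple built from None and Ellipsis: it keeps val's values along axis and adds length-1 axes everywhere else, so val broadcasts against arr. A small sketch (the sizes 5, ndim=4 and axis=1 are arbitrary choices) shows the resulting shape and an equivalent reshape:

```python
import numpy as np

val = np.arange(5)
ndim, axis = 4, 1

# None adds a length-1 axis; Ellipsis stands in for val's own single axis
s = [None] * ndim
s[axis] = Ellipsis
reshaped = val[tuple(s)]
print(reshaped.shape)  # (1, 5, 1, 1)

# Same result with an explicit reshape: -1 at `axis`, 1 elsewhere
shape = [1] * ndim
shape[axis] = -1
print(np.array_equal(reshaped, val.reshape(shape)))  # True
```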
Divakar
  • Or using TensorFlow, that Numpy one-liner would be `tf.argmax(tf.cast(tf.reduce_all(tf.equal(arr, val), axis=1), tf.int32))` (working around the fact that == is not overloaded, .all is not implemented, and argmax doesn't work on bools). Note that both versions are somewhat dangerous since they'll return 0 if the pattern was never found. – Allen Lavoie Dec 11 '17 at 22:01
  • 1
    @AllenLavoie If not found, we could use : `np.where` , as discussed here - https://stackoverflow.com/questions/47269390/numpy-how-to-find-first-non-zero-value-in-every-column-of-a-numpy-array. – Divakar Dec 12 '17 at 05:07
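Following the np.where suggestion, here is a minimal sketch of the 2D case that returns a sentinel instead of a misleading 0 when no row matches. The helper name first_match_row and the default invalid_val=-1 are hypothetical choices, not from the answer:

```python
import numpy as np

def first_match_row(arr, val, invalid_val=-1):
    """Index of the first row of arr equal to val, or invalid_val if none.

    Hypothetical helper: same mask as the one-liner, but only trusts
    argmax when at least one row matched (argmax of all-False is 0).
    """
    row_matches = (arr == val).all(axis=-1)
    return int(np.where(row_matches.any(), row_matches.argmax(), invalid_val))

arr = np.array([[1, 1, 1],
                [1, 0, 1],
                [0, 0, 1]])
print(first_match_row(arr, np.array([1, 0, 1])))  # 1
print(first_match_row(arr, np.array([1, 1, 1])))  # 0
print(first_match_row(arr, np.array([0, 0, 0])))  # -1
```

With this, a genuine match in row 0 and "not found" are no longer indistinguishable.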