all_max = tf.convert_to_tensor([[4, 2, 3], [3, 4, 5]], dtype=tf.float32)

How can I get the index of the element [3, 4, 5] in the tensor all_max?

With a Python list, we can simply use list.index(element) to get the index of an element present in the list.

Thanks

vishak raj
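A minimal TensorFlow sketch of a list.index-style lookup for a whole row, assuming TensorFlow 2.x in eager mode. The name `target` is introduced here for illustration. It compares every row of all_max with the target row, keeps rows where all entries match, and uses tf.where to get their indices:

```python
import tensorflow as tf

all_max = tf.convert_to_tensor([[4, 2, 3], [3, 4, 5]], dtype=tf.float32)
target = tf.constant([3, 4, 5], dtype=tf.float32)  # hypothetical name for the row we look for

# Element-wise comparison broadcasts target against every row: bool tensor of shape (2, 3)
matches = tf.equal(all_max, target)
# A row matches only if all of its entries are equal: bool tensor of shape (2,)
row_matches = tf.reduce_all(matches, axis=1)
# tf.where with one argument returns the coordinates of True entries, shape (num_matches, 1)
row_indices = tf.reshape(tf.where(row_matches), [-1])
print(row_indices.numpy())  # [1]
```

Unlike list.index, this returns every matching row, and it returns an empty tensor instead of raising ValueError when nothing matches; take row_indices[0] if you only want the first match. Exact equality works here because the values are small integers stored as float32; for computed floating-point values a tolerance-based comparison would be safer.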

1 Answer


I found this on https://www.py4u.net/discuss/147615 and it worked for me.

To find the index of an element in a 2D/3D tensor, convert it into a 1D tensor first, i.e. example.view(number_of_elements).

Example:

import torch

mat = torch.tensor([[1, 2], [4, 3]])
# find the index of 2
target = 2
mat = mat.view(4)  # flatten the 2x2 tensor into 1D (4 elements)
num_elements = 4
for o in range(num_elements):
    if mat[o] == target:
        print(torch.tensor([o]))  # prints tensor([1])
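The loop above can also be written without a Python loop. The sketch below is an alternative, not the linked answer's method. It assumes PyTorch 1.8 or newer (for the rounding_mode argument of torch.div) and uses torch.nonzero to find matching positions, then shows how a flat index maps back to (row, column) coordinates: for a tensor with C columns, flat index k corresponds to row k // C and column k % C.

```python
import torch

mat = torch.tensor([[1, 2], [4, 3]])
target = 2

# Flat (1D) indices of every match, like the loop above: tensor([1])
flat_idx = torch.nonzero(mat.flatten() == target).flatten()

# Coordinates in the original 2D tensor, one row per match: tensor([[0, 1]])
coords = torch.nonzero(mat == target)

# Recover (row, col) from the flat index: row = k // C, col = k % C
num_cols = mat.shape[1]
rows = torch.div(flat_idx, num_cols, rounding_mode="floor")
cols = flat_idx % num_cols
print(flat_idx, coords, rows, cols)
```

Both forms return all matches rather than only the first one, so an empty result means the value is not present.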
H Sa