tf.math.argmax returns the index of the maximum value along an axis of a tensor (axis 0 by default).
import tensorflow as tf
a = tf.constant([1, 2, 3])
print(a)
print(tf.math.argmax(input=a))
output:
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
<tf.Tensor: shape=(), dtype=int64, numpy=2>
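For a tensor with more than one dimension, the axis argument decides which values are compared. The following is a small illustrative sketch (the matrix m is made up for this example) showing that axis=0 yields one index per column and axis=1 yields one index per row:

```python
import tensorflow as tf

# Illustrative 2 x 3 matrix, not taken from the question.
m = tf.constant([[1, 5, 3],
                 [4, 2, 6]])

# axis=0 (the default): compare down each column -> one index per column.
col_idx = tf.math.argmax(m, axis=0)

# axis=1 (same as axis=-1 here): compare across each row -> one index per row.
row_idx = tf.math.argmax(m, axis=1)

print(col_idx.numpy())
print(row_idx.numpy())
```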
I want to apply the tf.math.argmax function to each tensor in a list of tensors. How can I do it?
input = tf.constant([1,2,3,4,5,6])
split_sequence = tf.split(input, num_or_size_splits=2, axis=-1)
print(split_sequence)
tf.math.argmax(input = split_sequence)
output:
[<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>]
tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32)
<tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 1, 1])>
It returns numpy=array([1, 1, 1]), which are not the indices I want.
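The [1, 1, 1] result follows from how the list is handled. Judging by the shape-(3,) output above, tf.math.argmax packs the list of two shape-(3,) tensors into a single (2, 3) tensor (as tf.stack would) and then reduces along the default axis 0, comparing the two rows column by column. A minimal sketch reproducing this, assuming the same split as in the question (the names x, parts, and stacked are illustrative):

```python
import tensorflow as tf

x = tf.constant([1, 2, 3, 4, 5, 6])
parts = tf.split(x, num_or_size_splits=2, axis=-1)  # two tensors of shape (3,)

# Stacking the list explicitly gives the (2, 3) tensor that argmax works on.
stacked = tf.stack(parts)
print(stacked.shape)

# Default axis=0: for every column, the larger value sits in row 1,
# so each column reports index 1.
print(tf.math.argmax(parts).numpy())
print(tf.math.argmax(stacked, axis=0).numpy())
```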
Desired output (the index of the maximum within each tensor of the list):
numpy=array([[2], [2]])
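A sketch of two ways to get one index per tensor, shaped like the desired output. Option 1 assumes every tensor in the list has the same shape (true for an even tf.split); option 2 does not need that. Since tf.math.argmax has no keepdims argument, the trailing size-1 axis is added afterwards with tf.expand_dims or tf.reshape. The variable names are illustrative only.

```python
import tensorflow as tf

x = tf.constant([1, 2, 3, 4, 5, 6])
parts = tf.split(x, num_or_size_splits=2, axis=-1)

# Option 1: stack into shape (2, 3), reduce along the last axis
# (one index per original tensor), then append an axis of size 1.
per_row = tf.math.argmax(tf.stack(parts), axis=-1)  # shape (2,)
result = tf.expand_dims(per_row, axis=-1)           # shape (2, 1)

# Option 2: call argmax on each tensor separately (works for unequal lengths),
# then stack the scalar results and reshape to a column.
per_tensor = [tf.math.argmax(t) for t in parts]     # list of scalar tensors
result2 = tf.reshape(tf.stack(per_tensor), (-1, 1))  # shape (2, 1)

print(result.numpy())
print(result2.numpy())
```

Option 1 runs as a single vectorized op on the stacked tensor; option 2 is a plain Python loop, which is simpler to read but issues one argmax op per tensor.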