
tf.math.argmax returns the index of the maximum value in a tensor.

a = tf.constant([1,2,3])
print(a)
print(tf.math.argmax(input = a))

output:

tf.Tensor([1 2 3], shape=(3,), dtype=int32)
<tf.Tensor: shape=(), dtype=int64, numpy=2>

I want to apply the tf.math.argmax function to a list of tensors. How can I do it?

input = tf.constant([1,2,3,4,5,6])
split_sequence = tf.split(input, num_or_size_splits=2, axis=-1)
print(split_sequence)
tf.math.argmax(input = split_sequence)

output:

[<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>]
tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32)
<tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 1, 1])>

It is giving the wrong indices: numpy=array([1, 1, 1])

desired output:

numpy=array([[2], [2]])
Question asked by AB Music Box, edited by I'mahdi
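Why does the question get [1, 1, 1]? When a Python list of tensors is passed where a tensor is expected, TensorFlow converts it with tf.convert_to_tensor, which packs the two (3,)-shaped chunks into a single (2, 3) tensor. tf.math.argmax then reduces along axis 0 by default, so it compares the chunks column by column instead of looking inside each chunk. The pure-Python sketch below models that behavior with nested lists standing in for tensors; the helper names split_even, argmax_axis0 and argmax_last_axis are hypothetical and not part of TensorFlow.

```python
def split_even(values, num_splits):
    """Model of tf.split(values, num_splits) for a 1-D list: equal-sized chunks."""
    if len(values) % num_splits != 0:
        raise ValueError("length must be divisible by num_splits")
    size = len(values) // num_splits
    return [values[i * size:(i + 1) * size] for i in range(num_splits)]


def argmax_axis0(matrix):
    """Model of tf.math.argmax(matrix) with the default axis=0: one index per column."""
    num_rows = len(matrix)
    return [
        max(range(num_rows), key=lambda r: matrix[r][c])
        for c in range(len(matrix[0]))
    ]


def argmax_last_axis(matrix):
    """Model of tf.math.argmax(matrix, axis=-1): one index per row."""
    return [max(range(len(row)), key=row.__getitem__) for row in matrix]


# The two chunks are packed into a 2 x 3 "matrix": [[1, 2, 3], [4, 5, 6]].
chunks = split_even([1, 2, 3, 4, 5, 6], 2)

print(argmax_axis0(chunks))      # [1, 1, 1]  (row 1 wins in every column)
print(argmax_last_axis(chunks))  # [2, 2]     (index of the max inside each chunk)
```

The TensorFlow counterpart of the second helper is tf.math.argmax(split_sequence, axis=-1).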

3 Answers


You can use map to apply any function to each element of the list.

(It's better not to use the name of a Python built-in function as a variable, so I changed input to inp.)

import tensorflow as tf

inp = tf.constant([1,2,3,4,5,6])
split_sequence = tf.split(inp, num_or_size_splits=2, axis=-1)
print(split_sequence)

result = list(map(lambda x: [tf.math.argmax(x).numpy()] , split_sequence))
print(result)

Or, thanks to @jkr, we can use a list comprehension too. (Which one is better, map or list comprehension?)

>>> [[tf.math.argmax(item).numpy()] for item in split_sequence]
[[2], [2]]

Output of the first snippet:

[
    <tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, 
    <tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6], dtype=int32)>
]

[[2], [2]]
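To see what the map call and the list comprehension do without TensorFlow, here is a pure-Python sketch in which plain lists stand in for the tensors produced by tf.split, and a hypothetical helper, argmax_1d, stands in for tf.math.argmax(x).numpy(). Each chunk is reduced on its own, and wrapping each index in a one-element list gives the [[2], [2]] shape the question asks for.

```python
def argmax_1d(values):
    """Hypothetical stand-in for tf.math.argmax(x).numpy() on a 1-D input."""
    return max(range(len(values)), key=values.__getitem__)


# Plain lists standing in for the two (3,)-shaped tensors returned by tf.split.
split_sequence = [[1, 2, 3], [4, 5, 6]]

# map applies the function to each chunk independently.
result_map = list(map(lambda x: [argmax_1d(x)], split_sequence))

# The equivalent list comprehension.
result_comp = [[argmax_1d(item)] for item in split_sequence]

print(result_map)   # [[2], [2]]
print(result_comp)  # [[2], [2]]
```

Both forms loop in Python over the chunks, one argmax call per chunk; they differ only in syntax.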

Benchmark (on Colab):

import tensorflow as tf
inp = tf.constant([1,2,3,4,5,6]*1_000_000)
split_sequence = tf.split(inp, num_or_size_splits=20, axis=-1)

%timeit tf.math.top_k(split_sequence, k=1).indices
# 13.5 ms ± 394 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


%timeit list(map(lambda x: [tf.math.argmax(x).numpy()] , split_sequence))
# 14 ms ± 2.39 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


%timeit [[tf.math.argmax(item).numpy()] for item in split_sequence]
# 8.77 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
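%timeit is an IPython magic, so it only works in Jupyter or Colab. In a plain Python script, the standard-library timeit module gives comparable measurements. The sketch below times pure-Python stand-ins for the map and list-comprehension variants; the workload and the argmax_1d helper are assumptions for illustration, not the TensorFlow benchmark above, so the absolute numbers will differ from the Colab figures.

```python
import timeit


def argmax_1d(values):
    """Hypothetical pure-Python stand-in for tf.math.argmax on a 1-D input."""
    return max(range(len(values)), key=values.__getitem__)


# Illustrative workload: 20 chunks of 300 values each (not the TensorFlow benchmark data).
data = list(range(6)) * 1_000
size = len(data) // 20
split_sequence = [data[i * size:(i + 1) * size] for i in range(20)]


def with_map():
    return list(map(lambda x: [argmax_1d(x)], split_sequence))


def with_comprehension():
    return [[argmax_1d(item)] for item in split_sequence]


for fn in (with_map, with_comprehension):
    # Several runs of many loops each, as %timeit does; report the best per-loop time.
    runs = timeit.repeat(fn, repeat=5, number=20)
    best_ms = min(runs) / 20 * 1_000
    print(f"{fn.__name__}: {best_ms:.3f} ms per loop")
```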
Answered by I'mahdi

  • If using map, why not a list comprehension? `[tf.math.argmax(item) for item in split_sequence]` – jkr Aug 28 '22 at 14:38
  • @jkr, we can read [here](https://stackoverflow.com/a/6407222/1740577) that in some cases `map` is better than a list comprehension, but here we need to check. – I'mahdi Aug 28 '22 at 14:47
  • @jkr, here the list comprehension is the fastest approach. – I'mahdi Aug 28 '22 at 15:11
  • I wasn't thinking about speed. I was thinking about readability. In my humble opinion, the list comprehension is easier to read, and it is easier to understand what is happening. – jkr Aug 28 '22 at 17:04
  • @jkr, yes, I agree with you: the list comprehension is more readable than using `map`, and in this case it is faster too. – I'mahdi Aug 28 '22 at 17:08

I would recommend simply using tf.math.top_k in your case:

import tensorflow as tf

input = tf.constant([1,2,3,4,5,6])
split_sequence = tf.split(input, num_or_size_splits=2, axis=-1)
x = tf.math.top_k(split_sequence, sorted=False, k=1).indices
print(x)
tf.Tensor(
[[2]
 [2]], shape=(2, 1), dtype=int32)

Afterwards, if you want a NumPy array, just call x.numpy().
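tf.math.top_k works along the last axis and returns, for each row, the values and indices of the k largest entries, which is why k=1 yields an index tensor of shape (2, 1). The pure-Python sketch below models only the .indices part; topk_indices is a hypothetical helper, and resolving ties toward the lower index is a choice made in this sketch.

```python
def topk_indices(matrix, k):
    """Model of tf.math.top_k(matrix, k=k).indices for a 2-D list: per-row indices of the k largest values."""
    result = []
    for row in matrix:
        # sorted() is stable (also with reverse=True), so equal values keep the lower index first.
        order = sorted(range(len(row)), key=row.__getitem__, reverse=True)
        result.append(order[:k])
    return result


split_sequence = [[1, 2, 3], [4, 5, 6]]
print(topk_indices(split_sequence, k=1))  # [[2], [2]]
print(topk_indices(split_sequence, k=2))  # [[2, 1], [2, 1]]
```

With sorted=False, as in the answer, TensorFlow does not promise any particular order among the k returned entries; for k=1 that makes no difference.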

Update 1: An even simpler, but slower, approach is to just change the axis of tf.argmax:

input = tf.constant([1,2,3,4,5,6])
split_sequence = tf.split(input, num_or_size_splits=2, axis=-1)
tf.argmax(split_sequence, axis=-1)
# <tf.Tensor: shape=(2,), dtype=int64, numpy=array([2, 2])>
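Note that tf.argmax(..., axis=-1) returns shape (2,), i.e. [2, 2], whereas the question asked for [[2], [2]] with shape (2, 1); it also returns int64 by default (its output_type parameter changes that), while top_k indices are int32. In TensorFlow, tf.expand_dims(indices, axis=-1) adds the trailing axis. The pure-Python sketch below models both steps; the helper names are hypothetical.

```python
def argmax_last_axis(matrix):
    """Model of tf.argmax(matrix, axis=-1): one index per row, shape (rows,)."""
    return [max(range(len(row)), key=row.__getitem__) for row in matrix]


def expand_last_dim(vector):
    """Model of tf.expand_dims(vector, axis=-1): shape (n,) becomes (n, 1)."""
    return [[v] for v in vector]


split_sequence = [[1, 2, 3], [4, 5, 6]]
flat = argmax_last_axis(split_sequence)
print(flat)                   # [2, 2]
print(expand_last_dim(flat))  # [[2], [2]]
```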

You do not need any explicit loop or map. See benchmarks:

import tensorflow as tf

input = tf.constant([1,2,3,4,5,6]*1_000_000)
split_sequence = tf.split(input, num_or_size_splits=20, axis=-1)

@tf.function
def top_k(split_sequence):
  return tf.math.top_k(split_sequence, k=1, sorted=False).indices

@tf.function
def argmax(split_sequence):
  return tf.argmax(split_sequence, axis=-1)

@tf.function
def _map(split_sequence):
  return list(map(lambda x: [tf.math.argmax(x)] , split_sequence))

@tf.function
def _list(split_sequence):
  return [[tf.math.argmax(item)] for item in split_sequence]

%timeit top_k(split_sequence)
# 3.5 ms ± 246 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit argmax(split_sequence)
# 16.6 ms ± 3.79 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit _map(split_sequence)
# 10.3 ms ± 929 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit _list(split_sequence)
# 10.2 ms ± 2.15 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
Answered by AloneTogether
  • @AloneTogether Thanks. Doesn't the top_k method use a loop internally? Actually, it would be helpful if I could avoid loops, because I want to use this code in custom layers in TensorFlow. – AB Music Box Aug 29 '22 at 06:42
  • Yes, but actually you can just use `tf.argmax` directly as shown in my answer. – AloneTogether Aug 29 '22 at 06:45
  • @AloneTogether Does tf.argmax use graph code for fast execution? – AB Music Box Aug 29 '22 at 06:52
  • @ABMusicBox If graph-mode execution is activated, see the benchmarks: `tf.math.top_k` is fastest in `Graph` mode. – AloneTogether Aug 29 '22 at 09:33
  • I have posted a question in which I am trying to build a custom layer in TensorFlow where I need to use this code. I would be thankful if you could help me solve this problem. Link: https://stackoverflow.com/q/73527851/13096521. I have given the whole code in the question, but I need a little help as I don't know a lot about TensorFlow. – AB Music Box Aug 29 '22 at 11:13

<tf.Tensor: shape=(), dtype=int64, numpy=2>

You can see the output is numpy=2, i.e. index 2 of your constant (indices are zero-based, so this is the third element), whose value is 3.
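Because indices are zero-based, index 2 refers to the third element of [1, 2, 3]. In TensorFlow, tf.gather(a, tf.math.argmax(a)) reads the maximum value back from the index; the plain-Python sketch below shows the same lookup with a list, using a hypothetical argmax_1d stand-in.

```python
def argmax_1d(values):
    """Hypothetical stand-in for tf.math.argmax on a 1-D input."""
    return max(range(len(values)), key=values.__getitem__)


a = [1, 2, 3]
idx = argmax_1d(a)
print(idx)     # 2  (zero-based: the third position)
print(a[idx])  # 3  (the maximum value itself)
```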