
I have run into a problem. After calling tf.nn.max_pool_with_argmax, I obtain the indices, i.e. argmax, which the documentation describes as: "A Tensor of type Targmax. 4-D. The flattened indices of the max values chosen for each output."
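For reference, the TensorFlow documentation for tf.nn.max_pool_with_argmax states that, for NHWC input with include_batch_in_index=False (the default in recent versions), a maximum at position [b, y, x, c] is stored as the flat index (y * width + x) * channels + c. The minimal sketch below, in plain NumPy with hypothetical helper names, shows that formula and its inverse using integer division and remainder. np.unravel_index over (height, width, channels) recovers the same triple.

```python
import numpy as np

def flat_argmax_index(y, x, c, width, channels):
    # Flat index of a max at [b, y, x, c] when the batch is not folded in
    return (y * width + x) * channels + c

def unravel_manual(flat, width, channels):
    # Invert the formula with integer division and remainder
    y = flat // (width * channels)
    x = flat % (width * channels) // channels
    c = flat % channels
    return y, x, c

height, width, channels = 4, 5, 3
flat = flat_argmax_index(2, 3, 1, width, channels)  # (2*5 + 3)*3 + 1 = 40
print(unravel_manual(flat, width, channels))
# NumPy's row-major unravel over (height, width, channels) gives the same triple
print(np.unravel_index(flat, (height, width, channels)))
```

Because the helpers only use `//`, `%`, `*` and `+`, they also work element-wise on whole NumPy arrays of indices.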

How can I unravel the flattened indices back into a list of coordinates in TensorFlow?

Thank you very much.

karl_TUM


I had the same problem today and I ended up with this solution:

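import tensorflow as tf

def unravel_argmax(argmax, shape):
    # shape: NHWC input shape; assumes flat index = (y * width + x) * channels + c
    output_list = []
    output_list.append(argmax // (shape[2] * shape[3]))  # row y
    output_list.append(argmax % (shape[2] * shape[3]) // shape[3])  # column x
    return tf.stack(output_list)  # tf.stack was called tf.pack before TensorFlow 1.0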
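The function above recovers only the row and column. On TensorFlow 2.x, tf.unravel_index can perform the division and remainder and also return the channel. What follows is a minimal sketch, not the answerer's code, assuming NHWC input, a statically known input shape, and the default include_batch_in_index=False; unravel_argmax_v2 is a hypothetical name.

```python
import tensorflow as tf

def unravel_argmax_v2(argmax, input_shape):
    # Assumes NHWC input and include_batch_in_index=False, so each flat
    # index is (y * width + x) * channels + c.
    _, height, width, channels = input_shape
    flat = tf.reshape(argmax, [-1])  # tf.unravel_index accepts 0-D or 1-D indices
    dims = tf.constant([height, width, channels], dtype=flat.dtype)
    coords = tf.unravel_index(flat, dims)  # shape [3, N]: rows are y, x, c
    # Move the coordinate axis last: argmax.shape + [3]
    out_shape = tf.concat([tf.shape(argmax), [3]], axis=0)
    return tf.reshape(tf.transpose(coords), out_shape)

# Usage: values increase along every axis, so each 2x2 window's max is its bottom-right cell
x = tf.reshape(tf.range(2 * 4 * 4 * 3, dtype=tf.float32), [2, 4, 4, 3])
pooled, argmax = tf.nn.max_pool_with_argmax(
    x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID")
coords = unravel_argmax_v2(argmax, x.shape)
```

Each entry of `coords` is a (y, x, c) triple for the corresponding pooled output. Since the batch is not folded into the index, the batch coordinate is simply the first index of that output position.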

I have a usage example in an IPython notebook, where I use it to forward the pooling argmax positions to my unpooling method.
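For the unpooling use case, one common approach (not necessarily the answerer's method) is to scatter each pooled value back to its argmax position with tf.scatter_nd. What follows is a minimal sketch assuming TensorFlow 2.x and argmax computed with include_batch_in_index=True, so every index addresses the whole flattened NHWC input; unpool_with_argmax is a hypothetical name.

```python
import math
import tensorflow as tf

def unpool_with_argmax(pooled, argmax, input_shape):
    # Assumes argmax came from max_pool_with_argmax(..., include_batch_in_index=True),
    # so each index is ((b * height + y) * width + x) * channels + c.
    total = math.prod(input_shape)  # number of elements in the unpooled tensor
    indices = tf.reshape(argmax, [-1, 1])  # one scatter target per pooled value
    updates = tf.reshape(pooled, [-1])
    # Positions that were never a max stay zero
    flat = tf.scatter_nd(indices, updates, tf.constant([total], dtype=indices.dtype))
    return tf.reshape(flat, input_shape)

x = tf.reshape(tf.range(2 * 4 * 4 * 3, dtype=tf.float32), [2, 4, 4, 3])
pooled, argmax = tf.nn.max_pool_with_argmax(
    x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="VALID",
    include_batch_in_index=True)
restored = unpool_with_argmax(pooled, argmax, [2, 4, 4, 3])
```

tf.scatter_nd adds updates that land on the same index. With overlapping windows (strides smaller than ksize), a value that wins several windows would therefore be summed. When ksize equals strides, as here, that cannot happen.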

Fabian