17

While generalized slicing is still being worked on in this issue, what is the best way to implement an op that gathers columns of a 2D tensor (matrix)? For example, for tensor t:

1 2 3 4
5 6 7 8 

and indices [1,3], I would like to get:

2 4
6 8

which is equivalent to NumPy's t[:, [1, 3]].
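For reference, the target behavior can be reproduced in plain NumPy. This is only a sketch of the desired result using the matrix and indices from the question; the variable name cols is introduced here for illustration.

```python
import numpy as np

# The matrix and column indices from the question.
t = np.array([[1, 2, 3, 4],
              [5, 6, 7, 8]])
indices = [1, 3]

# Indexing with a list along axis 1 keeps every row and picks the listed columns.
cols = t[:, indices]
print(cols)
```

Any TensorFlow solution should return a tensor with these same values and shape (2, 2).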

nbro
Andrzej Pronobis

3 Answers

25

In the meantime, tf.gather has gained an axis parameter, so columns can be gathered directly:

import tensorflow as tf
params = tf.constant([[1,2,3],[4,5,6]])
indices = [0,2]
op = tf.gather(params, indices, axis=1)

produces the output

[[1 3]
 [4 6]]
AlexConfused
9

There is a function named tf.nn.embedding_lookup(params, ind) which retrieves the rows of the params tensor.

To achieve what you want, first transpose the tensor t whose columns you want to select. Then look up the rows of tf.transpose(t), which are the columns of t. Finally, transpose the result back.

import tensorflow as tf


t = tf.constant([[1, 2, 3], 
                 [4, 5, 6]])
ind = tf.constant([0, 2])

result = tf.transpose(tf.nn.embedding_lookup(tf.transpose(t), ind))

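# TensorFlow 1.x graph mode; under TensorFlow 2.x eager execution,
# print(result.numpy()) gives the same output without a session.
with tf.Session() as sess:
    print(sess.run(result))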
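The transpose, look up rows, transpose back idea can be checked without TensorFlow. The sketch below uses NumPy row indexing as a stand-in for tf.nn.embedding_lookup (an assumption made purely for illustration; the names via_transpose and direct are hypothetical).

```python
import numpy as np

t = np.array([[1, 2, 3],
              [4, 5, 6]])
ind = np.array([0, 2])

# Rows of t.T are the columns of t, so a row lookup on the transpose
# followed by a second transpose selects columns of t.
via_transpose = t.T[ind].T

# Direct column selection, for comparison.
direct = t[:, ind]

print(via_transpose)
print(np.array_equal(via_transpose, direct))
```

Both routes give the same matrix, which is why the double transpose works as a column gather.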
nbro
lucky6qi
5

So far, I created a workaround by flattening the input and using gather:

import tensorflow as tf

def gather_cols(params, indices, name=None):
    """Gather columns of a 2D tensor.

    Args:
        params: A 2D tensor.
        indices: A 1D tensor. Must be one of the following types: ``int32``, ``int64``.
        name: A name for the operation (optional).

    Returns:
        A 2D Tensor. Has the same type as ``params``.
    """
    # tf.op_scope was removed in TensorFlow 1.0; tf.name_scope replaces it.
    with tf.name_scope(name or "gather_cols"):
        # Check input
        params = tf.convert_to_tensor(params, name="params")
        indices = tf.convert_to_tensor(indices, name="indices")
        try:
            params.get_shape().assert_has_rank(2)
        except ValueError:
            raise ValueError('\'params\' must be 2D.')
        try:
            indices.get_shape().assert_has_rank(1)
        except ValueError:
            raise ValueError('\'indices\' must be 1D.')

        # Define op
        p_shape = tf.shape(params)
        p_flat = tf.reshape(params, [-1])
        i_flat = tf.reshape(tf.reshape(tf.range(0, p_shape[0]) * p_shape[1],
                                       [-1, 1]) + indices, [-1])
        return tf.reshape(tf.gather(p_flat, i_flat),
                          [p_shape[0], -1])

For the input:

params = tf.constant([[1, 2, 3],
                      [4, 5, 6]])
indices = [0, 2]
op = gather_cols(params, indices)

produces the expected output:

[[1 3]
 [4 6]]
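The flat-index arithmetic inside gather_cols may be easier to follow with concrete numbers. Below is a NumPy sketch of the same computation (NumPy stands in for the TensorFlow ops here; row_starts, i_flat and result are illustrative names).

```python
import numpy as np

params = np.array([[1, 2, 3],
                   [4, 5, 6]])
indices = np.array([0, 2])
n_rows, n_cols = params.shape

# Offset of each row's first element in the flattened (row-major) array.
row_starts = (np.arange(n_rows) * n_cols).reshape(-1, 1)

# Broadcasting adds every column index to every row offset, giving the
# position of each wanted element in the flattened array.
i_flat = (row_starts + indices).reshape(-1)

# Gather from the flattened array, then restore one row per original row.
result = params.reshape(-1)[i_flat].reshape(n_rows, -1)

print(i_flat)
print(result)
```

Each row contributes its offset (row index times the number of columns) plus the requested column indices, so a single 1D gather over the flattened tensor is enough.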
Yuval Atzmon
Andrzej Pronobis