tf.dynamic_partition will probably help, but it requires the number of output tensors to be known statically, when the graph is built. If you can establish a maximum number of tensors, you can use it:
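import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)
import numpy as np

# Input with an unknown number of rows and two columns.
x = tf.placeholder(tf.int32, shape=[None, 2])
data = np.random.randint(10, size=(10, 2))

# Send row i to partition i, so every row ends up in its own tensor.
parts = list(range(len(data)))

# 20 is the static upper bound on the number of output tensors.
out = tf.dynamic_partition(x, parts, 20)

sess = tf.Session()
print('out tensors:\n', out)
print()
print('input data:\n', data)
print()
print('sess.run result:\n', sess.run(out, {x: data}))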
This outputs the following:
out tensors:
[<tf.Tensor 'DynamicPartition:0' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:1' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:2' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:3' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:4' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:5' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:6' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:7' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:8' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:9' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:10' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:11' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:12' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:13' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:14' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:15' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:16' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:17' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:18' shape=(?, 2) dtype=int32>,
<tf.Tensor 'DynamicPartition:19' shape=(?, 2) dtype=int32>]
input data:
[[7 6]
[5 1]
[4 6]
[4 8]
[4 9]
[0 9]
[9 6]
[7 6]
[0 5]
[9 7]]
sess.run result:
[array([[7, 6]], dtype=int32),
array([[5, 1]], dtype=int32),
array([[4, 6]], dtype=int32),
array([[4, 8]], dtype=int32),
array([[4, 9]], dtype=int32),
array([[0, 9]], dtype=int32),
array([[9, 6]], dtype=int32),
array([[7, 6]], dtype=int32),
array([[0, 5]], dtype=int32),
array([[9, 7]], dtype=int32),
array([], shape=(0, 2), dtype=int32),
array([], shape=(0, 2), dtype=int32),
array([], shape=(0, 2), dtype=int32),
array([], shape=(0, 2), dtype=int32),
array([], shape=(0, 2), dtype=int32),
array([], shape=(0, 2), dtype=int32),
array([], shape=(0, 2), dtype=int32),
array([], shape=(0, 2), dtype=int32),
array([], shape=(0, 2), dtype=int32),
array([], shape=(0, 2), dtype=int32)]
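Only the first 10 of the 20 outputs contain data; the rest are empty because no row was assigned to those partition indices. To make that rule concrete, here is a minimal pure-Python sketch of the partitioning semantics: row i goes to output partitions[i], and the number of outputs is fixed before any data is seen. The helpers partition_rows and drop_empty are hypothetical names, not part of TensorFlow, and this only imitates the behaviour on plain lists rather than showing how the op is implemented.

```python
def partition_rows(rows, partitions, num_partitions):
    """Imitate the semantics of tf.dynamic_partition on plain Python lists."""
    if len(rows) != len(partitions):
        raise ValueError("rows and partitions must have the same length")
    # The number of outputs is fixed up front, like the static argument
    # passed to tf.dynamic_partition.
    outputs = [[] for _ in range(num_partitions)]
    for row, index in zip(rows, partitions):
        if not 0 <= index < num_partitions:
            raise ValueError("partition index %d out of range" % index)
        outputs[index].append(row)
    return outputs


def drop_empty(outputs):
    """Discard the unused slots once the results have been computed."""
    return [part for part in outputs if part]


rows = [[7, 6], [5, 1], [4, 6]]
parts = partition_rows(rows, [0, 1, 2], num_partitions=5)
print(parts)              # [[[7, 6]], [[5, 1]], [[4, 6]], [], []]
print(drop_empty(parts))  # [[[7, 6]], [[5, 1]], [[4, 6]]]
```

In the TensorFlow version the same filtering would have to happen after sess.run, since the graph itself always exposes the full fixed number of output tensors.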