I've read the `tf.scatter_nd` documentation and run the example code for 1D and 3D tensors, and now I'm trying to do the same for a 2D tensor. I want to 'interleave' the columns of two tensors. For 1D tensors, one can do this via:
```python
'''
We want to interleave elements of 1D tensors arr1 and arr2, where
arr1 = [10, 11, 12]
arr2 = [1, 2, 3, 4, 5, 6]
such that
desired result = [1, 2, 10, 3, 4, 11, 5, 6, 12]
'''
import tensorflow as tf

with tf.Session() as sess:
    updates1 = tf.constant([1, 2, 3, 4, 5, 6])
    indices1 = tf.constant([[0], [1], [3], [4], [6], [7]])
    shape = tf.constant([9])
    scatter1 = tf.scatter_nd(indices1, updates1, shape)

    updates2 = tf.constant([10, 11, 12])
    indices2 = tf.constant([[2], [5], [8]])
    scatter2 = tf.scatter_nd(indices2, updates2, shape)

    # Positions not listed in the indices are zero-filled, so summing merges the two.
    result = scatter1 + scatter2
    print(sess.run(result))
```
(Aside: is there a better way to do this? I'm all ears.)

This gives the output

```
[ 1  2 10  3  4 11  5  6 12]
```

Yay! That worked!
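On the "better way" aside, one possible alternative that avoids `tf.scatter_nd` entirely is to view `arr2` as rows of two elements, append `arr1` as a third column, and flatten. In TensorFlow this would presumably be `tf.reshape(arr2, [-1, 2])`, then `tf.concat([..., tf.reshape(arr1, [-1, 1])], axis=1)`, then `tf.reshape(..., [-1])`; that variant is not run here. The plain-Python sketch below only checks the index logic of that idea, and the `interleave` helper is hypothetical:

```python
def interleave(arr2, arr1, group=2):
    """Place one element of arr1 after every `group` elements of arr2.

    Hypothetical helper mirroring: reshape(arr2, [-1, group]) ->
    append arr1 as an extra column -> flatten (row-major).
    """
    if len(arr2) != group * len(arr1):
        raise ValueError("arr2 needs exactly `group` elements per element of arr1")
    # Rows of `group` elements from arr2, like reshape(arr2, [-1, group]).
    rows = [arr2[i * group:(i + 1) * group] for i in range(len(arr1))]
    # Append arr1[i] as the last column of row i, then flatten row by row.
    return [x for row, extra in zip(rows, arr1) for x in row + [extra]]


arr1 = [10, 11, 12]
arr2 = [1, 2, 3, 4, 5, 6]
print(interleave(arr2, arr1))  # [1, 2, 10, 3, 4, 11, 5, 6, 12]
```

The same idea should carry over to the 2D case row by row: reshape the six-column tensor to `[3, 3, 2]`, the three-column tensor to `[3, 3, 1]`, concatenate along the last axis, and reshape back to `[3, 9]`.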
Now let's try to extend this to 2D.
```python
'''
We want to interleave the *columns* (not rows; rows would be easy!) of
arr1 = [[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]]
arr2 = [[10, 11, 12], [10, 11, 12], [10, 11, 12]]
such that
desired result = [[1, 2, 10, 3, 4, 11, 5, 6, 12],
                  [1, 2, 10, 3, 4, 11, 5, 6, 12],
                  [1, 2, 10, 3, 4, 11, 5, 6, 12]]
'''
updates1 = tf.constant([[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]])
indices1 = tf.constant([[0], [1], [3], [4], [6], [7]])
shape = tf.constant([3, 9])
scatter1 = tf.scatter_nd(indices1, updates1, shape)
```
This gives the error

```
ValueError: The outer 1 dimensions of indices.shape=[6,1] must match the outer 1
dimensions of updates.shape=[3,6]: Dimension 0 in both shapes must be equal, but
are 6 and 3. Shapes are [6] and [3]. for 'ScatterNd_2' (op: 'ScatterNd') with
input shapes: [6,1], [3,6], [2].
```
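The error follows from the shape rule stated in the `tf.scatter_nd` documentation: `updates.shape` must equal `indices.shape[:-1] + shape[indices.shape[-1]:]`. With `indices` of shape `[6, 1]` and an output shape of `[3, 9]`, each index addresses a whole row of the output, so `updates` would need shape `[6, 9]`. The small plain-Python helper below (hypothetical, just for checking shapes by hand) computes that expected shape:

```python
def expected_updates_shape(indices_shape, output_shape):
    """Shape tf.scatter_nd requires for `updates`, per its documented rule:
    updates.shape == indices.shape[:-1] + shape[indices.shape[-1]:]
    """
    index_depth = indices_shape[-1]  # number of leading output dims each index addresses
    if index_depth > len(output_shape):
        raise ValueError("index depth exceeds output rank")
    return list(indices_shape[:-1]) + list(output_shape[index_depth:])


# The 2D attempt: each [k] index selects a whole *row* of the [3, 9] output.
print(expected_updates_shape([6, 1], [3, 9]))     # [6, 9], but updates1 is [3, 6]
# Full (row, col) pairs: one index per scalar element of updates1.
print(expected_updates_shape([3, 6, 2], [3, 9]))  # [3, 6]
```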
It seems like my `indices` are specifying row indices instead of column indices. Given the way arrays are laid out in NumPy and TensorFlow (i.e. row-major order), does that mean I need to explicitly specify every single pair of indices for every element in `updates1`?
Or is there some kind of 'wildcard' specification I can use for the rows? (Note: `indices1 = tf.constant([[:,0], [:,1], [:,3], [:,4], [:,6], [:,7]])` gives a syntax error, as it probably should.)
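In the direct formulation, each index does need both a row and a column (index depth 2), so `indices` has shape `[3, 6, 2]`; the index tensor is plain integers with no slice-style wildcard. The pairs don't have to be typed by hand, though. The plain-Python sketch below (helper names hypothetical) builds them and applies a tiny reference version of the depth-2 scatter, zero-filling and summing like `tf.scatter_nd`, to check the result against the desired output. The nested lists it builds could presumably be passed to `tf.constant` and then to `tf.scatter_nd` with shape `[3, 9]`.

```python
def column_scatter_indices(num_rows, target_cols):
    """Build [row, col] pairs with shape [num_rows, len(target_cols), 2].

    Entry [r][j] says: updates[r][j] goes to output[r][target_cols[j]].
    """
    return [[[r, c] for c in target_cols] for r in range(num_rows)]


def scatter_nd_2d_reference(indices, updates, shape):
    """Tiny reference for the depth-2 case only: zero-filled output,
    with updates summed into the addressed cells."""
    out = [[0] * shape[1] for _ in range(shape[0])]
    for idx_row, upd_row in zip(indices, updates):
        for (r, c), value in zip(idx_row, upd_row):
            out[r][c] += value
    return out


updates1 = [[1, 2, 3, 4, 5, 6]] * 3
updates2 = [[10, 11, 12]] * 3
indices1 = column_scatter_indices(3, [0, 1, 3, 4, 6, 7])
indices2 = column_scatter_indices(3, [2, 5, 8])

scatter1 = scatter_nd_2d_reference(indices1, updates1, [3, 9])
scatter2 = scatter_nd_2d_reference(indices2, updates2, [3, 9])
result = [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(scatter1, scatter2)]
print(result[0])  # [1, 2, 10, 3, 4, 11, 5, 6, 12]
```

Inside TensorFlow the same index tensor could presumably be built without Python loops, e.g. `tf.meshgrid(tf.range(3), cols, indexing='ij')` followed by `tf.stack(..., axis=-1)`; that variant is untested here.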
Would it be easier to just do a transpose, interleave the rows, then transpose back? Because I tried that...
```python
scatter1 = tf.scatter_nd(indices1, tf.transpose(updates1), tf.transpose(shape))
print(sess.run(tf.transpose(scatter1)))
```
...and got a much longer error message that I don't feel like posting unless someone requests it.
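A likely culprit in the transpose attempt: `shape` is a rank-1 tensor, and transposing a rank-1 tensor leaves it unchanged, so the output shape is still `[3, 9]` rather than `[9, 3]`. With `indices1` of shape `[6, 1]`, `tf.scatter_nd` would then expect updates of shape `[6, 9]`, while `tf.transpose(updates1)` is `[6, 3]`. The plain-Python sketch below (helper names hypothetical) walks through the transpose, row scatter, transpose-back route with the transposed output shape written out explicitly as 9 rows of width 3:

```python
def transpose(matrix):
    """Swap rows and columns of a nested-list matrix."""
    return [list(col) for col in zip(*matrix)]


def scatter_rows_reference(row_indices, row_updates, num_rows):
    """Depth-1 scatter: row_updates[k] lands in output row row_indices[k];
    rows not addressed stay zero-filled."""
    width = len(row_updates[0])
    out = [[0] * width for _ in range(num_rows)]
    for k, row in zip(row_indices, row_updates):
        out[k] = [a + b for a, b in zip(out[k], row)]  # accumulate, like scatter_nd
    return out


updates1 = [[1, 2, 3, 4, 5, 6]] * 3
updates2 = [[10, 11, 12]] * 3

# Transposed updates are [6, 3] and [3, 3]; the transposed output is [9, 3], not [3, 9].
scatter1 = scatter_rows_reference([0, 1, 3, 4, 6, 7], transpose(updates1), 9)
scatter2 = scatter_rows_reference([2, 5, 8], transpose(updates2), 9)
summed = [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(scatter1, scatter2)]
result = transpose(summed)  # back to [3, 9]
print(result[0])  # [1, 2, 10, 3, 4, 11, 5, 6, 12]
```

In TensorFlow terms, the corresponding change would presumably be to pass the transposed shape literally, e.g. `tf.constant([9, 3])`, instead of `tf.transpose(shape)`.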
PS: I searched to make sure this isn't a duplicate (I find it hard to imagine that no one else has asked this before) but turned up nothing.