In TensorFlow, I'm missing a straightforward way to assign something to a slice of the flattened view of a variable.
Here's an example that achieves the result in a roundabout way:
var = tf.Variable(tf.reshape(tf.range(12), [4,3]))
# <tf.Variable 'Variable:0' shape=(4, 3) dtype=int32, numpy=
# array([[ 0, 1, 2],
# [ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11]], dtype=int32)>
flat_indices = tf.range(4, 8)
multi_dim_indices = tf.transpose(tf.unravel_index(flat_indices, dims=[4,3]))
# <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
# array([[1, 1],
# [1, 2],
# [2, 0],
# [2, 1]], dtype=int32)>
update = [40, 50, 60, 70]
var.scatter_nd_update(multi_dim_indices, update)
# <tf.Variable 'UnreadVariable' shape=(4, 3) dtype=int32, numpy=
# array([[ 0, 1, 2],
# [ 3, 40, 50],
# [60, 70, 8],
# [ 9, 10, 11]], dtype=int32)>
But that's not an efficient solution for large tensors: building multi_dim_indices should be unnecessary, and scatter_nd_update is a sparse operation, while what I am looking for is really a dense assignment to a consecutive stretch of memory.
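For completeness, another workaround I'm aware of avoids building the index matrix entirely by reshaping, concatenating, and assigning back. It is just a sketch for this example's slice bounds (4 and 8), and it copies the whole tensor, so it is presumably no faster:

```python
import tensorflow as tf

var = tf.Variable(tf.reshape(tf.range(12), [4, 3]))
update = tf.constant([40, 50, 60, 70])

# Flatten (this produces a copy, not a view), splice in the update,
# and write the result back into the variable.
flat = tf.reshape(var, [-1])
new_flat = tf.concat([flat[:4], update, flat[8:]], axis=0)
var.assign(tf.reshape(new_flat, var.shape))
```

This avoids the index bookkeeping, but it still materializes two full-size temporaries, so it doesn't answer the efficiency question either.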
With a numpy-like API I could write:
var.flat[4:8] = update
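That is, NumPy's flat iterator gives exactly the dense in-place write I'm after:

```python
import numpy as np

a = np.arange(12).reshape(4, 3)
# Assign to flat positions 4..7 without constructing multi-dim indices.
a.flat[4:8] = [40, 50, 60, 70]
print(a)
# [[ 0  1  2]
#  [ 3 40 50]
#  [60 70  8]
#  [ 9 10 11]]
```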
Is there an efficient way to achieve the same result in TensorFlow, maybe with an uglier API?