
In TensorFlow, I can't find a straightforward way to assign values to a slice of the flattened view of a variable.

Here is a roundabout way to achieve it:

```python
import tensorflow as tf

var = tf.Variable(tf.reshape(tf.range(12), [4,3]))

# <tf.Variable 'Variable:0' shape=(4, 3) dtype=int32, numpy=
# array([[ 0,  1,  2],
#        [ 3,  4,  5],
#        [ 6,  7,  8],
#        [ 9, 10, 11]], dtype=int32)>

# Convert the flat indices 4..7 into (row, col) pairs for the 4x3 shape.
flat_indices = tf.range(4, 8)
multi_dim_indices = tf.transpose(tf.unravel_index(flat_indices, dims=[4,3]))

# <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
# array([[1, 1],
#        [1, 2],
#        [2, 0],
#        [2, 1]], dtype=int32)>

update = [40, 50, 60, 70]

var.scatter_nd_update(multi_dim_indices, update)

# <tf.Variable 'UnreadVariable' shape=(4, 3) dtype=int32, numpy=
# array([[ 0,  1,  2],
#        [ 3, 40, 50],
#        [60, 70,  8],
#        [ 9, 10, 11]], dtype=int32)>
```

But that is not an efficient solution for large tensors: building `multi_dim_indices` should be unnecessary. `scatter_nd_update` is a sparse operation, whereas what I am really looking for is a dense assignment to a contiguous stretch of memory.

With a NumPy-like API, I could simply write:

```python
var.flat[4:8] = update
```
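For comparison, the NumPy behaviour being asked for can be checked on the same 4x3 example. This is a minimal sketch using plain NumPy (no TensorFlow): `ndarray.flat` is a `numpy.flatiter`, and assigning through a slice of it writes directly into the array's own row-major buffer without building any index arrays.

```python
import numpy as np

arr = np.arange(12, dtype=np.int32).reshape(4, 3)

# arr.flat walks the array in row-major (C) order, so flat positions 4..7
# are (1, 1), (1, 2), (2, 0), (2, 1); the slice assignment writes in place.
arr.flat[4:8] = [40, 50, 60, 70]
print(arr)
```

After the assignment `arr` still has shape (4, 3), and its rows match the `scatter_nd_update` result in the question.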

Is there an efficient way to achieve the same result in TensorFlow, maybe with an uglier API?
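A sketch of one possible workaround (not from the original thread; the helper name `assign_flat_slice_2d` is hypothetical): in a row-major 2-D variable with `ncols` columns, flat index `i` sits at row `i // ncols`, column `i % ncols`. A contiguous flat range `[start, stop)` therefore covers at most three rectangular pieces: the tail of one row, some complete rows, and the head of another row. Each piece can be written with `tf.Variable` slice assignment instead of building per-element indices. I have not verified whether TensorFlow performs these writes without intermediate copies, so treat that as an assumption to measure.

```python
import tensorflow as tf


def assign_flat_slice_2d(var, start, stop, values):
    """Hypothetical helper emulating numpy's `var.flat[start:stop] = values`
    for a 2-D tf.Variable, using at most three dense slice assignments."""
    ncols = var.shape[1]
    values = tf.convert_to_tensor(values, dtype=var.dtype)
    first_row, first_col = divmod(start, ncols)
    last_row, last_col = divmod(stop, ncols)  # position just past the range

    if first_row == last_row:
        # The whole range lies inside a single row.
        var[first_row, first_col:last_col].assign(values)
        return var

    # 1) Tail of the first (possibly partial) row.
    head = ncols - first_col
    var[first_row, first_col:].assign(values[:head])
    pos = head

    # 2) Complete rows in between, written as one 2-D block.
    n_full = last_row - first_row - 1
    if n_full > 0:
        block = tf.reshape(values[pos:pos + n_full * ncols], [n_full, ncols])
        var[first_row + 1:last_row].assign(block)
        pos += n_full * ncols

    # 3) Head of the last (possibly partial) row.
    if last_col > 0:
        var[last_row, :last_col].assign(values[pos:])
    return var


var = tf.Variable(tf.reshape(tf.range(12), [4, 3]))
assign_flat_slice_2d(var, 4, 8, [40, 50, 60, 70])
print(var.numpy())
```

The number of assignment calls stays at three or fewer regardless of how long the range is. The helper only handles 2-D variables; higher ranks would need the same row-major decomposition applied per axis.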

rerx

1 Answer


You can do a similar operation in TensorFlow as follows:

```python
import tensorflow as tf

var = tf.Variable(tf.reshape(tf.range(12), [4,3]))
var = tf.Variable(tf.reshape(var, [-1]))   # flatten into a new 1-D variable
var[4:8].assign([40, 50, 60, 70])          # assign values to the contiguous slice 4..7
var = tf.Variable(tf.reshape(var, [4,3]))  # reshape back into a new variable with the original shape
var
```

```
<tf.Variable 'Variable:0' shape=(4, 3) dtype=int32, numpy=
array([[ 0,  1,  2],
       [ 3, 40, 50],
       [60, 70,  8],
       [ 9, 10, 11]], dtype=int32)>
```
Innat
  • Thanks. Are these guaranteed to not do any extra memory allocations and copies? – rerx May 31 '21 at 08:04
  • It (`tf.Variable.assign`) assigns the new values to the particular indices in place. It doesn't create a new tensor. – Innat May 31 '21 at 08:12
  • And `var = tf.Variable(tf.reshape(var, [-1]))`? – rerx May 31 '21 at 08:18
  • The difference between `tf.Variable` and `tf.Tensor` is mutability, [details](https://stackoverflow.com/questions/44167134/whats-the-difference-between-tensor-and-variable-in-tensorflow/44167844#44167844). `tf.reshape` returns a **tensor**, and that's why you can't use it to update the old tensor in place but have to create a new one. – Innat May 31 '21 at 08:40
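If the surrounding code can keep the variable flat from the start, the re-wrapping in `tf.Variable(...)` that the comments ask about can be avoided entirely. This is a sketch under that assumption (the names `flat_var` and `matrix` are illustrative): writes go through slice assignment on the one existing variable, and a 2-D tensor is produced with `tf.reshape` only for reading. Whether `tf.reshape` shares the buffer or copies is an implementation detail not covered in this thread.

```python
import tensorflow as tf

rows, cols = 4, 3

# Store the variable 1-D so that a contiguous flat range is an ordinary
# slice of the variable itself.
flat_var = tf.Variable(tf.range(rows * cols))

# Slice assignment on the existing variable; no new tf.Variable is created.
flat_var[4:8].assign([40, 50, 60, 70])

# Produce the 2-D view only when it is needed for reading.
matrix = tf.reshape(flat_var, [rows, cols])
print(matrix.numpy())
```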