I want to update a slice of a tensor with 3 dimensions. Following the approach in "How to do slice assignment in Tensorflow", I would do something like this:

import tensorflow as tf

with tf.Session() as sess:
    init_val = tf.Variable(tf.zeros((2, 3, 3)))
    indices = tf.constant([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1]])
    update = tf.scatter_nd_add(init_val, indices, tf.ones(4))

    init = tf.global_variables_initializer()
    sess.run(init)
    print(sess.run(update))

This works, but since my actual problem is more complex, I would like to generate the set of indices automatically by defining the beginning and the size of the slice, as you would with tf.slice(...). Do you have any ideas? Thanks in advance!

I am using TensorFlow 1.12, which is currently the most recent release.

1 Answer

tf.strided_slice accepts a var parameter that indicates the variable the slice refers to. When you pass it, the function returns an assignable object (I'm not sure why they didn't just do that depending on the type of the input, but whatever). You can do something like this:

import tensorflow as tf
import numpy as np

var = tf.Variable(np.ones((3, 4), dtype=np.float32))
s = tf.strided_slice(var, [0, 2], [2, 3], var=var, name='var_slice')
s2 = s.assign([[2], [3]])
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run(s2))

Output:

[[1. 1. 2. 1.]
 [1. 1. 3. 1.]
 [1. 1. 1. 1.]]

Note that in tf.strided_slice you give the begin and end indices (end not included), unlike in tf.slice, where you give begin and size. Also, as the code currently stands, you have to provide a name value for either the slice or the assign operation (I feel this must be a bug, and it probably happens because that part of the API is used almost exclusively internally).
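
If you would rather keep thinking in terms of begin and size, as with tf.slice, the conversion to the exclusive end index is just an element-wise sum. Here is a minimal sketch, assuming begin and size are plain Python lists of non-negative integers known when the graph is built; the helper name slice_end is made up for illustration and is not part of TensorFlow:

```python
def slice_end(begin, size):
    """Convert tf.slice-style (begin, size) into the exclusive end
    index expected by tf.strided_slice. Hypothetical helper."""
    if len(begin) != len(size):
        raise ValueError("begin and size must have the same length")
    # end is exclusive, so end[i] = begin[i] + size[i]
    return [b + s for b, s in zip(begin, size)]


begin = [0, 2]
size = [2, 1]
end = slice_end(begin, size)
print(end)  # [2, 3]

# With the variable from the example above, this would be passed as:
# s = tf.strided_slice(var, begin, end, var=var, name='var_slice')
```

This sketch does not handle the special size value -1 that tf.slice accepts to mean "everything up to the end of that dimension"; you would need to substitute the dimension length yourself in that case.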

jdehesa