How do I change a single value of a Tensor inside a while loop? I know that I can manipulate a single value of a tf.Variable using tf.scatter_update(variable, index, value), but inside a loop I cannot access variables. Is there a way (or a workaround) to manipulate a given value of a Tensor inside a while loop?

For reference, here is my current code:

my_variable = tf.Variable()

def body(i, my_variable):
    [...]
    return tf.add(i, 1), tf.scatter_update(my_variable, [index], value)


loop = tf.while_loop(lambda i, _: tf.less(i, 5), body, [0, my_variable])
Ian Rehwinkel
  • Check this https://stackoverflow.com/questions/51419333/modify-a-tensorflow-variable-inside-a-loop – TassosK Feb 10 '19 at 11:45
  • @TassosK while_loop converts my `tf.Variable` into a `tf.Tensor`, so I get `TypeError: 'Tensor' object does not support item assignment` when I do what is said in that post. – Ian Rehwinkel Feb 10 '19 at 11:50
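The error above comes from how `tf.while_loop` treats its loop variables: they are values threaded from one iteration to the next. `body` receives the current values and must return the new ones instead of mutating anything in place. The toy loop below is a hypothetical, eager, pure-Python imitation of that contract (`functional_while_loop` is not TensorFlow's implementation; the real one builds a graph and, in TF 1.x graph mode, calls `body` only once while constructing it). It uses immutable tuples, which, like a `tf.Tensor`, reject item assignment.

```python
def functional_while_loop(cond, body, loop_vars):
    """Toy stand-in for tf.while_loop's contract (hypothetical, eager, pure Python)."""
    while cond(*loop_vars):
        # Each iteration replaces the loop variables with whatever body returns.
        loop_vars = tuple(body(*loop_vars))
    return loop_vars


def body(i, v):
    # v is an immutable tuple: `v[i] = 3.0` would raise
    # TypeError: 'tuple' object does not support item assignment.
    # Instead, build and return a new value with position i replaced.
    v_updated = v[:i] + (3.0,) + v[i + 1:]
    return i + 1, v_updated


i_final, v_final = functional_while_loop(lambda i, _: i < 5, body, (0, (1.0,) * 5))
print(i_final, v_final)  # 5 (3.0, 3.0, 3.0, 3.0, 3.0)
```

The takeaway is the same as in TensorFlow: compute an updated copy inside `body` and return it as the next value of the loop variable.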

1 Answer

Inspired by this post, you can use a sparse tensor to store the delta between the value you want to assign and the current value, and then add it to the tensor to "set" that entry; every other entry gets zero added and stays the same. For example (I'm assuming some shapes and values here, but it should be straightforward to generalize this to tensors of higher rank):

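import tensorflow as tf  # uses the TensorFlow 1.x graph API (tf.Session)

my_variable = tf.Variable(tf.ones([5]))

def body(i, v):
    index = i
    new_value = 3.0
    # Difference between the desired value and the current value at `index`
    # (a slice of shape [1], one value per sparse index).
    delta_value = new_value - v[index:index+1]
    # Sparse tensor that is zero everywhere except at position `index`.
    delta = tf.SparseTensor([[index]], delta_value, (5,))
    # Adding the densified delta "sets" v[index]; all other entries get + 0.
    v_updated = v + tf.sparse_tensor_to_dense(delta)
    return tf.add(i, 1), v_updated


# The loop passes the updated tensor along as `v`; my_variable itself is not changed.
_, updated = tf.while_loop(lambda i, _: tf.less(i, 5), body, [0, my_variable])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(my_variable))  # original values
    print(sess.run(updated))      # values after the loop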

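This prints the unchanged variable followed by the tensor returned by the loop (to store the result back in the variable, you would still need an assignment such as `tf.assign(my_variable, updated)`):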

[1. 1. 1. 1. 1.]
[3. 3. 3. 3. 3.]
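To check the arithmetic without a TensorFlow 1.x install, here is a NumPy analogue. It is not TensorFlow code, and `sparse_to_dense`, `set_via_delta` and `set_via_select` are hypothetical helpers. `sparse_to_dense` mimics what `tf.sparse_tensor_to_dense` does in the 1-D case, so adding its output changes only position `index`. The sketch also shows one caveat of the additive trick: in floating point, `old + (new - old)` is not guaranteed to equal `new`. A select-based update avoids that; in TensorFlow it would presumably look something like `tf.where(tf.equal(tf.range(5), i), tf.fill([5], 3.0), v)` (untested here).

```python
import numpy as np


def sparse_to_dense(indices, values, dense_shape):
    """Hypothetical NumPy stand-in for tf.sparse_tensor_to_dense (1-D case only)."""
    dense = np.zeros(dense_shape, dtype=np.asarray(values).dtype)
    for idx, val in zip(indices, values):
        dense[idx[0]] = val
    return dense


def set_via_delta(v, index, new_value):
    """The answer's trick: add a delta that is zero everywhere except at `index`."""
    delta_value = new_value - v[index:index + 1]  # shape (1,), like the TF slice
    delta = sparse_to_dense([[index]], delta_value, v.shape)
    return v + delta  # returns a new array; v itself is not modified


def set_via_select(v, index, new_value):
    """Alternative: choose new_value at `index` and keep v elsewhere (no arithmetic)."""
    mask = np.arange(v.shape[0]) == index
    return np.where(mask, new_value, v)


# Same loop as the answer: overwrite every position of ones(5) with 3.0.
v = np.ones(5, dtype=np.float32)
i = 0
while i < 5:
    v = set_via_delta(v, i, 3.0)
    i += 1
print(v)  # [3. 3. 3. 3. 3.]

# Caveat: old + (new - old) can lose the new value when magnitudes differ a lot.
w = np.array([1e20, 1.0, 1.0])
print(set_via_delta(w, 0, 1.0))   # [0. 1. 1.]  (1e20 + (1.0 - 1e20) rounds to 0.0)
print(set_via_select(w, 0, 1.0))  # [1. 1. 1.]
```

For the values in the answer (1.0 replaced by 3.0) both versions give identical results. The difference only matters for values of very different magnitude, or when the old entry is `inf` or `nan`, in which case the additive version produces `nan`.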
kafman