
I am trying to do the following:

state[0,:] = state[0,:].assign( 0.9*prev_state + 0.1*( tf.matmul(inputs, weights) + biases ) )
for i in xrange(1,BATCH_SIZE):
    state[i,:] = state[i,:].assign( 0.9*state[i-1,:] + 0.1*( tf.matmul(inputs, weights) + biases ) )
prev_state = prev_state.assign( state[BATCH_SIZE-1,:] )

with

state = tf.Variable(tf.zeros([BATCH_SIZE, HIDDEN_1]), name='inner_state')
prev_state = tf.Variable(tf.zeros([HIDDEN_1]), name='previous_inner_state')

As a follow-up to this question: I get an error saying that a Tensor does not have an assign method.

What is the correct way to call the assign method on a slice of a Variable tensor?


Full current code:

import tensorflow as tf
import math
import numpy as np

INPUTS = 10
HIDDEN_1 = 20
BATCH_SIZE = 3


def create_graph(inputs, state, prev_state):
    with tf.name_scope('h1'):
        weights = tf.Variable(
        tf.truncated_normal([INPUTS, HIDDEN_1],
                            stddev=1.0 / math.sqrt(float(INPUTS))),
        name='weights')
        biases = tf.Variable(tf.zeros([HIDDEN_1]), name='biases')

        updated_state = tf.scatter_update(state, [0], 0.9 * prev_state + 0.1 * (tf.matmul(inputs[0,:], weights) + biases))
        for i in xrange(1, BATCH_SIZE):
          updated_state = tf.scatter_update(
              updated_state, [i], 0.9 * updated_state[i-1, :] + 0.1 * (tf.matmul(inputs[i,:], weights) + biases))

        prev_state = prev_state.assign(updated_state[BATCH_SIZE-1, :])
        output = tf.nn.relu(updated_state)
    return output

def data_iter():
    while True:
        idxs = np.random.rand(BATCH_SIZE, INPUTS)
        yield idxs

with tf.Graph().as_default():
    inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS))
    state = tf.Variable(tf.zeros([BATCH_SIZE, HIDDEN_1]), name='inner_state')
    prev_state = tf.Variable(tf.zeros([HIDDEN_1]), name='previous_inner_state')

    output = create_graph(inputs, state, prev_state)

    sess = tf.Session()
    # Run the Op to initialize the variables.
    init = tf.initialize_all_variables()
    sess.run(init)
    iter_ = data_iter()
    for i in xrange(0, 2):
        print ("iteration: ",i)
        input_data = iter_.next()
        out = sess.run(output, feed_dict={ inputs: input_data})
diffeomorphism
1 Answer


TensorFlow Variable objects have limited support for updating slices, via the tf.scatter_update(), tf.scatter_add(), and tf.scatter_sub() ops. Each of these ops takes a variable, a vector of slice indices (indices into the 0th dimension of the variable, identifying the rows to be mutated), and a tensor of values (the new values to be written to the variable at the corresponding slice indices).
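For intuition, the effect of tf.scatter_update on the 0th dimension is the same as indexed row assignment on an array. A minimal sketch of the semantics, using NumPy purely for illustration (the shapes here are made up, not taken from the question):

```python
import numpy as np

# A stand-in for a variable of shape (num_rows, row_width).
state = np.zeros((3, 4))
new_row = np.ones(4)

# tf.scatter_update(state, [0], new_row) corresponds to:
state[0, :] = new_row

print(state[0])  # [1. 1. 1. 1.]  -- row 0 overwritten
print(state[1])  # [0. 0. 0. 0.]  -- other rows untouched
```

The indices argument is a vector, so several contiguous or non-contiguous rows can be overwritten in one op.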

To update a single row of the variable, you can use tf.scatter_update(). For example, to update the 0th row of state, you would do:

updated_state = tf.scatter_update(
    state, [0], 0.9 * prev_state + 0.1 * (tf.matmul(inputs, weights) + biases))

To chain multiple updates, you can use the mutable updated_state tensor that is returned from tf.scatter_update():

for i in xrange(1, BATCH_SIZE):
  updated_state = tf.scatter_update(
      updated_state, [i], 0.9 * updated_state[i-1, :] + ...)

prev_state = prev_state.assign(updated_state[BATCH_SIZE-1, :])

Finally, you can evaluate the resulting updated_state.op to apply all of the updates to state:

sess.run(updated_state.op)  # or `sess.run(updated_state)` to fetch the result

PS. You might find it more efficient to use tf.scan() to compute the intermediate states, and just materialize prev_state in a variable.
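As a sketch of what tf.scan would compute here, the recurrence from the question, state_i = 0.9 * state_{i-1} + 0.1 * (x_i W + b), unrolled over the batch dimension looks like this in NumPy (the 0.9/0.1 mixing factors and shapes are taken from the question; W, b, and the random inputs are stand-ins for the weights, biases, and feed data):

```python
import numpy as np

BATCH_SIZE, INPUTS, HIDDEN_1 = 3, 10, 20
rng = np.random.default_rng(0)
inputs = rng.random((BATCH_SIZE, INPUTS))
W = rng.random((INPUTS, HIDDEN_1))
b = np.zeros(HIDDEN_1)

prev_state = np.zeros(HIDDEN_1)  # carried across batches, like the variable
state = prev_state
states = []
for i in range(BATCH_SIZE):
    # The same per-element step that tf.scan would apply.
    state = 0.9 * state + 0.1 * (inputs[i] @ W + b)
    states.append(state)

states = np.stack(states)   # shape (BATCH_SIZE, HIDDEN_1), like `updated_state`
prev_state = states[-1]     # only this row needs to live in a variable
```

With tf.scan, only prev_state needs to be materialized in a Variable; the intermediate states are ordinary tensors, which avoids the chain of scatter updates entirely.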

mrry
  • thanks, I am a bit confused about how I would use `scan` in this case. Shouldn't I first have to build a graph with ops and then use scan on the output nodes? – diffeomorphism Jun 02 '16 at 07:54
  • the other issue I'm having is that since I'm breaking down the graph along the batch dimension, `tf.matmul(inputs, weights)` doesn't just work, and the line `updated_state = tf.scatter_update(state, [0], 0.9 * prev_state + 0.1 * (tf.matmul(inputs[0,:], weights) + biases))` gives: `ValueError: Shape (10,) must have rank 2`. I will update the OP with the full code – diffeomorphism Jun 02 '16 at 09:09