0

I have an RNN-like structure built from some building blocks (component neural networks) that are passed in by the user. Here is a minimal example:

import tensorflow as tf
# Start from an empty default graph (presumably so the snippet can be
# re-run in the same interpreter session without name clashes).
tf.reset_default_graph()

def initialize(shape):
    """Return a float32 random-normal tensor of `shape` (mean 0, stddev 0.1)."""
    return tf.random_normal(shape, mean=0, stddev=0.1, dtype=tf.float32)

def test_rnn_with_external(input, hiddens, external_fct):
    """
    A simple rnn that makes the standard update, then
    feeds the new hidden state through some external
    function.

    Args:
        input: tensor that tf.scan iterates over along its first axis.
            Axis 1 is read as the batch size and the last axis as the
            input width (the example below feeds a (t, btsz, dim)
            placeholder).
        hiddens: width of the hidden state.
        external_fct: callable applied to every new hidden state inside
            the scan step; its result is the state carried forward.

    Returns:
        The tensor produced by tf.scan, i.e. the hidden state after
        each step.
    """
    # Static shape information taken from the input tensor.
    dim_in = input.get_shape().as_list()[-1]
    btsz = input.get_shape().as_list()[1]
    # A single weight matrix acts on the concatenation [input, previous state].
    shape = (dim_in + hiddens, hiddens)
    _init = initialize(shape)
    # The initializer is a tensor, so no explicit shape is passed.
    W = tf.get_variable("rnn_w", initializer=_init)
    _init = tf.zeros([hiddens])
    b = tf.get_variable("rnn_b", initializer=_init)

    def _step(previous, input):
        """One scan step: `previous` is the carried state, `input` the current slice."""
        # Axis-first argument order, as used by TensorFlow releases before 1.0.
        concat = tf.concat(1, [input, previous])
        h_t = tf.tanh(tf.add(tf.matmul(concat, W), b))

        # NOTE: external_fct is invoked while the scan body is being built,
        # so anything it creates (including variables) is created here.
        h_t = external_fct(h_t)

        return h_t

    # Initial hidden state: all zeros, one row per batch element.
    h_0 = tf.zeros([btsz, hiddens])
    states = tf.scan(_step,
                     input,
                     initializer=h_0,
                     name="states")
    return states

# the external function, relying on the templating mechanism.
def ext_fct(hiddens):
    """
    Build the external function as a TensorFlow template.

    Args:
        hiddens: side length of the square weight matrix "ext_w".

    Returns:
        The callable returned by tf.make_template, wrapping `tmp` under
        the name "external_fct".
    """
    def tmp(input):
        """Multiply `input` by the variable "ext_w" and add the constant bias 0."""
        shape = (hiddens, hiddens)
        # The initializer tensor is created each time tmp is executed,
        # i.e. wherever the template is called from.
        _init = initialize(shape)
        W = tf.get_variable("ext_w", initializer=_init)
        b = 0
        return tf.add(tf.matmul(input, W), b, name="external")
    return tf.make_template(name_="external_fct", func_=tmp)

# run from here on
# Problem sizes: leading (scanned) axis, batch size, input width, hidden width.
t, btsz, dim, hiddens = 5, 4, 2, 3

# Input placeholder laid out as (t, btsz, dim).
x = tf.placeholder(tf.float32, shape=(t, btsz, dim))

# Wrap the external function in a template and build the RNN on top of x.
ext = ext_fct(hiddens)
states = test_rnn_with_external(x, hiddens, external_fct=ext)

# Run the variable initializers; the question reports the error from this run.
sess = tf.InteractiveSession()
init_op = tf.initialize_all_variables()
sess.run(init_op)

Running this fails with an error ending in:

InvalidArgumentError: All inputs to node external_fct/ext_w/Assign must be from the same frame.

With "frame", I would associate an area on the stack. So I thought that maybe tf.make_template does something very weird and is therefore not usable here. The external function can be rewritten a bit and then called more directly, like so:

import tensorflow as tf
# Start from an empty default graph (presumably so the snippet can be
# re-run in the same interpreter session without name clashes).
tf.reset_default_graph()

def initialize(shape):
    """Return a float32 random-normal tensor of `shape` (mean 0, stddev 0.1)."""
    return tf.random_normal(shape, mean=0, stddev=0.1, dtype=tf.float32)

def test_rnn_with_external(input, hiddens, external_fct):
    """
    Simple rnn whose new hidden state is passed through `external_fct`.

    Args:
        input: tensor that tf.scan iterates over along its first axis.
            Axis 1 is read as the batch size and the last axis as the
            input width (the example below feeds a (t, btsz, dim)
            placeholder).
        hiddens: width of the hidden state; also forwarded to
            `external_fct`.
        external_fct: callable invoked as external_fct(h_t, hiddens)
            inside the scan step; its result is the state carried forward.

    Returns:
        The tensor produced by tf.scan, i.e. the hidden state after
        each step.
    """
    # Static shape information taken from the input tensor.
    dim_in = input.get_shape().as_list()[-1]
    btsz = input.get_shape().as_list()[1]
    # A single weight matrix acts on the concatenation [input, previous state].
    shape = (dim_in + hiddens, hiddens)
    _init = initialize(shape)
    # The initializer is a tensor, so no explicit shape is passed.
    W = tf.get_variable("rnn_w", initializer=_init)
    _init = tf.zeros([hiddens])
    b = tf.get_variable("rnn_b", initializer=_init)

    def _step(previous, input):
        """
        One scan step: `previous` is the carried state, `input` the
        current slice.
        """
        # Axis-first argument order, as used by TensorFlow releases before 1.0.
        concat = tf.concat(1, [input, previous])
        h_t = tf.tanh(tf.add(tf.matmul(concat, W), b))

        # NOTE: external_fct is invoked while the scan body is being built,
        # so anything it creates (including variables) is created here.
        h_t = external_fct(h_t, hiddens)

        return h_t

    # Initial hidden state: all zeros, one row per batch element.
    h_0 = tf.zeros([btsz, hiddens])
    states = tf.scan(_step,
                     input,
                     initializer=h_0,
                     name="states")
    return states

def ext_fct_new(input, hiddens):
    """
    External function without templating: multiply `input` by the
    variable "ext_w_new" and add the constant bias 0.

    Args:
        input: tensor that is matrix-multiplied with the weight matrix.
        hiddens: side length of the square weight matrix.

    Returns:
        The result of the tf.add op named "external_new".
    """
    shape = (hiddens, hiddens)
    # The initializer tensor is created at the point where this function
    # is called.
    _init = initialize(shape)
    W = tf.get_variable("ext_w_new", initializer=_init)
    b = 0
    return tf.add(tf.matmul(input, W), b, name="external_new")

# Problem sizes: leading (scanned) axis, batch size, input width, hidden width.
t, btsz, dim, hiddens = 5, 4, 2, 3

# Input placeholder laid out as (t, btsz, dim).
x = tf.placeholder(tf.float32, shape=(t, btsz, dim))

# The plain function is handed over directly, without tf.make_template.
states = test_rnn_with_external(x, hiddens, external_fct=ext_fct_new)

# Run the variable initializers; the question reports the error from this run.
sess = tf.InteractiveSession()
init_op = tf.initialize_all_variables()
sess.run(init_op)

However, this still produces the same error: InvalidArgumentError: All inputs to node ext_w_new/Assign must be from the same frame.

Of course, moving the contents of the external function into the _step part (and calling tf.get_variable beforehand) works. But then the flexibility (necessary in the original code) is gone.

What am I doing wrong? Any help, tips, or pointers are greatly appreciated.

(Note: Asked this on github, too: https://github.com/tensorflow/tensorflow/issues/4478)

osdf
  • 818
  • 10
  • 20

1 Answers1

0

Using a tf.constant_initializer solves the problem. This is described here.

Community
  • 1
  • 1
osdf
  • 818
  • 10
  • 20