My question is related to this question*.

Is it possible to turn standard TensorFlow layers into 'cells', to be used together with RNN cells to compose recurrent neural networks?

So the new 'cell' should store its parameters (weights, ...) and be callable on varying inputs. Something like this:

import tensorflow as tf
from tensorflow.contrib.rnn import MultiRNNCell, LSTMCell

bn_cell = cell_creation_fun(tf.nn.batch_normalization, otherparams)  # batch norm cell
conv_cell = cell_creation_fun(tf.nn.conv2d, otherparams)             # non-RNN conv cell
# or `conv_cell = cell_creation_fun(tf.layers.Conv2D, otherparams)`  # using tf.layers

So that they can be used like this:

multi_cell = MultiRNNCell([LSTMCell(...), conv_cell, bn_cell])

Or like this:

h = ...
conv_h, _ = conv_cell(h, state=None)
normed_h, _ = bn_cell(h, state=None)

The only thing I could think of is manually writing such a 'cell' for every layer I want to use, subclassing RNNCell. But it doesn't seem straightforward to reuse existing functions like conv2d, since they cannot be created first and called on an `input` later, as shown below. (Will post code when I manage.)
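
For illustration, here is the difference I mean, with made-up shapes: the functional tf.layers.conv2d needs its inputs at creation time, whereas the class tf.layers.Conv2D only binds the hyperparameters and can be called later (a minimal sketch, assuming TF 1.x):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 40, 40, 3])
x2 = tf.placeholder(tf.float32, [None, 40, 40, 3])

# Functional API: inputs must be supplied up front, so nothing
# reusable is left over to call on a different input later.
y = tf.layers.conv2d(x, filters=8, kernel_size=5, padding='same')

# Class API: construction only stores the hyperparameters; the layer
# creates its weights on the first call and can then be re-applied,
# which is exactly what a 'cell' needs.
conv = tf.layers.Conv2D(filters=8, kernel_size=5, padding='same')
y1 = conv(x)   # weights created here
y2 = conv(x2)  # same weights reused on a different input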


* Maybe asking in a more targeted way has a better chance of getting an answer.

1 Answer

Ok, here's what I have so far:

import tensorflow as tf
from tensorflow.python.ops import rnn_cell_impl


class LayerCell(rnn_cell_impl.RNNCell):

    def __init__(self, tf_layer, **kwargs):
        ''' :param tf_layer: a TensorFlow layer class, e.g. tf.layers.Conv2D or
            tf.keras.layers.Conv2D. NOT the functional tf.layers.conv2d !
            kwargs are passed on to the layer's constructor, so the layer
            object stores the parameters. '''
        self.layer_fn = tf_layer(**kwargs)

    def __call__(self, inputs, state, scope=None):
        ''' Every `RNNCell` must implement `call` with
            the signature `(output, next_state) = call(input, state)`. The optional
            third input argument, `scope`, is allowed for backwards compatibility
            purposes, but should be left off for new subclasses. '''
        # Stateless: apply the wrapped layer and pass the state through untouched.
        return (self.layer_fn(inputs), state)

    def __str__(self):
        return "Cell wrapper of " + str(self.layer_fn)

    def __getattr__(self, attr):
        ''' Forward unknown attribute lookups to the wrapped layer. Credits to
            https://stackoverflow.com/questions/1382871/dynamically-attaching-a-method-to-an-existing-python-object-generated-with-swig/1383646#1383646 '''
        return getattr(self.layer_fn, attr)

    @property
    def state_size(self):
        """size(s) of state(s) used by this cell.

        It can be represented by an Integer, a TensorShape or a tuple of Integers
        or TensorShapes.
        """
        # Dummy size: the cell keeps no state of its own.
        return (0,)

    @property
    def output_size(self):
        """Integer or TensorShape: size of outputs produced by this cell."""
        # Use with caution: undefined until the wrapped layer has been built.
        return self.layer_fn.output_shape

(Naturally, don't use this with recurrent layers, since their state-keeping would be destroyed.)

Seems to work with: tf.layers.Conv2D, tf.keras.layers.Conv2D, tf.keras.layers.Activation, tf.layers.BatchNormalization

Does NOT work with: tf.keras.layers.BatchNormalization. At least it failed for me when I used it inside a tf.while_loop, complaining about combining variables from different frames, similar to here. Maybe Keras uses tf.Variable() instead of tf.get_variable()...?
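
As a side note, because __getattr__ forwards to the wrapped layer, the wrapper mostly behaves like the layer it wraps, and a single cell can be called on its own just as sketched in the question. A minimal check (shapes made up):

conv_cell = LayerCell(tf.layers.Conv2D, filters=8, kernel_size=[3, 3], padding='same')

h = tf.zeros([2, 8, 8, 3])
conv_h, _ = conv_cell(h, state=None)  # the state is passed through untouched

print(conv_cell.filters)         # forwarded to the wrapped Conv2D -> 8
print(len(conv_cell.variables))  # kernel + bias of the wrapped layer -> 2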


Usage:

import numpy as np

cell0 = tf.contrib.rnn.ConvLSTMCell(conv_ndims=2, input_shape=[40, 40, 3], output_channels=16, kernel_shape=[5, 5])
cell1 = LayerCell(tf.keras.layers.Conv2D, filters=8, kernel_size=[5, 5], strides=(1, 1), padding='same')
cell2 = LayerCell(tf.layers.BatchNormalization, axis=-1)

inputs = np.random.rand(10, 40, 40, 3).astype(np.float32)
multicell = tf.contrib.rnn.MultiRNNCell([cell0, cell1, cell2])
state = multicell.zero_state(batch_size=10, dtype=tf.float32)

output, state = multicell(inputs, state)
print("Yippee!")