
I'm trying to implement an encoder-decoder type network in Keras, with Bidirectional GRUs.

The following code seems to be working:

from keras.layers import Input, Embedding, GRU, Bidirectional

# vocab_size: vocabulary size, defined elsewhere
src_input = Input(shape=(5,))
ref_input = Input(shape=(5,))

src_embedding = Embedding(output_dim=300, input_dim=vocab_size)(src_input)
ref_embedding = Embedding(output_dim=300, input_dim=vocab_size)(ref_input)

# a Bidirectional GRU with return_state=True returns
# [sequence_output, forward_state, backward_state]
encoder = Bidirectional(
                GRU(2, return_sequences=True, return_state=True)
        )(src_embedding)

# unidirectional decoder, initialized with the encoder's forward state
decoder = GRU(2, return_sequences=True)(ref_embedding, initial_state=encoder[1])

But when I change the decoder to use the Bidirectional wrapper, the encoder and src_input layers stop showing up in model.summary(). The new decoder looks like:

decoder = Bidirectional(
                GRU(2, return_sequences=True)
        )(ref_embedding, initial_state=encoder[1:])
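
The Bidirectional decoder needs two initial states, one for the forward GRU and one for the backward GRU, which is why I pass both encoder states with encoder[1:]. For reference, the model behind the summary below is built roughly like this (the exact construction wasn't included above):

from keras.models import Model

# build the full model from both inputs to the decoder output
model = Model(inputs=[src_input, ref_input], outputs=decoder)
model.summary()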

Here is the output of model.summary() with the Bidirectional decoder:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         (None, 5)                 0         
_________________________________________________________________
embedding_2 (Embedding)      (None, 5, 300)            6610500   
_________________________________________________________________
bidirectional_2 (Bidirection (None, 5, 4)              3636      
=================================================================
Total params: 6,614,136
Trainable params: 6,614,136
Non-trainable params: 0
_________________________________________________________________

Question: Am I missing something when passing initial_state to the Bidirectional decoder? How can I fix this? Is there another way to make this work?

nisargjhaveri

1 Answer


It's a bug. The RNN layer implements __call__ so that the tensors passed in initial_state can be collected into a model instance. However, the Bidirectional wrapper did not implement this, so the topological information about the initial_state tensors is missing and some strange bugs happen.

I wasn't aware of it when I implemented initial_state support for Bidirectional. It should be fixed now, after this PR. You can install the latest master branch from GitHub to get the fix.
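
For example, after installing the master branch (e.g. pip install git+https://github.com/keras-team/keras.git, assuming pip and git are available), the code from the question should work as-is. A minimal sketch, reusing src_input, ref_input, ref_embedding and encoder as defined in the question:

# same decoder as in the question, with both encoder states as initial states
decoder = Bidirectional(
                GRU(2, return_sequences=True)
        )(ref_embedding, initial_state=encoder[1:])

model = Model([src_input, ref_input], decoder)
# with the fix, the summary should also list src_input, its embedding
# and the encoder's bidirectional layer, not just the decoder side
model.summary()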

Yu-Yang
  • Thanks, it worked! :) Btw, side question, what is the expected release cycle for Keras? When is it going to be available in a release? – nisargjhaveri Jan 31 '18 at 04:52
  • Well, I'm not sure. It seems that Keras doesn't have a fixed time frame for releasing new versions (or maybe I just don't know about it). I think it depends on the decision of the project owner. – Yu-Yang Jan 31 '18 at 05:37
  • Thanks. Also, not sure if this is the right place, but the Bidirectional wrapper should maybe also support the `constants` argument, which RNN supports. – nisargjhaveri Jan 31 '18 at 06:02
  • I totally agree. But I don't have a plan (or time) to implement it right now, since this feature is not so critical to my own project (I only use a unidirectional decoder with attention). If you find it urgent, maybe you can submit a feature request on the Keras issue board. Implementing it and submitting a PR to Keras would be even better :) – Yu-Yang Jan 31 '18 at 06:18
  • Sure! I'll try to add a PR :) – nisargjhaveri Jan 31 '18 at 06:21
  • Again, a side question: I'm struggling with adding attention to an encoder-decoder in Keras. Is there any good example you can point me to? I've searched a lot, but can't get a clear idea. Thanks a lot for your time! :) – nisargjhaveri Jan 31 '18 at 06:26
  • I couldn't find one either, so I've implemented my own wrapper classes from scratch (also because I need some additional functionality). Unfortunately I can't share my code here. I've recently tested [this PR](https://github.com/keras-team/keras/pull/8296#issuecomment-358887309) and it seems to work fine after some small fixes. You can find what I've modified in [this branch](https://github.com/myutwo150/keras/compare/76ff23e663fa7d279e0a977df07c9e97b12c671a...myutwo150:test-attention). There's also a simple dot-product attention wrapper, `MultiplicativeAttention`, in it. – Yu-Yang Jan 31 '18 at 07:50