
I am trying to use a shared LSTM layer with state in a Keras model, but it seems that the internal state is modified by each parallel use. This raises two questions:

  1. When training a model with a shared LSTM layer and stateful=True, do the parallel uses update the same state during training as well?
  2. If my observation is valid, is there a way to use weight-sharing LSTMs such that the state is stored independently for each of the parallel uses?

The code below illustrates the problem with three sequences sharing the LSTM. The prediction for the full input is compared with the result obtained by splitting the input into two halves and feeding them into the network consecutively.

What can be observed is that a1 is the same as the first half of aFull, meaning that the uses of the LSTM really run in parallel with independent states during the first prediction. That is, z1 is not affected by the parallel calls that create z2 and z3. But a2 differs from the second half of aFull, so there is some interaction between the states of the parallel uses.

I was hoping that the concatenation of the two pieces a1 and a2 would be the same as the result of calling the prediction with the longer input sequence, but this doesn't seem to be the case. A further concern is that if this kind of interaction takes place during prediction, it may also be happening during training.

import keras
import keras.backend as K
import numpy as np

nOut = 3
xShape = (3, 50, 4)
inShape = (xShape[0], None, xShape[2])   
batchInShape = (1, ) + inShape
x = np.random.randn(*xShape)

# construct network
xIn = keras.layers.Input(shape=inShape, batch_shape=batchInShape)

# shared LSTM layer
sharedLSTM = keras.layers.LSTM(units=nOut, stateful=True, return_sequences=True, return_state=False)

# split the input on the first axis
x1 = keras.layers.Lambda(lambda x: x[:,0,:,:])(xIn)
x2 = keras.layers.Lambda(lambda x: x[:,1,:,:])(xIn)
x3 = keras.layers.Lambda(lambda x: x[:,2,:,:])(xIn)

# pass each input through the LSTM
z1 = sharedLSTM(x1)
z2 = sharedLSTM(x2)
z3 = sharedLSTM(x3)

# add a singleton dimension
y1 = keras.layers.Lambda(lambda x: K.expand_dims(x, axis=1))(z1)
y2 = keras.layers.Lambda(lambda x: K.expand_dims(x, axis=1))(z2)
y3 = keras.layers.Lambda(lambda x: K.expand_dims(x, axis=1))(z3)

# combine the outputs
y = keras.layers.Concatenate(axis=1)([y1, y2, y3])

model = keras.models.Model(inputs=xIn, outputs=y)
model.compile(loss='mse', optimizer='adam')
model.summary()

# no need to train, since we're interested only what is happening mechanically

# reset to a known state and predict for full input
model.reset_states()
aFull = model.predict(x[np.newaxis,:,:,:])

# reset to a known state and predict for the same input, but in two pieces
model.reset_states()
a1 = model.predict(x[np.newaxis,:,:xShape[1]//2,:])
a2 = model.predict(x[np.newaxis,:,xShape[1]//2:,:])
# combine the pieces
aSplit = np.concatenate((a1, a2), axis=2)

print('full diff: {}, first half diff: {}, second half diff: {}'.format(
    np.sum(np.abs(aFull - aSplit)),
    np.sum(np.abs(aFull[:, :, :xShape[1]//2, :] - aSplit[:, :, :xShape[1]//2, :])),
    np.sum(np.abs(aFull[:, :, xShape[1]//2:, :] - aSplit[:, :, xShape[1]//2:, :]))))

Update: The behaviour described above was observed with Keras using TensorFlow 1.14 and 1.15 as the backend. Running the same code with TF 2.0 (with the adjusted imports) changes the result so that a1 is no longer the same as the first half of aFull. That match can still be obtained by setting stateful=False in the layer instantiation.

This would suggest to me that using a recurrent layer with shared parameters, but with a separate state for each of the parallel uses, is not really possible this way.

Update 2: It seems that others have missed the same functionality before as well: a closed, unanswered question at Keras' GitHub.

For comparison, here is a scribbling in PyTorch (the first time I've tried to use it) implementing a simple network with N parallel LSTMs sharing the weights, but having independent states. In this case the states are stored explicitly in a list and provided to the LSTM cell manually.

import torch
import numpy as np

class sharedLSTM(torch.nn.Module):

    def __init__(self, batchSz, nBands, nDims, outDim):
        super(sharedLSTM, self).__init__()
        self.internalLSTM = torch.nn.LSTM(input_size=nDims, hidden_size=outDim, num_layers=1, bias=True, batch_first=True)
        allStates = list()
        for bandIdx in range(nBands):
            h_0 = torch.zeros(1, batchSz, outDim)
            c_0 = torch.zeros(1, batchSz, outDim)
            allStates.append((h_0, c_0))

        self.allStates = allStates            
        self.nBands = nBands

    def forward(self, x):
        allOut = list()
        for dimIdx in range(self.nBands):
            thisSlice = x[:,dimIdx,:,:] # (batchSz, nSteps, nFeats)
            thisState = self.allStates[dimIdx]

            thisY, thisState = self.internalLSTM(thisSlice, thisState) 
            self.allStates[dimIdx] = thisState
            allOut.append(thisY[:,None,:,:]) # => (batchSz, 1, nSteps, outDim)

        y = torch.cat(allOut, dim=1) # => (batchSz, nBands, nSteps, outDim)

        return y

    def resetStates(self):
        for bandIdx in range(self.nBands):
            self.allStates[bandIdx][0][:] = 0.0
            self.allStates[bandIdx][1][:] = 0.0


batchSz = 5
nBands = 3
nFeats = 4
nOutDims = 2
net = sharedLSTM(batchSz, nBands, nFeats, nOutDims)
net = net.float()
print(net)

N = 20
x = torch.from_numpy(np.random.rand(batchSz, nBands, N, nFeats)).float()
x1 = x[:, :, :N//2, :]
x2 = x[:, :, N//2:, :]

aa = net.forward(x)
net.resetStates()
a1 = net.forward(x1)
a2 = net.forward(x2)

print('(with reset) first half abs diff: {}'.format(str(torch.sum(torch.abs(a1 - aa[:,:,:N//2,:])).detach().numpy())))
print('(with reset) second half abs diff: {}'.format(str(torch.sum(torch.abs(a2 - aa[:,:,N//2:,:])).detach().numpy())))

Result: the output is the same regardless of whether we do the prediction in one go or in pieces.

I've tried to replicate this in Keras using sub-classing, but without success:

import keras
import numpy as np

class sharedLSTM(keras.Model):
    def __init__(self, batchSz, nBands, nDims, outDim):
        super(sharedLSTM, self).__init__()
        self.internalLSTM = keras.layers.LSTM(units=outDim, stateful=True, return_sequences=True, return_state=True)
        self.internalLSTM.build((batchSz, None, nDims))
        self.internalLSTM.reset_states()
        allStates = list()
        allSlicers = list()
        for bandIdx in range(nBands):
            allStates.append(None)
            allSlicers.append(keras.layers.Lambda(lambda x, b: x[:, :, b, :], arguments = {'b' : bandIdx}))

        self.allStates = allStates            
        self.allSlicers = allSlicers
        self.Concat = keras.layers.Lambda(lambda x: keras.backend.concatenate(x, axis=2))

        self.nBands = nBands

    def call(self, x):
        allOut = list()
        for bandIdx in range(self.nBands):
            thisSlice = self.allSlicers[bandIdx]( x )
            thisState = self.allStates[bandIdx]

            thisY, *thisState = self.internalLSTM(thisSlice, initial_state=thisState) 
            self.allStates[bandIdx] = thisState.copy()
            allOut.append(thisY[:,:,None,:]) 

        y = self.Concat( allOut )
        return y

batchSz = 1
nBands = 3
nFeats = 4
nOutDims = 2
N = 20

model = sharedLSTM(batchSz, nBands, nFeats, nOutDims)
model.compile(optimizer='SGD', loss='mae')

x = np.random.rand(batchSz, N, nBands, nFeats)
x1 = x[:, :N//2, :, :]
x2 = x[:, N//2:, :, :]

aa = model.predict(x)

model.reset_states()
a1 = model.predict(x1)
a2 = model.predict(x2)

print('(with reset) first half abs diff: {}'.format(str(np.sum(np.abs(a1 - aa[:,:N//2,:,:])))))
print('(with reset) second half abs diff: {}'.format(str(np.sum(np.abs(a2 - aa[:,N//2:,:,:])))))

If you now ask "why don't you then use torch and shut up?", the answer is that the surrounding experimental framework has been built assuming Keras and changing it would be a non-negligible amount of work.

  • [This answer](https://stackoverflow.com/questions/58276337/proper-way-to-feed-time-series-data-to-stateful-lstm/58277760#58277760) may help – OverLordGoldDragon Dec 12 '19 at 14:50
  • @OverLordGoldDragon That answer is clearly related, but it is missing the aspect of sharing the LSTM layer. If we simplify the code above to omit x2, x3, z2, z3, y2, and y3, i.e., run the LSTM without sharing, everything works as I would expect: it doesn't matter whether the input to the prediction is split into multiple shorter pieces; concatenating the outputs gives the same result as feeding in the full long sequence at once. The problem is that I would like to have N parallel sequences sharing the LSTM weights but having independent states. In the example code the states aren't independent. – paulus Dec 13 '19 at 08:46

1 Answer


My current understanding of the behaviour of LSTMs (and other RNNs) in Keras is that a shared LSTM layer in stateful=True mode does not work as one would expect: there is only one set of state variables, and it gets updated through all the parallel uses. So the answers to the questions appear to be:

  1. Yes, they are. The processing runs over the first of the parallel sequences, stores the state at the end, uses that as the initial state for the second parallel sequence, and so forth.
  2. Yes, but it requires some work. See below for the details.

I've managed to handle the states in two ways. The first is to derive sub-classes from Keras' LSTM and LSTMCell and to overload LSTMCell.call() so that it handles the parallel data streams by splitting the input and by storing and recovering the state of each parallel stream. A drawback here is that the input shape to an RNN is fixed to be 3D, which means that the parallel inputs need to be reshaped into the feature dimension along with the real features; a sketch of this idea follows.
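
A minimal sketch of this first idea (a simplified illustration rather than my exact code) could look roughly like the following. It assumes tf.keras in TF 2.x; the class name MultiStreamLSTMCell and the packing of the streams into contiguous blocks of the feature dimension are just illustrative choices.

import numpy as np
import tensorflow as tf
from tensorflow import keras

class MultiStreamLSTMCell(keras.layers.LSTMCell):
    # A weight-shared LSTM cell that keeps an independent (h, c) pair for each
    # parallel stream. The streams are packed into the feature dimension,
    # i.e. each time step carries nStreams * nFeats features.
    def __init__(self, units, nStreams, **kwargs):
        super(MultiStreamLSTMCell, self).__init__(units, **kwargs)
        self.nStreams = nStreams
        # the cell state is the concatenation of the per-stream states
        self.state_size = [units * nStreams, units * nStreams]
        self.output_size = units * nStreams

    def build(self, input_shape):
        # build the parent cell for a single stream's feature count
        nFeats = int(input_shape[-1]) // self.nStreams
        super(MultiStreamLSTMCell, self).build((input_shape[0], nFeats))

    def call(self, inputs, states, training=None):
        nFeats = int(inputs.shape[-1]) // self.nStreams
        hAll, cAll = states
        outs, hNew, cNew = [], [], []
        for i in range(self.nStreams):
            xI = inputs[:, i * nFeats:(i + 1) * nFeats]
            hI = hAll[:, i * self.units:(i + 1) * self.units]
            cI = cAll[:, i * self.units:(i + 1) * self.units]
            # an ordinary LSTM step, but with this stream's own state
            oI, (hI, cI) = super(MultiStreamLSTMCell, self).call(
                xI, [hI, cI], training=training)
            outs.append(oI)
            hNew.append(hI)
            cNew.append(cI)
        return (tf.concat(outs, axis=-1),
                [tf.concat(hNew, axis=-1), tf.concat(cNew, axis=-1)])

# usage: pack (batchSz, nStreams, nSteps, nFeats) into (batchSz, nSteps, nStreams * nFeats)
batchSz, nStreams, nSteps, nFeats, nOut = 1, 3, 50, 4, 3
cell = MultiStreamLSTMCell(nOut, nStreams)
rnn = keras.layers.RNN(cell, return_sequences=True, stateful=True)
xIn = keras.layers.Input(batch_size=batchSz, shape=(None, nStreams * nFeats))
model = keras.models.Model(xIn, rnn(xIn))

x = np.random.randn(batchSz, nStreams, nSteps, nFeats)
xPacked = np.transpose(x, (0, 2, 1, 3)).reshape(batchSz, nSteps, -1)
model.reset_states()
aFull = model.predict(xPacked)
model.reset_states()
a1 = model.predict(xPacked[:, :nSteps // 2, :])
a2 = model.predict(xPacked[:, nSteps // 2:, :])
# the output features are likewise the concatenation of the per-stream outputs
print(np.sum(np.abs(aFull - np.concatenate((a1, a2), axis=1))))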

The second approach is to create a wrapper Layer, not completely dissimilar to the sharedLSTM model in the question, that slices the input into the parallel streams, calls the internal LSTM with the correct state for each stream, and stores the returned states. The storing of the states works through an add_update() call inserted at the end of call(). This add_update() does not (seem to) work with Model, hence the Layer. However, when run with Keras < 2.3, the weights of the nested layers are not tracked or updated, so Keras 2.3+ or TF2 is needed. A sketch of this second approach follows.
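
A minimal sketch of this second approach (again simplified, not my exact code) is below. It assumes tf.keras in TF 2.x; the class name ParallelStatefulLSTM is just illustrative, and the per-stream states are stored as non-trainable weights.

import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K

class ParallelStatefulLSTM(keras.layers.Layer):
    # A weight-shared LSTM applied to nBands parallel streams, with an
    # independent (h, c) pair stored per stream as non-trainable weights.
    # Expected input shape: (batchSz, nBands, nSteps, nFeats).
    def __init__(self, units, nBands, batchSz, **kwargs):
        super(ParallelStatefulLSTM, self).__init__(**kwargs)
        self.units = units
        self.nBands = nBands
        self.batchSz = batchSz
        # stateful=False: the state bookkeeping is done by this wrapper
        self.internalLSTM = keras.layers.LSTM(
            units, return_sequences=True, return_state=True)

    def build(self, input_shape):
        self.h = [self.add_weight(name='h_{}'.format(b),
                                  shape=(self.batchSz, self.units),
                                  initializer='zeros', trainable=False)
                  for b in range(self.nBands)]
        self.c = [self.add_weight(name='c_{}'.format(b),
                                  shape=(self.batchSz, self.units),
                                  initializer='zeros', trainable=False)
                  for b in range(self.nBands)]
        super(ParallelStatefulLSTM, self).build(input_shape)

    def call(self, x):
        allOut = []
        updates = []
        for b in range(self.nBands):
            thisSlice = x[:, b, :, :]
            thisY, hNew, cNew = self.internalLSTM(
                thisSlice, initial_state=[self.h[b], self.c[b]])
            updates += [K.update(self.h[b], hNew), K.update(self.c[b], cNew)]
            allOut.append(K.expand_dims(thisY, axis=1))
        # register the state updates so they run on every forward pass
        self.add_update(updates)
        return K.concatenate(allOut, axis=1)

    def reset_states(self):
        for v in self.h + self.c:
            K.set_value(v, np.zeros(K.int_shape(v)))

batchSz, nBands, nSteps, nFeats, nOut = 1, 3, 50, 4, 3
xIn = keras.layers.Input(batch_size=batchSz, shape=(nBands, None, nFeats))
sharedLayer = ParallelStatefulLSTM(nOut, nBands, batchSz)
model = keras.models.Model(xIn, sharedLayer(xIn))

x = np.random.randn(batchSz, nBands, nSteps, nFeats)
sharedLayer.reset_states()
aFull = model.predict(x)
sharedLayer.reset_states()
a1 = model.predict(x[:, :, :nSteps // 2, :])
a2 = model.predict(x[:, :, nSteps // 2:, :])
print(np.sum(np.abs(aFull - np.concatenate((a1, a2), axis=2))))

The add_update() with K.update() essentially mirrors what Keras' own stateful RNNs do internally; in pure TF2 eager execution one could presumably also assign the state variables directly at the end of call().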
