I am trying to use a shared LSTM layer with state in a Keras model, but it seems that the internal state is modified by each parallel use. This raises two questions:
- When training a model with a shared LSTM layer and stateful=True, are the parallel uses updating the same state also during training?
- If my observation is valid, is there a way to use weight-sharing LSTMs such that the state is stored independently for each of the parallel uses?
The code below exemplifies the problem with three sequences sharing the LSTM: the prediction for a full input is compared with the result of splitting the same input into two halves and feeding them through the network consecutively.
What can be observed is that a1 is the same as the first half of aFull, meaning that the uses of the LSTM really are in parallel with independent states during the first prediction, i.e., z1 is not affected by the parallel calls creating z2 and z3. But a2 differs from the second half of aFull, so there is some interaction between the states of the parallel uses.
What I was hoping for is that the concatenation of the two pieces a1 and a2 would be the same as the result of calling the prediction with the longer input sequence, but this doesn't seem to be the case. A further concern is whether this kind of interaction, observed here in prediction, also happens during training (one way to probe this is sketched after the script below).
import keras
import keras.backend as K
import numpy as np
nOut = 3
xShape = (3, 50, 4)
inShape = (xShape[0], None, xShape[2])
batchInShape = (1, ) + inShape
x = np.random.randn(*xShape)
# construct network
xIn = keras.layers.Input(shape=inShape, batch_shape=batchInShape)
# shared LSTM layer
sharedLSTM = keras.layers.LSTM(units=nOut, stateful=True, return_sequences=True, return_state=False)
# split the input on the first axis
x1 = keras.layers.Lambda(lambda x: x[:,0,:,:])(xIn)
x2 = keras.layers.Lambda(lambda x: x[:,1,:,:])(xIn)
x3 = keras.layers.Lambda(lambda x: x[:,2,:,:])(xIn)
# pass each input through the LSTM
z1 = sharedLSTM(x1)
z2 = sharedLSTM(x2)
z3 = sharedLSTM(x3)
# add a singleton dimension
y1 = keras.layers.Lambda(lambda x: K.expand_dims(x, axis=1))(z1)
y2 = keras.layers.Lambda(lambda x: K.expand_dims(x, axis=1))(z2)
y3 = keras.layers.Lambda(lambda x: K.expand_dims(x, axis=1))(z3)
# combine the outputs
y = keras.layers.Concatenate(axis=1)([y1, y2, y3])
model = keras.models.Model(inputs=xIn, outputs=y)
model.compile(loss='mse', optimizer='adam')
model.summary()
# no need to train, since we're only interested in what happens mechanically
# reset to a known state and predict for full input
model.reset_states()
aFull = model.predict(x[np.newaxis,:,:,:])
# reset to a known state and predict for the same input, but in two pieces
model.reset_states()
a1 = model.predict(x[np.newaxis,:,:xShape[1]//2,:])
a2 = model.predict(x[np.newaxis,:,xShape[1]//2:,:])
# combine the pieces
aSplit = np.concatenate((a1, a2), axis=2)
print('full diff: {}, first half diff: {}, second half diff: {}'.format(
    np.sum(np.abs(aFull - aSplit)),
    np.sum(np.abs(aFull[:, :, :xShape[1]//2, :] - aSplit[:, :, :xShape[1]//2, :])),
    np.sum(np.abs(aFull[:, :, xShape[1]//2:, :] - aSplit[:, :, xShape[1]//2:, :]))))
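One way to probe whether the same interaction happens during training would be to peek at the shared layer's state variables around a single train_on_batch call; a sketch continuing the script above (the all-zero target is only a dummy so that the call goes through):
# sketch: inspect the stateful layer's state variables around one training step
model.reset_states()
statesBefore = [K.get_value(s) for s in sharedLSTM.states]
model.train_on_batch(x[np.newaxis, :, :, :], np.zeros((1,) + xShape[:2] + (nOut,)))
statesAfter = [K.get_value(s) for s in sharedLSTM.states]
# the shared layer exposes only a single [h, c] pair even though it is called
# three times, which is consistent with the interaction observed above
print('number of state tensors in the shared layer: {}'.format(len(sharedLSTM.states)))
print('state change over one training step: {}'.format(
    [float(np.sum(np.abs(b - a))) for b, a in zip(statesBefore, statesAfter)]))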
Update: The behaviour described above was observed with Keras using TensorFlow 1.14 and 1.15 as the backend. Running the same code with tf2.0 (with the adjusted imports) changes the result so that a1 is no longer the same as the first half of aFull. The match can still be obtained by setting stateful=False in the layer instantiation.
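The adjusted imports amount to something like the following; apart from them (and, for the last observation, the stateful flag) the script stays the same:
# tf2.0 variant: use the Keras bundled with TensorFlow; rest of the script unchanged
from tensorflow import keras
from tensorflow.keras import backend as K
import numpy as np
# ...
# with stateful=False the first half of the split prediction again matches aFull
sharedLSTM = keras.layers.LSTM(units=nOut, stateful=False, return_sequences=True, return_state=False)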
This would suggest to me that the way I'm trying to use the recurrent layer, with shared parameters but separate states for the parallel uses, is not really possible like this.
Update 2: It seems that others have missed the same functionality earlier as well: there is a closed, unanswered question at Keras' GitHub.
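The only Keras-side alternative I can see is to drop stateful=True and instead make the states explicit model inputs and outputs, keeping them outside the model between calls. A rough, untested sketch with names of my own (one shared non-stateful LSTM, per-band [h, c] pairs held in plain numpy arrays):
# sketch: shared (non-stateful) LSTM, per-band states managed outside the model
import keras
import numpy as np
nOut = 3
nFeats = 4
nBands = 3
xSeq = keras.layers.Input(shape=(None, nFeats))
h0In = keras.layers.Input(shape=(nOut,))
c0In = keras.layers.Input(shape=(nOut,))
lstm = keras.layers.LSTM(units=nOut, return_sequences=True, return_state=True)
ySeq, hOut, cOut = lstm(xSeq, initial_state=[h0In, c0In])
stateModel = keras.models.Model(inputs=[xSeq, h0In, c0In], outputs=[ySeq, hOut, cOut])
# one [h, c] pair per band, carried manually between predict calls
bandStates = [[np.zeros((1, nOut)), np.zeros((1, nOut))] for _ in range(nBands)]
def predictBand(bandIdx, chunk):
    # chunk: (1, nSteps, nFeats); only this band's state is read and advanced
    y, h, c = stateModel.predict([chunk, bandStates[bandIdx][0], bandStates[bandIdx][1]])
    bandStates[bandIdx] = [h, c]
    return y
y0 = predictBand(0, np.random.randn(1, 10, nFeats))
This moves the state bookkeeping into Python, much the same way as in the PyTorch scribbling below, but I have not checked how it behaves during training.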
For comparison, here is a scribbling in PyTorch (the first time I've tried to use it) implementing a simple network with N parallel LSTMs sharing the weights but having independent states. In this case the states are stored explicitly in a list and provided to the LSTM manually.
import torch
import numpy as np
class sharedLSTM(torch.nn.Module):
    def __init__(self, batchSz, nBands, nDims, outDim):
        super(sharedLSTM, self).__init__()
        # one LSTM whose weights are shared across all bands
        self.internalLSTM = torch.nn.LSTM(input_size=nDims, hidden_size=outDim, num_layers=1, bias=True, batch_first=True)
        # one (h_0, c_0) pair per band, so each band keeps its own state
        allStates = list()
        for bandIdx in range(nBands):
            h_0 = torch.zeros(1, batchSz, outDim)
            c_0 = torch.zeros(1, batchSz, outDim)
            allStates.append((h_0, c_0))
        self.allStates = allStates
        self.nBands = nBands

    def forward(self, x):
        allOut = list()
        for bandIdx in range(self.nBands):
            thisSlice = x[:, bandIdx, :, :]          # (batchSz, nSteps, nFeats)
            thisState = self.allStates[bandIdx]      # this band's (h, c)
            thisY, thisState = self.internalLSTM(thisSlice, thisState)
            self.allStates[bandIdx] = thisState      # carry the state to the next chunk
            allOut.append(thisY[:, None, :, :])      # => (batchSz, 1, nSteps, outDim)
        y = torch.cat(allOut, dim=1)                 # => (batchSz, nBands, nSteps, outDim)
        return y

    def resetStates(self):
        for bandIdx in range(self.nBands):
            self.allStates[bandIdx][0][:] = 0.0
            self.allStates[bandIdx][1][:] = 0.0
batchSz = 5
nBands = 3
nFeats = 4
nOutDims = 2
net = sharedLSTM(batchSz, nBands, nFeats, nOutDims)
net = net.float()
print(net)
N = 20
x = torch.from_numpy(np.random.rand(batchSz, nBands, N, nFeats)).float()
x1 = x[:, :, :N//2, :]
x2 = x[:, :, N//2:, :]
aa = net.forward(x)
net.resetStates()
a1 = net.forward(x1)
a2 = net.forward(x2)
print('(with reset) first half abs diff: {}'.format(str(torch.sum(torch.abs(a1 - aa[:,:,:N//2,:])).detach().numpy())))
print('(with reset) second half abs diff: {}'.format(str(torch.sum(torch.abs(a2 - aa[:,:,N//2:,:])).detach().numpy())))
Result: the output is the same regardless of whether we do the prediction in one go or in pieces.
I've tried to replicate this in Keras using sub-classing, but without success:
import keras
import numpy as np
class sharedLSTM(keras.Model):
    def __init__(self, batchSz, nBands, nDims, outDim):
        super(sharedLSTM, self).__init__()
        # one stateful LSTM whose weights are shared across all bands
        self.internalLSTM = keras.layers.LSTM(units=outDim, stateful=True, return_sequences=True, return_state=True)
        self.internalLSTM.build((batchSz, None, nDims))
        self.internalLSTM.reset_states()
        # per-band states (initially None) and per-band slicing layers
        allStates = list()
        allSlicers = list()
        for bandIdx in range(nBands):
            allStates.append(None)
            allSlicers.append(keras.layers.Lambda(lambda x, b: x[:, :, b, :], arguments={'b': bandIdx}))
        self.allStates = allStates
        self.allSlicers = allSlicers
        self.Concat = keras.layers.Lambda(lambda x: keras.backend.concatenate(x, axis=2))
        self.nBands = nBands

    def call(self, x):
        allOut = list()
        for bandIdx in range(self.nBands):
            thisSlice = self.allSlicers[bandIdx](x)      # (batchSz, nSteps, nFeats)
            thisState = self.allStates[bandIdx]          # this band's [h, c] (None on the first call)
            thisY, *thisState = self.internalLSTM(thisSlice, initial_state=thisState)
            self.allStates[bandIdx] = thisState.copy()   # keep the returned state for the next call
            allOut.append(thisY[:, :, None, :])          # => (batchSz, nSteps, 1, outDim)
        y = self.Concat(allOut)                          # => (batchSz, nSteps, nBands, outDim)
        return y
batchSz = 1
nBands = 3
nFeats = 4
nOutDims = 2
N = 20
model = sharedLSTM(batchSz, nBands, nFeats, nOutDims)
model.compile(optimizer='SGD', loss='mae')
x = np.random.rand(batchSz, N, nBands, nFeats)
x1 = x[:, :N//2, :, :]
x2 = x[:, N//2:, :, :]
aa = model.predict(x)
model.reset_states()
a1 = model.predict(x1)
a2 = model.predict(x2)
print('(with reset) first half abs diff: {}'.format(str(np.sum(np.abs(a1 - aa[:,:N//2,:,:])))))
print('(with reset) second half abs diff: {}'.format(str(np.sum(np.abs(a2 - aa[:,N//2:,:,:])))))
If you now ask "why don't you then use torch and shut up?", the answer is that the surrounding experimental framework has been built assuming Keras and changing it would be a non-negligible amount of work.