There isn't a straightforward Keras implementation, as Keras enforces the batch axis (samples dimension, dimension 0) as fixed for the input & output layers (but not for all layers in-between), whereas you seek to collapse it by averaging. There is, however, a workaround; see below:
```python
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Input, Dense, GRU, Lambda
from tensorflow.keras.layers import GlobalAveragePooling1D
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model
import numpy as np

def make_model(batch_shape):
    ipt = Input(batch_shape=batch_shape)           # (1, m, timesteps, 1)
    x = Lambda(lambda x: K.squeeze(x, 0))(ipt)     # (m, timesteps, 1)
    x, s = GRU(4, return_state=True)(x)            # s == last returned state, (m, 4)
    x = Lambda(lambda x: K.expand_dims(x, 0))(s)   # (1, m, 4)
    x = GlobalAveragePooling1D()(x)                # averages over the m samples (axis 1) -> (1, 4)
    x = Dense(32, activation='relu')(x)
    out = Dense(1, activation='sigmoid')(x)

    model = Model(ipt, out)
    model.compile('adam', 'binary_crossentropy')
    return model

def make_data(batch_shape):
    return (np.random.randn(*batch_shape),
            np.random.randint(0, 2, (batch_shape[0], 1)))

m, timesteps = 16, 100
batch_shape = (1, m, timesteps, 1)

model = make_model(batch_shape)
model.summary()                       # see model structure
plot_model(model, show_shapes=True)   # requires pydot and graphviz

x, y = make_data(batch_shape)
model.train_on_batch(x, y)
```
The above assumes the task is binary classification, but you can easily adapt it to anything else. The main trick is fooling Keras by feeding the `m` samples as a single batch entry (batch size `1`); the layers in-between can then freely treat `m` as their batch dimension, since Keras doesn't enforce the `1` there.
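To make the trick concrete, here is a NumPy-only shape trace of the forward pass. The GRU is replaced by a hypothetical stand-in (`fake_last_state`, a fixed random projection of the last timestep) purely so the shapes can be followed without TensorFlow; it is not what a GRU computes.

```python
import numpy as np

def fake_last_state(seqs, units, rng):
    """Hypothetical stand-in for GRU(units, return_state=True)'s state:
    maps (n, timesteps, features) -> (n, units). NOT a real GRU."""
    W = rng.standard_normal((seqs.shape[-1], units))
    return np.tanh(seqs[:, -1, :] @ W)

def trace_shapes(m=16, timesteps=100, units=4, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((1, m, timesteps, 1))   # what Keras sees: a batch of 1
    seqs = np.squeeze(x, axis=0)                     # (m, timesteps, 1): m now acts as the batch axis
    states = fake_last_state(seqs, units, rng)       # (m, units): one state per sample
    expanded = np.expand_dims(states, axis=0)        # (1, m, units): a batch of 1 again
    pooled = expanded.mean(axis=1)                   # (1, units): average over the m samples
    return [a.shape for a in (x, seqs, states, expanded, pooled)]

for shape in trace_shapes():
    print(shape)
```

The first and last shapes both have a leading `1`, which is all Keras checks at the input and output; everything in-between runs with `m` in the batch position.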
Note, however, that I cannot guarantee this'll work as intended, for the following reasons:
1. Keras treats all entries along the batch axis as independent, whereas your samples are claimed to be dependent.
2. Per (1), the main concern is backpropagation: I'm not really sure how the gradient will flow with all the dimensionality shuffling going on.
3. (1) also matters for stateful RNNs, as Keras constructs `batch_size` independent states. These will still likely behave as intended, since all they do is keep memory, but it's worth understanding fully; see here.
(2) is the "elephant in the room", but aside from that, the model fits your exact description. Chances are, if you've planned out forward-prop and all dims agree w/ the code's, it'll work as intended; otherwise, and also as a sanity check, I'd suggest opening another question to verify that gradients flow as you intend them to in the above code.
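As a first sanity check you can run yourself, here is a `tf.GradientTape` sketch. It assumes TF 2.x eager execution and rewrites the forward pass with plain TF ops (`tf.squeeze`, `tf.reduce_mean`) instead of the `Lambda`/`GlobalAveragePooling1D` layers, and omits the `Dense(32)` layer for brevity, so it is an approximation of the model above rather than the exact model. It checks that every trainable weight receives a gradient and that each of the `m` samples receives a nonzero input gradient.

```python
import numpy as np
import tensorflow as tf

def gradient_check(m=8, timesteps=20, units=4, seed=0):
    rng = np.random.default_rng(seed)
    gru = tf.keras.layers.GRU(units)
    head = tf.keras.layers.Dense(1, activation='sigmoid')

    x = tf.constant(rng.standard_normal((1, m, timesteps, 1)).astype(np.float32))
    y = tf.constant([[1.0]])

    with tf.GradientTape() as tape:
        tape.watch(x)                                            # also track d(loss)/d(input)
        seqs = tf.squeeze(x, axis=0)                             # (m, timesteps, 1)
        states = gru(seqs)                                       # (m, units): last state per sample
        pooled = tf.reduce_mean(states, axis=0, keepdims=True)   # (1, units)
        pred = head(pooled)                                      # (1, 1)
        loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y, pred))

    weights = gru.trainable_variables + head.trainable_variables
    grads = tape.gradient(loss, [x] + weights)

    dx = grads[0]                                                # (1, m, timesteps, 1)
    per_sample = tf.reduce_sum(tf.abs(dx), axis=[0, 2, 3]).numpy()  # (m,)
    weights_ok = all(g is not None for g in grads[1:])
    return weights_ok, per_sample

weights_ok, per_sample = gradient_check()
print("all weights receive gradients:", weights_ok)
print("input-gradient magnitude per sample:", per_sample)
```

Nonzero entries for every sample only show that the gradient reaches all of them; whether it flows the way your problem requires still has to be judged against your intended forward-prop.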
`model.summary()` output (generated with `m = 32`, `GRU(16)` and `Dense(8)` rather than the values in the code above, which is why the shapes and parameter counts differ):

```
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(1, 32, 100, 1)]         0
_________________________________________________________________
lambda (Lambda)              (32, 100, 1)              0
_________________________________________________________________
gru (GRU)                    [(32, 16), (32, 16)]      864
_________________________________________________________________
lambda_1 (Lambda)            (1, 32, 16)               0
_________________________________________________________________
global_average_pooling1d (Gl (1, 16)                   0
_________________________________________________________________
dense (Dense)                (1, 8)                    136
_________________________________________________________________
dense_1 (Dense)              (1, 1)                    9
```
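The parameter counts in that summary can be reproduced by hand. This sketch uses the standard Keras weight shapes: a GRU has a kernel of `(input_dim, 3*units)`, a recurrent kernel of `(units, 3*units)`, and a bias of `3*units` values with `reset_after=False` (the TF 1.x default) or `2*3*units` with `reset_after=True` (the TF 2.x default); a Dense layer has `in*out` weights plus `out` biases. The summary's 864 matches `reset_after=False`; under TF 2.x defaults the same GRU would report 912.

```python
def gru_params(input_dim, units, reset_after):
    """Trainable parameters of a Keras GRU layer."""
    kernel = input_dim * 3 * units                 # input -> update, reset, candidate
    recurrent = units * 3 * units                  # state -> update, reset, candidate
    bias = (2 if reset_after else 1) * 3 * units   # reset_after keeps separate input/recurrent biases
    return kernel + recurrent + bias

def dense_params(in_dim, out_dim):
    return in_dim * out_dim + out_dim              # weights + biases

print(gru_params(1, 16, reset_after=False))        # 864, as in the summary
print(gru_params(1, 16, reset_after=True))         # 912 with the TF 2.x default
print(dense_params(16, 8), dense_params(8, 1))     # 136 9
```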
On LSTMs: with `return_state=True`, an LSTM returns two last states in addition to its output, one for the hidden state and one for the cell state (see source code); you should understand what exactly this means if you are to use it. If you do, you'll need to concatenate them:

```python
from tensorflow.keras.layers import LSTM, concatenate
# ...
x, s1, s2 = LSTM(4, return_state=True)(x)   # s1: hidden state, s2: cell state, each (m, 4)
x = concatenate([s1, s2], axis=-1)          # (m, 8)
# ...
```
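A quick eager-mode check of what those returned tensors are, assuming TF 2.x; the sizes are arbitrary. With `return_sequences=False` (the default), the LSTM's output is the same as its last hidden state, and concatenating the hidden and cell states doubles the feature size before the average over samples.

```python
import numpy as np
import tensorflow as tf

m, timesteps, units = 16, 100, 4
seqs = tf.constant(np.random.randn(m, timesteps, 1).astype(np.float32))  # (m, timesteps, 1)

out, state_h, state_c = tf.keras.layers.LSTM(units, return_state=True)(seqs)
both = tf.concat([state_h, state_c], axis=-1)          # (m, 2*units)
pooled = tf.reduce_mean(both, axis=0, keepdims=True)   # (1, 2*units): average over the m samples

print(out.shape, state_h.shape, state_c.shape)  # (16, 4) (16, 4) (16, 4)
print(both.shape, pooled.shape)                  # (16, 8) (1, 8)
```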