I am trying to implement an LSTM VAE (following this example I found) that also accepts variable-length sequences via a Masking layer. I tried to combine the code from that example with the ideas from this SO question, which seems to handle this the "best way" by cropping the gradients to get the most accurate loss possible. However, my implementation does not seem to be able to reproduce sequences, even on a small set of data. I am thus relatively confident that something is amiss with my implementation, but I cannot pinpoint what exactly is wrong. The relevant part is here:
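from keras import backend as K
from keras import objectives
from keras.layers import Input, Masking, LSTM, Dense, Lambda, RepeatVector
from keras.models import Model

# Encoder input: variable-length sequences of input_dim features
x = Input(shape=(None, input_dim))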
x_masked = Masking(mask_value=0.0, input_shape=(None, input_dim))(x)
h = LSTM(intermediate_dim)(x_masked)
z_mean = Dense(latent_dim)(h)
z_log_sigma = Dense(latent_dim)(h)
def sampling(args):
    z_mean, z_log_sigma = args
    epsilon = K.random_normal(shape=(batch_size, latent_dim), mean=0., stddev=epsilon_std)
    return z_mean + z_log_sigma * epsilon
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_sigma])
decoder_h = LSTM(intermediate_dim, return_sequences=True)
# Output width is input_dim so the reconstruction lines up with x
decoder_mean = LSTM(input_dim, return_sequences=True)
h_decoded = RepeatVector(max_timesteps)(z)
h_decoded = decoder_h(h_decoded)
x_decoded_mean = decoder_mean(h_decoded)
def crop_outputs(x):
    # x[0] is the decoder output, x[1] is the zero-padded input
    padding = K.cast(K.not_equal(x[1], 0), dtype=K.floatx())
    return x[0] * padding
x_decoded_mean = Lambda(crop_outputs, output_shape=(max_timesteps, input_dim))([x_decoded_mean, x])
vae = Model(x, x_decoded_mean)
def vae_loss(x, x_decoded_mean):
    xent_loss = objectives.mse(x, x_decoded_mean)
    kl_loss = -0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma))
    loss = xent_loss + kl_loss
    return loss
vae.compile(optimizer='adam', loss=vae_loss)
# Here, X is variable length time series data of shape
# (num_examples, max_timesteps, input_dim) and is zero padded
# on the right for all the examples of length less than max_timesteps
# X has been appropriately scaled using the StandardScaler.
vae.fit(X, X, epochs=num_epochs, batch_size=batch_size)
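To make the intent of the cropping step concrete, here is a minimal NumPy sketch of the reconstruction loss I am hoping to approximate: a mean squared error averaged only over the non-padded timesteps. It assumes that a timestep counts as padding when all of its features are zero. The helper name `masked_mse` and the toy arrays are my own illustration, not part of Keras or of the examples I linked.

```python
import numpy as np


def masked_mse(x_true, x_pred, pad_value=0.0):
    """Per-example MSE averaged over real (non-padded) timesteps only.

    Hypothetical helper: a timestep is treated as padding when every
    feature in it equals pad_value.
    """
    # mask has shape (batch, timesteps): 1.0 for real steps, 0.0 for padding
    mask = np.any(x_true != pad_value, axis=-1).astype(x_true.dtype)
    # squared error averaged over features -> shape (batch, timesteps)
    sq_err = np.mean((x_true - x_pred) ** 2, axis=-1)
    # number of real steps per example (at least 1 to avoid dividing by zero)
    n_real = np.maximum(mask.sum(axis=-1), 1.0)
    return (sq_err * mask).sum(axis=-1) / n_real


# Two toy sequences, zero padded on the right to 3 timesteps, 2 features each
x_true = np.array([[[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]],
                   [[1.0, 1.0], [0.0, 0.0], [0.0, 0.0]]])
x_pred = np.ones_like(x_true)

loss = masked_mse(x_true, x_pred)  # one value per example
```

As far as I can tell, multiplying only the decoder output by the padding mask zeroes the error on padded positions but still averages over all `max_timesteps`, whereas this sketch divides by the true length of each sequence.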
As always, any help is much appreciated. Thank you!