I've consulted the answer here: Why does my training loss have regular spikes? about spikes in the training loss. However, I am not using mini-batches: my batch_size is larger than the data set, so each epoch is a single full-batch gradient step.
My model is as follows, with a batch size of 1_000_000 and a data size of 100_000:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, regularizers
from tensorflow.keras.layers.experimental import preprocessing

# Normalize inputs using statistics learned from the data
norm = preprocessing.Normalization()
norm.adapt(data)

model = keras.Sequential([
    norm,
    layers.Dense(100, activation='tanh', kernel_regularizer=regularizers.l2(1e-5)),
    layers.Dense(100, activation='tanh', kernel_regularizer=regularizers.l2(1e-5)),
    layers.Dense(100, activation='tanh', kernel_regularizer=regularizers.l2(1e-5)),
    layers.Dense(1)
])

# Learning rate decays as 0.01 / (1 + step / 2000)
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    0.01,
    decay_steps=2000,
    decay_rate=1,
    staircase=False)
optimizer = tf.keras.optimizers.Adam(lr_schedule)

model.compile(loss='huber',
              optimizer=optimizer, metrics=['mean_absolute_error'])

# batch_size exceeds the data size, so each epoch is one full-batch update
history = model.fit(
    train[0], train[1], validation_data=test, batch_size=1_000_000,
    verbose=2, epochs=epochs)
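For reference, this is roughly how I plot the recorded losses to see the spikes (a minimal sketch using matplotlib; history is the return value of fit above):

import matplotlib.pyplot as plt

# Per-epoch losses recorded by model.fit
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='val loss')
plt.yscale('log')  # spikes stand out on a log scale
plt.xlabel('epoch')
plt.ylabel('Huber loss')
plt.legend()
plt.show()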
Looking at the regression fits, I can see that the spikes are real in the sense that they correspond to visibly bad fits.
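To pair each spike with its fit, one option is a small callback that snapshots the model's predictions after every epoch (a sketch, not my exact code; it assumes test[0] holds the test inputs):

import tensorflow as tf

# Record predictions on a fixed input set after each epoch,
# so epochs with loss spikes can be matched to their fits.
class SnapshotPredictions(tf.keras.callbacks.Callback):
    def __init__(self, x_eval):
        super().__init__()
        self.x_eval = x_eval
        self.snapshots = []

    def on_epoch_end(self, epoch, logs=None):
        self.snapshots.append(self.model.predict(self.x_eval, verbose=0))

snap = SnapshotPredictions(test[0])
# then add callbacks=[snap] to the model.fit call above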