
Let's suppose I have a sequence of integers:

0, 1, 2, ...

and want to predict the next integer given the last 3 integers, e.g.:

[0,1,2] -> 3, [3,4,5] -> 6, etc.

Suppose I set up my model like so:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense

batch_size = 1
time_steps = 3
model = Sequential()
model.add(LSTM(4, batch_input_shape=(batch_size, time_steps, 1), stateful=True))
model.add(Dense(1))

It is my understanding that the model has the following structure (please excuse the crude drawing):

[figure: hand-drawn diagram of the LSTM unrolled over 3 timesteps, feeding a single Dense output]

First Question: is my understanding correct?

Note I have drawn the previous states C_{t-1}, h_{t-1} entering the picture, as these are exposed when specifying stateful=True. In this simple "next integer prediction" problem, performance should improve by providing this extra information (as long as the previous state results from the previous 3 integers).

This brings me to my main question: it seems the standard practice (see, for example, this blog post and the TimeseriesGenerator Keras preprocessing utility) is to feed a staggered set of inputs to the model during training.

For example:

batch0: [[0, 1, 2]]
batch1: [[1, 2, 3]]
batch2: [[2, 3, 4]]
etc
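To make that concrete, here is roughly how TimeseriesGenerator yields these staggered windows (just a sketch, assuming the tensorflow.keras import path):

import numpy as np
from tensorflow.keras.preprocessing.sequence import TimeseriesGenerator

data = np.arange(10).reshape(-1, 1)            # the integer sequence as a (timesteps, 1) array
gen = TimeseriesGenerator(data, data, length=3, batch_size=1)

for i in range(3):
    x, y = gen[i]
    print(x.ravel(), '->', y.ravel())
# [0 1 2] -> [3]
# [1 2 3] -> [4]
# [2 3 4] -> [5]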

This has me confused, because it seems this requires the output of the 1st LSTM cell (corresponding to the 1st time step). See this figure:

From the tensorflow docs:

stateful: Boolean (default False). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch.

It seems this "internal" state isn't available, and all that is available is the final state. See this figure:

So, if my understanding is correct (which it's clearly not), shouldn't we be feeding non-overlapping windows of samples to the model when using stateful=True? E.g.:

batch0: [[0, 1, 2]]
batch1: [[3, 4, 5]]
batch2: [[6, 7, 8]]
etc
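(For reference, TimeseriesGenerator can also produce this non-overlapping layout via its stride argument - a sketch, reusing data from the snippet above:)

gen = TimeseriesGenerator(data, data, length=3, stride=3, batch_size=1)
for i in range(3):
    x, y = gen[i]
    print(x.ravel(), '->', y.ravel())
# [0 1 2] -> [3]
# [3 4 5] -> [6]
# [6 7 8] -> [9]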
  • As I predicted, that's two questions in one. To briefly answer your first question: probably yes. What matters more is what you _think_ that image depicts - but the gist is accurate: LSTM's pass information across hidden states, and pass only one feature tensor to Dense for prediction. (Many-to-one). – OverLordGoldDragon Oct 07 '19 at 22:11
  • What do u mean by "what you *think* that image depicts"? Are u saying its conceptually accurate, but there is a lot more going on than what I have drawn? – rmccabe3701 Oct 07 '19 at 22:14
  • 1
    I'm saying it's a very high-level representation, and that there's plenty to what goes on _inside_ the LSTM - such as `kernel` vs `recurrent` weights, each gate's role, and how information flows between the timesteps. – OverLordGoldDragon Oct 07 '19 at 22:16

1 Answer


The answer is: it depends on the problem at hand. For your case of one-step prediction - yes, you can, but you don't have to. But whether you do or not will significantly impact learning.


Batch vs. sample mechanism ("see AI" = see "additional info" section)

All models treat samples as independent examples; a batch of 32 samples is like feeding 1 sample at a time, 32 times (with differences - see AI). From the model's perspective, data is split into the batch dimension, batch_shape[0], and the feature dimensions, batch_shape[1:] - and the two "don't talk." The only relation between them is via the gradient (see AI).
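A quick way to convince yourself (a sketch of my own, using a stateless model): predicting a batch of 32 samples at once and predicting them one at a time gives identical outputs - the samples never interact.

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense

m = Sequential([LSTM(4, input_shape=(3, 1)), Dense(1)])    # stateless counterpart of your model
x = np.random.randn(32, 3, 1).astype('float32')

together   = m.predict(x)                                  # one batch of 32
one_by_one = np.concatenate([m.predict(x[i:i + 1]) for i in range(32)])
print(np.allclose(together, one_by_one))                   # True: samples "don't talk"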


Overlap vs no-overlap batch

Perhaps the best approach to understand it is information-based. I'll begin with timeseries binary classification, then tie it to prediction: suppose you have 10-minute EEG recordings, 240000 timesteps each. Task: seizure or non-seizure?

  • As 240k timesteps is too much for an RNN to handle, we use a CNN for dimensionality reduction
  • We have the option to use "sliding windows" - i.e. feed a subsegment at a time; let's use 54k

Take 10 samples, shape (240000, 1). How to feed?

  1. (10, 54000, 1), all samples included, slicing as sample[0:54000]; sample[54000:108000] ...
  2. (10, 54000, 1), all samples included, slicing as sample[0:54000]; sample[1:54001] ...

Which of the two above do you take? If (2), your neural net will never confuse a seizure for a non-seizure for those 10 samples. But it'll also be clueless about any other sample. I.e., it will massively overfit, because the information it sees per iteration barely differs (1/54000 = 0.0019%) - so you're basically feeding it the same batch several times in a row. Now suppose (3):

  3. (10, 54000, 1), all samples included, slicing as sample[0:54000]; sample[27000:81000] ...

A lot more reasonable; now our windows have a 50% overlap, rather than 99.998%.
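To make the three slicing options concrete, here is a small sketch (make_windows is my own helper, and I use a toy sequence instead of the 240k EEG recording):

import numpy as np

def make_windows(seq, win_len, stride):
    """Slice a 1-D sequence into shape (n_windows, win_len, 1) for an RNN/CNN."""
    starts = range(0, len(seq) - win_len + 1, stride)
    return np.stack([seq[s:s + win_len] for s in starts])[..., None]

seq = np.arange(12)
print(make_windows(seq, 4, stride=4)[..., 0])   # option (1): no overlap   -> [0..3], [4..7], [8..11]
print(make_windows(seq, 4, stride=1)[..., 0])   # option (2): 1-step slide -> [0..3], [1..4], [2..5], ...
print(make_windows(seq, 4, stride=2)[..., 0])   # option (3): 50% overlap  -> [0..3], [2..5], [4..7], ...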


Prediction: overlap bad?

If you are doing a one-step prediction, the information landscape is now changed:

  • Chances are, your sequence length is faaar from 240000, so overlaps of any kind don't suffer the "same batch several times" effect
  • Prediction fundamentally differs from classification in that the labels (the next timestep) differ for every subsample you feed; classification uses one label for the entire sequence

This dramatically changes your loss function, and what is 'good practice' for minimizing it:

  • A predictor must be robust to its initial sample, especially for LSTM - so we train for every such "start" by sliding the sequence as you have shown
  • Since labels differ timestep-to-timestep, the loss function changes substantially timestep-to-timestep, so risks of overfitting are far less

What should I do?

First, make sure you understand this entire post, as nothing here's really "optional." Then, here's the key about overlap vs no-overlap, per batch:

  1. One sample shifted: model learns to better predict one step ahead for each starting step - meaning: (1) LSTM's robust against initial cell state; (2) LSTM predicts well for any step ahead given X steps behind
  2. Many samples, shifted in later batch: model less likely to 'memorize' train set and overfit

Your goal: balance the two; 1's main edge over 2 is:

  • 2 can handicap the model by making it forget seen samples
  • 1 allows model to extract better quality features by examining the sample over several starts and ends (labels), and averaging the gradient accordingly

Should I ever use (2) in prediction?

  • If your sequence lengths are very long and you can afford to "slide a window" of ~50% its length, maybe - but it depends on the nature of the data: signals (EEG)? Yes. Stocks, weather? Doubtful.
  • Many-to-many prediction: (2) is more common here, especially for longer sequences.

LSTM stateful: may actually be entirely useless for your problem.

Stateful is used when the LSTM can't process the entire sequence at once, so it's "split up" - or when different gradients are desired from backpropagation. With the former, the idea is that the LSTM considers the earlier sequence in its assessment of the later one:

  • t0=seq[0:50]; t1=seq[50:100] makes sense; t0 logically leads to t1
  • seq[0:50] --> seq[1:51] makes no sense; t1 doesn't causally derive from t0

In other words: do not overlap across separate batches when stateful. The same batch is OK since, again, the samples are independent - no "state" passes between them.

When to use stateful: when LSTM benefits from considering previous batch in its assessment of the next. This can include one-step predictions, but only if you can't feed the entire seq at once:

  • Desired: 100 timesteps. Can do: 50. So we set up t0, t1 as in above's first bullet.
  • Problem: not straightforward to implement programmatically. You'll need a way to feed the LSTM without applying gradients - e.g. freezing weights or setting lr = 0; a rough sketch follows below.
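One possible workaround (my own sketch, not the only way): run the first chunk through predict(), which advances the stateful LSTM's internal state without a weight update, then call train_on_batch on the second chunk. This assumes a stateful model built with batch_input_shape=(1, 50, 1); x_t0, x_t1, y_t1 are hypothetical arrays for the two 50-step chunks and the target.

model.reset_states()                  # start of a new full sequence
model.predict(x_t0)                   # forward pass only: advances the state, no gradient applied
model.train_on_batch(x_t1, y_t1)      # gradients flow through t1 only, starting from t0's final state
model.reset_states()                  # clear before the next full sequence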

When and how does LSTM "pass states" in stateful?

  • When: only batch-to-batch; samples are entirely independent
  • How: in Keras, only batch-sample to batch-sample: stateful=True requires you to specify batch_shape (or batch_input_shape, as in your code) instead of input_shape - because Keras builds batch_size separate states for the LSTM at compile time
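You can inspect this directly with the model from the question (a quick sketch): the stateful LSTM holds one (h, c) pair per batch slot, each of shape (batch_size, units).

lstm = model.layers[0]
print([tuple(s.shape) for s in lstm.states])   # [(1, 4), (1, 4)] for batch_size=1, units=4
model.reset_states()                           # zeroes these states, e.g. between independent sequences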

Per above, you cannot do this:

# sampleNM = sample N at timestep(s) M
batch1 = [sample10, sample20, sample30, sample40]
batch2 = [sample21, sample41, sample11, sample31]

This implies 21 causally follows 10 - and will wreck training. Instead do:

batch1 = [sample10, sample20, sample30, sample40]
batch2 = [sample11, sample21, sample31, sample41]
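In code, keeping each sample in its batch slot might look like the following sketch (long_seqs is a hypothetical (4, 100, 1) array holding 4 independent sequences; the model is assumed stateful with batch_input_shape=(4, time_steps, 1)):

model.reset_states()
for t0 in range(0, long_seqs.shape[1] - time_steps, time_steps):   # non-overlapping chunks
    x = long_seqs[:, t0:t0 + time_steps]       # same slot order in every batch
    y = long_seqs[:, t0 + time_steps]          # next value of each sequence, shape (4, 1)
    model.train_on_batch(x, y)
model.reset_states()                           # clear states once the sequences end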

Batch vs. sample: additional info

A "batch" is a set of samples - 1 or greater (assume always latter for this answer) . Three approaches to iterate over data: Batch Gradient Descent (entire dataset at once), Stochastic GD (one sample at a time), and Minibatch GD (in-between). (In practice, however, we call the last SGD also and only distinguish vs BGD - assume it so for this answer.) Differences:

  • SGD never actually optimizes the train set's loss function - only its 'approximations'; every batch is a subset of the entire dataset, and the gradients computed only pertain to minimizing loss of that batch. The greater the batch size, the better its loss function resembles that of the train set.
  • Above can extend to fitting batch vs. sample: a sample is an approximation of the batch - or, a poorer approximation of the dataset
  • First fitting 16 samples and then 16 more is not the same as fitting 32 at once - since weights are updated in-between, model outputs for the latter half will change (a small demo follows after this list)
  • The main reason for picking SGD over BGD is not, in fact, computational limitations - but that it's superior, most of the time. Explained simply: a lot easier to overfit with BGD, and SGD converges to better solutions on test data by exploring a more diverse loss space.
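A small demo of the "16 then 16 vs. 32 at once" point above (my own toy example, not from the original answer):

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

np.random.seed(0)
x = np.random.randn(32, 8).astype('float32')
y = np.random.randn(32, 1).astype('float32')

def make_model():
    m = Sequential([Dense(1, input_shape=(8,))])
    m.compile('sgd', 'mse')
    return m

m_full, m_split = make_model(), make_model()
m_split.set_weights(m_full.get_weights())      # identical starting weights

m_full.train_on_batch(x, y)                    # one update on all 32 samples
m_split.train_on_batch(x[:16], y[:16])         # two updates of 16 samples each
m_split.train_on_batch(x[16:], y[16:])

print([np.allclose(a, b) for a, b in zip(m_full.get_weights(), m_split.get_weights())])
# typically [False, False]: the second 16-sample step is taken at already-updated weights,
# so the two training paths end at different weights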

BONUS DIAGRAMS:

[diagrams: the first shows each batch's final LSTM state being passed as the next batch's initial state]
  • 2
    I'm still digesting this wonderful answer (thanks for that), but I'm still unclear wrt which one of my later two figures is "correct": when ``stateful=True`` does the "final" LSTM state (corresponding to the output given the entire sample's time sequence up to that point) get passed along to the next batch? Or is it some intermediate state? – rmccabe3701 Oct 07 '19 at 22:31
  • @rmccabe3701 Yeah I just realized my answer is incomplete - this is one hell of a question. Working on it – OverLordGoldDragon Oct 07 '19 at 22:32
  • @rmccabe3701 Updated -- I'm not actually entirely sure as to what your diagrams are showing, but they do seem off; let me know if anything remains unclear. (Edit: looking a bit closer, you're probably right deeming the second diagram 'correct' - but I'll take a bit of a break for now) – OverLordGoldDragon Oct 07 '19 at 22:43
  • 1
    Woah, your first “bonus diagram” was exactly what I was asking. So it looks like my second diagram (feeding the ‘final” state into the next batch) is the most accurate (your figure it much clearer). I’m almost ready to mark this question as resolved. But before I do: I’m still not clear of the validity of passing along the state in this way if the input is staggered. I totally understand your motivating examples of why staggering the input is useful, but my confusion was on the apparent inconsistency in the algorithm that carries over the state in this case. – rmccabe3701 Oct 08 '19 at 02:23
  • @rmccabe3701 What do you mean by "stagger"? Splitting the sequence, or slicing by one timestep? – OverLordGoldDragon Oct 08 '19 at 02:29
  • Feeding overlapping windows for the same sample for subsequent batches. A single time step being the most extreme example. It seems to me that the math for ``stateful = True`` only makes sense if you are feeding something like ``[0,1,2]`` *then* ``[3,4,5]`` for batches 0 and 1 respectively (for simplicity numSamples/batch = 1). – rmccabe3701 Oct 08 '19 at 02:41
  • @rmccabe3701 That's exactly right; to your knowledge, I've moved sections around, and encourage you to double-check whether you've read them all - as your question is answered in bold under "LSTM stateful." – OverLordGoldDragon Oct 08 '19 at 02:43
  • @rmccabe3701 You're welcome - glad you found it useful. If you can afford it, I'd appreciate the minimal bounty on this answer - as I did put quite some time & effort into it. Good luck with your DL – OverLordGoldDragon Oct 08 '19 at 02:47
  • Let us [continue this discussion in chat](https://chat.stackoverflow.com/rooms/200563/discussion-between-rmccabe3701-and-overlordgolddragon). – rmccabe3701 Oct 08 '19 at 15:10
  • @OverLordGoldDragon This keras-preprocessing [merge request](https://github.com/keras-team/keras-preprocessing/pull/251) addresses my concerns. – rmccabe3701 Oct 11 '19 at 00:20
  • @rmccabe3701 Unsure what you're referring to, or how it's related to this question, but I did respond in chat. If it's regarding a data input pipeline, that's again a whole question of its own -- fortunately, I wrote almost an entire API for various ways of feeding/preprocessing timeseries data, though it isn't yet published, I'd probably have an answer – OverLordGoldDragon Oct 11 '19 at 00:47