
I'd like to speed up my training routine, which uses the Estimator API with an input_fn written using tf.data.Dataset.

My implementation takes 2 seconds to prepare a batch of data, then runs training on the GPU for 1 second, and then starts over preparing the next batch, which is really inefficient.

I'm looking for a way to prepare the batches asynchronously and upload them to the GPU to speed up training. Alternatively, I'd like a way to cache datasets between invocations of input_fn (dataset.cache() doesn't seem to be a good choice, as the dataset has to be recreated on each input_fn invocation).

Here is a simplified version of my code:

def input_fn(filenames, labels, epochs):
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)
  if shuffle:
    dataset = dataset.shuffle(buffer_size=len(labels))
  dataset = dataset.map(_post_process, num_parallel_calls=num_map_threads)
  dataset = dataset.map(lambda wav, label: ({'wav': wav}, label))
  dataset = dataset.batch(128)
  dataset = dataset.repeat(epochs)  # epochs=None iterates over the training set forever
  iterator = dataset.make_one_shot_iterator()
  features, labels = iterator.get_next()
  return features, labels

train_input_fn = lambda : input_fn(train_files, train_labels, None)
eval_input_fn = lambda : input_fn(eval_files, eval_labels, 1)

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=45000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn) 
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

I've noticed that the Estimator API is under active development, and that in the master branch of TensorFlow the input_fn can already return datasets, so maybe I'm asking too early and this feature isn't ready yet. If so, please point me to a ticket where this implementation can be tracked.

Olivier Moindrot
Piotr Czapla

2 Answers


Using tf.data.Dataset.cache() is indeed not a good choice since it will cache the whole dataset into memory, which takes time and might overflow your memory.

The way to go is to use tf.data.Dataset.prefetch() at the end of your pipeline. It decouples producing data from consuming it: while the model trains on the current batch, the pipeline prepares up to buffer_size elements in the background. It is usually enough to have buffer_size = 1 at the end:

dataset = ...
dataset = dataset.batch(128)
dataset = dataset.prefetch(1)  # prefetch one batch

As explained by @mrry in this answer, you can also try to increase the number of prefetched batches a bit.

Typically it is most useful to add a small prefetch buffer (with perhaps just a single element) at the very end of the pipeline, but more complex pipelines can benefit from additional prefetching, especially when the time to produce a single element can vary.
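
To make the idea concrete, here is a minimal pure-Python sketch of what a prefetch buffer does conceptually: a background thread fills a bounded queue while the consumer works on the previous item. This is only an illustration under my own assumptions, not how tf.data implements it; `prefetch` and `make_batches` are hypothetical names, and `make_batches` stands in for the slow read and post-process stage.

```python
import queue
import threading


def prefetch(iterable, buffer_size=1):
    """Yield items from `iterable`, producing them ahead of time in a background thread."""
    # The bounded queue holds at most `buffer_size` ready items
    # (the producer may hold one more while it waits for a free slot).
    q = queue.Queue(maxsize=buffer_size)
    sentinel = object()
    error = []

    def producer():
        try:
            for item in iterable:
                q.put(item)  # blocks while the buffer is full
        except Exception as exc:
            error.append(exc)  # hand the failure over to the consumer
        finally:
            q.put(sentinel)  # always signal the end of the stream

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = q.get()  # blocks only if the producer has fallen behind
        if item is sentinel:
            if error:
                raise error[0]
            return
        yield item


def make_batches(n_batches, batch_size):
    # Hypothetical stand-in for the slow input stage of the pipeline.
    for i in range(n_batches):
        yield list(range(i * batch_size, (i + 1) * batch_size))


batches = list(prefetch(make_batches(3, 4), buffer_size=1))
```

With a buffer of one, the next batch is already being prepared while the current one is consumed, so in the best case each step costs the longer of the two stages instead of their sum.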


If you still have a slow input pipeline compared to your GPU computations, you need to increase the number of threads working in parallel using the num_parallel_calls argument of tf.data.Dataset.map().
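
As a rough analogy for what num_parallel_calls buys you, the sketch below applies a function to several inputs with a pool of worker threads while keeping the input order. It is plain Python, not tf.data (which runs its map function on its own threads), and `parallel_map` and `fake_read_wav` are hypothetical names made up for the illustration.

```python
from concurrent.futures import ThreadPoolExecutor


def parallel_map(fn, items, num_parallel_calls):
    """Apply `fn` to every item using a pool of worker threads, keeping input order."""
    with ThreadPoolExecutor(max_workers=num_parallel_calls) as pool:
        # Executor.map yields results in input order even if calls finish out of order.
        return list(pool.map(fn, items))


def fake_read_wav(filename):
    # Hypothetical stand-in for an I/O-bound loader such as _read_wav.
    return "decoded:" + filename


results = parallel_map(fake_read_wav, ["a.wav", "b.wav", "c.wav"], num_parallel_calls=2)
```

If one element takes 2 seconds to load and several loads run at once, the time per element seen by the consumer drops roughly in proportion to the number of workers, as long as disk and CPU are not already saturated.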

Olivier Moindrot
  • @Olivier: according to documentation, prefetch(1) means prefetching 1 example, rather than one batch, right? – John Jiang Jul 24 '18 at 23:09
  • This will prefetch one example of the current dataset. Since we batched the dataset just above, the current dataset contains batches of `128` images so it will prefetch one batch. – Olivier Moindrot Jul 25 '18 at 09:18
  • I think an important note is that you can cache even if the data doesn't fit into memory. TensorFlow allows caching to file storage if you specify a filename as the argument: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#cache – Matěj Račinský Sep 25 '18 at 12:40

A few points to add to Olivier's answer, mostly from this post:

  • Calling repeat before shuffle is slightly faster, at the cost of blurred epoch boundaries. This may be significant in rare cases, but I doubt it.
  • Shuffle before mapping: this reduces the memory footprint of your shuffle buffer, since it only needs to buffer the filenames rather than the file contents.
  • It makes more sense to me to apply the third map transform to the output of get_next() rather than to the dataset; I'm not sure whether that affects speed much. You could also consider combining the other two map calls into one to reduce scheduling issues.
  • Experiment with repeat before batching. It probably won't make a difference, or only a minor one. If you repeat before shuffle as mentioned above, you'll have to.
  • As mentioned by Olivier, use prefetch.
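
To show why shuffling before mapping saves memory, here is a minimal pure-Python sketch of a buffer-based shuffle. It is my own simplified model of the idea, and the actual tf.data implementation may differ in detail; `shuffle_buffer` and the file names are hypothetical.

```python
import random


def shuffle_buffer(iterable, buffer_size, seed=None):
    """Yield the items of `iterable` in randomized order using a fixed-size buffer."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        # Buffer is full: emit a random buffered item and put the new one in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain the remaining items in random order.
    rng.shuffle(buffer)
    yield from buffer


filenames = ["clip_%03d.wav" % i for i in range(10)]
shuffled = list(shuffle_buffer(filenames, buffer_size=len(filenames), seed=0))
```

The buffer never holds more than buffer_size items, so its memory use is about buffer_size times the size of one item. Shuffling short filename strings is therefore far cheaper than shuffling decoded waveforms, which is why the shuffle goes before the map that reads the files.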

Code with modifications:

def input_fn(filenames, labels, epochs):
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = dataset.repeat(epochs)
  if shuffle:
    dataset = dataset.shuffle(buffer_size=len(labels))

  def combined_map_fn(*args):
    # _read_wav returns (wav, label); unpack it so _post_process receives both
    return _post_process(*_read_wav(*args))

  dataset = dataset.map(combined_map_fn, num_parallel_calls=num_map_threads)
  dataset = dataset.batch(128)
  dataset = dataset.prefetch(1)

  iterator = dataset.make_one_shot_iterator()
  wavs, labels = iterator.get_next()
  features = {'wav': wavs}
  return features, labels
DomJack