
I'm trying to understand the TensorFlow Wide & Deep Learning Tutorial. The census income dataset comes in two files: adult.data (used for training) and adult.test (used for evaluation). After a certain number of epochs, the script prints an evaluation (you can see the complete code here: https://github.com/tensorflow/models/blob/master/official/wide_deep/wide_deep.py). It uses "input_fn" to read the input data from a CSV file, and the same function is used to read both adult.data and adult.test.

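import tensorflow as tf  # TF 1.x API; in TF 2.x these calls live under tf.io / tf.compat.v1

# _CSV_COLUMNS, _CSV_COLUMN_DEFAULTS and _NUM_EXAMPLES are module-level
# constants defined elsewhere in wide_deep.py.
def input_fn(data_file, num_epochs, shuffle, batch_size):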
  """Generate an input function for the Estimator."""
  assert tf.gfile.Exists(data_file), (
      '%s not found. Please make sure you have either run data_download.py or '
      'set both arguments --train_data and --test_data.' % data_file)

  def parse_csv(value):
    # Runs once, when Dataset.map traces parse_csv, not once per CSV line.
    print('Parsing', data_file)
    columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
    features = dict(zip(_CSV_COLUMNS, columns))
    labels = features.pop('income_bracket')
    return features, tf.equal(labels, '>50K')

  # Extract lines from input files using the Dataset API.
  dataset = tf.data.TextLineDataset(data_file)

  if shuffle:
    dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])

  dataset = dataset.map(parse_csv, num_parallel_calls=5)

  # We call repeat after shuffling, rather than before, to prevent separate
  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  return dataset
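To make the ordering in input_fn concrete, here is a plain-Python model of the same sequence: shuffle, then repeat, then batch (CSV parsing omitted). It is only an illustration and assumes a full shuffle (buffer_size at least the number of records, as with _NUM_EXAMPLES['train'] on the training file); model_pipeline is a hypothetical name, not a tf.data API. For evaluation (num_epochs=1, shuffle=False) it yields every record exactly once, with a smaller final batch:

```python
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def model_pipeline(records: Sequence[T], num_epochs: int, shuffle: bool,
                   batch_size: int, seed: int = 0) -> Iterator[List[T]]:
    """Hypothetical plain-Python model of shuffle -> repeat -> batch."""
    rng = random.Random(seed)

    def repeated_stream() -> Iterator[T]:
        for _ in range(num_epochs):
            epoch = list(records)
            if shuffle:
                # Assumption: full shuffle, i.e. buffer_size >= number of records.
                rng.shuffle(epoch)
            # Each epoch is emitted completely before the next one starts,
            # so examples from different epochs are never interleaved.
            yield from epoch

    batch: List[T] = []
    for record in repeated_stream():
        batch.append(record)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        # Final partial batch, as with Dataset.batch(drop_remainder=False).
        yield batch


if __name__ == "__main__":
    # One evaluation pass over 10 records with batch_size=4.
    print(list(model_pipeline(range(10), num_epochs=1, shuffle=False, batch_size=4)))
```

Because batching happens after repeat, a batch may straddle an epoch boundary when num_epochs > 1; tf.data behaves the same way unless batch is applied before repeat.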

It builds a DNNLinearCombinedClassifier estimator, then evaluates it and prints the resulting metrics (accuracy among them) like this:

...
results = model.evaluate(input_fn=lambda: input_fn(
    FLAGS.test_data, 1, False, FLAGS.batch_size))

# Display evaluation metrics
print('Results at epoch', (n + 1) * FLAGS.epochs_per_eval)
print('-' * 60)

for key in sorted(results):
  print('%s: %s' % (key, results[key]))
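Evaluating in batches still covers the whole test file. With num_epochs=1 and no steps argument, Estimator.evaluate runs until input_fn signals the end of input, and TensorFlow's streaming metrics (for example tf.metrics.accuracy) keep running total/count variables across batches. The sketch below reproduces only that arithmetic in plain Python; the function names are hypothetical, not TensorFlow APIs. It shows that accumulating counts gives the same accuracy for any batch size, whereas averaging per-batch accuracies does not:

```python
from typing import Iterator, Sequence, Tuple


def batched(preds: Sequence, labels: Sequence,
            batch_size: int) -> Iterator[Tuple[Sequence, Sequence]]:
    """Yield aligned (predictions, labels) slices of at most batch_size items."""
    for start in range(0, len(labels), batch_size):
        yield preds[start:start + batch_size], labels[start:start + batch_size]


def streaming_accuracy(preds: Sequence, labels: Sequence, batch_size: int) -> float:
    """Accumulate correct/total counts across batches, then divide once."""
    correct = total = 0
    for p, y in batched(preds, labels, batch_size):
        correct += sum(1 for a, b in zip(p, y) if a == b)
        total += len(y)
    return correct / total


def mean_of_batch_accuracies(preds: Sequence, labels: Sequence, batch_size: int) -> float:
    """Common mistake: averaging per-batch accuracies over-weights a small last batch."""
    accs = [sum(1 for a, b in zip(p, y) if a == b) / len(y)
            for p, y in batched(preds, labels, batch_size)]
    return sum(accs) / len(accs)


if __name__ == "__main__":
    preds = [1, 1, 1, 0, 0]
    labels = [1, 1, 1, 1, 1]
    print(streaming_accuracy(preds, labels, batch_size=4))        # 0.6, same as batch_size=5
    print(mean_of_batch_accuracies(preds, labels, batch_size=4))  # 0.375
```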

I understand that you should train your network in batches. My question is: why would they evaluate the model in batches too? Shouldn't they use the entire evaluation dataset at once? The test set has 16281 examples, so shouldn't model.evaluate be called like this?

_NUM_EXAMPLES = {
  'train': 32561,
  'validation': 16281,
}
...
results = model.evaluate(input_fn=lambda: input_fn(
    FLAGS.test_data, 1, False, _NUM_EXAMPLES['validation']))
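Whether batch_size is FLAGS.batch_size or _NUM_EXAMPLES['validation'], one pass with num_epochs=1 visits all 16281 examples; only the number of evaluation steps and the size of each batch change. A small sketch of that arithmetic (40 and 1024 are illustrative batch sizes, not values taken from wide_deep.py):

```python
def eval_steps(num_examples: int, batch_size: int) -> int:
    """Number of batches one pass (num_epochs=1) produces; the last may be partial."""
    return -(-num_examples // batch_size)  # ceiling division on integers


def last_batch_size(num_examples: int, batch_size: int) -> int:
    """Size of the final batch when the remainder is not dropped."""
    remainder = num_examples % batch_size
    return remainder if remainder else batch_size


if __name__ == "__main__":
    n = 16281  # _NUM_EXAMPLES['validation']
    for bs in (40, 1024, n):  # 40 and 1024 are illustrative values
        print(bs, eval_steps(n, bs), last_batch_size(n, bs))
```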

Is it wrong to evaluate on the entire validation dataset as a single batch?

Maxim
Kuranes

1 Answer


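Both training and testing need mini-batches of data, because in either case feeding everything at once may cause an out-of-memory (OOM) error. You are right that the problem is more critical in training, because the backward pass effectively doubles memory consumption. But that doesn't mean OOM is impossible in inference. Also note that evaluating in batches does not mean only part of the test set is used: with num_epochs=1 and no steps argument, model.evaluate keeps consuming batches until input_fn runs out of data, and the metrics are accumulated over all 16281 examples.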
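A rough way to see why batch size matters for memory even without a backward pass: hidden-layer activations scale linearly with the batch size. The back-of-envelope sketch below assumes float32 activations and hidden layers of [100, 75, 50, 25] units (an assumption; check the hidden_units actually used in wide_deep.py). It ignores weights, input embeddings and framework overhead, and applies the rough factor of 2 for training mentioned above. activation_bytes is a hypothetical helper:

```python
from typing import Sequence


def activation_bytes(batch_size: int, layer_widths: Sequence[int],
                     bytes_per_value: int = 4, training: bool = False) -> int:
    """Back-of-envelope size of hidden-layer activations for one batch.

    Ignores weights, input features/embeddings and framework overhead.
    The x2 for training is only the rough 'backward pass doubles memory' rule.
    """
    per_example = sum(layer_widths) * bytes_per_value
    total = batch_size * per_example
    return 2 * total if training else total


if __name__ == "__main__":
    hidden = [100, 75, 50, 25]  # assumed hidden_units, not verified against wide_deep.py
    print(activation_bytes(16281, hidden))                 # 16281000 bytes, about 16 MB
    print(activation_bytes(16281, hidden, training=True))  # 32562000 bytes
```

For a model this small the whole test set is only about 16 MB of activations, which is why a single batch can work here; wide convolutional or embedding layers multiply the per-example cost, and the same batch size could then exhaust device memory.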

Examples from my experience:

... and I'm sure there are many more examples that I haven't seen. Depending on your resources, 16281 examples might be small enough to fit into one batch, but in general it makes perfect sense to iterate in batches during inference and to have a separate setting for the inference batch size, for instance in case the model ever runs on another machine with fewer resources.
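One way to keep inference batching configurable is a separate eval batch-size setting with a fallback derived from a memory budget. This is a hypothetical sketch: --eval_batch_size, --eval_memory_budget_mb and pick_eval_batch_size are illustrative names that do not exist in wide_deep.py, and the 1000 bytes per example is an assumed figure:

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=40,
                        help="training batch size (illustrative default)")
    # Hypothetical flags: not part of wide_deep.py.
    parser.add_argument("--eval_batch_size", type=int, default=0,
                        help="0 means: derive from --eval_memory_budget_mb")
    parser.add_argument("--eval_memory_budget_mb", type=int, default=256,
                        help="memory budget for one evaluation batch")
    return parser


def pick_eval_batch_size(num_examples: int, bytes_per_example: int,
                         budget_bytes: int, requested: int = 0) -> int:
    """Use the requested size if given, else the largest batch that fits the budget."""
    if requested > 0:
        return min(requested, num_examples)
    if bytes_per_example <= 0:
        raise ValueError("bytes_per_example must be positive")
    fits = max(1, budget_bytes // bytes_per_example)  # at least one example per batch
    return min(num_examples, fits)


if __name__ == "__main__":
    flags = build_parser().parse_args(["--eval_memory_budget_mb", "1"])
    size = pick_eval_batch_size(
        16281,
        bytes_per_example=1000,  # assumed per-example cost
        budget_bytes=flags.eval_memory_budget_mb * 1024 * 1024,
        requested=flags.eval_batch_size,
    )
    print(size)
```

With the default 256 MB budget all 16281 examples fit in one batch; with a 1 MB budget the same code falls back to batches of 1048.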

Maxim