I'm trying to understand the TensorFlow Wide & Deep Learning Tutorial. The census income dataset comes in two files: adult.data, which is used for training, and adult.test, which is used for evaluation. After a certain number of epochs the script prints evaluation results (you can see the complete code here: https://github.com/tensorflow/models/blob/master/official/wide_deep/wide_deep.py). It defines a single function, input_fn, that reads examples from a CSV file, and that same function is used for both files, adult.data and adult.test.
def input_fn(data_file, num_epochs, shuffle, batch_size):
  """Generate an input function for the Estimator."""
  assert tf.gfile.Exists(data_file), (
      '%s not found. Please make sure you have either run data_download.py or '
      'set both arguments --train_data and --test_data.' % data_file)

  def parse_csv(value):
    print('Parsing', data_file)
    columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
    features = dict(zip(_CSV_COLUMNS, columns))
    labels = features.pop('income_bracket')
    return features, tf.equal(labels, '>50K')

  # Extract lines from input files using the Dataset API.
  dataset = tf.data.TextLineDataset(data_file)

  if shuffle:
    dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])

  dataset = dataset.map(parse_csv, num_parallel_calls=5)

  # We call repeat after shuffling, rather than before, to prevent separate
  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  return dataset
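For comparison, the training side of the loop calls the same input_fn, but with shuffling enabled and multiple epochs. This is only a rough sketch; the flag names match the script, but treat the exact call as an assumption rather than a verbatim quote:

# Sketch of the training call that pairs with the evaluation call shown
# further below; the flag values are assumptions, not copied verbatim.
model.train(input_fn=lambda: input_fn(
    FLAGS.train_data, FLAGS.epochs_per_eval, True, FLAGS.batch_size))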
It builds an estimator, a DNNLinearCombinedClassifier, then evaluates it and prints the accuracy like this:
...
results = model.evaluate(input_fn=lambda: input_fn(
    FLAGS.test_data, 1, False, FLAGS.batch_size))

# Display evaluation metrics
print('Results at epoch', (n + 1) * FLAGS.epochs_per_eval)
print('-' * 60)

for key in sorted(results):
  print('%s: %s' % (key, results[key]))
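For completeness, the estimator itself is constructed roughly along these lines; the feature-column variable names and hidden-unit sizes below are illustrative assumptions on my part, not an exact copy of the script:

# Rough sketch of the estimator construction; wide_columns, deep_columns
# and the hidden-unit sizes are assumed, illustrative names and values.
model = tf.estimator.DNNLinearCombinedClassifier(
    model_dir=FLAGS.model_dir,
    linear_feature_columns=wide_columns,
    dnn_feature_columns=deep_columns,
    dnn_hidden_units=[100, 75, 50, 25])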
I understand that you should train your network in batches. My question is: why do they also evaluate the model in batches? Shouldn't they use the entire evaluation dataset at once? The dataset has 16,281 validation examples; shouldn't model.evaluate be called like this?
_NUM_EXAMPLES = {
    'train': 32561,
    'validation': 16281,
}
...
results = model.evaluate(input_fn=lambda: input_fn(
    FLAGS.test_data, 1, False, _NUM_EXAMPLES['validation']))
Is it wrong to use the entire validation dataset?