I am trying to learn TensorFlow and studying the example at: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynb

I have some questions about the code below:

for epoch in range(training_epochs):
    # Loop over all batches
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)  # batch_ys is unused here (autoencoder)
        # Run optimization op (backprop) and cost op (to get loss value)
        _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
    # Display logs per epoch step
    if epoch % display_step == 0:
        print("Epoch:", '%04d' % (epoch+1),
              "cost=", "{:.9f}".format(c))

Since mnist is just a dataset, what exactly does mnist.train.next_batch do? How is next_batch defined?

Thanks!

Edamame

1 Answer

The mnist object is returned by the read_data_sets() function defined in the tf.contrib.learn module. The mnist.train.next_batch(batch_size) method is implemented in tensorflow/contrib/learn/python/learn/datasets/mnist.py, and it returns a tuple of two arrays: the first represents a batch of batch_size MNIST images, and the second represents a batch of batch_size labels corresponding to those images.

The images are returned as a 2-D NumPy array of size [batch_size, 784] (since there are 784 pixels in an MNIST image), and the labels are returned as either a 1-D NumPy array of size [batch_size] (if read_data_sets() was called with one_hot=False) or a 2-D NumPy array of size [batch_size, 10] (if read_data_sets() was called with one_hot=True).
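
For concreteness, here is a minimal sketch of how you can inspect those shapes yourself. It assumes the TF 1.x-era tutorial helper tensorflow.examples.tutorials.mnist (which wraps the same read_data_sets() function); "MNIST_data/" is just an arbitrary cache directory:

    from tensorflow.examples.tutorials.mnist import input_data

    # Downloads the dataset on first run and caches it in "MNIST_data/".
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    # Draw one batch of 100 examples.
    batch_xs, batch_ys = mnist.train.next_batch(100)

    print(batch_xs.shape)  # (100, 784) -- flattened 28x28 images
    print(batch_ys.shape)  # (100, 10)  -- one-hot labels, since one_hot=True

With one_hot=False, the second array would instead have shape (100,), holding integer class labels.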

mrry
  • It's worth mentioning that [next_batch](https://github.com/tensorflow/tensorflow/blob/7c36309c37b04843030664cdc64aca2bb7d6ecaa/tensorflow/contrib/learn/python/learn/datasets/mnist.py#L160) reshuffles the examples after going through all of them each epoch. You can track where you are in the epoch via `DataSet._index_in_epoch`, e.g. `mnist.train._index_in_epoch` – Yibo Yang Aug 02 '17 at 22:13
  • @YiboYang So does that mean that with next_batch() not all of the training data may be fed in during training? I am a newbie in TensorFlow, so please excuse me if this seems silly. Thanks – Loochie Jan 24 '22 at 11:42