EDIT3: You cannot do this natively and I marked the answer that said so. However, I posted an example solution in another answer below for those curious.
EDIT2: Simple code with issue replication below.
EDIT: This is not a question about how to queue/batch over multiple epochs in general, which is what the suggested duplicate explains. I'm asking specifically how to get non-perfect batch sizes working properly. That post simply mentions that the allow_smaller_final_batch=True argument should account for this scenario, but it does not seem to (as demonstrated in my code below).
In my TF neural network, I am using tf.train.slice_input_producer and tf.train.batch to batch my data over epochs, which works flawlessly when my batch size is a perfect multiple of my number of samples.
Unfortunately, if it isn't, the last batch of an epoch spills over into the next epoch (i.e. there is no true epoch boundary), which eventually means that every epoch is different. EXAMPLE:
2 Epochs * 12 samples = 24 total values, Batch_size = 5,
WHAT IS CORRECT:
Epoch 1: [5 items], [5 items], [2 items]
Epoch 2: [5 items], [5 items], [2 items]
WHAT IT'S ACTUALLY DOING:
Epoch 1: [5 items], [5 items], [5 items]
Epoch 2: [5 items], [4 items], [0 items: out of bounds]
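The correct per-epoch batch sizes above can be computed directly: each epoch should contain floor(12/5) = 2 full batches plus one remainder batch of 2. A minimal sketch (the helper name epoch_batch_sizes is hypothetical, not a TF function):

```python
def epoch_batch_sizes(n_samples, batch_size):
    """Batch sizes for one epoch, with a smaller final batch if needed."""
    full, rem = divmod(n_samples, batch_size)
    return [batch_size] * full + ([rem] if rem else [])

print(epoch_batch_sizes(12, 5))  # [5, 5, 2]
```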
Code that produces the above example (very similar to my NN implementation):
import tensorflow as tf
import numpy as np
batch_size = 5
epochs = 2
Data = list(range(12))
iterations = int(np.ceil(len(Data)/batch_size)*epochs)
sess = tf.InteractiveSession()
x1 = tf.train.slice_input_producer([Data], num_epochs=epochs)
x2 = tf.train.batch(x1, batch_size=batch_size, allow_smaller_final_batch=True)
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
for i in range(iterations):
    temp_batch = sess.run(x2)
    print('\n' + str(temp_batch))
sess.close()
I know this is likely just a by-product of how tf.train.slice_input_producer works, and I can probably achieve/avoid this manually in various ways, but is there no way to natively distinguish the "end" of an epoch with slicing?
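Since this apparently cannot be done natively (see EDIT3), one manual workaround is to reset the slicing at each epoch boundary yourself, so the short final batch never spills over. A sketch in plain NumPy rather than the TF queue API (the generator name batches_per_epoch is hypothetical):

```python
import numpy as np

def batches_per_epoch(data, batch_size, epochs, shuffle=True, seed=0):
    """Yield (epoch, batch) pairs, restarting batching at every epoch
    boundary so each epoch ends with its own (possibly smaller) final batch."""
    data = np.asarray(data)
    rng = np.random.RandomState(seed)
    for epoch in range(epochs):
        idx = np.arange(len(data))
        if shuffle:
            rng.shuffle(idx)  # reshuffle within each epoch
        for start in range(0, len(data), batch_size):
            yield epoch, data[idx[start:start + batch_size]]

for epoch, batch in batches_per_epoch(range(12), 5, 2, shuffle=False):
    print(epoch, batch)
```

With 12 samples, batch size 5, and 2 epochs, this produces batches of sizes 5, 5, 2 in each epoch, matching the "WHAT IS CORRECT" listing above.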