I am trying to reimplement some parts of Nvidia's noise2noise repo to learn TensorFlow and the tf.data pipeline, and I am having a lot of trouble understanding what is happening. So far I am able to create a TFRecord consisting of tf.train.Example protos, as described in https://github.com/NVlabs/noise2noise/blob/master/dataset_tool_tf.py:
image = load_image(imgname)
feature = {
    'shape': shape_feature(image.shape),
    'data': bytes_feature(tf.compat.as_bytes(image.tostring()))
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example.SerializeToString())
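As a sanity check (standalone, not repo code; the array here is a made-up stand-in for load_image's output), the 'data' feature is just the flat array buffer, so it round-trips with numpy given the stored shape:

```python
import numpy as np

# Hypothetical stand-in for load_image(): any HxWxC uint8 array works.
image = np.arange(24, dtype=np.uint8).reshape(2, 4, 3)

# This is what goes into the 'data' feature (tostring() is the older
# alias of tobytes()).
raw = image.tobytes()

# Reading it back only needs the raw bytes plus the 'shape' feature.
restored = np.frombuffer(raw, dtype=np.uint8).reshape(image.shape)
assert np.array_equal(image, restored)
```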
That part makes sense. What's driving me nuts is the noise augmentation piece in https://github.com/NVlabs/noise2noise/blob/master/dataset.py. Specifically, this function:
def create_dataset(train_tfrecords, minibatch_size, add_noise):
    print('Setting up dataset source from', train_tfrecords)
    buffer_mb = 256
    num_threads = 2
    dset = tf.data.TFRecordDataset(train_tfrecords, compression_type='', buffer_size=buffer_mb << 20)
    dset = dset.repeat()
    buf_size = 1000
    dset = dset.prefetch(buf_size)
    dset = dset.map(parse_tfrecord_tf, num_parallel_calls=num_threads)
    dset = dset.shuffle(buffer_size=buf_size)
    dset = dset.map(lambda x: random_crop_noised_clean(x, add_noise))
    dset = dset.batch(minibatch_size)
    it = dset.make_one_shot_iterator()
    return it
which returns an iterator. This iterator is used in train.py, and it yields three elements at every iteration:
noisy_input, noisy_target, clean_target = dataset_iter.get_next()
I've tried reimplementing this in a local TensorFlow Jupyter notebook, and I can't figure out where those three items are coming from. The way I understood it, the create_dataset(...) function just takes every input image in the Example record and augments it with Gaussian/Poisson noise. But then why does the returned iterator point to three different images? What's the connection between the augmentation in create_dataset(...) and the three different images in the iterator?
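This is the kind of minimal experiment I ran in my notebook (standalone TF 2.x eager style, not repo code; fake_triplet is a made-up placeholder, not the real random_crop_noised_clean). It at least shows that a map function returning a tuple makes each iteration yield a tuple of tensors:

```python
import tensorflow as tf

# Toy dataset of scalar "images".
dset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])

def fake_triplet(x):
    # Made-up placeholder: a map function that returns a 3-tuple.
    return (x + 0.1, x + 0.2, x)

dset = dset.map(fake_triplet).batch(2)

# In TF 2.x eager mode the dataset is directly iterable; in TF 1.x this
# would be make_one_shot_iterator() + get_next().
noisy_a, noisy_b, clean = next(iter(dset))
print(noisy_a.numpy(), noisy_b.numpy(), clean.numpy())
```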
I found this question, which was really helpful in understanding map, batch, and shuffle: What does batch, repeat, and shuffle do with TensorFlow Dataset?