I'm following along with the Keras tutorial on image classification. I have created a tf.data.Dataset and selected a single batch using the .take() method:

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "data",
    validation_split=0.2,
    subset="training",
    image_size=(224,224),
    batch_size=32)

train_batch = train_ds.take(1)

Inspecting the train_batch object, as expected, I see it is made up of two objects: images and labels:

<TakeDataset shapes: ((None, 224, 224, 3), (None,)), types: (tf.float32, tf.int32)>

The tutorial uses the following code to plot the images in this batch:

for images, labels in train_batch:
    for i in range(32):
        ax = plt.subplot(4, 8, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))

My question is: how does "for images, labels in train_batch:" manage to assign the images and labels separately? Apart from enumerate, I have not come across a for loop that specifies two variables. Is this the only way to access the images and labels in a batch?

  • I think this is a duplicate of https://stackoverflow.com/questions/56226621/how-to-extract-data-labels-back-from-tensorflow-dataset – krenerd Feb 01 '21 at 03:57

1 Answer

Iterating over train_batch yields a tuple (images, labels) for each batch. Take, for example, the code below:

x=(1,2,3)
a,b,c=x
print ('a= ', a,' b= ',b,' c= ', c)
# the result will be a=  1  b=  2  c=  3

The same unpacking happens in the for loop: on each iteration, images receives the image part of the tuple and labels receives the label part.
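
To make the loop case concrete, here is a minimal sketch in plain Python. The batches list is a hypothetical stand-in for a batched dataset (the strings are placeholders, not real tensors), so this only illustrates the unpacking rule and is not the actual tf.data implementation. The key point is that the for loop unpacks each element it receives, not the container itself.

```python
# Hypothetical stand-in for a batched dataset: every element is an
# (images, labels) pair. The strings are placeholders, not real tensors.
batches = [
    (["img0", "img1"], [0, 1]),
    (["img2", "img3"], [1, 0]),
]

# Each iteration yields one 2-tuple, which is unpacked into two names.
seen = []
for images, labels in batches:
    seen.append((images, labels))

# Equivalent loop that unpacks inside the body instead.
for batch in batches:
    images, labels = batch

# Another way to reach a single element: pull it from an iterator.
first_images, first_labels = next(iter(batches))

# Iterating over (1, 2, 3) yields ints, and an int cannot be unpacked
# into three names, so this loop raises TypeError.
try:
    for a, b, c in (1, 2, 3):
        pass
    unpack_error = None
except TypeError as exc:
    unpack_error = exc

# Wrapping the tuple in a list makes each element a 3-tuple, so this works.
for a, b, c in [(1, 2, 3)]:
    pass
```

So the for loop is not the only option. With eager execution, the same next(iter(...)) pattern should also work on a tf.data.Dataset such as train_batch, giving the images and labels of the first batch directly.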

Gerry P
  • Thanks for the comment, however if I defined x=(1,2,3) and then tried the iterator as was described in the keras example above, I would get: for a,b,c in x: which raises an exception. Is it because a tf.data.dataset is similar to a zip object? – Matt_Haythornthwaite Feb 01 '21 at 12:16