I have a tf.data.Dataset that looks like this:
<BatchDataset shapes: ((None, 256, 256, 3), (None,)), types: (tf.float32, tf.int32)>
The second element of each tuple (index 1, with zero indexing) is the label. I want to cast these labels to tf.uint8.
How can one use tf.cast on a tf.data.Dataset?
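One common approach is Dataset.map, which applies a function to every element. Because each element is an (images, labels) tuple, map unpacks it into two arguments, so the function can cast only the labels and pass the images through unchanged. The sketch below uses a small synthetic dataset (the zero-filled images and the label values are assumptions made so the snippet runs without PetImages) with the same structure as the one in the question.

```python
import tensorflow as tf

# Synthetic stand-in for the batched (images, labels) dataset in the question:
# 4 float32 images of shape 256x256x3 with int32 labels, batched by 2.
images = tf.zeros([4, 256, 256, 3], dtype=tf.float32)
labels = tf.constant([0, 1, 0, 1], dtype=tf.int32)
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(2)

# Dataset.map unpacks each (image, label) tuple into two arguments;
# return a new tuple where only the label is cast to tf.uint8.
ds_uint8 = ds.map(lambda x, y: (x, tf.cast(y, tf.uint8)))

# The label component of the element spec should now report tf.uint8.
print(ds_uint8.element_spec)
```

map returns a new dataset rather than modifying ds in place, so the result has to be reassigned (for example `ds = ds.map(...)`).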
Similar Questions
How to convert tf.int64 to tf.float32? is very similar, but it is not about a tf.data.Dataset.
Repro
From Image classification from scratch:
curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip
unzip kagglecatsanddogs_5340.zip
Then, in Python with tensorflow~=2.4:
import tensorflow as tf

# Defaults: image_size=(256, 256) and label_mode="int" (labels inferred from subdirectory names)
ds = tf.keras.preprocessing.image_dataset_from_directory(
    "PetImages", batch_size=32
)
print(ds)
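To apply the same cast to the repro's dataset, the map call can be wrapped in a small named function. The helper name cast_labels and its dtype parameter are hypothetical, introduced only for illustration; since PetImages may not be present, the snippet builds a random synthetic dataset with the same (float32 images, int32 labels) structure that image_dataset_from_directory yields.

```python
import tensorflow as tf


def cast_labels(ds, dtype=tf.uint8):
    """Hypothetical helper: return an (images, labels) dataset with labels cast to `dtype`.

    Images are passed through unchanged; the input dataset is not modified.
    """
    def _cast(images, labels):
        return images, tf.cast(labels, dtype)

    return ds.map(_cast)


# Stand-in for the repro's dataset (PetImages may not be available here):
# batched float32 images with int32 labels, like image_dataset_from_directory yields.
images = tf.random.uniform([6, 256, 256, 3])
labels = tf.constant([0, 1, 1, 0, 1, 0], dtype=tf.int32)
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

ds = cast_labels(ds)
print(ds.element_spec)
```

With the real repro, the equivalent step would be `ds = cast_labels(ds)` right after the image_dataset_from_directory call. Note that tf.uint8 can only represent values from 0 to 255, which is plenty for the two PetImages classes but would not suit a dataset with more than 256 classes.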