I am working through this TensorFlow tutorial on image classification: https://www.tensorflow.org/tutorials/keras/basic_classification

I loaded my data into a dataset by following this tutorial: https://www.tensorflow.org/tutorials/load_data/images

When I call the model's fit method, I get the following exception: Error when checking input: expected flatten_6_input to have 4 dimensions, but got array with shape (28, 28, 3)

The images have the shape (28, 28, 3).

I have already tried passing several different shapes as the input parameter, and I have also read the data back out of the dataset to verify it.

image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(all_image_labels, tf.int64))
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))

model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28, 3)),
    keras.layers.Dense(256, activation=tf.nn.relu),
    keras.layers.Dense(2, activation=tf.nn.softmax)
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(image_label_ds.shuffle(5), epochs=5)
Epoch 1/5
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-114-a638eaff329f> in <module>()
----> 1 model.fit(train_dataset.shuffle(5), epochs=5)

5 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    374                            ': expected ' + names[i] + ' to have ' +
    375                            str(len(shape)) + ' dimensions, but got array '
--> 376                            'with shape ' + str(data_shape))
    377         if not check_batch_axis:
    378           data_shape = data_shape[1:]

ValueError: Error when checking input: expected flatten_6_input to have 4 dimensions, but got array with shape (28, 28, 3)

1 Answer

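I added the following line before calling model.fit, and it worked:

image_label_ds = image_label_ds.batch(2)

Keras expects the input to have a leading batch dimension. Without batching, each dataset element is a single image of shape (28, 28, 3), while the model expects 4 dimensions: (batch, 28, 28, 3).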
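To illustrate why batching resolves the error, here is a minimal sketch, assuming TensorFlow 2.x. The random arrays `images` and `labels` are hypothetical stand-ins for the real image and label data from the question, and the batch size of 2 is just the value used above. It compares the element shape of the dataset before and after calling `batch`, then trains for one epoch on the batched dataset.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in data: 8 random RGB images of 28x28 and 2 classes.
images = np.random.rand(8, 28, 28, 3).astype("float32")
labels = np.random.randint(0, 2, size=(8,)).astype("int64")

image_label_ds = tf.data.Dataset.from_tensor_slices((images, labels))

# Unbatched: every element is one image with 3 dimensions.
unbatched_shape = image_label_ds.element_spec[0].shape
print(unbatched_shape)

# Batched: every element gains a leading batch axis, giving 4 dimensions.
batched_ds = image_label_ds.shuffle(8).batch(2)
batched_shape = batched_ds.element_spec[0].shape
print(batched_shape)

# Same architecture as in the question; the input shape excludes the batch axis.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation="relu"),
    tf.keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# Training works because the dataset now yields batches of images.
history = model.fit(batched_ds, epochs=1, verbose=0)
```

Note that `shuffle` is applied before `batch` here so that individual images, rather than whole batches, are shuffled. A larger batch size such as 32 is a more common choice for real training runs.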