I got the following error while using a tf.data.Dataset (TensorFlow 2.6.2) with train_on_batch:
/lib/python3.9/site-packages/tensorflow/python/keras/engine/data_adapter.py", line 1522, in <genexpr>
num_samples = set(int(i.shape[0]) for i in nest.flatten(data))
AttributeError: 'PrefetchDataset' object has no attribute 'shape'
It happens with this code:
import tensorflow as tf

data_x, data_y = load_tiles_and_prepare(perm[j*batch_size:(j+1)*batch_size])
# returns two float64 ndarrays, each of shape (8, 512, 512), with values in [0, 1]
train_data = tf.data.Dataset.from_tensor_slices((data_x, data_y)).repeat()
train_data = train_data.batch(32)
train_data = train_data.prefetch(1)
loss = model.train_on_batch(train_data)  # raises the AttributeError above
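For context, here is a minimal sketch of the difference between the two inputs. It assumes small random arrays (4x4 instead of 512x512 tiles) as hypothetical stand-ins for what load_tiles_and_prepare() returns. The traceback shows Keras reading `.shape[0]` from every flattened element of its input. A Dataset object is only a pipeline description with an element_spec and no shape, whereas a batch pulled out of it with next(iter(...)) is an ordinary tensor:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for the arrays returned by load_tiles_and_prepare();
# tiles are shrunk from 512x512 to 4x4 to keep the sketch light.
data_x = np.random.rand(8, 4, 4)
data_y = np.random.rand(8, 4, 4)

train_data = (
    tf.data.Dataset.from_tensor_slices((data_x, data_y))
    .repeat()
    .batch(4)
    .prefetch(1)
)

# The dataset describes its elements (element_spec) but has no .shape attribute,
# which is what the data_adapter code in the traceback tries to read.
print(hasattr(train_data, "shape"))  # False
print(train_data.element_spec)

# Pulling one element out of the pipeline yields concrete tensors with a shape.
x_batch, y_batch = next(iter(train_data))
print(x_batch.shape, y_batch.shape)  # (4, 4, 4) (4, 4, 4)
```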
Skipping the conversion to a tf.data.Dataset is (sadly) not an option: I want to use MirroredStrategy, so I need a tf.data.Dataset to set the options:
train_data = train_data.with_options(options)
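A sketch of how that setup might look. The post does not say which options are set, so this assumes, purely for illustration, a tf.data.Options object with the auto-shard policy set to DATA. It also scales the global batch size with the number of replicas, which is the usual convention with MirroredStrategy. The data arrays are hypothetical stand-ins:

```python
import numpy as np
import tensorflow as tf

# Single-machine MirroredStrategy; without GPUs it runs on the CPU with one replica.
strategy = tf.distribute.MirroredStrategy()

# Global batch = per-replica batch * number of replicas, so each device gets
# per_replica_batch samples per step.
per_replica_batch = 4
global_batch = per_replica_batch * strategy.num_replicas_in_sync

# Hypothetical stand-in data (4x4 instead of 512x512 tiles).
data_x = np.random.rand(8, 4, 4).astype(np.float32)
data_y = np.random.rand(8, 4, 4).astype(np.float32)

# Assumed options: the original post does not show how `options` is built.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA
)

train_data = (
    tf.data.Dataset.from_tensor_slices((data_x, data_y))
    .repeat()
    .batch(global_batch)
    .prefetch(1)
    .with_options(options)
)
```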
It may be worth noting, though, that train_on_batch works fine like this:
data_x, data_y = load_tiles_and_prepare(perm[j*batch_size:(j+1)*batch_size])
loss = model.train_on_batch(tf.convert_to_tensor(data_x), tf.convert_to_tensor(data_y))
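Since the tensor call works, one possible workaround is to keep building the input with tf.data but feed train_on_batch one batch at a time from an iterator, so it only ever sees concrete (x, y) tensors. A minimal sketch, with a hypothetical tiny model and random 4x4 data standing in for the real ones:

```python
import numpy as np
import tensorflow as tf

# Hypothetical tiny model standing in for the real one (4x4 tiles instead of 512x512).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4, 4)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(16),
    tf.keras.layers.Reshape((4, 4)),
])
model.compile(optimizer="sgd", loss="mse")

# Hypothetical stand-in data.
data_x = np.random.rand(8, 4, 4).astype(np.float32)
data_y = np.random.rand(8, 4, 4).astype(np.float32)

train_data = (
    tf.data.Dataset.from_tensor_slices((data_x, data_y))
    .repeat()
    .batch(4)
    .prefetch(1)
)

# Create the iterator once; calling iter() inside the loop would restart the
# pipeline every step and always return the first batch.
batches = iter(train_data)

losses = []
for step in range(3):
    # next() yields concrete (x, y) tensors, the same kind of input
    # that works in the tf.convert_to_tensor call.
    x_batch, y_batch = next(batches)
    losses.append(model.train_on_batch(x_batch, y_batch))
```

Because the dataset uses .repeat(), next() never raises StopIteration here; without it, the loop would have to stop once the iterator is exhausted.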
So how can I use a tf.data.Dataset with train_on_batch without getting the missing 'shape' attribute error?