
I got the following error message while using tf.data.Dataset (TF version 2.6.2) with train_on_batch:

/lib/python3.9/site-packages/tensorflow/python/keras/engine/data_adapter.py", line 1522, in <genexpr>
    num_samples = set(int(i.shape[0]) for i in nest.flatten(data))
AttributeError: 'PrefetchDataset' object has no attribute 'shape'

when using this code:

import tensorflow as tf

data_x, data_y = load_tiles_and_prepare(perm[j*batch_size:(j+1)*batch_size])
# returns two float64 ndarrays of shape (8, 512, 512), values in [0, 1]
train_data = tf.data.Dataset.from_tensor_slices((data_x, data_y)).repeat()
train_data = train_data.batch(32)
train_data = train_data.prefetch(1)
loss = model.train_on_batch(train_data)  # raises the AttributeError above
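For reference, the traceback suggests that train_on_batch treats its first argument as an array or tensor and reads .shape[0] from it, which a Dataset does not have. A minimal sketch of one workaround, assuming it is acceptable to pull a single (x, y) batch out of the dataset and pass the two parts separately. The small random arrays, the 16×16 tile size, batch(8) and the tiny Dense model are hypothetical stand-ins for load_tiles_and_prepare and the (unshared) autoencoder.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for load_tiles_and_prepare(): 8 small tiles,
# cast to float32 (Keras's default dtype) instead of the original float64.
data_x = np.random.rand(8, 16, 16).astype("float32")
data_y = np.random.rand(8, 16, 16).astype("float32")

# Hypothetical tiny model; the real autoencoder is not shown in the question.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(16, 16)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(16 * 16),
    tf.keras.layers.Reshape((16, 16)),
])
model.compile(optimizer="adam", loss="mse")

train_data = tf.data.Dataset.from_tensor_slices((data_x, data_y)).repeat()
train_data = train_data.batch(8).prefetch(1)

# train_on_batch wants tensors/arrays, so take one batch from the dataset
# and pass features and targets as separate arguments.
batch_x, batch_y = next(iter(train_data))
loss = model.train_on_batch(batch_x, batch_y)
print(loss)
```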

Avoiding the conversion to a tf.data.Dataset is (sadly) not an option: I want to use MirroredStrategy, and I need a tf.data.Dataset to set the options:

train_data = train_data.with_options(options)
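The options object itself is not shown above. A minimal sketch, assuming the usual case of setting the auto-shard policy for distributed training; the DATA policy is an assumption (it, or OFF, is often chosen for in-memory data built with from_tensor_slices), and the random arrays are hypothetical stand-ins.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in data.
data_x = np.random.rand(8, 16, 16).astype("float32")
data_y = np.random.rand(8, 16, 16).astype("float32")

# Assumed options: the question does not say which ones are set.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA
)

train_data = tf.data.Dataset.from_tensor_slices((data_x, data_y)).repeat()
train_data = train_data.batch(8).prefetch(1)
train_data = train_data.with_options(options)

# Read the policy back from the dataset to confirm it was applied.
policy = train_data.options().experimental_distribute.auto_shard_policy
print(policy)
```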

It may be worth noting that train_on_batch works fine like this:

data_x, data_y = load_tiles_and_prepare(perm[j*batch_size:(j+1)*batch_size])
loss = model.train_on_batch(tf.convert_to_tensor(data_x), tf.convert_to_tensor(data_y))

So how can I use a tf.data.Dataset with train_on_batch without getting the missing 'shape' attribute error?
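One possible pattern, sketched under assumptions: iterate over the dataset in Python and pass each (x, y) batch to train_on_batch, with the model created and compiled inside strategy.scope() as MirroredStrategy requires. The stand-in data, the 16×16 tiles, the tiny Dense model and take(3) (used only to bound the infinite repeat()) are hypothetical; on a machine without GPUs MirroredStrategy falls back to a single CPU replica.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in data; the real tiles are (8, 512, 512) float64 in [0, 1].
data_x = np.random.rand(8, 16, 16).astype("float32")
data_y = np.random.rand(8, 16, 16).astype("float32")

strategy = tf.distribute.MirroredStrategy()  # all visible GPUs, else the CPU

# The model must be built and compiled inside the strategy scope.
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(16, 16)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(16 * 16),
        tf.keras.layers.Reshape((16, 16)),
    ])
    model.compile(optimizer="adam", loss="mse")

train_data = tf.data.Dataset.from_tensor_slices((data_x, data_y)).repeat()
train_data = train_data.batch(8).prefetch(1)

# Iterate over the dataset and hand each batch to train_on_batch as tensors.
losses = []
for batch_x, batch_y in train_data.take(3):  # take() bounds the endless repeat()
    losses.append(float(model.train_on_batch(batch_x, batch_y)))
print(losses)
```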

Christian Gold
  • Sharing load_tiles_and_prepare is unnecessary, as I already explained what the data looks like afterwards (8 images of 512×512 pixels each). The model is an autoencoder and probably too much code to share. I will try to put together a reduced model if that helps. – Christian Gold Jan 19 '22 at 21:00
  • It is just hard to help without further details. – AloneTogether Jan 25 '22 at 10:32
  • You can try https://stackoverflow.com/questions/62436302/extract-target-from-tensorflow-prefetchdataset – Kliment Feb 20 '23 at 16:54

0 Answers