
I have a dataset that I build with tf.data.Dataset using the following code:

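import tensorflow as tf

# corona_new and WINDOW_SIZE are defined elsewhere; the output below implies
# corona_new has 5 columns and WINDOW_SIZE = 8.
dataset = tf.data.Dataset.from_tensor_slices(corona_new)
dataset = dataset.window(WINDOW_SIZE, 1, drop_remainder=True)  # overlapping windows, shift 1
dataset = dataset.flat_map(lambda x: x.batch(WINDOW_SIZE))     # each window -> one (WINDOW_SIZE, 5) tensor
dataset = dataset.map(lambda x: tf.transpose(x))               # -> (5, WINDOW_SIZE)

for i in dataset:
    print(i.numpy())
    break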

When I run it, I get the following output (this is one element of the dataset, i.e. one window):

[[  0. 125. 111. 232. 164. 134. 235. 190.] 
 [  0.  14.  16.   7.   9.   7.   6.   8.]
 [  0. 132. 199. 158. 148. 141. 179. 174.]
 [  0.   0.   0.   2.   0.   2.   1.   2.]
 [  0.   0.   0.   0.   3.   5.   0.   0.]]
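To make the shapes concrete, here is a minimal sketch of the same pipeline on toy data. The array `data` is a hypothetical stand-in for corona_new (10 time steps x 5 features), and WINDOW_SIZE = 8 is inferred from the 5 x 8 output above.

```python
import numpy as np
import tensorflow as tf

WINDOW_SIZE = 8
# Hypothetical stand-in for corona_new: 10 time steps x 5 features.
data = np.arange(50, dtype=np.float32).reshape(10, 5)

ds = tf.data.Dataset.from_tensor_slices(data)              # elements of shape (5,)
ds = ds.window(WINDOW_SIZE, shift=1, drop_remainder=True)  # 10 - 8 + 1 = 3 windows
ds = ds.flat_map(lambda w: w.batch(WINDOW_SIZE))           # each window -> tensor of shape (8, 5)
ds = ds.map(tf.transpose)                                  # -> shape (5, 8)

shapes = [tuple(x.shape) for x in ds]
first = next(iter(ds)).numpy()
print(shapes)  # [(5, 8), (5, 8), (5, 8)]
```

Each element is a single transposed window, so one row of an element is one feature over WINDOW_SIZE consecutive time steps.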

How can I unbatch them?

Tomergt45

1 Answer


Found my solution.

In TensorFlow 2.0 and later, you can unbatch a tf.data.Dataset by calling its .unbatch() method, which splits each element along its first dimension into separate elements.

Example: dataset = dataset.unbatch()  (it returns a new dataset rather than modifying the existing one, so assign the result)
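A minimal sketch of what .unbatch() does, using toy data rather than the question's corona_new. The all-zeros (5, 8) windows are hypothetical placeholders with the same shape as the output in the question.

```python
import numpy as np
import tensorflow as tf

# Round trip: batching and then unbatching restores the original elements.
ds = tf.data.Dataset.range(6).batch(4)   # elements: [0 1 2 3], [4 5]
restored = ds.unbatch()                  # elements: 0, 1, 2, 3, 4, 5
values = [int(x.numpy()) for x in restored]
print(values)  # [0, 1, 2, 3, 4, 5]

# unbatch() splits each element along its FIRST axis. On (5, 8) windows
# like the ones in the question, it yields one length-8 row per feature.
windows = tf.data.Dataset.from_tensor_slices(np.zeros((3, 5, 8), np.float32))
rows = windows.unbatch()
row_shapes = [tuple(r.shape) for r in rows]
print(len(row_shapes), row_shapes[0])  # 15 (8,)
```

Because the windows in the question overlap (shift 1), unbatching them does not give back the original series; each time step appears in up to WINDOW_SIZE windows.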

Tomergt45