I have written a more general unzip function for `tf.data.Dataset` pipelines. It also handles the recursive case, where a pipeline has multiple levels of zipping.
import tensorflow as tf
def tfdata_unzip(
    tfdata: tf.data.Dataset,
    *,
    recursive: bool = False,
    eager_numpy: bool = False,
    num_parallel_calls: int = tf.data.AUTOTUNE,
):
    """Unzip a zipped tf.data pipeline.

    Args:
        tfdata: The :py:class:`tf.data.Dataset` to unzip.
        recursive: Set to ``True`` to recursively unzip multiple layers of
            zipped pipelines. Defaults to ``False``.
        eager_numpy: Set to ``True`` to materialize every returned dataset
            into a Python list of NumPy values (via ``as_numpy_iterator()``).
            Applies in both recursive and non-recursive mode. Defaults to
            ``False``.
        num_parallel_calls: The level of parallelism used each time we
            ``map()`` over a :py:class:`tf.data.Dataset`. Defaults to
            ``tf.data.AUTOTUNE``.

    Returns:
        If each element of ``tfdata`` is a single (dense, ragged or sparse)
        tensor: ``tfdata`` itself, or a list of its NumPy elements when
        ``eager_numpy`` is set.

        If each element is a tuple: a list with one entry per tuple
        position. Each entry is a :py:class:`tf.data.Dataset` yielding that
        position (or a list of NumPy values when ``eager_numpy`` is set).
        With ``recursive=True`` every entry is itself unzipped, producing
        nested lists.

    Raises:
        ValueError: If the dataset's ``element_spec`` is neither a tensor
            spec nor a tuple (for example, a dict).
    """
    spec = tfdata.element_spec

    # Leaf case: one tensor per element. These are the component spec types
    # that as_numpy_iterator() knows how to convert.
    leaf_specs = (tf.TensorSpec, tf.RaggedTensorSpec, tf.SparseTensorSpec)
    if isinstance(spec, leaf_specs):
        return list(tfdata.as_numpy_iterator()) if eager_numpy else tfdata

    if not isinstance(spec, tuple):
        raise ValueError(
            "Unknown tf.data.Dataset element_spec: " +
            str(tfdata.element_spec)
        )

    # Dataset.map() spreads an element into positional arguments only when
    # it is a *plain* tuple; namedtuples (also tuple subclasses) are passed
    # as a single argument, so choose the right view before indexing.
    unpacks_args = type(spec) is tuple

    def column(i: int) -> tf.data.Dataset:
        """Return a dataset yielding only position ``i`` of each element."""

        def select(*args):
            row = args if unpacks_args else args[0]
            return row[i]

        return tfdata.map(
            select,
            deterministic=True,  # keep element order aligned across columns
            num_parallel_calls=num_parallel_calls,
        )

    columns = [column(i) for i in range(len(spec))]

    if recursive:
        return [
            tfdata_unzip(
                col,
                recursive=True,
                eager_numpy=eager_numpy,
                num_parallel_calls=num_parallel_calls,
            )
            for col in columns
        ]
    if eager_numpy:
        # Previously ignored in non-recursive mode; honour the documented flag.
        return [list(col.as_numpy_iterator()) for col in columns]
    return columns
Here is how `tfdata_unzip()` works, given these example datasets:
>>> import numpy as np
>>> baby = tf.data.Dataset.from_tensor_slices([
...     np.array([1,2]),
...     np.array([3,4]),
...     np.array([5,6]),
... ])
>>> baby.element_spec
TensorSpec(shape=(2,), dtype=tf.int64, name=None)
>>> parent = tf.data.Dataset.zip((baby, baby))
>>> parent.element_spec
(TensorSpec(shape=(2,), dtype=tf.int64, name=None),
TensorSpec(shape=(2,), dtype=tf.int64, name=None))
>>> grandparent = tf.data.Dataset.zip((parent, parent))
>>> grandparent.element_spec
((TensorSpec(shape=(2,), dtype=tf.int64, name=None),
TensorSpec(shape=(2,), dtype=tf.int64, name=None)),
(TensorSpec(shape=(2,), dtype=tf.int64, name=None),
TensorSpec(shape=(2,), dtype=tf.int64, name=None)))
This is what `tfdata_unzip()` returns on the above `baby`, `parent`, and `grandparent` datasets:
>>> tfdata_unzip(baby)
<TensorSliceDataset shapes: (2,), types: tf.int64>
>>> tfdata_unzip(parent)
[<ParallelMapDataset shapes: (2,), types: tf.int64>,
<ParallelMapDataset shapes: (2,), types: tf.int64>]
>>> tfdata_unzip(grandparent)
[<ParallelMapDataset shapes: ((2,), (2,)), types: (tf.int64, tf.int64)>,
<ParallelMapDataset shapes: ((2,), (2,)), types: (tf.int64, tf.int64)>]
>>> tfdata_unzip(grandparent, recursive=True)
[[<ParallelMapDataset shapes: (2,), types: tf.int64>,
<ParallelMapDataset shapes: (2,), types: tf.int64>],
[<ParallelMapDataset shapes: (2,), types: tf.int64>,
<ParallelMapDataset shapes: (2,), types: tf.int64>]]
>>> tfdata_unzip(grandparent, recursive=True, eager_numpy=True)
[[[array([1, 2]), array([3, 4]), array([5, 6])],
[array([1, 2]), array([3, 4]), array([5, 6])]],
[[array([1, 2]), array([3, 4]), array([5, 6])],
[array([1, 2]), array([3, 4]), array([5, 6])]]]