If your preprocessing pipeline is expensive but its output is small enough to fit in memory, you can use tf.data.Dataset.cache to cache the processed data in memory or in a file.
From the official performance guide:
The tf.data.Dataset.cache transformation can cache a dataset, either in memory or on local storage. If the user-defined function passed into the map transformation is expensive, apply the cache transformation after the map transformation as long as the resulting dataset can still fit into memory or local storage. If the user-defined function increases the space required to store the dataset beyond the cache capacity, consider pre-processing your data before your training job to reduce resource usage.
Example use of cache in memory
Here is an example where preprocessing each element takes a long time (0.5s). The second epoch over the dataset will be much faster than the first:
import time

import tensorflow as tf  # TensorFlow 1.x API (tf.py_func, tf.Session)

def my_fn(x):
    time.sleep(0.5)  # simulate an expensive preprocessing step
    return x

def parse_fn(x):
    return tf.py_func(my_fn, [x], tf.int64)

dataset = tf.data.Dataset.range(5)
dataset = dataset.map(parse_fn)
dataset = dataset.cache()    # cache the processed dataset, so every input is processed only once
dataset = dataset.repeat(2)  # repeat for multiple epochs
res = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    for i in range(10):
        # The first 5 iterations take 0.5s each; the last 5 read from the cache
        print(sess.run(res))
Caching to a file
If you want to write the cached data to a file, you can pass a file path to cache():
dataset = dataset.cache('/tmp/cache') # will write cached data to a file
This lets you process the dataset only once and then run multiple experiments on the data without reprocessing it.
Warning: You have to be careful when caching to a file. If you change your data but keep the /tmp/cache.* files, the old cached data will still be read. For instance, if we take the pipeline from above and change the range of the data to [10, 15), we will still obtain data in [0, 5):
dataset = tf.data.Dataset.range(10, 15)
dataset = dataset.map(parse_fn)
dataset = dataset.cache('/tmp/cache')
dataset = dataset.repeat(2)  # repeat for multiple epochs
res = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    for i in range(10):
        print(sess.run(res))  # still prints 0..4, read from the old cache
Always delete the cached files whenever the data that you want to cache changes.
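One way to make stale caches harder to hit by accident is to derive the cache path from the settings that define the data, so that changing the data also changes the path. This is a different technique from deleting the files by hand, and only a minimal sketch: the helper name cache_prefix and the parameters start and stop are hypothetical, not part of TensorFlow.

```python
import hashlib
import json


def cache_prefix(base, **params):
    """Return a cache path that changes whenever `params` change.

    `base` and `params` are hypothetical: pass whatever settings define
    your data (value range, preprocessing version, ...).
    """
    # Serialize with sorted keys so the same settings always hash the same.
    blob = json.dumps(params, sort_keys=True).encode("utf-8")
    digest = hashlib.sha256(blob).hexdigest()[:12]
    return "{}-{}".format(base, digest)


old_prefix = cache_prefix("/tmp/cache", start=0, stop=5)
new_prefix = cache_prefix("/tmp/cache", start=10, stop=15)
print(old_prefix)
print(new_prefix)  # differs from old_prefix, so the old files are not reused

# Then pass the result to cache(), for example:
# dataset = dataset.cache(new_prefix)
```

The old cache files stay on disk with this approach, so they still need to be removed at some point to reclaim space.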
Another issue arises if you interrupt the script before all the data has been cached. The lockfile is left behind, and you will receive an error like this:
AlreadyExistsError (see above for traceback): There appears to be a concurrent caching iterator running - cache lockfile already exists ('/tmp/cache.lockfile'). If you are sure no other running TF computations are using this cache prefix, delete the lockfile and re-initialize the iterator.
Make sure you let the whole dataset be processed so that the cache file is complete.
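If a run was interrupted, the leftover lockfile and any partial cache files can be removed before starting again. The helper below is a sketch rather than a TensorFlow API: clear_cache is a hypothetical name, the "<prefix>.*" pattern follows the /tmp/cache.* naming mentioned above, and the "<prefix>_*.lockfile" pattern is an assumption about how other TensorFlow versions may name the lockfile. As the error message says, only do this when you are sure no other running computation is using the same cache prefix.

```python
import glob
import os


def clear_cache(prefix):
    """Delete the cache files and lockfiles written for `prefix`.

    Returns the sorted list of paths that were removed.
    """
    safe = glob.escape(prefix)  # treat special characters in the path literally
    # "<prefix>.*" matches files such as /tmp/cache.lockfile;
    # "<prefix>_*.lockfile" is an assumed naming for other TensorFlow versions.
    patterns = [safe + ".*", safe + "_*.lockfile"]
    removed = set()
    for pattern in patterns:
        for path in glob.glob(pattern):
            if os.path.isfile(path):
                os.remove(path)
                removed.add(path)
    return sorted(removed)


# Example (run only when nothing else is using this prefix):
# clear_cache("/tmp/cache")
```

Because the patterns match on the prefix alone, keeping the cache in a directory of its own avoids deleting unrelated files that happen to share the same name stem.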