I'm struggling to port my (messy) code from TensorFlow core to the Estimator paradigm, especially using Experiments with learn_runner.run. But I'm actually having issues feeding data to my neural network.
What I'm trying to achieve is actually pretty close to what's done in all the TensorFlow examples that use tf.TextLineReader, e.g. https://github.com/GoogleCloudPlatform/cloudml-samples/blob/master/census/customestimator/trainer/model.py#L297, except that I load data not from a file on disk but via a web service.
From my understanding (and from looking at the code of tensorflow.python.estimator._train_model()), the input_fn is only called once, to build the graph, and not at each iteration. I could easily load all my data and then do something like:
import tensorflow as tf

def input_fn():
    data = ...  # all data loaded in memory
    queue = tf.train.input_producer(tf.constant(data))
    return queue.dequeue_many(batch_size)
but this is not sustainable as my data won't fit in memory. I'm trying to do something like:
1. load first piece of data (say N lines)
2. consume it by batches in a queue just like the input_fn above
2'. feed this queue asynchronously with new data when it's almost empty
I know how to do this in "pure" TensorFlow, e.g. How to prefetch data using a custom python function in tensorflow or Tensorflow: custom data load + asynchronous computation, but I'm finding it hard to translate to the Experiment paradigm, as I don't have access to the session to load things myself, nor to the graph to append operations to.
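For reference, the "pure" TensorFlow pattern from those links looks roughly like the sketch below. It's a minimal sketch under assumptions of mine: load_chunk() stands in for the web-service call, and the float32 rows of width 10 are made up.

import threading
import numpy as np
import tensorflow as tf

batch_size = 32

def load_chunk():
    # Stands in for the web-service call; returns the next N rows.
    return np.random.rand(100, 10).astype(np.float32)

# Graph side: a placeholder to push chunks through, a FIFO queue, a batch op.
chunk_placeholder = tf.placeholder(tf.float32, shape=[None, 10])
queue = tf.FIFOQueue(capacity=1000, dtypes=[tf.float32], shapes=[[10]])
enqueue_op = queue.enqueue_many(chunk_placeholder)
batch = queue.dequeue_many(batch_size)

def feed_queue(sess):
    # Python side: keep the queue topped up from the web service.
    while True:
        sess.run(enqueue_op, feed_dict={chunk_placeholder: load_chunk()})

with tf.Session() as sess:
    threading.Thread(target=feed_queue, args=(sess,), daemon=True).start()
    for _ in range(1000):
        sess.run(batch)  # a training op would consume `batch` here

The point is that dequeue_many blocks until enough elements are available, so training only waits on the feeder when the queue runs dry. But this needs the session, which the Experiment machinery owns.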
EDIT
I managed to do it using tf.py_func(), something like:
class Reader(object):
    """A plain Python object that loads the data and holds the reading
    logic; not related to TF, initialized with batch_size."""

    def read_up_to(self):
        """Reads up to batch_size elements, entirely in Python."""

def input_fn():
    reader = Reader()  # instantiated once
    return tf.py_func(reader.read_up_to, inp=[], Tout=...)
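Concretely, it ends up looking like the sketch below. This assumes the reader returns float32 rows of width 10 (my own placeholder choice); note that py_func drops static shape information, so set_shape is needed afterwards.

import numpy as np
import tensorflow as tf

batch_size = 32

class Reader(object):
    def __init__(self, batch_size):
        self._batch_size = batch_size

    def read_up_to(self):
        # Stands in for the web-service call; must return a numpy array
        # whose dtype matches Tout below.
        return np.zeros((self._batch_size, 10), dtype=np.float32)

def input_fn():
    reader = Reader(batch_size)  # instantiated once
    features = tf.py_func(reader.read_up_to, inp=[], Tout=tf.float32)
    features.set_shape([None, 10])  # py_func loses the static shape
    return features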
It works fine, though it's a bit slower (as expected, the round trip from the C++ runtime to Python introduces roughly a 50% delay). I'm trying to work around this by asynchronously pushing the data the reader loads in Python into a dedicated TensorFlow queue, so that the training step dequeues straight from the queue instead of waiting on a Python-to-C++ hand-off at every batch (just as in the two links above).
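One way to get at a session inside the Estimator machinery without owning it is a tf.train.SessionRunHook: build the queue ops inside input_fn, keep handles to them, and start the feeding thread from after_create_session. A rough sketch under my own assumptions (the AsyncFeeder name and the float32/[None, 10] shape are illustrative, not any official API):

import threading
import tensorflow as tf

class AsyncFeeder(tf.train.SessionRunHook):
    """Feeds a TF queue from a background Python thread once the session exists."""

    def __init__(self, reader, batch_size):
        self._reader = reader
        self._batch_size = batch_size

    def input_fn(self):
        # Called by the Estimator while building the graph; keep handles
        # to the ops so the hook can run them later.
        self._placeholder = tf.placeholder(tf.float32, shape=[None, 10])
        queue = tf.FIFOQueue(capacity=1000, dtypes=[tf.float32], shapes=[[10]])
        self._enqueue_op = queue.enqueue_many(self._placeholder)
        return queue.dequeue_many(self._batch_size)

    def after_create_session(self, session, coord):
        def feed():
            while True:
                chunk = self._reader.read_up_to()  # pure Python, no TF here
                session.run(self._enqueue_op, feed_dict={self._placeholder: chunk})
        threading.Thread(target=feed, daemon=True).start()

The same object would then be passed both as the train_input_fn (feeder.input_fn) and as a hook, e.g. via Experiment's train_monitors argument, so the dequeue at training time never has to cross into Python.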