
I have a Dataset API pipeline that is part of my TensorFlow graph. How do I swap it out when I want to use different data?

import tensorflow as tf

dataset = tf.data.Dataset.range(3)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

variable = tf.Variable(3, dtype=tf.int64)
model = variable*next_element

#pretend like this is me training my model, or something
with tf.Session() as sess:
    sess.run(variable.initializer)
    try:
        while True:
            print(sess.run(model)) # (0,3,6)
    except tf.errors.OutOfRangeError:
        pass

dataset = tf.data.Dataset.range(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()  

### HOW TO DO THIS THING?
with tf.Session() as sess:
    sess.run(variable.initializer) #This would be a saver restore operation, normally...
    try:
        while True:
            print(sess.run(model)) # (0,3)... hopefully
    except tf.errors.OutOfRangeError:
        pass

2 Answers


I do not believe this is possible. You are asking to change the computation graph itself, which is not allowed in TensorFlow. Rather than explain that myself, I'll point to the accepted answer in this post, which explains the point particularly clearly: Is it possible to modify an existing TensorFlow computation graph?
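To see concretely why rebinding the Python variables in the question does not work, here is a minimal sketch (TF 1.x assumed): rebuilding dataset and next_element appends new nodes to the graph, but model keeps reading from the iterator it was originally wired to.

import tensorflow as tf

dataset = tf.data.Dataset.range(3)
next_element = dataset.make_one_shot_iterator().get_next()
model = 3 * next_element  # `model` is wired to the range(3) iterator

# Rebinding the Python names appends new graph nodes; it does not rewire `model`
dataset = tf.data.Dataset.range(2)
next_element = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    print(sess.run(model))  # still 0 from range(3); the graph edge is unchanged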

Now, that said, I think there is a fairly simple, clean way to accomplish what you seek. Essentially, you want to reset the graph and rebuild the Dataset part while reusing the model part of the code, so just put the model in a class or function. A simple example built on your code:

import tensorflow as tf

# the part of the graph you want to reuse
def get_model(next_element):
    variable = tf.Variable(3, dtype=tf.int64)
    return variable * next_element

# the first graph you want to build
tf.reset_default_graph()

# the part of the graph you don't want to reuse
dataset = tf.data.Dataset.range(3)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

# reusable part
model = get_model(next_element)

#pretend like this is me training my model, or something
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    try:
        while True:
            print(sess.run(model)) # (0,3,6)
    except tf.errors.OutOfRangeError:
        pass

# now the second graph
tf.reset_default_graph()

# the part of the graph you don't want to reuse
dataset = tf.data.Dataset.range(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()  

# reusable part
model = get_model(next_element)

# now build and run the second graph the same way
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    try:
        while True:
            print(sess.run(model)) # (0,3)
    except tf.errors.OutOfRangeError:
        pass
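Since the question notes that the initializer call "would be a saver restore operation, normally": a hedged sketch of how the trained weights could be carried across the graph reset with tf.train.Saver (the checkpoint path here is an assumption):

# in the first session, after training:
saver = tf.train.Saver()
saver.save(sess, '/tmp/model.ckpt')  # checkpoint path is an assumption

# in the second graph, after tf.reset_default_graph() and get_model(...):
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, '/tmp/model.ckpt')  # instead of global_variables_initializer

This works because get_model creates variables with the same names in both graphs, so the checkpoint maps cleanly onto the rebuilt graph.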

Final note: you will also see some references here and there to tf.contrib.graph_editor (docs here). The docs specifically say that you can't accomplish exactly what you want with the graph_editor (see "Here is an example of what you cannot do" in that link), though you can get pretty close. Even so, it's not good practice: there is good reason the graph is append-only, and I think the method above is the cleaner way to accomplish what you seek.

– muskrat
  • This is what I've been doing, and it is working. When I asked the question, I was having problems restoring the variables because I wasn't resetting the default graph (duh!). This is definitely the way to go. – Him Jan 10 '18 at 18:39

One approach I would suggest, though it will make things slower, is to use a placeholder feeding into the tf.data.Dataset. You will have something like the following:

import tensorflow as tf

batch_size = 32  # just an example
train_placeholder = tf.placeholder(dtype=tf.float32, shape=[None, None, 1]) # just an example
# then build the tf.data.Dataset from the placeholder
train_dataset = tf.data.Dataset.from_tensor_slices(train_placeholder).shuffle(10000).batch(batch_size)

Now when running the graph within a session, you feed the data through the placeholder when you initialize the iterator, so you can feed whatever data you like without rebuilding the graph.
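A minimal sketch of that feeding step, assuming the names above and a hypothetical NumPy array my_train_array:

iterator = train_dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    # the placeholder is fed once, when (re)initializing the iterator
    sess.run(iterator.initializer, feed_dict={train_placeholder: my_train_array})
    print(sess.run(next_element))

Re-running iterator.initializer with a different array swaps the data without touching the graph.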

Hope this helps!!

– I. A