I am trying to implement the deep Q-learning algorithm DeepMind used to train an AI to play Atari games. One of the features they use, and one mentioned in multiple tutorials, is to keep two versions of your neural network: one that you update as you cycle through mini-batch training data (call this Q), and one that you call as you're doing this to help construct the training data (Q'). Then periodically (say, every 10k data points) the weights in Q' get set to the current values of Q.
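
In rough outline, the schedule I'm after looks like this (a sketch only; the numpy arrays just stand in for network weights, and the actual training step is elided):

import numpy as np

q_weights = np.random.randn(512, 4)   # online network Q (arbitrary shape)
q_prime_weights = q_weights.copy()    # frozen copy Q'

SYNC_EVERY = 10000
for step in range(50000):
    # ... build targets with q_prime_weights, train q_weights on a mini-batch ...
    if step % SYNC_EVERY == 0:
        q_prime_weights = q_weights.copy()  # set Q' to the current Q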

My question is: what is the best way to do this in TensorFlow? Both to store two identical-architecture networks at the same time, and to periodically update one's weights from the other? My current net is shown below; it currently just uses the default graph and an interactive session.

import tensorflow as tf

# weight_variable, bias_variable and conv2d are the usual helper functions
# from the TensorFlow tutorials; height, width, m and env (an OpenAI Gym
# environment) are defined elsewhere.
sess = tf.InteractiveSession()

x = tf.placeholder(tf.float32, shape=[None, height, width, m])
y_ = tf.placeholder(tf.float32, shape=[None, env.action_space.n])

W_conv1 = weight_variable([8, 8, 4, 32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x, W_conv1, 4, 4) + b_conv1)

W_conv2 = weight_variable([4, 4, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_conv1, W_conv2, 2, 2) + b_conv2)

W_conv3 = weight_variable([3, 3, 64, 64])
b_conv3 = bias_variable([64])
h_conv3 = tf.nn.relu(conv2d(h_conv2, W_conv3, 1, 1) + b_conv3)

# Flatten conv output for the dense layers
flat_input_size = 14*10*64
h_conv3_reshape = tf.reshape(h_conv3, [-1, flat_input_size])

# Dense layers
W_fc1 = weight_variable([flat_input_size, 512])
b_fc1 = bias_variable([512])
h_fc1 = tf.nn.relu(tf.matmul(h_conv3_reshape, W_fc1) + b_fc1)

W_fc2 = weight_variable([512, env.action_space.n])
b_fc2 = bias_variable([env.action_space.n])
y_conv = tf.matmul(h_fc1, W_fc2) + b_fc2

# Element-wise squared error (not an accuracy metric)
squared_error = tf.squared_difference(y_, y_conv)
loss = tf.reduce_mean(squared_error)
optimizer = tf.train.AdamOptimizer(0.0001).minimize(loss)

tf.global_variables_initializer().run()

1 Answer

Here's a way to arrange this. First up, you make a separate graph for each network to run them in parallel in different sessions:

graph1 = tf.Graph()
with graph1.as_default():
  model1 = build_model()

graph2 = tf.Graph()
with graph2.as_default():
  model2 = build_model()

... where build_model() defines all placeholders, variables, and training ops. Both models should use the same naming for their variables, which lets them swap state easily.
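
For instance, a build_model() matching the question's network might look like the sketch below. It uses tf.layers instead of the question's manual weight_variable/bias_variable helpers, and n_actions stands in for env.action_space.n; treat it as illustrative, not definitive:

def build_model(height=84, width=84, m=4, n_actions=6):
  # Identical op/variable names in every graph this runs under make the
  # checkpoints interchangeable between the two networks.
  x = tf.placeholder(tf.float32, [None, height, width, m], name='x')
  y_ = tf.placeholder(tf.float32, [None, n_actions], name='y_')
  h = tf.layers.conv2d(x, 32, 8, strides=4, activation=tf.nn.relu, name='conv1')
  h = tf.layers.conv2d(h, 64, 4, strides=2, activation=tf.nn.relu, name='conv2')
  h = tf.layers.conv2d(h, 64, 3, strides=1, activation=tf.nn.relu, name='conv3')
  h = tf.layers.flatten(h)
  h = tf.layers.dense(h, 512, activation=tf.nn.relu, name='fc1')
  y_conv = tf.layers.dense(h, n_actions, name='out')
  loss = tf.reduce_mean(tf.squared_difference(y_, y_conv))
  train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)
  return x, y_, y_conv, train_op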

Each network can be trained using a snapshot of the other network for the target (the latest one or the previous best; it's up to you). Periodically, each network is saved to disk via tf.train.Saver() and restored using the weights of one of the networks. For example, this code will load the weights from the 2nd network into the 1st graph:

with graph1.as_default():
  # import_meta_graph adds ops to the current default graph, so graph1
  # must be the default while importing the saved metagraph
  saver = tf.train.import_meta_graph('/tmp/model-2/network.meta')
  with tf.Session() as sess:
    saver.restore(sess, '/tmp/model-2/network')
    # ... continue training

And this is how the model is saved:

with graph1.as_default():
  saver = tf.train.Saver()  # the Saver must be created inside the graph it saves
  with tf.Session() as sess:
    # ... do some training
    save_path = saver.save(sess, '/tmp/model-1/network')

More on saving and restoring in this question. You can do this in the same session or start a new one.

In fact, you can even use the same location on disk for both networks, so that both save to and restore from the same file. But this forces you to always use the latest snapshot, while the previous approach is more flexible.
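
A sketch of that single-location variant (the /tmp/shared/network path is just an illustration; this works because both graphs use the same variable names):

import os
os.makedirs('/tmp/shared', exist_ok=True)
shared_path = '/tmp/shared/network'  # illustrative shared checkpoint location

# Network 1 trains for a while and publishes its weights:
with graph1.as_default():
  saver1 = tf.train.Saver()
  with tf.Session() as sess1:
    sess1.run(tf.global_variables_initializer())
    # ... train network 1 ...
    saver1.save(sess1, shared_path)

# Network 2 picks up whatever snapshot is currently at that location:
with graph2.as_default():
  saver2 = tf.train.Saver()
  with tf.Session() as sess2:
    saver2.restore(sess2, shared_path)
    # ... continue training network 2 against the restored weights ...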

One thing you need to be careful about is the use of sessions: a session created for graph1 can only evaluate tensors and ops from graph1. Example:

def build_model():
  x = tf.placeholder(tf.float32, name='x')
  y = tf.placeholder(tf.float32, name='y')
  z = x + y
  return x, y, z

graph1 = tf.Graph()
with graph1.as_default():
  x1, y1, z1 = build_model()

graph2 = tf.Graph()
with graph2.as_default():
  x2, y2, z2 = build_model()

with tf.Session(graph=graph1) as sess1:
  with tf.Session(graph=graph2) as sess2:
    # Good
    print(sess1.run(z1, feed_dict={x1: 1, y1: 2}))  # 3.0
    print(sess2.run(z2, feed_dict={x2: 3, y2: 1}))  # 4.0

    # BAD! Wrong graph
    # print(sess1.run(z2, feed_dict={x2: 3, y2: 1}))