
I am currently experimenting with distributed TensorFlow. I am using the tf.estimator.Estimator class (with a custom model function) together with tf.contrib.learn.Experiment, and I managed to get data-parallel execution working.

However, I would now like to try model-parallel execution. I was not able to find any example of that, except the question "Implementation of model parallelism in tensorflow". But I am not sure how to implement this using tf.estimator (e.g. how should I deal with the input functions?).

Does anybody have any experience with it or can provide a working example?

Tobias

1 Answer


First up, you should stop using tf.contrib.learn.Estimator in favor of tf.estimator.Estimator: contrib is an experimental module, and the contrib versions of classes that have graduated to the core API (such as Estimator) are deprecated.

Now, back to your main question: you can place different parts of the model on different devices with tf.device inside the model function, and pass that function via the model_fn parameter of tf.estimator.Estimator.__init__.

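import tensorflow as tf  # TensorFlow 1.x API
X_FEATURE = 'x'  # assumed feature key; must match the key used by the input_fn

def my_model(features, labels, mode):
  is_training = (mode == tf.estimator.ModeKeys.TRAIN)
  net = features[X_FEATURE]

  # First part of the model: three hidden layers (each with dropout) on GPU:1.
  with tf.device('/device:GPU:1'):
    for units in [10, 20, 10]:
      net = tf.layers.dense(net, units=units, activation=tf.nn.relu)
      # Dropout is only applied when training=True (the default is False).
      net = tf.layers.dropout(net, rate=0.1, training=is_training)

  # Second part of the model: output layer and loss on GPU:2.
  with tf.device('/device:GPU:2'):
    logits = tf.layers.dense(net, 3, activation=None)
    if mode == tf.estimator.ModeKeys.PREDICT:  # no labels in PREDICT mode
      predictions = {'class': tf.argmax(logits, 1),
                     'prob': tf.nn.softmax(logits)}
      return tf.estimator.EstimatorSpec(mode, predictions=predictions)
    onehot_labels = tf.one_hot(labels, 3, 1, 0)
    loss = tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels,
                                           logits=logits)

  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(mode, loss=loss)

  optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)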

[...]

classifier = tf.estimator.Estimator(model_fn=my_model)

The model above places the three hidden dense layers, each followed by a dropout layer (six layers in total), on /device:GPU:1, and the output layer and the loss computation on /device:GPU:2. The model function must return a tf.estimator.EstimatorSpec instance. A complete working example can be found in the TensorFlow examples.
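To make the placement idea concrete without TensorFlow, here is a framework-free toy in plain Python. It is not TensorFlow code and not the answer's method; every name in it is hypothetical: `relu_scale` stands in for a dense layer, `input_fn` for an Estimator input function, and `run_model_parallel` merely records where an activation would have to be copied from one device to another (TensorFlow inserts those copies itself for ops under different `tf.device` scopes). Note that the input function carries no device information; only the model is split into per-device stages.

```python
"""Toy, framework-free illustration of model parallelism (not TensorFlow)."""
from typing import Callable, Dict, List, Tuple

Vector = List[float]
Layer = Callable[[Vector], Vector]


def relu_scale(factor: float) -> Layer:
    # Hypothetical stand-in for a dense layer: scale each element, then ReLU.
    return lambda v: [max(0.0, factor * x) for x in v]


def input_fn() -> Tuple[Dict[str, Vector], int]:
    # Hypothetical input function: it yields features and a label only,
    # with no device-specific code.
    return {'x': [1.0, -2.0, 3.0]}, 2


def run_model_parallel(stages: List[Tuple[str, List[Layer]]],
                       features: Dict[str, Vector]) -> Tuple[Vector, List[str]]:
    """Run each stage 'on' its device and log every cross-device transfer."""
    net = features['x']
    current_device = '/device:CPU:0'  # assumed location of the input pipeline
    transfers = []
    for device, layers in stages:
        if device != current_device:
            # Only the activation crossing the boundary has to move.
            transfers.append(f'{current_device} -> {device}')
            current_device = device
        for layer in layers:
            net = layer(net)
    return net, transfers


# Two stages, mirroring the two tf.device blocks in my_model.
stages = [
    ('/device:GPU:1', [relu_scale(2.0), relu_scale(0.5), relu_scale(3.0)]),
    ('/device:GPU:2', [relu_scale(1.0)]),
]
features, label = input_fn()
output, transfers = run_model_parallel(stages, features)
print(output)     # [3.0, 0.0, 9.0]
print(transfers)  # ['/device:CPU:0 -> /device:GPU:1', '/device:GPU:1 -> /device:GPU:2']
```

The design point it illustrates: splitting a model across devices changes where the layers run, not what the input function produces, so the same input function can feed a single-device or a model-parallel model_fn.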

Maxim
  • Thanks! I am already using `tf.estimator.Estimator`. Only the `Experiment` class comes from `tf.contrib.learn`, but to my knowledge there is no other version of this class. Also thanks for your example! I think I've got it now. I just wasn't sure how to deal with the input fn etc., but apparently it doesn't have to be device-specific. – Tobias Nov 13 '17 at 08:27
  • Right, the `Experiment` class is still experimental (sounds funny), which is why it's in `contrib`. This means its API can change. But once it or any other class graduates to core, the code should be migrated to the core API. – Maxim Nov 13 '17 at 19:02