
I am using tf.estimator.Estimator to create my model. It trains for a while and then calls estimator.export_savedmodel. Since I use dropout during training, I'm afraid that exporting straight after training will leave dropout active when making predictions.

Now all I have is a model loaded with tf.saved_model.loader.load. I figured I can get the graph definition from the session into which I load the model. Can I check the dropout value there?

Ciprian Tomoiagă

Answer


It turns out you can check the values of any variables or constants in the graph. After all, that is the purpose of exporting a model.

You should have access to the session in which the model was loaded. In that case, you can go through all the nodes in the graph, as explained in this question, and extract the one corresponding to the dropout value. If you didn't give the dropout a specific name, its node name defaults to something like name_space/dropout/keep_prob.

dropout_nodes = [node for node in sess.graph_def.node if 'dropout' in node.name]
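To make the steps above concrete, here is a minimal round-trip sketch, assuming TensorFlow with the tf.compat.v1 graph/session APIs (TF 1.13+ or 2.x). Instead of a real tf.estimator export, it builds a hypothetical toy graph whose only dropout node is a constant named deep_bidirectional_lstm/dropout/keep_prob, saves it with SavedModelBuilder, reloads it with tf.saved_model.loader.load into a fresh session, and applies the same 'dropout' in node.name filter. The helpers export_toy_model and find_dropout_nodes are illustrative names, not part of any API.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session APIs, also available in TF 2.x


def export_toy_model(export_dir):
    """Export a tiny graph with a constant keep_prob (hypothetical stand-in for the Estimator export)."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf1.placeholder(tf.float32, shape=[None, 4], name="x")
        w = tf1.get_variable("w", shape=[4, 1], initializer=tf1.ones_initializer())
        with tf1.name_scope("deep_bidirectional_lstm"):
            with tf1.name_scope("dropout"):
                # Same naming pattern as in the answer: <scope>/dropout/keep_prob
                keep_prob = tf.constant(1.0, name="keep_prob")
            tf.identity(tf.matmul(x, w) * keep_prob, name="output")
        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            builder = tf1.saved_model.Builder(export_dir)
            builder.add_meta_graph_and_variables(
                sess, [tf1.saved_model.tag_constants.SERVING])
            builder.save()


def find_dropout_nodes(export_dir):
    """Load the SavedModel into a fresh session and return NodeDefs whose name mentions dropout."""
    with tf1.Session(graph=tf.Graph()) as sess:
        tf1.saved_model.loader.load(
            sess, [tf1.saved_model.tag_constants.SERVING], export_dir)
        return [node for node in sess.graph_def.node if "dropout" in node.name]


# The builder wants a directory that does not exist yet, so use a fresh sub-path.
export_dir = os.path.join(tempfile.mkdtemp(), "toy_model")
export_toy_model(export_dir)
for node in find_dropout_nodes(export_dir):
    print(node.name, node.op)
```

On this toy model the filter should match only the keep_prob node, whose op is Const. In a real export the name filter may also match other ops (for example the random-mask ops an active dropout adds), so checking each node's op helps tell a fixed constant apart from live dropout.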

Then, you can inspect the value of any such node. In my case, it looks like this:

name: "deep_bidirectional_lstm/dropout/keep_prob"
op: "Const"
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "value"
  value {
    tensor {
      dtype: DT_FLOAT
      tensor_shape {
      }
      float_val: 1.0
    }
  }
}

This is a protobuf message. It says the operation is a "Const" and its value is a scalar tensor (empty shape) of type DT_FLOAT with value 1.0.

You can use the protobuf API to parse this into a dictionary, or, if you want just the number, you can extract it like so:

print(dropout_nodes[0].attr.get('value').tensor.float_val[0])
1.0

So you are safe: keep_prob is 1.0, which means no units are dropped at prediction time :)
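Both options mentioned above (a dictionary via the protobuf API, or just the number) can be sketched as follows, assuming google.protobuf.json_format.MessageToDict and tf.make_ndarray; the one-constant graph stands in for the loaded model and is hypothetical. tf.make_ndarray has the advantage of decoding the tensor whether its data sits in float_val or is packed into tensor_content, so you don't need to know which field TensorFlow used.

```python
import tensorflow as tf
from google.protobuf import json_format

tf1 = tf.compat.v1

# Hypothetical toy graph standing in for the loaded model.
graph = tf.Graph()
with graph.as_default():
    with tf1.name_scope("deep_bidirectional_lstm"):
        with tf1.name_scope("dropout"):
            tf.constant(1.0, name="keep_prob")

# Pick the keep_prob NodeDef, similar to the dropout_nodes filter in the answer.
node = next(n for n in graph.as_graph_def().node if n.name.endswith("dropout/keep_prob"))

# Option 1: decode the TensorProto stored under the "value" attribute into a NumPy array.
value = tf.make_ndarray(node.attr["value"].tensor)
print(float(value))  # 1.0

# Option 2: turn the whole NodeDef into a plain dict (JSON field names come out in camelCase).
as_dict = json_format.MessageToDict(node)
print(as_dict["attr"]["dtype"]["type"])  # DT_FLOAT
```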


Coming back to this about a year later, I realise there is a point of confusion: in .attr.get('value'), the string 'value' is the key that selects which of the node's two attributes to read ("dtype" or "value"). It has nothing to do with the value { ... } block that each attribute entry contains in the printed protobuf.
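A small sketch of that distinction, assuming TensorFlow 1.13+ or 2.x; the one-node graph is hypothetical. It reads both entries of a Const node's attr map. The key passed to attr[...] or attr.get(...) selects the entry; each entry is an AttrValue message whose payload sits in a oneof field, which TensorFlow's attr_value.proto happens to also call value, adding to the naming overlap.

```python
import tensorflow as tf

# Hypothetical one-node graph: a scalar Const like the keep_prob node in the answer.
graph = tf.Graph()
with graph.as_default():
    tf.constant(1.0, name="keep_prob")
node = graph.as_graph_def().node[0]

# The map keys are the attribute names; 'value' is simply one of the two keys.
print(sorted(node.attr.keys()))  # ['dtype', 'value']

# Each entry is an AttrValue; WhichOneof reports which payload field is set.
print(node.attr["dtype"].WhichOneof("value"))  # type
print(node.attr["value"].WhichOneof("value"))  # tensor

# So attr["dtype"] holds the element type, attr["value"] holds the constant tensor itself.
print(tf.as_dtype(node.attr["dtype"].type).name)  # float32
```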

Ciprian Tomoiagă