
I'm using RLlib to train a reinforcement learning policy (PPO algorithm). I want to see the weights in the neural network underlying the policy.

After digging through RLlib's PPO object, I found the TensorFlow Graph object and expected the weights of the neural network to be there, but I can't find them. The graph has ~1,000 nodes, and I can't for the life of me find where TensorFlow is hiding the actual weights. I looked through the nodes, keeping an eye out for tf.Variable objects as I was told, but I couldn't find any. The closest things I found are nodes of type ReadVariableOp, but none of them contained a tf.Variable. I did find a tf.Tensor in there, but I'm not sure whether it holds actual numbers, and if so, how to get them.

Where do I find the weights of my neural network?

Ram Rachum



In a single-agent setup, do this (algo is your PPO algorithm object):

weights = algo.get_policy().get_state()["weights"]
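To see what is actually inside that weights entry, a small helper can list each variable's name, shape and parameter count. This is a minimal sketch, assuming the entry is a dict mapping variable names to NumPy arrays (or array-like values); the fake_weights dict and its variable names are hypothetical stand-ins, since the real names depend on your model config.

```python
import numpy as np


def summarize_weights(weights):
    """Return (name, shape, parameter count) for each entry in a weights dict.

    Assumes `weights` maps variable names to NumPy arrays or array-likes.
    """
    rows = []
    for name, value in weights.items():
        arr = np.asarray(value)  # convert array-likes so .shape/.size work
        rows.append((name, arr.shape, arr.size))
    return rows


# Hypothetical stand-in for algo.get_policy().get_state()["weights"].
fake_weights = {
    "dense/kernel": np.zeros((4, 256)),
    "dense/bias": np.zeros(256),
    "logits/kernel": np.zeros((256, 2)),
}

for name, shape, count in summarize_weights(fake_weights):
    print(f"{name:20s} shape={shape} params={count}")

total = sum(count for _, _, count in summarize_weights(fake_weights))
print("total parameters:", total)
```

With a real algorithm, pass the weights entry in place of fake_weights; indexing it by one of the printed names gives you the array of actual numbers.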

In a multi-agent setup, you'll need to specify the policy name:

weights = algo.get_policy(policy_name).get_state()["weights"]
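With several policies, the same call can be looped over the policy names to collect every policy's weights at once. This is a sketch built only on the get_policy(...).get_state()["weights"] call above; the stub classes and the policy names "attacker" and "defender" are hypothetical, there only so the example runs without Ray. With a real RLlib algorithm you would pass algo directly, along with the policy names from your multi-agent config.

```python
def collect_weights(algo, policy_names):
    """Fetch the weights of each named policy from a multi-agent algorithm."""
    return {name: algo.get_policy(name).get_state()["weights"] for name in policy_names}


# Hypothetical stubs mimicking only the two methods used above.
class _StubPolicy:
    def __init__(self, weights):
        self._weights = weights

    def get_state(self):
        return {"weights": self._weights}


class _StubAlgo:
    def __init__(self, policies):
        self._policies = policies

    def get_policy(self, policy_id="default_policy"):
        return self._policies[policy_id]


algo = _StubAlgo({
    "attacker": _StubPolicy({"w": [1.0, 2.0]}),
    "defender": _StubPolicy({"w": [3.0]}),
})
all_weights = collect_weights(algo, ["attacker", "defender"])
print(all_weights)
```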
Pedro Fillastre