I am using a custom PPO model with ray.tune(), and I want to add a self-supervised learning objective that depends on batch["obs"], batch["done"], batch["action"], and batch["next_obs"].
I have defined some layers in my model that are called only during training.
I have defined a loss function that I pass to the trainer. Inside that loss function, I pass various inputs (train_batch["actions"] and things taken from the observation) through layers that are never called in the forward pass; those layers are stored as attributes of the model (e.g. model.loss_context).
The layers that are not part of the forward pass (i.e. the ones called only inside the loss function) do not seem to receive gradients: I am logging the magnitude of their weights and they never change, even in a trivially simple case where the loss is just a huge weight-decay penalty on a layer called outside the forward pass.
I have also tried adding these layers to an overridden custom_loss() method, as in the example https://github.com/ray-project/ray/blob/50e1fda022a81e5015978cf723f7b5fd9cc06b2c/rllib/examples/models/custom_loss_model.py, but in that case the weights for those layers are never even initialised.
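One thing that can cause exactly the "weights never initialise" symptom: Keras layers create their variables lazily, on first call, so a layer that is only ever called inside the loss has no variables yet when the trainer snapshots the model's variable list to build the train op. A minimal sketch of the workaround, building the layer eagerly in the constructor (ModelWithLossLayer is a hypothetical stand-in for the ModelV2 subclass, not RLlib code):

```python
import tensorflow as tf

# Keras layers create their weights lazily, on the first call.
# If a layer is only ever called inside the loss, anything that
# snapshots the model's variables beforehand will miss it.
# Building the layer eagerly in __init__ avoids that.

class ModelWithLossLayer:
    def __init__(self, context_dim):
        self.test_dense = tf.keras.layers.Dense(16, name="test_dense")
        # Force variable creation now, not at first call:
        self.test_dense.build((None, context_dim))

    def variables(self):
        # In an actual TFModelV2 subclass you would hand these to
        # register_variables() so RLlib's optimizer sees them.
        return self.test_dense.variables

model = ModelWithLossLayer(context_dim=8)
n_vars = len(model.variables())
print(n_vars)  # kernel + bias exist immediately
```

If the loss-only layers are built this way and their variables are registered alongside the forward-pass variables, they at least show up in the gradient computation; whether RLlib then routes gradients to them is the separate question above.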
Has anyone solved this? I see a number of Stack Overflow questions asking about it, but no answers!
See above: I was expecting the weights to change. Here is the loss function:
LoggedPPO = PPOTFPolicy.with_updates(
    name="SHPPOPolicy",
    loss_fn=ppo_surrogate_loss,
    grad_stats_fn=grad_stats,
    stats_fn=stats,
)
context, action_mask, net_mask = tf.split(
    logits,
    [
        model.context_dim * model.max_num_nets,
        model.max_num_nets * (9 + model.svg_feature_dict["max_layers"]),
        model.max_num_nets,
    ],
    axis=1,
)
x = model.test_dense(context)
wd_loss = sum(
    tf.reduce_sum(v ** 2) for v in model.test_dense.variables
) + 1e-4
batch_loss = [ ..... wd_loss]
In this example test_dense, which is never called during the forward pass, is never updated, even though the case is trivial and the optimiser should drive the magnitude of its weights toward zero.