I want to change the loss function every alternate epoch (or every 5 epochs). I tried the loss-wrapper method suggested in this, but it didn't work: the wrapper keeps the initially captured current_epoch value forever, and the updated value never reaches the loss wrapper, even though the variable is updated at the end of each epoch in the on_epoch_end callback.

I also tried calling model.add_loss in the on_epoch_end callback. That doesn't work either: the model keeps using only the loss function set in model.compile and ignores the one passed to model.add_loss.
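For reference, here is a framework-agnostic Python sketch (plain Python, not the Keras API) of why the wrapper can go stale: a wrapper that snapshots the epoch value when it is built never sees later callback updates, while one that reads shared mutable state at call time does. (In graph-compiled Keras losses even the latter may only be traced once, which is why recompiling, as in the answer below, is the more robust fix.)

```python
# Shared state that an on_epoch_end callback would mutate.
state = {"epoch": 0}

def make_loss_stale(epoch):
    # 'epoch' is frozen at the value passed in when the wrapper is built.
    def loss(y_true, y_pred):
        return abs(y_true - y_pred) * (1 if epoch % 2 == 0 else 2)
    return loss

def make_loss_live(shared):
    # Reads shared["epoch"] on every call, so callback updates are visible.
    def loss(y_true, y_pred):
        return abs(y_true - y_pred) * (1 if shared["epoch"] % 2 == 0 else 2)
    return loss

stale = make_loss_stale(state["epoch"])
live = make_loss_live(state)

state["epoch"] = 1  # what an on_epoch_end callback would do

print(stale(1.0, 0.5))  # 0.5 -- still behaves as if epoch == 0
print(live(1.0, 0.5))   # 1.0 -- sees the updated epoch
```

The stale variant reproduces the symptom described above: the value passed into the wrapper at build time is the only one it ever uses.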

jkstar

1 Answer

Just compile the model multiple times:

for epoch in range(num_epochs):
    # alternate the loss every epoch
    if epoch % 2 == 0:
        current_loss_fn = loss_fn_1
    else:
        current_loss_fn = loss_fn_2
    # re-compiling swaps in the new loss before the next epoch of training
    model.compile(optimizer=current_optimizer, loss=current_loss_fn)
    # train... (one epoch of training goes here,
    # e.g. model.fit(x, y, initial_epoch=epoch, epochs=epoch + 1))
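The `epoch % 2` check above covers the alternating case; for the every-5-epochs variant the question also asks about, integer division generalises the same idea. A minimal sketch (the `pick_loss` helper and the returned names are illustrative, not from the answer):

```python
def pick_loss(epoch, period=1):
    # Switch between two losses every `period` epochs:
    # period=1 alternates each epoch, period=5 switches every 5 epochs.
    return "loss_fn_1" if (epoch // period) % 2 == 0 else "loss_fn_2"

print([pick_loss(e, period=5) for e in range(12)])
# epochs 0-4 use loss_fn_1, 5-9 use loss_fn_2, 10-11 use loss_fn_1 again
```

In the loop above you would replace the if/else with `current_loss_fn = pick_loss(epoch, period=5)` (returning the actual loss objects rather than strings).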
DMcC
  • Can you please share a full code example if possible? – jkstar Apr 22 '23 at 11:10
  • Do you know how to create a model and generate a training loop? The former would go above the code provided and the latter in the "train..." section. If you're not sure how to do those there are plenty of resources. – DMcC Apr 23 '23 at 07:26