I encountered an issue while building a network that is loosely based on the CycleGAN architecture.
I put all of its components inside a single `nn.Module`:
```python
from torch import nn
from classes.EncoderDecoder import EncoderDecoder
from classes.Discriminator import Discriminator


class CycleGAN(nn.Module):
    def __init__(self):
        super(CycleGAN, self).__init__()
        self.encdec1 = EncoderDecoder(encoder_in_channels=3)
        self.encdec2 = EncoderDecoder(encoder_in_channels=3)
        self.disc = Discriminator()

    def forward(self, images, images_bw):
        disc_color = self.disc(images)  # I want the Discriminator to be trained here
        disc_bw = self.disc(images_bw)  # I want the Discriminator to be trained here
        decoded1 = self.encdec1(images_bw)  # EncoderDecoder forward pass
        decoded2 = self.encdec2(decoded1)
        decoded_disc = self.disc(decoded1)  # I don't want to train the Discriminator here,
                                            # only the EncoderDecoder should be trained based
                                            # on this Discriminator's result
        return [disc_color, disc_bw, decoded1, decoded2, decoded_disc]
```
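One possible way to keep the discriminator's weights out of the `EncoderDecoder -> Discriminator` path, while still letting gradients flow back *through* the discriminator to `decoded1`, is to switch off `requires_grad` on the discriminator's parameters just for that one call. The sketch below is not a verified fix for the original code: `TinyEncDec`, `TinyDisc` and `CycleGANSketch` are hypothetical stand-ins, because the real `EncoderDecoder` and `Discriminator` classes are not shown.

```python
import torch
from torch import nn


# Hypothetical stand-in for the author's EncoderDecoder.
class TinyEncDec(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.sigmoid(self.net(x))


# Hypothetical stand-in for the author's Discriminator: outputs one probability per image.
class TinyDisc(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 1, kernel_size=3, padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)


class CycleGANSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.encdec1 = TinyEncDec()
        self.encdec2 = TinyEncDec()
        self.disc = TinyDisc()

    def forward(self, images, images_bw):
        disc_color = self.disc(images)    # graph records the disc weights
        disc_bw = self.disc(images_bw)    # graph records the disc weights
        decoded1 = self.encdec1(images_bw)
        decoded2 = self.encdec2(decoded1)
        # While the flag is off, autograd does not record the disc weights for
        # this call, so a loss on decoded_disc cannot write into disc .grad.
        # decoded1 still requires grad, so gradients reach encdec1.
        self.disc.requires_grad_(False)
        try:
            decoded_disc = self.disc(decoded1)
        finally:
            self.disc.requires_grad_(True)
        return disc_color, disc_bw, decoded1, decoded2, decoded_disc
```

Because the flag is restored before `forward` returns, the losses built on `disc_color` and `disc_bw` still produce gradients for the discriminator when they are backpropagated.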
This is how I initialize the module, the loss functions, and the optimizer:
```python
from torch import float32
from torch.nn import BCELoss, MSELoss
from torch.optim import Adam

c_gan = CycleGAN().to('cuda', dtype=float32, non_blocking=True)
l2_loss = MSELoss().to('cuda', dtype=float32).train()
bce_loss = BCELoss().to('cuda', dtype=float32).train()
optimizer_gan = Adam(c_gan.parameters(), lr=0.00001)
```
This is how I train the network inside the training loop:
```python
c_gan.zero_grad()
optimizer_gan.zero_grad()

disc_color, disc_bw, decoded1, decoded2, decoded_disc = c_gan(images, images_bw)

loss_true = bce_loss(disc_color, label_true)
loss_false = bce_loss(disc_bw, label_false)
disc_loss = loss_true + loss_false
disc_loss.backward()

decoded_loss = l2_loss(decoded2, images_bw)
decoded_disc_loss = bce_loss(decoded_disc, label_true)  # This is where the loss for that Discriminator forward pass is calculated
both_decoded_losses = decoded_loss + decoded_disc_loss
both_decoded_losses.backward()

optimizer_gan.step()
```
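An alternative that leaves the forward pass unchanged is to tell each `backward()` call which parameters it may accumulate gradients into, through the `inputs` argument of `Tensor.backward` (assuming PyTorch 1.8 or newer). In this sketch, `TinyCycle` and `train_step` are hypothetical names built from `nn.Linear` layers, and the labels come from `torch.ones_like`/`torch.zeros_like` rather than the question's `label_true`/`label_false`.

```python
import torch
from torch import nn
from torch.nn import BCELoss, MSELoss
from torch.optim import Adam


# Hypothetical minimal model with the same forward structure as CycleGAN.
class TinyCycle(nn.Module):
    def __init__(self):
        super().__init__()
        self.encdec1 = nn.Linear(4, 4)
        self.encdec2 = nn.Linear(4, 4)
        self.disc = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())

    def forward(self, images, images_bw):
        disc_color = self.disc(images)
        disc_bw = self.disc(images_bw)
        decoded1 = self.encdec1(images_bw)
        decoded2 = self.encdec2(decoded1)
        decoded_disc = self.disc(decoded1)
        return disc_color, disc_bw, decoded1, decoded2, decoded_disc


def train_step(c_gan, optimizer, images, images_bw, l2_loss, bce_loss):
    optimizer.zero_grad()
    disc_color, disc_bw, decoded1, decoded2, decoded_disc = c_gan(images, images_bw)

    # Discriminator loss: gradients are accumulated only into the disc weights.
    disc_loss = (bce_loss(disc_color, torch.ones_like(disc_color))
                 + bce_loss(disc_bw, torch.zeros_like(disc_bw)))
    disc_loss.backward(inputs=list(c_gan.disc.parameters()))

    # Generator loss: backprop still passes through the disc to decoded1,
    # but only the EncoderDecoder weights receive .grad.
    gen_params = list(c_gan.encdec1.parameters()) + list(c_gan.encdec2.parameters())
    gen_loss = (l2_loss(decoded2, images_bw)
                + bce_loss(decoded_disc, torch.ones_like(decoded_disc)))
    gen_loss.backward(inputs=gen_params)

    optimizer.step()  # a single optimizer over all parameters
    return disc_loss.item(), gen_loss.item()


model = TinyCycle()
optimizer = Adam(model.parameters(), lr=1e-5)
train_step(model, optimizer, torch.rand(8, 4), torch.rand(8, 4), MSELoss(), BCELoss())
```

With this approach a single optimizer over `c_gan.parameters()` is sufficient, because each parameter's `.grad` only ever contains the contribution from the loss it is meant to learn from.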
The issue
I don't want to train the `Discriminator` module on the `EncoderDecoder -> Discriminator` forward pass. I do, however, want to train it on the `images -> Discriminator` and `images_bw -> Discriminator` forward passes.
- Is it possible to achieve this using only one optimizer for my `CycleGAN` module?
- Can I freeze the `Discriminator` during the optimizer's `step()` call?
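On the second question: instead of freezing the discriminator during `step()`, it may be enough to make sure its `.grad` holds only the discriminator loss's contribution by the time `step()` runs. The sketch below assumes hypothetical `nn.Linear` stand-ins and hypothetical `clear_grads`/`train_step` helpers. It runs the generator backward pass first, discards the discriminator's gradients, then backpropagates the discriminator loss, all with one Adam optimizer.

```python
import torch
from torch import nn
from torch.optim import Adam


def clear_grads(module: nn.Module) -> None:
    """Drop accumulated gradients by setting them to None (not zero)."""
    for p in module.parameters():
        p.grad = None


def train_step(encdec, disc, optimizer, images, images_bw):
    bce = nn.BCELoss()
    optimizer.zero_grad()

    # 1) Generator loss first: this fills .grad for BOTH encdec and disc.
    decoded_disc = disc(encdec(images_bw))
    gen_loss = bce(decoded_disc, torch.ones_like(decoded_disc))
    gen_loss.backward()

    # 2) Discard the discriminator's share of that gradient.
    clear_grads(disc)

    # 3) Discriminator loss on the real inputs; refills only disc .grad.
    disc_color, disc_bw = disc(images), disc(images_bw)
    disc_loss = (bce(disc_color, torch.ones_like(disc_color))
                 + bce(disc_bw, torch.zeros_like(disc_bw)))
    disc_loss.backward()

    # 4) One step: encdec moves on gen_loss, disc moves on disc_loss only.
    optimizer.step()
    return gen_loss.item(), disc_loss.item()


# Hypothetical stand-ins for the real sub-modules.
encdec = nn.Linear(4, 4)
disc = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
optimizer = Adam(list(encdec.parameters()) + list(disc.parameters()), lr=1e-5)
train_step(encdec, disc, optimizer, torch.rand(8, 4), torch.rand(8, 4))
```

Setting `.grad` to `None` rather than zeroing it matters whenever you want `step()` to skip a parameter entirely: PyTorch's Adam ignores parameters whose `.grad` is `None`, whereas a zero gradient can still move a parameter through Adam's running averages once they are non-zero.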
I would appreciate any help.