
I ran into an issue while building a network loosely based on the CycleGAN architecture.

I put all of its components inside a single nn.Module:

from torch import nn

from classes.EncoderDecoder import EncoderDecoder
from classes.Discriminator import Discriminator

class CycleGAN(nn.Module):
    def __init__(self):
        super(CycleGAN, self).__init__()
        self.encdec1 = EncoderDecoder(encoder_in_channels=3)
        self.encdec2 = EncoderDecoder(encoder_in_channels=3)
        self.disc = Discriminator()
        

    def forward(self, images, images_bw):

        disc_color = self.disc(images) # I want the Discriminator to be trained here
        disc_bw = self.disc(images_bw) # I want the Discriminator to be trained here

        decoded1 = self.encdec1(images_bw) # EncoderDecoder forward pass
        decoded2 = self.encdec2(decoded1)

        decoded_disc = self.disc(decoded1)  # I don't want to train the Discriminator here, 
                                            # only the EncoderDecoder should be trained based
                                            # on this Discriminator's result

        return [disc_color, disc_bw, decoded1, decoded2, decoded_disc]

This is how I initialize the module, the loss functions and the optimizer:

from torch import float32
from torch.nn import BCELoss, MSELoss
from torch.optim import Adam

c_gan = CycleGAN().to('cuda', dtype=float32, non_blocking=True)

l2_loss = MSELoss().to('cuda', dtype=float32).train()
bce_loss = BCELoss().to('cuda', dtype=float32).train()

optimizer_gan = Adam(c_gan.parameters(), lr=0.00001)

This is how I train the network inside the training loop:

c_gan.zero_grad()
optimizer_gan.zero_grad()

disc_color, disc_bw, decoded1, decoded2, decoded_disc = c_gan(images, images_bw)

loss_true = bce_loss(disc_color, label_true)
loss_false = bce_loss(disc_bw, label_false)
disc_loss = loss_true + loss_false
disc_loss.backward()

decoded_loss = l2_loss(decoded2, images_bw)
decoded_disc_loss = bce_loss(decoded_disc, label_true) # This is where the loss for that Discriminator forward pass is calculated
both_decoded_losses = decoded_loss + decoded_disc_loss
both_decoded_losses.backward()
optimizer_gan.step()
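One way to keep a single optimizer with this loop (an alternative not taken from the thread) is to reorder the two backward calls: backpropagate the EncoderDecoder loss first, discard the gradients it left on the discriminator's parameters, and only then backpropagate the discriminator loss. The minimal sketch below assumes tiny nn.Linear stand-ins for EncoderDecoder and Discriminator and random placeholder tensors (all hypothetical); the L2 reconstruction term is left out for brevity.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins for EncoderDecoder and Discriminator, wrapped in one
# module so that a single optimizer owns all of the parameters.
gen = nn.Linear(4, 4)
disc = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
model = nn.ModuleDict({"gen": gen, "disc": disc})
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
bce = nn.BCELoss()

# Random placeholder batches and labels
images, images_bw = torch.randn(8, 4), torch.randn(8, 4)
label_true, label_false = torch.ones(8, 1), torch.zeros(8, 1)

optimizer.zero_grad()

# 1) EncoderDecoder loss first. Its gradient flows through disc into gen,
#    but it also accumulates into disc's own parameters as a side effect.
decoded = gen(images_bw)
gen_loss = bce(disc(decoded), label_true)
gen_loss.backward()

# 2) Discard the discriminator gradients produced by step 1.
for p in disc.parameters():
    p.grad = None

# 3) Discriminator loss: now the only gradient stored on disc's parameters.
disc_loss = bce(disc(images), label_true) + bce(disc(images_bw), label_false)
disc_loss.backward()

# One optimizer step updates gen (from step 1) and disc (from step 3).
optimizer.step()
```

The two losses build separate graphs that share only the discriminator's parameters (which are leaves), so neither backward call needs `retain_graph=True`, and the EncoderDecoder's gradient is the same regardless of the order.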

The issue

I don't want the Discriminator to be trained by the EncoderDecoder -> Discriminator forward pass. I do, however, want it to be trained by the images -> Discriminator and images_bw -> Discriminator forward passes.

  • Is it possible to achieve this using only one optimizer for my CycleGAN module?
  • Can I freeze the Discriminator during the optimizer's .step()?

I would appreciate any help.
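Regarding the second question, a small check suggests that freezing at `.step()` time comes too late: as far as their implementations go, `torch.optim.SGD` and `Adam` update every parameter in their param groups whose `.grad` is not None, without checking `requires_grad`. The single nn.Linear layer below is a hypothetical stand-in for a discriminator layer.

```python
import torch
from torch import nn

layer = nn.Linear(2, 1)  # hypothetical stand-in for a discriminator layer
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

optimizer.zero_grad()
layer(torch.ones(1, 2)).sum().backward()  # .grad is now populated

before = layer.weight.detach().clone()
for p in layer.parameters():
    p.requires_grad = False  # "freeze" only right before the step
optimizer.step()

# The weight still moved, because step() uses whatever is stored in .grad.
weight_changed = not torch.equal(before, layer.weight.detach())
print(weight_changed)  # True
```

Setting `p.grad = None` for the discriminator's parameters before `step()` would skip their update, but in the loop above the gradients from all three discriminator passes are already summed into `.grad` by then, so the separation has to happen before or during the backward pass.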

NakedCat

1 Answer


From the GitHub gist "PyTorch example: freezing a part of the net (including fine-tuning)":

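class CycleGAN(nn.Module):
    ...

c_gan = CycleGAN()
# Freeze every layer of the discriminator, i.e. for each layer:
# c_gan.disc.{layer}.weight.requires_grad = False
# c_gan.disc.{layer}.bias.requires_grad = False
for param in c_gan.disc.parameters():
    param.requires_grad = False

...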
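A minimal sketch of what freezing the discriminator this way does: frozen parameters receive no gradient at all, yet the gradient still flows through the frozen layers into whatever produced their input. The tiny `gen` and `disc` modules are hypothetical stand-ins for EncoderDecoder and Discriminator.

```python
import torch
from torch import nn

# Hypothetical stand-ins for EncoderDecoder and Discriminator
gen = nn.Linear(4, 4)
disc = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())

# Freeze every parameter of the discriminator
for param in disc.parameters():
    param.requires_grad = False

out = disc(gen(torch.randn(8, 4)))
out.mean().backward()

print(disc[0].weight.grad)          # None: frozen parameters get no gradient
print(gen.weight.grad is not None)  # True: the gradient still reaches gen
```

Frozen like this for the whole run, the discriminator would never learn, which is the concern raised in the first comment below.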
Felipe Whitaker
  • Will this work for my use case? I don't want to freeze it immediately after initializing the network, I actually want to train the `Discriminator` based on the first two forward passes through it. I just want to freeze it during the third pass, where it receives input from `EncoderDecoder`. I want to train the `EncoderDecoder` based on this `Discriminator`'s output label. – NakedCat Sep 25 '21 at 13:36
  • It seems to me that you would like to `detach` or set the `requires_grad` to False for the `Discriminator` for the third pass - [Difference Between Detach and with torch.no_grad()](https://stackoverflow.com/questions/56816241/difference-between-detach-and-with-torch-nograd-in-pytorch), hope it helps! – Felipe Whitaker Sep 25 '21 at 13:45
  • Thank you, I will do some testing and get back to you with my results today or tomorrow.
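Following the comment's suggestion to set `requires_grad` to False for the third pass only, a minimal sketch: autograd decides whether to track a parameter when the operation runs, so during the third pass the discriminator's weights act as constants while the gradient still flows back through them into `decoded1`. The flag is restored right afterwards, so a single `backward()` and a single optimizer step update the discriminator only from the first two passes. Detaching `decoded1`, or running the third pass under `torch.no_grad()`, would also cut the EncoderDecoder's gradient from that pass, which is the opposite of what the question asks for. `TinyCycleGAN`, its nn.Linear layers, and `_set_disc_trainable` are hypothetical stand-ins for the question's classes.

```python
import torch
from torch import nn


class TinyCycleGAN(nn.Module):
    """Hypothetical stand-in for the question's CycleGAN, using tiny layers."""

    def __init__(self):
        super().__init__()
        self.encdec1 = nn.Linear(4, 4)
        self.encdec2 = nn.Linear(4, 4)
        self.disc = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())

    def _set_disc_trainable(self, trainable):
        for p in self.disc.parameters():
            p.requires_grad_(trainable)

    def forward(self, images, images_bw):
        # Passes 1 and 2: gradients will reach the discriminator's parameters.
        disc_color = self.disc(images)
        disc_bw = self.disc(images_bw)

        decoded1 = self.encdec1(images_bw)
        decoded2 = self.encdec2(decoded1)

        # Pass 3: the discriminator's parameters are treated as constants, so
        # they get no gradient from this pass, but decoded1 (and encdec1) still do.
        self._set_disc_trainable(False)
        try:
            decoded_disc = self.disc(decoded1)
        finally:
            self._set_disc_trainable(True)

        return disc_color, disc_bw, decoded1, decoded2, decoded_disc


torch.manual_seed(0)
model = TinyCycleGAN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
l2_loss, bce_loss = nn.MSELoss(), nn.BCELoss()

# Random placeholder batches and labels
images, images_bw = torch.randn(8, 4), torch.randn(8, 4)
label_true, label_false = torch.ones(8, 1), torch.zeros(8, 1)

optimizer.zero_grad()
disc_color, disc_bw, decoded1, decoded2, decoded_disc = model(images, images_bw)
disc_loss = bce_loss(disc_color, label_true) + bce_loss(disc_bw, label_false)
gen_loss = l2_loss(decoded2, images_bw) + bce_loss(decoded_disc, label_true)
(disc_loss + gen_loss).backward()  # a single backward pass
optimizer.step()                   # a single optimizer
```

The `try`/`finally` makes sure the discriminator is trainable again even if the third pass raises, so the next iteration's first two passes are tracked as usual.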