
[Update: I'm no longer getting all-zero gradients for the discriminator. There might be an issue with my architecture or the initialization of the layers; I'll try to fix it.]

I'm trying to train a conditional GAN in TensorFlow that synthesizes images from text captions. This is the original paper: http://arxiv.org/abs/1605.05396

The problem I'm facing is that the gradients I get for the discriminator's parameters are all zeros. I'm not able to trace the issue back from this point, since the discriminator loss is positive and the gradient is computed with the built-in tape.gradient(disc_loss, discriminator.trainable_variables).
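
For reference, the check I'm running on the discriminator gradients returned by my train_step function (shown below) looks roughly like this; the helper itself is only illustrative and not part of the training code:

import tensorflow as tf

def gradients_all_zero(grads):
    # A gradient of None means the variable is not connected to the loss at all;
    # a tensor of zeros means a path exists but the gradient vanished.
    return all(g is not None and bool(tf.reduce_all(tf.equal(g, 0.0))) for g in grads)

# gradients_all_zero(gradients_of_discriminator) comes back True in my runs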

I have also implemented this in PyTorch and didn't face this issue there. Since the way gradients are computed in TensorFlow and PyTorch differs somewhat, I think the problem lies in my understanding of TensorFlow rather than in the Generator and Discriminator architectures, but I could be wrong.
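
For context, this is my understanding of the basic GradientTape pattern that train_step is supposed to follow, written out on a toy model with random data (the model, optimizer, and shapes here are made up purely to show the mechanics):

import tensorflow as tf

# Toy setup, only to illustrate the tape -> gradient -> apply_gradients pattern.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam(1e-3)
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

x = tf.random.normal((8, 4))
y = tf.ones((8, 1))

with tf.GradientTape() as tape:
    logits = model(x, training=True)  # forward pass has to be recorded by the tape
    loss = bce(y, logits)

grads = tape.gradient(loss, model.trainable_variables)            # ~ loss.backward() in PyTorch
optimizer.apply_gradients(zip(grads, model.trainable_variables))  # ~ optimizer.step() in PyTorch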

I'm pasting the important parts of the code below; I'd appreciate any help in figuring out the problem.

Please let me know if I should post more details or remove some clutter.

My guess is that something is wrong in the train_step function, but I've also included the Generator and Discriminator architectures below (this might be too much code to read, though).

import tensorflow as tf
from tensorflow.keras import layers

generator = Generator()
discriminator = Discriminator()
criterion = tf.keras.losses.BinaryCrossentropy(from_logits=True)

generator_optimizer = tf.keras.optimizers.Adam(1e-3, beta_1=0.5, beta_2=0.999)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-2, beta_1=0.5, beta_2=0.999)

def discriminator_loss(criterion, real_output, fake_output, wrong_caption_output):

    real_labels = tf.ones_like(real_output)    
    fake_labels = tf.zeros_like(fake_output)

    real_loss = criterion(real_labels, real_output)
    fake_loss = criterion(fake_labels, fake_output) + criterion(fake_labels, wrong_caption_output)

    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(criterion, fake_output):
    real_labels = tf.ones_like(fake_output)
    loss = criterion(real_labels, fake_output)
    return loss
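
For what it's worth, the loss functions can be sanity-checked on small dummy outputs like this (the values below are arbitrary; the shapes match the (B,) outputs the discriminator returns):

dummy_real = tf.constant([2.0, -1.0])
dummy_fake = tf.constant([-2.0, 0.5])
dummy_wrong = tf.constant([0.0, 0.0])
print(discriminator_loss(criterion, dummy_real, dummy_fake, dummy_wrong))
print(generator_loss(criterion, dummy_fake))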

@tf.function
def train_step(right_images, right_embed, wrong_images, wrong_embed, noise):

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator.forward(right_embed, noise, training=True) #(B, C, 64, 64)

        real_output, _ = discriminator.forward(right_embed, right_images, training = True)
        # stop_gradient so the discriminator loss does not backprop into the generator
        fake_output, _ = discriminator.forward(right_embed, tf.stop_gradient(generated_images), training = True)

        wrong_caption_output, _ = discriminator.forward(wrong_embed, right_images, training = True)

        # Disc Losses
        disc_loss = discriminator_loss(criterion, real_output, fake_output, wrong_caption_output)


        ## Pass generated images through the trained disc
        fake_output_1, _ = discriminator.forward(right_embed, generated_images, training = True)

        # Gen loss
        gen_loss = generator_loss(criterion, fake_output_1)

    # Train Disc
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    # Train Gen
    gen_variables = generator.trainable_variables

    gradients_of_generator = gen_tape.gradient(gen_loss, gen_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, gen_variables))

    return gen_loss, disc_loss, real_output, fake_output, gradients_of_generator, gradients_of_discriminator
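
For completeness, the loop that calls train_step looks roughly like this; train_dataset, num_epochs, and noise_dim are placeholders for my actual dataset object and hyperparameters:

for epoch in range(num_epochs):
    for right_images, right_embed, wrong_images, wrong_embed in train_dataset:
        noise = tf.random.normal((right_images.shape[0], noise_dim))
        gen_loss, disc_loss, real_output, fake_output, gen_grads, disc_grads = train_step(
            right_images, right_embed, wrong_images, wrong_embed, noise)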

The architectures of the Generator and the Discriminator:


class Generator(tf.Module):
    def __init__(self):
        super().__init__()

        w_init = tf.random_normal_initializer(stddev=0.02)

        gamma_init = tf.random_normal_initializer(1., 0.02)

        model = tf.keras.Sequential()
        model.add(layers.Dense(projected_embedding_size, input_shape = (embedding_size, ),
                              kernel_initializer = w_init))
        model.add(layers.BatchNormalization(epsilon = 1e-5, 
                                            gamma_initializer = gamma_init))
        model.add(layers.ReLU()) #(B, projected_embedding_size)

        self.projection = model

        model_1 = tf.keras.Sequential()

        model_1.add(layers.Conv2DTranspose(filters = ngf*8, input_shape = (latent_dim, 1, 1),
                                            kernel_size = 4, kernel_initializer = w_init,
                                         strides= 1, padding = 'valid', 
                                              data_format='channels_first', use_bias = False))
        model_1.add(layers.BatchNormalization(epsilon = 1e-5,
                                             gamma_initializer = gamma_init))
        model_1.add(layers.ReLU()) #(B, ngf*8, 4, 4)


        model_1.add(layers.Conv2DTranspose(filters = ngf*4, kernel_size = 4,
                                         strides= 2, kernel_initializer = w_init,
                                         padding = 'same', data_format='channels_first', use_bias = False))
        model_1.add(layers.BatchNormalization(epsilon = 1e-5,
                                             gamma_initializer = gamma_init))
        model_1.add(layers.ReLU()) # (B, ngf*8, 8, 8)


        model_1.add(layers.Conv2DTranspose(filters = ngf*2, kernel_size = 4, 
                                         strides= 2, kernel_initializer = w_init,
                                         padding = 'same', data_format='channels_first', use_bias = False))
        model_1.add(layers.BatchNormalization(epsilon = 1e-5,
                                             gamma_initializer = gamma_init))
        model_1.add(layers.ReLU()) # (B, ngf*2, 16, 16)


        model_1.add(layers.Conv2DTranspose(filters = ngf, kernel_size = 4, 
                                         strides= 2, kernel_initializer = w_init,
                                         padding = 'same', data_format='channels_first', use_bias = False))
        model_1.add(layers.BatchNormalization(epsilon = 1e-5,
                                             gamma_initializer = gamma_init))
        model_1.add(layers.ReLU()) # (B, ngf, 32, 32)


        model_1.add(layers.Conv2DTranspose(filters = img_channels, kernel_size = 4, 
                                         strides= 2, kernel_initializer = w_init,
                                         padding = 'same', data_format='channels_first', use_bias = False))
        model_1.add(layers.Activation('tanh')) # (B, img_channels, 64, 64)

        self.netG = model_1

    def forward(self, embedding, noise, training = True):
        projected_embedding = self.projection(embedding, training = training) # (B, projected_embedding_size)
        noise = noise # (B, noise_dim)

        input = tf.keras.backend.concatenate((noise, projected_embedding), axis = 1) #(B, projected_embedding_size + noise_dim)
        input = tf.keras.backend.reshape(input, shape=(input.shape[0], input.shape[1], 1, 1))
        output = self.netG(input, training = training) # (B, img_channels, 64, 64)

        return output

class Discriminator(tf.Module):
    def __init__(self):
        super().__init__()

        w_init = tf.random_normal_initializer(stddev=0.02)

        gamma_init = tf.random_normal_initializer(1., 0.02)
        model = tf.keras.Sequential() #(B, 64, 64, img_channels)

        model.add(layers.Conv2D(filters = ndf, 
                                input_shape = (generated_img_size, generated_img_size, img_channels),
                                kernel_size = 4,kernel_initializer = w_init,
                                strides= 2, padding = 'same', 
                                data_format='channels_last', use_bias = False))
        model.add(layers.BatchNormalization(epsilon = 1e-5,
                                           gamma_initializer = gamma_init))
        model.add(layers.LeakyReLU(0.2)) #(B, 32, 32, ndf)


        model.add(layers.Conv2D(filters = ndf*2, kernel_size = 4, kernel_initializer = w_init,
                                strides= 2, padding = 'same', 
                                              data_format='channels_last', use_bias = False))
        model.add(layers.BatchNormalization(epsilon = 1e-5,
                                           gamma_initializer = gamma_init))
        model.add(layers.LeakyReLU(0.2)) #(B, 16, 16, ndf*2)


        model.add(layers.Conv2D(filters = ndf*4, kernel_size = 4, kernel_initializer = w_init,
                                strides= 2, padding = 'same', 
                                              data_format='channels_last', use_bias = False))
        model.add(layers.BatchNormalization(epsilon = 1e-5,
                                           gamma_initializer = gamma_init))
        model.add(layers.LeakyReLU(0.2)) #(B, 8, 8, ndf*4)


        model.add(layers.Conv2D(filters = ndf*8, kernel_size = 4, kernel_initializer = w_init,
                                strides= 2, padding = 'same', 
                                              data_format='channels_last', use_bias = False))
        model.add(layers.BatchNormalization(epsilon = 1e-5,
                                           gamma_initializer = gamma_init))
        model.add(layers.LeakyReLU(0.2)) #(B, 4, 4, ndf*8)

        self.netD_1 = model

        # Projection model
        model_1 = tf.keras.Sequential()
        model_1.add(layers.Dense(projected_embedding_size, input_shape=(embedding_size,),
                                kernel_initializer = w_init))
        model_1.add(layers.BatchNormalization(epsilon = 1e-5,
                                             gamma_initializer = gamma_init))
        model_1.add(layers.LeakyReLU(0.2))

        self.projection = model_1

        # Discriminator model 2 - Combining with captions embedding
        model_2 = tf.keras.Sequential()
        model_2.add(layers.Conv2D(filters = 1, input_shape = (4, 4, ndf *8 + projected_embedding_size),
                                            kernel_size = 4, kernel_initializer = w_init,
                                strides= 1, padding = 'valid', 
                                              data_format='channels_last', use_bias = False))
        model_2.add(layers.Activation('sigmoid'))

        self.netD_2 = model_2

    def forward(self, embedding, input, training = True):
        projected_embedding = self.projection(embedding, training = training) # (B, projected_embedding_size)
        projected_embedding = tf.keras.backend.reshape(projected_embedding, 
                                                       shape = (1, 1, projected_embedding.shape[0], 
                                                                projected_embedding.shape[1]))
        projected_embedding = tf.keras.backend.repeat_elements(projected_embedding, rep =4, axis = 0)
        projected_embedding = tf.keras.backend.repeat_elements(projected_embedding, rep =4, axis = 1)

        projected_embedding = projected_embedding  # (4, 4, B, projected_embedding_size)
        projected_embedding = tf.keras.backend.permute_dimensions(projected_embedding, pattern = (2, 0, 1, 3)) # (B, 4, 4, projected_embedding_size)

        # input = (B, C, 64, 64)
        input = tf.keras.backend.permute_dimensions(input, pattern = (0, 2, 3, 1)) # (B, 64, 64, C)

        x_intermediate = self.netD_1(input, training = training) # (B, 4, 4, ndf*8)

        output = tf.keras.backend.concatenate((x_intermediate, projected_embedding), axis = 3)  # (B, 4, 4, ndf*8 + projected_embedding_size)

        output = self.netD_2(output, training = training) # (B, 1, 1, 1)
        output = tf.keras.backend.reshape(output, shape= (output.shape[0], ))

        return output, x_intermediate