[Update: I'm no longer getting all-zero gradients for the discriminator. There may instead be some issue with my architecture or with the initialization of the layers; I'll try to fix it.]
I'm trying to train a conditional GAN in TensorFlow for image synthesis using a caption. This is the original paper: http://arxiv.org/abs/1605.05396
The problem I'm facing is that the gradients I get for the discriminator's parameters are all zeros. I can't backtrack the issue from this point, because the discriminator loss is positive and the gradients are computed with the built-in disc_tape.gradient(disc_loss, discriminator.trainable_variables) call on a tf.GradientTape (see train_step below).
I have also implemented this in PyTorch and didn't run into this issue there. Since the syntax for computing gradients in TensorFlow and PyTorch is somewhat different, I suspect the problem lies in my understanding of TensorFlow rather than in the Generator and Discriminator architectures, but I could be wrong.
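For context, this is roughly what I mean by the syntax difference. The snippet below is a toy model only, not my actual code; it just shows the explicit GradientTape pattern I'm following in TF2, with the PyTorch equivalent noted in the comments:

import tensorflow as tf

# Toy model and loss, only to illustrate the explicit-tape pattern.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
optimizer = tf.keras.optimizers.Adam(1e-3)
x = tf.random.normal((8, 3))
y = tf.zeros((8, 1))

with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
grads = tape.gradient(loss, model.trainable_variables)   # explicit call, returns a list of tensors
optimizer.apply_gradients(zip(grads, model.trainable_variables))

# In PyTorch the same update would be:
#   optimizer.zero_grad(); loss.backward(); optimizer.step()
# with gradients stored on each parameter's .grad attribute rather than returned.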
I'm pasting the important parts of the code below; I'd appreciate any help figuring out the problem.
Please let me know if I should post more details or remove some of the clutter.
My guess is that something is wrong in the train_step function, but I've also included the Generator and Discriminator architectures further down (it might be a lot of code to read, though).
import tensorflow as tf
from tensorflow.keras import layers

generator = Generator()
discriminator = Discriminator()

criterion = tf.keras.losses.BinaryCrossentropy(from_logits=True)
generator_optimizer = tf.keras.optimizers.Adam(1e-3, beta_1=0.5, beta_2=0.999)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-2, beta_1=0.5, beta_2=0.999)
def discriminator_loss(criterion, real_output, fake_output, wrong_caption_output):
    real_labels = tf.ones_like(real_output)
    fake_labels = tf.zeros_like(fake_output)
    real_loss = criterion(real_labels, real_output)
    fake_loss = criterion(fake_labels, fake_output) + criterion(fake_labels, wrong_caption_output)
    total_loss = real_loss + fake_loss
    return total_loss
def generator_loss(criterion, fake_output):
    real_labels = tf.ones_like(fake_output)
    loss = criterion(real_labels, fake_output)
    return loss
@tf.function
def train_step(right_images, right_embed, wrong_images, wrong_embed, noise):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator.forward(right_embed, noise, training=True)  # (B, C, 64, 64)

        real_output, _ = discriminator.forward(right_embed, right_images, training=True)
        fake_output, _ = discriminator.forward(right_embed, tf.stop_gradient(generated_images), training=True)
        wrong_caption_output, _ = discriminator.forward(wrong_embed, right_images, training=True)

        # Disc losses
        disc_loss = discriminator_loss(criterion, real_output, fake_output, wrong_caption_output)

        # Pass generated images through the disc again (no stop_gradient) for the generator loss
        fake_output_1, _ = discriminator.forward(right_embed, generated_images, training=True)
        gen_loss = generator_loss(criterion, fake_output_1)

    # Train Disc
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    # Train Gen
    gen_variables = generator.trainable_variables
    gradients_of_generator = gen_tape.gradient(gen_loss, gen_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, gen_variables))

    return gen_loss, disc_loss, real_output, fake_output, gradients_of_generator, gradients_of_discriminator
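For debugging, here's a minimal sketch of how the gradients returned by train_step could be inspected (this is not part of my original code; batch_size and the random inputs are placeholders, and I'm assuming noise_dim + projected_embedding_size == latent_dim as in the Generator). A gradient of None means the variable isn't connected to the loss at all, while an all-zero tensor means it's connected but receives no signal:

batch_size = 16
right_images = tf.random.normal((batch_size, img_channels, 64, 64))
wrong_images = tf.random.normal((batch_size, img_channels, 64, 64))
right_embed = tf.random.normal((batch_size, embedding_size))
wrong_embed = tf.random.normal((batch_size, embedding_size))
noise = tf.random.normal((batch_size, noise_dim))

gen_loss, disc_loss, _, _, gen_grads, disc_grads = train_step(
    right_images, right_embed, wrong_images, wrong_embed, noise)

# Print the gradient norm for every discriminator variable.
for var, grad in zip(discriminator.trainable_variables, disc_grads):
    if grad is None:
        print(var.name, '-> no gradient (unconnected)')
    else:
        print(var.name, '->', float(tf.norm(grad)))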
The architectures of the Generator and the Discriminator:
class Generator(tf.Module):
    def __init__(self):
        super().__init__()
        w_init = tf.random_normal_initializer(stddev=0.02)
        gamma_init = tf.random_normal_initializer(1., 0.02)

        # Projection of the caption embedding
        model = tf.keras.Sequential()
        model.add(layers.Dense(projected_embedding_size, input_shape=(embedding_size,),
                               kernel_initializer=w_init))
        model.add(layers.BatchNormalization(epsilon=1e-5, gamma_initializer=gamma_init))
        model.add(layers.ReLU())  # (B, projected_embedding_size)
        self.projection = model

        # Upsampling stack: (latent_dim, 1, 1) -> (img_channels, 64, 64)
        model_1 = tf.keras.Sequential()
        model_1.add(layers.Conv2DTranspose(filters=ngf*8, input_shape=(latent_dim, 1, 1),
                                           kernel_size=4, kernel_initializer=w_init,
                                           strides=1, padding='valid',
                                           data_format='channels_first', use_bias=False))
        model_1.add(layers.BatchNormalization(epsilon=1e-5, gamma_initializer=gamma_init))
        model_1.add(layers.ReLU())  # (B, ngf*8, 4, 4)
        model_1.add(layers.Conv2DTranspose(filters=ngf*4, kernel_size=4,
                                           strides=2, kernel_initializer=w_init,
                                           padding='same', data_format='channels_first', use_bias=False))
        model_1.add(layers.BatchNormalization(epsilon=1e-5, gamma_initializer=gamma_init))
        model_1.add(layers.ReLU())  # (B, ngf*4, 8, 8)
        model_1.add(layers.Conv2DTranspose(filters=ngf*2, kernel_size=4,
                                           strides=2, kernel_initializer=w_init,
                                           padding='same', data_format='channels_first', use_bias=False))
        model_1.add(layers.BatchNormalization(epsilon=1e-5, gamma_initializer=gamma_init))
        model_1.add(layers.ReLU())  # (B, ngf*2, 16, 16)
        model_1.add(layers.Conv2DTranspose(filters=ngf, kernel_size=4,
                                           strides=2, kernel_initializer=w_init,
                                           padding='same', data_format='channels_first', use_bias=False))
        model_1.add(layers.BatchNormalization(epsilon=1e-5, gamma_initializer=gamma_init))
        model_1.add(layers.ReLU())  # (B, ngf, 32, 32)
        model_1.add(layers.Conv2DTranspose(filters=img_channels, kernel_size=4,
                                           strides=2, kernel_initializer=w_init,
                                           padding='same', data_format='channels_first', use_bias=False))
        model_1.add(layers.Activation('tanh'))  # (B, img_channels, 64, 64)
        self.netG = model_1

    def forward(self, embedding, noise, training=True):
        projected_embedding = self.projection(embedding, training=training)  # (B, projected_embedding_size)
        # noise: (B, noise_dim)
        input = tf.keras.backend.concatenate((noise, projected_embedding), axis=1)  # (B, noise_dim + projected_embedding_size)
        input = tf.keras.backend.reshape(input, shape=(input.shape[0], input.shape[1], 1, 1))
        output = self.netG(input, training=training)  # (B, img_channels, 64, 64)
        return output
class Discriminator(tf.Module):
    def __init__(self):
        super().__init__()
        w_init = tf.random_normal_initializer(stddev=0.02)
        gamma_init = tf.random_normal_initializer(1., 0.02)

        # Image encoder: (B, 64, 64, img_channels) -> (B, 4, 4, ndf*8)
        model = tf.keras.Sequential()
        model.add(layers.Conv2D(filters=ndf,
                                input_shape=(generated_img_size, generated_img_size, img_channels),
                                kernel_size=4, kernel_initializer=w_init,
                                strides=2, padding='same',
                                data_format='channels_last', use_bias=False))
        model.add(layers.BatchNormalization(epsilon=1e-5, gamma_initializer=gamma_init))
        model.add(layers.LeakyReLU(0.2))  # (B, 32, 32, ndf)
        model.add(layers.Conv2D(filters=ndf*2, kernel_size=4, kernel_initializer=w_init,
                                strides=2, padding='same',
                                data_format='channels_last', use_bias=False))
        model.add(layers.BatchNormalization(epsilon=1e-5, gamma_initializer=gamma_init))
        model.add(layers.LeakyReLU(0.2))  # (B, 16, 16, ndf*2)
        model.add(layers.Conv2D(filters=ndf*4, kernel_size=4, kernel_initializer=w_init,
                                strides=2, padding='same',
                                data_format='channels_last', use_bias=False))
        model.add(layers.BatchNormalization(epsilon=1e-5, gamma_initializer=gamma_init))
        model.add(layers.LeakyReLU(0.2))  # (B, 8, 8, ndf*4)
        model.add(layers.Conv2D(filters=ndf*8, kernel_size=4, kernel_initializer=w_init,
                                strides=2, padding='same',
                                data_format='channels_last', use_bias=False))
        model.add(layers.BatchNormalization(epsilon=1e-5, gamma_initializer=gamma_init))
        model.add(layers.LeakyReLU(0.2))  # (B, 4, 4, ndf*8)
        self.netD_1 = model

        # Projection model for the caption embedding
        model_1 = tf.keras.Sequential()
        model_1.add(layers.Dense(projected_embedding_size, input_shape=(embedding_size,),
                                 kernel_initializer=w_init))
        model_1.add(layers.BatchNormalization(epsilon=1e-5, gamma_initializer=gamma_init))
        model_1.add(layers.LeakyReLU(0.2))
        self.projection = model_1

        # Discriminator model 2 - combining image features with the caption embedding
        model_2 = tf.keras.Sequential()
        model_2.add(layers.Conv2D(filters=1, input_shape=(4, 4, ndf*8 + projected_embedding_size),
                                  kernel_size=4, kernel_initializer=w_init,
                                  strides=1, padding='valid',
                                  data_format='channels_last', use_bias=False))
        model_2.add(layers.Activation('sigmoid'))
        self.netD_2 = model_2

    def forward(self, embedding, input, training=True):
        projected_embedding = self.projection(embedding, training=training)  # (B, projected_embedding_size)
        projected_embedding = tf.keras.backend.reshape(projected_embedding,
                                                       shape=(1, 1, projected_embedding.shape[0],
                                                              projected_embedding.shape[1]))
        projected_embedding = tf.keras.backend.repeat_elements(projected_embedding, rep=4, axis=0)
        projected_embedding = tf.keras.backend.repeat_elements(projected_embedding, rep=4, axis=1)  # (4, 4, B, projected_embedding_size)
        projected_embedding = tf.keras.backend.permute_dimensions(projected_embedding, pattern=(2, 0, 1, 3))  # (B, 4, 4, projected_embedding_size)

        # input: (B, C, 64, 64)
        input = tf.keras.backend.permute_dimensions(input, pattern=(0, 2, 3, 1))  # (B, 64, 64, C)
        x_intermediate = self.netD_1(input, training=training)  # (B, 4, 4, ndf*8)
        output = tf.keras.backend.concatenate((x_intermediate, projected_embedding), axis=3)  # (B, 4, 4, ndf*8 + projected_embedding_size)
        output = self.netD_2(output, training=training)  # (B, 1, 1, 1)
        output = tf.keras.backend.reshape(output, shape=(output.shape[0],))
        return output, x_intermediate