
I'm working on basic GANs using TF Keras. I train the generator and the discriminator alternately with the train_on_batch method, which, unlike the fit method of a Keras model, has no callbacks parameter, so I can't attach the TensorBoard callback to write logs. I want to write logs during training so that I can monitor the model's weights and gradients in TensorBoard.
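A minimal sketch of one way to log weights without callbacks, assuming TensorFlow 2.x: tf.summary can write to a log directory from any Python code, so weight histograms can be written by hand, for example once per epoch. The helper name log_weight_histograms and the log directory in the usage comment are hypothetical choices for illustration.

```python
import tensorflow as tf


def log_weight_histograms(writer, model, step):
    """Write one histogram per weight tensor of `model` at the given step.

    Hypothetical helper: call it e.g. once per epoch for g_model and d_model.
    """
    with writer.as_default():
        for weight in model.weights:
            # Keras 3 variables expose a unique `path`; older tf.keras variables
            # only have `name` (e.g. "dense/kernel:0"), so fall back to it and
            # replace ':' to keep the tag readable.
            name = getattr(weight, "path", weight.name).replace(":", "_")
            tf.summary.histogram(name, weight.numpy(), step=step)
    writer.flush()


# Usage sketch (the log directory is an assumption):
# writer = tf.summary.create_file_writer("logs/gan")
# log_weight_histograms(writer, g_model, step=i)
```

Creating the writer once, before the training loop, and reusing it keeps all steps in the same run in TensorBoard.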

The training code is as follows:

```python
from numpy import ones
from IPython import display

# generate_real_samples, generate_fake_samples, generate_latent_points and
# summarize_performance are helper functions defined elsewhere in my code.

def train(g_model, d_model, gan_model, dataset, latent_dim, seed, n_epochs=100, n_batch=128):
  bat_per_epo = dataset.shape[0] // n_batch
  half_batch = n_batch // 2

  for i in range(n_epochs):
    for j in range(bat_per_epo):
      # Train the discriminator on real images
      X_real, y_real = generate_real_samples(dataset, half_batch)
      d_loss1, _ = d_model.train_on_batch(X_real, y_real * .9)  # label smoothing

      # Train the discriminator on fake images
      X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
      d_loss2, _ = d_model.train_on_batch(X_fake, y_fake + .1)  # label smoothing

      # Train the generator (through the combined GAN model) on latent points labelled as real
      X_gan = generate_latent_points(latent_dim, n_batch)
      y_gan = ones((n_batch, 1))
      g_loss = gan_model.train_on_batch(X_gan, y_gan)

      # Print progress every 10 batches
      if j % 10 == 0:
        print('>%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' % (i+1, j+1, bat_per_epo, d_loss1, d_loss2, g_loss))

    # End of epoch: clear the notebook output, print the last losses and summarize
    display.clear_output(True)
    print('>%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' % (i+1, j+1, bat_per_epo, d_loss1, d_loss2, g_loss))
    summarize_performance(i, g_model, d_model, dataset, latent_dim, seed)

  display.clear_output(True)
  print('>%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' % (i+1, j+1, bat_per_epo, d_loss1, d_loss2, g_loss))
  summarize_performance(i, g_model, d_model, dataset, latent_dim, seed)
```
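The losses that the loop prints (d_loss1, d_loss2, g_loss) could be sent to TensorBoard as scalars in the same manual way. A sketch, assuming a writer created once before training with tf.summary.create_file_writer and a global step such as i * bat_per_epo + j; log_losses is a hypothetical helper name. train_on_batch returns a plain scalar only when the model was compiled without metrics, which is how g_loss is already used in the print statements.

```python
import tensorflow as tf


def log_losses(writer, step, d_loss_real, d_loss_fake, g_loss):
    """Write the discriminator and generator losses as TensorBoard scalars."""
    with writer.as_default():
        tf.summary.scalar("discriminator/loss_real", d_loss_real, step=step)
        tf.summary.scalar("discriminator/loss_fake", d_loss_fake, step=step)
        tf.summary.scalar("generator/loss", g_loss, step=step)


# Usage sketch inside the inner loop of train():
# log_losses(writer, i * bat_per_epo + j, d_loss1, d_loss2, g_loss)
```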

Here are some of the approaches to computing gradients that I've found:

  1. How to obtain the gradients in keras?
  2. Getting gradient of model output w.r.t weights using Keras
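In TensorFlow 2 the approaches in those answers come down to tf.GradientTape: run a forward pass under the tape, compute the loss, and ask the tape for the gradient with respect to model.trainable_weights. A minimal sketch, assuming the loss function is passed in explicitly (e.g. tf.keras.losses.BinaryCrossentropy(), which is an assumption about how the models were compiled); compute_gradients is a hypothetical name. For gan_model, trainable_weights contains only the generator's weights while d_model.trainable is False.

```python
import tensorflow as tf


def compute_gradients(model, x, y, loss_fn):
    """Return (loss, gradients) of loss_fn(y, model(x)) w.r.t. model.trainable_weights.

    The gradients are only measured here; they are not applied to the weights.
    """
    x = tf.convert_to_tensor(x, dtype=tf.float32)
    y = tf.convert_to_tensor(y, dtype=tf.float32)
    with tf.GradientTape() as tape:
        # training=True matches the forward pass used during training
        # (dropout active, batch norm on batch statistics); note that batch
        # norm layers may also update their moving statistics here.
        predictions = model(x, training=True)
        loss = loss_fn(y, predictions)
    grads = tape.gradient(loss, model.trainable_weights)
    return loss, grads
```

Because this runs an extra forward and backward pass, it is cheaper to call it only occasionally (for example once per epoch) rather than on every batch.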

But I'm confused by these answers, and I don't see how to log the gradients without the callbacks option. Can someone please help me with this?
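One way to log gradients without callbacks is to replace gan_model.train_on_batch(X_gan, y_gan) with a hand-written training step that computes the gradients, applies them with an optimizer and writes them to TensorBoard in the same pass. A sketch under these assumptions: the optimizer and loss are passed in explicitly (for instance gan_model.optimizer and tf.keras.losses.BinaryCrossentropy()), metrics of the compiled model are not updated, and train_step_with_grad_logging is a hypothetical name. It returns a Python float so it can stand in for g_loss in the print statements.

```python
import tensorflow as tf


def train_step_with_grad_logging(model, optimizer, loss_fn, x, y, writer, step):
    """One optimisation step that also writes the loss and gradient histograms."""
    x = tf.convert_to_tensor(x, dtype=tf.float32)
    y = tf.convert_to_tensor(y, dtype=tf.float32)
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    variables = model.trainable_weights
    grads = tape.gradient(loss, variables)
    # Update the weights, as train_on_batch would.
    optimizer.apply_gradients(zip(grads, variables))

    with writer.as_default():
        tf.summary.scalar("loss", loss, step=step)
        for var, grad in zip(variables, grads):
            if grad is not None:  # skip variables that did not affect the loss
                name = getattr(var, "path", var.name).replace(":", "_")
                tf.summary.histogram(name + "/gradient", grad, step=step)
    return float(loss)
```

Writing histograms on every batch can make the log files large; logging only when step is a multiple of some interval is a common compromise.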

Sanjay
