
I'm trying to set up a simple GAN training loop but am getting the following error:

RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.

Here is my training loop:

for epoch in range(N_EPOCHS):
    # gets data for the generator
    for i, batch in enumerate(dataloader, 0):

        # passing target images to the Discriminator
        global_disc.zero_grad()
        output_disc = global_disc(batch.to(device))
        error_target = loss(output_disc, torch.ones(output_disc.shape).cuda())
        error_target.backward()

        # apply mask to the images
        batch = apply_mask(batch)

        # passes fake images to the Discriminator
        global_output, local_output = gen(batch.to(device))
        output_disc = global_disc(global_output.detach())
        error_fake = loss(output_disc, torch.zeros(output_disc.shape).to(device))
        error_fake.backward()

        # combines the errors
        error_total = error_target + error_fake
        optimizer_disc.step()

        # updates the generator
        gen.zero_grad()
        error_gen = loss(output_disc, torch.ones(output_disc.shape).to(device))
        error_gen.backward()
        optimizer_gen.step()

        break
    break

As far as I can tell, I have the operations in the right order, I'm zeroing out the gradients, and I'm detaching the output of the generator before it goes into the discriminator.

This article was helpful but I'm still running into something I don't understand.

  • Why are you passing your batch to the generator (`global_output, local_output = gen(batch.to(device))`)? – Ivan Sep 09 '21 at 19:23

1 Answer

Two important points come to mind:

  1. You should feed your generator noise, not the real input:

    global_output, local_output = gen(noise.to(device))
    

Above, `noise` should have the appropriate shape: it is the input of your generator.
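
For instance, a minimal sketch of sampling that noise, assuming a generator that takes a flat latent vector (`latent_dim` is a hypothetical name here; set it to whatever `gen` expects as input):

    # hypothetical latent size -- must match the input dimension of gen
    latent_dim = 100
    # one latent vector per image in the batch
    noise = torch.randn(batch.size(0), latent_dim)
    global_output, local_output = gen(noise.to(device))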

  2. In order to optimize the generator, you need to recompute the discriminator output, because the graph from its previous forward pass has already been backpropagated through and freed. Simply add this line to recompute `output_disc`:

    # updates the generator
    gen.zero_grad()
    output_disc = global_disc(global_output)
    # ...
    

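Putting it together, a sketch of the corrected generator update, reusing the variable names and loss targets from the question's code:

    # updates the generator
    gen.zero_grad()
    # recompute the prediction on the non-detached fake images so that
    # gradients can flow back through the discriminator into the generator
    output_disc = global_disc(global_output)
    error_gen = loss(output_disc, torch.ones(output_disc.shape).to(device))
    error_gen.backward()
    optimizer_gen.step()
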
Please refer to this tutorial provided by PyTorch for a full walkthrough.

  • Thanks, number 2 solved it. To answer your number 1: this GAN model is for image inpainting, so it takes in an image with missing pieces and tries to generate a whole image. That's why the batch has a mask applied to it before going into the generator. – user2962197 Sep 09 '21 at 20:57
  • @user2962197 This makes sense! Thanks for explaining it. – Ivan Sep 09 '21 at 20:59
  • I still don't understand why we need to pass data to the discriminator twice. Yes, it's already been backpropagated, but by the last section of code I'm just trying to backpropagate the generator, and I already have what I need to compute its loss. – user2962197 Sep 09 '21 at 21:06
  • Good question: that's because we want to use the updated discriminator. In a GAN, training alternates between the generator and the discriminator. So after `optimizer_disc.step()` (*i.e.* `global_disc` gets its weights updated) we have to compute a new discriminator output prediction. The input is the same, but the output won't be, because the weights have changed! – Ivan Sep 09 '21 at 21:23