I'm trying to set up a simple GAN training loop but am getting the following error:
RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.
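The message itself is easy to reproduce outside the GAN. This minimal sketch uses a toy tensor `w`, not the question's models. The second `backward()` through the same graph fails because autograd freed the saved values after the first call. Passing `retain_graph=True` to the first call avoids the error, but the gradients from both calls are then summed:

```python
import torch

w = torch.randn(3, requires_grad=True)
y = (w ** 2).sum()  # the pow node saves w for the backward pass

y.backward()  # works; afterwards the saved tensors are freed
try:
    y.backward()  # second pass over the same, already freed, graph
    second_backward_failed = False
except RuntimeError as err:
    second_backward_failed = "second time" in str(err)

# retain_graph=True keeps the saved tensors alive, so a second pass is allowed,
# but the gradients from both passes accumulate in w.grad (2w + 2w = 4w).
w.grad = None
z = (w ** 2).sum()
z.backward(retain_graph=True)
z.backward()
print(second_backward_failed, torch.allclose(w.grad, 4 * w.detach()))
```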
```python
for epoch in range(N_EPOCHS):
    # gets data for the generator
    for i, batch in enumerate(dataloader, 0):
        # passing target images to the Discriminator
        global_disc.zero_grad()
        output_disc = global_disc(batch.to(device))
        error_target = loss(output_disc, torch.ones(output_disc.shape).cuda())
        error_target.backward()
        # apply mask to the images
        batch = apply_mask(batch)
        # passes fake images to the Discriminator
        global_output, local_output = gen(batch.to(device))
        output_disc = global_disc(global_output.detach())
        error_fake = loss(output_disc, torch.zeros(output_disc.shape).to(device))
        error_fake.backward()
        # combines the errors
        error_total = error_target + error_fake
        optimizer_disc.step()
        # updates the generator
        gen.zero_grad()
        error_gen = loss(output_disc, torch.ones(output_disc.shape).to(device))
        error_gen.backward()
        optimizer_gen.step()
        break
    break
```
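In the loop above, `error_fake.backward()` frees the graph that produced `output_disc`, and `error_gen` is then computed from that same `output_disc`, so `error_gen.backward()` has to go back through the freed graph. Even with `retain_graph=True`, that pass started from `global_output.detach()`, so no gradient would reach the generator. One common fix is to run the discriminator again on the non-detached generator output after `optimizer_disc.step()`. The sketch below assumes hypothetical stand-ins (`TinyGen`, a one-layer `global_disc`, a toy `apply_mask`, random batches) for the question's models and data; only the ordering of the steps is meant to carry over:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class TinyGen(nn.Module):
    """Hypothetical generator returning a (global, local) pair like `gen`."""
    def __init__(self, dim=16):
        super().__init__()
        self.net = nn.Linear(dim, dim)

    def forward(self, x):
        out = torch.tanh(self.net(x))
        return out, out  # the local output is unused in this sketch

gen = TinyGen().to(device)
global_disc = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid()).to(device)  # hypothetical
loss = nn.BCELoss()
optimizer_disc = torch.optim.Adam(global_disc.parameters(), lr=2e-4)
optimizer_gen = torch.optim.Adam(gen.parameters(), lr=2e-4)

def apply_mask(x):
    # Hypothetical mask: zero out the second half of each sample.
    masked = x.clone()
    masked[:, x.shape[1] // 2:] = 0
    return masked

dataloader = [torch.randn(8, 16) for _ in range(4)]  # fake "real" batches
N_EPOCHS = 2

for epoch in range(N_EPOCHS):
    for i, batch in enumerate(dataloader):
        batch = batch.to(device)

        # --- Discriminator step ---
        global_disc.zero_grad()
        output_real = global_disc(batch)
        error_target = loss(output_real, torch.ones_like(output_real))
        error_target.backward()

        global_output, local_output = gen(apply_mask(batch))
        # detach() keeps this backward pass out of the generator
        output_fake = global_disc(global_output.detach())
        error_fake = loss(output_fake, torch.zeros_like(output_fake))
        error_fake.backward()  # frees the graph behind output_fake
        error_total = error_target + error_fake  # for logging only
        optimizer_disc.step()

        # --- Generator step ---
        gen.zero_grad()
        # Fresh discriminator pass on the non-detached fake images: a new graph
        # that is not freed yet, and one that connects back to the generator.
        output_gen = global_disc(global_output)
        error_gen = loss(output_gen, torch.ones_like(output_gen))
        error_gen.backward()
        optimizer_gen.step()
```

The design keeps two discriminator passes over the fake batch: the detached one feeds only the discriminator's loss, and the attached one feeds only the generator's loss. `torch.ones_like` is used here only because it creates the labels with the same shape, dtype, and device as the output.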
As far as I can tell, I have the operations in the right order, I'm zeroing out the gradients, and I'm detaching the output of the generator before it goes into the discriminator.
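Detaching is the right move for the discriminator update, but it also means a detached pass cannot train the generator. A minimal sketch with two hypothetical `nn.Linear` layers standing in for `gen` and `global_disc` shows where the gradient stops:

```python
import torch
import torch.nn as nn

gen = nn.Linear(4, 4)   # hypothetical generator
disc = nn.Linear(4, 1)  # hypothetical discriminator
z = torch.randn(2, 4)

fake = gen(z)

# Discriminator-style pass: detached input, so the gradient stops at detach().
disc(fake.detach()).sum().backward()
grad_after_detached_pass = gen.weight.grad  # still None: nothing reached gen

# Generator-style pass: same fake tensor, not detached, fresh discriminator graph.
disc(fake).sum().backward()
grad_after_attached_pass = gen.weight.grad
print(grad_after_detached_pass is None, grad_after_attached_pass is not None)
```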
This article was helpful, but I'm still running into something I don't understand.