I think that you only asked one thing:
- is the first backward pass slower than subsequent ones?
There are two more questions that might as well be combined here:
- is it more efficient to combine losses?
- does keeping the graph (if not its values) require more memory?
Before that, however, let's clarify what retain_graph
is actually for: multiple backward passes through the same graph, which you need when your computation happens to produce several outputs (and therefore several losses) at different points in time. As an example, think of joint multi-task learning (see this question and its answers for a discussion of this).
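To make that concrete, here is a minimal sketch of such a setting (the two-headed model and all shapes are made up for illustration): two losses share one forward pass through a common trunk, so the first backward call has to retain the graph for the second one.

import torch
import torch.nn as nn

# hypothetical multi-task setup: two heads sharing one trunk
trunk = nn.Linear(8, 16)
head_a = nn.Linear(16, 1)
head_b = nn.Linear(16, 1)

x = torch.rand(4, 8)
features = trunk(x)          # one shared forward pass, one shared graph

loss_a = head_a(features).sum()
loss_b = head_b(features).sum()

loss_a.backward(retain_graph=True)  # keep the shared graph alive ...
loss_b.backward()                   # ... so this second pass through the trunk still works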
Back to the questions: in general, I would expect that it does not really matter whether you retain the graph. After all, retaining it just keeps partial computations in memory for future use, without "doing" anything with them.
That said - the first backward pass will take longer, as PyTorch will cache some of the computations needed when computing the gradients.
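Before the full benchmark, here is a minimal standalone sketch of that warm-up effect (the model and sizes are arbitrary placeholders):

import time
import torch
import torch.nn as nn

net = nn.Linear(1024, 1024)

for i in range(3):
    out = net(torch.rand(1, 1024)).sum()  # fresh graph each iteration
    start = time.time()
    out.backward()                         # the very first call is typically the slowest
    print(i, round(time.time() - start, 4))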
So here's proof:
import numpy as np
import torch
import torch.nn as nn
import time
import os
import psutil
D_in = 1024
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 1024))
device = torch.device('cpu')
def loss1(y_pred, x):
    # arbitrary scalar loss on the prediction
    return (y_pred * (0.5 - x.clamp(0, 1))).sum()

def loss2(y_pred, x):
    # a second, slightly different scalar loss
    return (y_pred * (1 - x.clamp(0, 1))).sum()

def timeit(func, repetitions):
    # run func repeatedly, recording wall-clock time and process memory (RSS)
    time_taken = []
    mem_used = []
    for _ in range(repetitions):
        time_start = time.time()
        mem_used.append(func())
        time_taken.append(time.time() - time_start)
    return np.round([np.mean(time_taken), np.min(time_taken), np.max(time_taken),
                     np.mean(mem_used), np.min(mem_used), np.max(mem_used)], 4).tolist()
# Predict random input
x = torch.rand(1, D_in, device=device, requires_grad=True)

def init():
    # warm-up: one forward/backward pass so later timings are not skewed
    out = model(x)
    loss = loss1(out, x)
    loss.backward()

def func1():
    # two separate graphs, two backward passes, no retain_graph
    x = torch.rand(1, D_in, device=device, requires_grad=True)
    loss = loss1(model(x), x)
    loss.backward()
    loss = loss2(model(x), x)
    loss.backward()
    del x
    process = psutil.Process(os.getpid())
    return process.memory_info().rss
def func2():
    # single combined loss over two forward passes, one backward pass
    x = torch.rand(1, D_in, device=device, requires_grad=True)
    loss = loss1(model(x), x) + loss2(model(x), x)
    loss.backward()
    del x
    process = psutil.Process(os.getpid())
    return process.memory_info().rss
def func3():
    # same as func1, but retaining the graph after each backward pass
    x = torch.rand(1, D_in, device=device, requires_grad=True)
    loss = loss1(model(x), x)
    loss.backward(retain_graph=True)
    loss = loss2(model(x), x)
    loss.backward(retain_graph=True)
    del x
    process = psutil.Process(os.getpid())
    return process.memory_info().rss
def func4():
    # same as func2, but retaining the graph after the backward pass
    x = torch.rand(1, D_in, device=device, requires_grad=True)
    loss = loss1(model(x), x) + loss2(model(x), x)
    loss.backward(retain_graph=True)
    del x
    process = psutil.Process(os.getpid())
    return process.memory_info().rss
init()  # warm-up pass (see below for what happens when you skip it)
print(timeit(func1, 100))
print(timeit(func2, 100))
print(timeit(func3, 100))
print(timeit(func4, 100))
The results are below, one row per function (func1 through func4); sorry for my lazy formatting:
# time mean, time min, time max, memory mean, memory min, memory max
[0.1165, 0.1138, 0.1297, 383456419.84, 365731840.0, 384438272.0]
[0.127, 0.1233, 0.1376, 400914759.68, 399638528.0, 434044928.0]
[0.1167, 0.1136, 0.1272, 400424468.48, 399577088.0, 401223680.0]
[0.1263, 0.1226, 0.134, 400815964.16, 399556608.0, 434307072.0]
However, if you skip the first backward pass (comment out the call to the init() function), the very first backward run in func1 will take longer:
# time mean, time min, time max, memory mean, memory min, memory max
[0.1208, 0.1136, **0.1579**, 350157455.36, 349331456.0, 350978048.0]
[0.1297, 0.1232, 0.1499, 393928540.16, 350052352.0, 401854464.0]
[0.1197, 0.1152, 0.1547, 350787338.24, 349982720.0, 351629312.0]
[0.1335, 0.1229, 0.1793, 382819123.2, 349929472.0, 401776640.0]