
I would like to calculate the gradient of my model for several loss functions.

I would like to find out whether successive backward calls with retain_graph=True are cheap or expensive.

In theory, I would expect the first call to be slower than the following ones, because the later calls do not have to re-evaluate the computational graph; they only need to perform a few matrix multiplications.

In practice I found it hard to benchmark.

My code:

# Code in file nn/two_layer_net_nn.py
import torch

D_in = 40
model = torch.load('model.pytorch')
device = torch.device('cpu')
def loss1(y_pred,x):
    return (y_pred*(0.5-x.clamp(0,1))).sum()

def loss2(y_pred,x):
    return (y_pred*(1-x.clamp(0,1))).sum()


# Predict random input
x = torch.rand(1, D_in, device=device, requires_grad=True)
y_pred = model(x)

%%timeit
# Is this (each %%timeit block is its own notebook cell; the magic must be the cell's first line)
loss = loss1(y_pred, x)
loss.backward(retain_graph=True)

202 µs ± 4.34 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%%timeit
# ...slower than this?
loss = loss2(y_pred, x)
loss.backward(retain_graph=True)

216 µs ± 27.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


# Are successive backward calls cheap? (lossX stands for either loss1 or loss2)
loss = lossX(y_pred, x)
loss.backward(retain_graph=True)

I think %%timeit doesn't help here, because it runs the statement many times and reports the average, which hides how long the first call takes compared with the later ones.
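One way around the averaging is to time each backward call on its own. The sketch below is a minimal illustration under assumptions: since model.pytorch is not available, it uses a small stand-in nn.Sequential with D_in = 40, runs one forward pass, and then times every backward call on that single retained graph with time.perf_counter, so the first call can be compared with the rest.

```python
import time

import torch
import torch.nn as nn

torch.manual_seed(0)
D_in = 40
# Stand-in for the question's model (assumption: output has the same shape as the input)
model = nn.Sequential(nn.Linear(D_in, 100), nn.ReLU(), nn.Linear(100, D_in))


def loss1(y_pred, x):
    return (y_pred * (0.5 - x.clamp(0, 1))).sum()


def loss2(y_pred, x):
    return (y_pred * (1 - x.clamp(0, 1))).sum()


x = torch.rand(1, D_in, requires_grad=True)
y_pred = model(x)  # a single forward pass, shared by every backward call below

timings = []
for i in range(10):
    # Alternate between the two losses; building the loss is not timed
    loss = (loss1 if i % 2 == 0 else loss2)(y_pred, x)
    start = time.perf_counter()
    loss.backward(retain_graph=True)  # keep the graph so the next call can reuse it
    timings.append(time.perf_counter() - start)

print(f"first call:  {timings[0] * 1e6:.0f} us")
print(f"later calls: {[round(t * 1e6) for t in timings[1:]]} us")
```

Single measurements are noisy, so it is worth repeating the whole script a few times before drawing conclusions. In IPython, `%timeit -n 1 -r 1` similarly runs a statement exactly once instead of averaging over many loops.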

How can I measure whether successive calls to backward will be fast? What does retain_graph=True actually mean for performance?

Sebastian Wozny

1 Answer


I think you only asked whether

  1. the first backward pass is slower than subsequent ones.

There are two more questions that might as well be answered here:

  2. Is it more efficient to combine losses?
  3. Does keeping the graph (if not its values) require more memory?
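For question 2, it helps to know first that both approaches compute the same thing. A quick check, assuming a toy stand-in model (not the asker's): two backward passes through one retained graph accumulate the gradients of loss1 and loss2 in .grad, which equals the gradient of their sum computed in a single backward pass.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D_in = 40
# Toy stand-in model (assumption); float64 keeps rounding differences far below
# the default tolerance of torch.allclose
model = nn.Sequential(nn.Linear(D_in, 64), nn.ReLU(), nn.Linear(64, D_in)).double()
x = torch.rand(1, D_in, dtype=torch.float64, requires_grad=True)


def loss1(y_pred, x):
    return (y_pred * (0.5 - x.clamp(0, 1))).sum()


def loss2(y_pred, x):
    return (y_pred * (1 - x.clamp(0, 1))).sum()


# Option A: one forward pass, two backward passes through the retained graph;
# the two gradients accumulate in each parameter's .grad
y_pred = model(x)
loss1(y_pred, x).backward(retain_graph=True)
loss2(y_pred, x).backward()
grads_separate = [p.grad.clone() for p in model.parameters()]

# Option B: one forward pass, one backward pass on the summed loss
model.zero_grad()
y_pred = model(x)
(loss1(y_pred, x) + loss2(y_pred, x)).backward()
grads_combined = [p.grad.clone() for p in model.parameters()]

same = all(torch.allclose(a, b) for a, b in zip(grads_separate, grads_combined))
print(same)  # True
```

Option B traverses the shared part of the graph once instead of twice, so it should not be slower; whether the saving is measurable depends on how much of the graph the two losses share.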

Before that, however, let's be clear about what retain_graph is actually for: making multiple backward passes through (parts of) the same graph, for example when you have several outputs that you backpropagate at different times. As an example, think of joint multi-task learning (see this question and its answers for a discussion).
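A minimal sketch of that situation, assuming a hypothetical shared trunk with two heads (trunk, head_a and head_b are illustrative names): the first backward call passes retain_graph=True so the second one can still traverse the shared trunk. Without it, PyTorch frees the tensors saved for the backward pass after the first call, and the second call raises a RuntimeError.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.rand(1, 4)
# Hypothetical multi-task setup: one shared trunk, two task heads
trunk = nn.Linear(4, 8)
head_a = nn.Linear(8, 1)
head_b = nn.Linear(8, 1)

h = torch.relu(trunk(x))  # shared intermediate result
loss_a = head_a(h).sum()
loss_b = head_b(h).sum()
loss_a.backward(retain_graph=True)  # keep the trunk's saved tensors for loss_b
loss_b.backward()                   # last pass through the graph, so it may be freed now

# Same computation without retain_graph: the first backward frees the
# tensors saved by relu, so backpropagating through the trunk again fails
h2 = torch.relu(trunk(x))
l1 = head_a(h2).sum()
l2 = head_b(h2).sum()
l1.backward()
try:
    l2.backward()
    second_pass_ok = True
except RuntimeError:
    second_pass_ok = False
print(second_pass_ok)  # False
```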

Back to the questions: in general, I would not expect retaining the graph to make much of a difference. After all, it just keeps intermediate results in memory for future use, without "doing" anything with them.
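To see concretely which "partial computations" are kept, torch.autograd.graph.saved_tensors_hooks (available in recent PyTorch versions) lets you observe every tensor that autograd stores for the backward pass; these are what retain_graph=True keeps alive after backward returns. This is a sketch on an assumed model that is smaller than the one in the benchmark; note that saved parameters (such as Linear weights) are references to existing tensors rather than copies, so the reported total overstates the extra memory.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Assumed model, smaller than the one used in the benchmark
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))
x = torch.rand(1, 1024, requires_grad=True)

saved_sizes = []  # size in bytes of each tensor autograd saves for backward


def pack_hook(t):
    # Called during the forward pass for every tensor saved for backward.
    # Parameters show up here too, but they are shared, not copied.
    saved_sizes.append(t.element_size() * t.nelement())
    return t  # store the tensor unchanged


def unpack_hook(t):
    # Called during backward to retrieve a saved tensor
    return t


with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    loss = (model(x) * (0.5 - x.clamp(0, 1))).sum()

print(f"{len(saved_sizes)} tensors, {sum(saved_sizes) / 2**20:.1f} MiB saved for backward")

loss.backward(retain_graph=True)  # the saved tensors stay alive after this call
loss.backward()                   # so a second pass works; afterwards they are released
```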

That said, the first backward pass will take longer, as PyTorch caches some of the computations needed for computing the gradients.

Here is a benchmark to back this up:

import numpy as np
import torch
import torch.nn as nn
import time
import os
import psutil  # third-party; used to read this process's memory usage

D_in = 1024
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 1024))
device = torch.device('cpu')

def loss1(y_pred, x):
    return (y_pred * (0.5 - x.clamp(0, 1))).sum()


def loss2(y_pred, x):
    return (y_pred * (1 - x.clamp(0, 1))).sum()


def timeit(func, repetitions):
    # Call func `repetitions` times; report mean/min/max wall time (seconds)
    # and mean/min/max of the value func returns (here: RSS in bytes)
    time_taken = []
    mem_used = []
    for _ in range(repetitions):
        time_start = time.time()
        mem_used.append(func())
        time_taken.append(time.time() - time_start)
    return np.round([np.mean(time_taken), np.min(time_taken), np.max(time_taken),
                     np.mean(mem_used), np.min(mem_used), np.max(mem_used)], 4).tolist()


# Predict random input
x = torch.rand(1, D_in, device=device, requires_grad=True)

def init():
    # Warm-up: one forward and one backward pass before anything is timed
    out = model(x)
    loss = loss1(out, x)
    loss.backward()

def func1():
    # Two forward passes (each model(x) call builds its own graph) and
    # two backward passes; each graph is freed after its backward pass
    x = torch.rand(1, D_in, device=device, requires_grad=True)
    loss = loss1(model(x), x)
    loss.backward()
    loss = loss2(model(x), x)
    loss.backward()
    del x
    process = psutil.Process(os.getpid())
    return process.memory_info().rss  # resident set size of this process, in bytes

def func2():
    # Two forward passes, losses summed, a single backward pass
    x = torch.rand(1, D_in, device=device, requires_grad=True)
    loss = loss1(model(x), x) + loss2(model(x), x)
    loss.backward()
    del x
    process = psutil.Process(os.getpid())
    return process.memory_info().rss


def func3():
    # Same as func1, but both backward passes retain their graphs
    # (the retained graphs are never reused and are released once unreferenced)
    x = torch.rand(1, D_in, device=device, requires_grad=True)
    loss = loss1(model(x), x)
    loss.backward(retain_graph=True)
    loss = loss2(model(x), x)
    loss.backward(retain_graph=True)
    del x
    process = psutil.Process(os.getpid())
    return process.memory_info().rss


def func4():
    # Same as func2, but with retain_graph=True
    x = torch.rand(1, D_in, device=device, requires_grad=True)
    loss = loss1(model(x), x) + loss2(model(x), x)
    loss.backward(retain_graph=True)
    del x
    process = psutil.Process(os.getpid())
    return process.memory_info().rss

init()
print(timeit(func1, 100))
print(timeit(func2, 100))
print(timeit(func3, 100))
print(timeit(func4, 100))

The results are (sorry for my lazy formatting):

# one row per function: [time mean, time min, time max (seconds), memory mean, memory min, memory max (RSS, bytes)]
func1: [0.1165, 0.1138, 0.1297, 383456419.84, 365731840.0, 384438272.0]
func2: [0.127, 0.1233, 0.1376, 400914759.68, 399638528.0, 434044928.0]
func3: [0.1167, 0.1136, 0.1272, 400424468.48, 399577088.0, 401223680.0]
func4: [0.1263, 0.1226, 0.134, 400815964.16, 399556608.0, 434307072.0]

However, if you skip the warm-up backward pass (comment out the call to init()), the very first backward run in func1 takes longer (see the maximum time in the func1 row):

# one row per function: [time mean, time min, time max (seconds), memory mean, memory min, memory max (RSS, bytes)]
func1: [0.1208, 0.1136, 0.1579, 350157455.36, 349331456.0, 350978048.0]   # max time 0.1579 includes the very first backward pass
func2: [0.1297, 0.1232, 0.1499, 393928540.16, 350052352.0, 401854464.0]
func3: [0.1197, 0.1152, 0.1547, 350787338.24, 349982720.0, 351629312.0]
func4: [0.1335, 0.1229, 0.1793, 382819123.2, 349929472.0, 401776640.0]
cleros