When I run several `torch.einsum` calculations in a row, the first one is always much slower than the ones that follow.
The following code and plot illustrate the problem:
```python
import numpy as np
import torch as tor
from timeit import default_timer as timer

N = 1000
L = 10
time_arr = np.zeros(L)

for i in range(L):
    # three random 1000x1000 matrices for each cycle
    a = tor.randn(N, N).to("cuda:0")
    b = tor.randn(N, N).to("cuda:0")
    c = tor.randn(N, N).to("cuda:0")
    time_start = timer()
    tor.einsum("ij, kj", tor.einsum("ij, ja", a, b), c)
    time_end = timer()
    time_arr[i] = time_end - time_start
```
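One way to check whether the slow first run is a one-time startup cost is sketched below. It assumes the first GPU call pays for setup work (for example CUDA context and library initialization) that later calls reuse. Two changes differ from the loop above. First, CUDA kernels launch asynchronously, so `torch.cuda.synchronize()` is called before reading the timer to include the actual execution time. Second, an untimed warm-up call runs before the timed loop. The helper name `time_einsum` and its `warmup` parameter are hypothetical, and the code falls back to the CPU when no GPU is available.

```python
import torch
from timeit import default_timer as timer


def time_einsum(n=1000, repeats=10, warmup=True):
    """Hypothetical helper: time the nested einsum from the question.

    Assumes the first-call slowdown is one-time setup; the optional
    untimed warm-up call absorbs it before the timed loop starts.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    def sync():
        # CUDA kernels run asynchronously; wait for them before reading the timer
        if device != "cpu":
            torch.cuda.synchronize()

    def run(a, b, c):
        # same nested einsum as in the question
        return torch.einsum("ij, kj", torch.einsum("ij, ja", a, b), c)

    if warmup:
        x = torch.randn(n, n, device=device)
        run(x, x, x)
        sync()

    times = []
    for _ in range(repeats):
        a = torch.randn(n, n, device=device)
        b = torch.randn(n, n, device=device)
        c = torch.randn(n, n, device=device)
        sync()  # make sure the inputs are ready before starting the clock
        start = timer()
        run(a, b, c)
        sync()  # include kernel execution time, not just the launch
        times.append(timer() - start)
    return times


if __name__ == "__main__":
    for i, t in enumerate(time_einsum()):
        print(f"run {i}: {t * 1e3:.2f} ms")
```

Running it once with `warmup=True` and once with `warmup=False` makes it possible to see whether the gap between the first run and the rest disappears once the setup cost is paid outside the timed loop.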