I'm trying to implement ALS baseline subtraction in PyTorch so that I can run it on my GPU, but I'm running into problems because torch.gesv gives a different result than scipy.sparse.linalg.spsolve. Here is my SciPy code:
    import numpy as np
    from scipy import sparse
    from scipy.sparse.linalg import spsolve

    def baseline_als(y, lam, p, niter=10):
        L = len(y)
        # D is the (L x (L-2)) second-difference operator
        D = sparse.diags([1, -2, 1], [0, -1, -2], shape=(L, L - 2))
        w = np.ones(L)
        for i in range(niter):
            W = sparse.spdiags(w, 0, L, L)
            Z = W + lam * D.dot(D.transpose())
            z = spsolve(Z, w * y)
            w = p * (y > z) + (1 - p) * (y < z)
        return z
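(For context on the penalty term: the D built here is the second-difference operator, so D.T applied to a signal computes its discrete second differences. A minimal sketch checking this, with my own variable names, not part of the original code:)

```python
import numpy as np
from scipy import sparse

# Sketch: D from sparse.diags([1, -2, 1], [0, -1, -2], shape=(L, L - 2))
# has the stencil (1, -2, 1) down each column, so D.T @ y is the
# second difference y[j] - 2*y[j+1] + y[j+2], i.e. np.diff(y, n=2).
L = 8
D = sparse.diags([1, -2, 1], [0, -1, -2], shape=(L, L - 2))
y = np.array([5., 5., 5., 5., 10., 10., 5., 5.])
print(np.allclose(D.T @ y, np.diff(y, n=2)))  # → True
```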
and here is my PyTorch code:
    import torch

    def baseline_als_pytorch(y, lam, p, niter=10):
        L = y.shape[0]
        # Build a dense version of scipy's
        # sparse.diags([1, -2, 1], [0, -1, -2], shape=(L, L - 2))
        diag = torch.diag(torch.ones(L), 0)
        diag_minus_one = torch.diag(torch.full((L - 1,), -2.0), -1)
        diag_minus_two = torch.diag(torch.ones(L - 2), -2)
        D = (diag + diag_minus_one + diag_minus_two)[:, :L - 2].double()
        w = torch.ones(L).double()
        for i in range(niter):
            W = torch.diag(w)  # weights on the diagonal, as in the SciPy version
            Z = W + lam * torch.mm(D, D.permute(1, 0))
            z = torch.gesv(w * y, Z)  # gesv returns (solution, LU)
            z = z[0].squeeze()
            w = p * (y > z).double() + (1 - p) * (y < z).double()
        return z
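To illustrate the check I'm describing, here is a sketch (NumPy/SciPy only, variable names mine) with NumPy's dense LU-based solver standing in for torch.gesv: on the same Z and right-hand side, a dense solve and spsolve should agree to numerical precision, so that is the behavior I expected from the PyTorch version:

```python
import numpy as np
from scipy import sparse
from scipy.sparse.linalg import spsolve

# Same Z as one iteration of baseline_als, solved two ways.
lam, L = 100000.0, 20
y = np.array([5, 5, 5, 5, 5, 10, 10, 5, 5, 5,
              10, 10, 10, 5, 5, 5, 5, 5, 5, 5], dtype=float)
D = sparse.diags([1, -2, 1], [0, -1, -2], shape=(L, L - 2))
w = np.ones(L)
Z = sparse.spdiags(w, 0, L, L) + lam * D.dot(D.transpose())
z_sparse = spsolve(sparse.csc_matrix(Z), w * y)   # sparse LU solve
z_dense = np.linalg.solve(Z.toarray(), w * y)     # dense LU solve
print(np.allclose(z_sparse, z_dense))  # → True
```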
Sorry that the PyTorch code looks so bad; I'm just starting out with it.
I've confirmed that Z, w, and y are all the same going into both the SciPy and PyTorch solves, and that z differs between them immediately after I solve the system of equations.
Thanks for the comment; here is an example.
I use lam = 100000 and p = 0.001.
Using the dummy input y = (5, 5, 5, 5, 5, 10, 10, 5, 5, 5, 10, 10, 10, 5, 5, 5, 5, 5, 5, 5),
I get (3.68010263, 4.90344214, 6.12679489, 7.35022406, 8.57384278, 9.79774074, 11.02197199, 12.2465927, 13.47164891, 14.69711435, 15.92287813, 17.14873257, 18.37456982, 19.60038184, 20.82626043, 22.05215157, 23.27805103, 24.50400438, 25.73010693, 26.95625922) from SciPy and
(6.4938312, 6.46912395, 6.44440175, 6.41963499, 6.39477958, 6.36977727, 6.34455582, 6.31907933, 6.29334844, 6.26735058, 6.24106029, 6.21443939, 6.18748732, 6.16024137, 6.13277694, 6.10515785, 6.07743658, 6.04965455, 6.02184242, 5.99402035) from PyTorch.
This is with just one iteration of the loop. SciPy is correct; PyTorch is not.
Interestingly, if I use a shorter dummy input, (5, 5, 5, 5, 5, 10, 10, 5, 5, 5), I get the same answer from both. My real input is 1011-dimensional.