Took me a while to figure out, but:
foo <- apply(matr2,2,function(x) rev(cumprod(x+1)))
matr3 <- matr1*t(foo[,x])
-- PROOF --
set.seed(100)
n = 5
m = 3
k = m * n
matr1 <- matrix(sample(seq(0,1, by = 0.1), size = k * n, replace = TRUE), nrow = k, ncol = n )
matr2 <- matrix(sample(seq(0,1, by = 0.1), size = m * n, replace = TRUE), nrow = n, ncol = m)
x <- sample(1:m, size = k, replace = TRUE)
foo <- apply(matr2,2,function(x) rev(cumprod(x+1)))
matr3 <- matr1*t(foo[,x])
for( i in 1:k){
  for( j in 1:n){
    matr1[i, 1:(n-j+1)] <- matr1[i, 1:(n-j+1)] +
      matr1[i, 1:(n-j+1)] * matr2[j , x[i]]
  }
}
all.equal(matr3,matr1)
# TRUE
-- EXPLANATION --
So it took me a while to figure this out correctly, but here goes... Taking your code with i = 1, we can basically write for j=1:
matr1[1,1:5] <- matr1[1,1:5] + matr1[1,1:5] * matr2[1,3]
So you take row 1, columns 1 to 5, and replace each number with the original number PLUS that number times some other number (in this case 0.8). Then, when j=2:
matr1[1,1:4] <- matr1[1,1:4] + matr1[1,1:4] * matr2[2,3]
So now you take all columns except column n, and update the values in the same way as in step 1. In the end, the pattern that should be clear is that matr1[1,1] is updated n times, whereas matr1[1,n] is updated 1 time (with only matr2[1,3]).
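To make that pattern concrete, here is a minimal sketch that only counts how often the loop touches each column of a single row. The name `updates` is my own and is not part of your code; n = 5 matches the proof above.

```r
# Count how many times the original inner loop updates each column of one row.
n <- 5
updates <- integer(n)                  # hypothetical counter, one slot per column
for (j in 1:n) {
  cols <- 1:(n - j + 1)                # columns touched at step j
  updates[cols] <- updates[cols] + 1L
}
updates
# 5 4 3 2 1
```

So column c is updated n - c + 1 times, using matr2[1, ], matr2[2, ], ..., matr2[n - c + 1, ] in that order.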
We exploit this pattern by pre-calculating all steps in one go. We do that with:
foo <- apply(matr2,2,function(x) rev(cumprod(x+1)))
This basically is a new n-by-m table: row j holds the factor for column j of matr1[i,], and the column of the table is the one picked by x[i]. Each factor combines all the loop iterations your previous code ran into a single number. So, instead of matr1[1,1] requiring 5 multiplications, we now just do 1.
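A small worked example may help here. The vector `v` below is made up (think of it as one column of matr2 with n = 3); the check runs the original update rule on a row of ones and compares it with the precomputed factors.

```r
v <- c(0.5, 1, 0)                      # hypothetical column of matr2, n = 3
factors <- rev(cumprod(v + 1))         # cumprod gives 1.5 3.0 3.0, rev flips it
factors
# 3.0 3.0 1.5

# Same thing via the original loop, starting from a row of ones
row <- rep(1, 3)
for (j in 1:3) {
  idx <- 1:(3 - j + 1)
  row[idx] <- row[idx] + row[idx] * v[j]
}
all.equal(row, factors)
# TRUE
```

The rev() is needed because the first column gets the longest product (all n factors) and the last column gets the shortest (only the first factor).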
So now we have:
for (i in 1:k) for (j in 1:n) matr1[i,j] <- matr1[i,j] * foo[j,x[i]]
We can reduce that to:
for (i in 1:k) matr1[i,] <- matr1[i,] * foo[,x[i]]
Since i always runs over 1:k every single time you index with it, we can just vectorize that as well:
matr1 <- matr1*t(foo[,x])
And we're done.
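If the t(foo[,x]) part looks like magic, this sketch checks the shapes on a tiny case. The sizes and the random `foo` here are stand-ins of my own, not the real precomputed table; the point is only that indexing with x repeats columns, and the transpose lines them up with the rows of matr1.

```r
set.seed(1)
n <- 4; m <- 2; k <- 6                                  # small made-up sizes
matr1 <- matrix(runif(k * n), nrow = k, ncol = n)
foo   <- matrix(runif(n * m) + 1, nrow = n, ncol = m)   # stand-in for the factor table
x     <- sample(1:m, size = k, replace = TRUE)

dim(foo[, x])      # n x k: one column of factors per row of matr1
dim(t(foo[, x]))   # k x n: same shape as matr1

# Row-by-row version versus the fully vectorized version
by_row <- matr1
for (i in 1:k) by_row[i, ] <- by_row[i, ] * foo[, x[i]]
vectorized <- matr1 * t(foo[, x])
all.equal(by_row, vectorized)
# TRUE
```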
-- BENCHMARK --
I reran the code block that I gave as a proof, but with n=100 and m=100.
Your code:
# user system elapsed
# 6.85 0.00 6.86
My code:
# user system elapsed
# 0.02 0.00 0.02
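If you want to reproduce this yourself, something along these lines should work. The wrapper functions `loop_version` and `vectorized_version` are names I made up for this sketch, and the timings will of course differ on your machine, so I am not printing any expected numbers.

```r
# Hypothetical wrappers around the two approaches, for timing only
loop_version <- function(matr1, matr2, x) {
  k <- nrow(matr1); n <- ncol(matr1)
  for (i in 1:k) {
    for (j in 1:n) {
      cols <- 1:(n - j + 1)
      matr1[i, cols] <- matr1[i, cols] + matr1[i, cols] * matr2[j, x[i]]
    }
  }
  matr1
}
vectorized_version <- function(matr1, matr2, x) {
  foo <- apply(matr2, 2, function(col) rev(cumprod(col + 1)))
  matr1 * t(foo[, x])
}

set.seed(100)
n <- 100; m <- 100; k <- m * n
matr1 <- matrix(sample(seq(0, 1, by = 0.1), size = k * n, replace = TRUE), nrow = k, ncol = n)
matr2 <- matrix(sample(seq(0, 1, by = 0.1), size = m * n, replace = TRUE), nrow = n, ncol = m)
x <- sample(1:m, size = k, replace = TRUE)

system.time(res_loop <- loop_version(matr1, matr2, x))
system.time(res_vec  <- vectorized_version(matr1, matr2, x))
all.equal(res_loop, res_vec)
```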