
I have the following four nested loops in Matlab:

timesteps = 5;
inputsize = 10;
additionalinputsize = 3;
outputsize = 7;

input = randn(timesteps, inputsize);
additionalinput = randn(timesteps, additionalinputsize);
factor = randn(inputsize, additionalinputsize, outputsize);

output = zeros(timesteps,outputsize);
for t=1:timesteps
    for i=1:inputsize
        for o=1:outputsize
            for a=1:additionalinputsize
                output(t,o) = output(t,o) + factor(i,a,o) * input(t,i) * additionalinput(t,a);
            end
        end
    end
end

There are three vectors: one input vector, one additional input vector and one output vector. All of them are connected by factors, and every vector has a value at each timestep. At every timestep, I need the sum of the products of all combinations of input, additional input and factor. Later, I need to calculate in the opposite direction, from the output back to the input:

result2 = zeros(timesteps,inputsize);
for t=1:timesteps
    for i=1:inputsize
        for o=1:outputsize
            for a=1:additionalinputsize
                result2(t,i) = result2(t,i) + factor(i,a,o) * output(t,o) * additionalinput(t,a);
            end
        end
    end
end

In a third case, I need the product of all three vectors, summed over all timesteps:

product = zeros(inputsize,additionalinputsize,outputsize);
for t=1:timesteps
    for i=1:inputsize
        for o=1:outputsize
            for a=1:additionalinputsize
                product(i,a,o) = product(i,a,o) + input(t,i) * output(t,o) * additionalinput(t,a);
            end
        end
    end
end
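Written as index sums, the three loops compute the following. Here x = input, y = additionalinput, z = output, F = factor, and T, I, A, O are timesteps, inputsize, additionalinputsize and outputsize; these symbols are introduced only as shorthand:

$$
\begin{aligned}
z_{t,o} &= \sum_{i=1}^{I} \sum_{a=1}^{A} F_{i,a,o}\, x_{t,i}\, y_{t,a} && \text{(output)}\\
r_{t,i} &= \sum_{o=1}^{O} \sum_{a=1}^{A} F_{i,a,o}\, z_{t,o}\, y_{t,a} && \text{(result2)}\\
p_{i,a,o} &= \sum_{t=1}^{T} x_{t,i}\, z_{t,o}\, y_{t,a} && \text{(product)}
\end{aligned}
$$

In each case the summed indices can be grouped, by forming a row-wise outer product and reshaping, so that the whole sum becomes elementwise products followed by summation or a single matrix product. That is one way to see why the loops can be removed.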

All three code snippets work but are incredibly slow. How can I remove the nested loops?

Edit: Added values and made minor changes so that the snippets are executable.

Edit 2: Added a third use case.

Divakar
user1406177

1 Answer


First Part

One approach -

t1 = bsxfun(@times,additionalinput,permute(input,[1 3 2]));
t2 = bsxfun(@times,t1,permute(factor,[4 2 1 3]));
t3 = permute(t2,[2 3 1 4]);
output = squeeze(sum(sum(t3,1),2));

Or a slight variant to avoid squeeze -

t1 = bsxfun(@times,additionalinput,permute(input,[1 3 2]));
t2 = bsxfun(@times,t1,permute(factor,[4 2 1 3]));
t3 = permute(t2,[1 4 2 3]);
output = sum(sum(t3,3),4); 
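Both versions above first build a timesteps-by-additionalinputsize-by-inputsize-by-outputsize intermediate array and then sum it. An alternative that is not part of this answer merges the (i,a) index pair into one index with reshape, so that a single matrix multiplication does the summing. This is a minimal sketch that assumes the question's sizes and variable names. The original loop is recomputed only to compare results, and bsxfun is kept instead of implicit expansion so the script also runs on releases before R2016b.

```matlab
% Sizes and random data as in the question
timesteps = 5; inputsize = 10; additionalinputsize = 3; outputsize = 7;
input = randn(timesteps, inputsize);
additionalinput = randn(timesteps, additionalinputsize);
factor = randn(inputsize, additionalinputsize, outputsize);

% Z(t,i,a) = input(t,i) * additionalinput(t,a), a T-by-I-by-A array
Z = bsxfun(@times, input, permute(additionalinput, [1 3 2]));

% Flatten (i,a) into column index i + (a-1)*inputsize. reshape(factor, ...)
% orders its rows the same way (column-major), so one matrix product
% sums over both i and a.
output = reshape(Z, timesteps, inputsize*additionalinputsize) * ...
         reshape(factor, inputsize*additionalinputsize, outputsize);

% Reference result from the original quadruple loop
output_loop = zeros(timesteps, outputsize);
for t = 1:timesteps
    for i = 1:inputsize
        for o = 1:outputsize
            for a = 1:additionalinputsize
                output_loop(t,o) = output_loop(t,o) + factor(i,a,o) * input(t,i) * additionalinput(t,a);
            end
        end
    end
end

% The two results should agree up to floating-point rounding
fprintf('max abs difference: %g\n', max(abs(output(:) - output_loop(:))));
```

This avoids the 4-D temporary array. Any speed-up depends on the sizes involved and should be measured, for example with timeit.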

Second Part

t11 = bsxfun(@times,additionalinput,permute(output,[1 3 2]));
t22 = bsxfun(@times,permute(t11,[1 4 2 3]),permute(factor,[4 1 2 3]));
result2=sum(sum(t22,3),4);
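The same reshape idea can be applied to the second part. Again, this is a sketch of an alternative rather than the answer's method: combine additionalinput and output into a T-by-(A*O) matrix, flatten the last two dimensions of factor in the same order, and multiply. For a standalone check, output is filled with random numbers here. In the question it comes from the first loop, but that does not affect whether the two computations agree.

```matlab
% Sizes and random data; output is random here instead of coming from the first loop
timesteps = 5; inputsize = 10; additionalinputsize = 3; outputsize = 7;
additionalinput = randn(timesteps, additionalinputsize);
output = randn(timesteps, outputsize);
factor = randn(inputsize, additionalinputsize, outputsize);

% W(t,a,o) = additionalinput(t,a) * output(t,o), a T-by-A-by-O array
W = bsxfun(@times, additionalinput, permute(output, [1 3 2]));

% Flatten (a,o) into column index a + (o-1)*additionalinputsize in both operands
Wm = reshape(W, timesteps, additionalinputsize*outputsize);
Fm = reshape(factor, inputsize, additionalinputsize*outputsize);
result2 = Wm * Fm.';   % T-by-I, summed over a and o

% Reference result from the original loop
result2_loop = zeros(timesteps, inputsize);
for t = 1:timesteps
    for i = 1:inputsize
        for o = 1:outputsize
            for a = 1:additionalinputsize
                result2_loop(t,i) = result2_loop(t,i) + factor(i,a,o) * output(t,o) * additionalinput(t,a);
            end
        end
    end
end

% The two results should agree up to floating-point rounding
fprintf('max abs difference: %g\n', max(abs(result2(:) - result2_loop(:))));
```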

Third Part

t11 = bsxfun(@times,permute(output,[4 3 2 1]),permute(additionalinput,[4 2 3 1]));
t22 = bsxfun(@times,permute(input,[2 4 3 1]),t11);
product = sum(t22,4);
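For the third part, the same (a,o) flattening can be reused. This is also a sketch of an alternative, not the answer's method. After forming W, the sum over t is simply input.' * W, and a final reshape restores the inputsize-by-additionalinputsize-by-outputsize shape. As before, output is assumed to be random data rather than the result of the first loop.

```matlab
% Sizes and random data; output is random here instead of coming from the first loop
timesteps = 5; inputsize = 10; additionalinputsize = 3; outputsize = 7;
input = randn(timesteps, inputsize);
additionalinput = randn(timesteps, additionalinputsize);
output = randn(timesteps, outputsize);

% W(t,a,o) = additionalinput(t,a) * output(t,o), flattened to T-by-(A*O)
W = bsxfun(@times, additionalinput, permute(output, [1 3 2]));
Wm = reshape(W, timesteps, additionalinputsize*outputsize);

% input.' * Wm sums over t; reshape restores the I-by-A-by-O layout
product = reshape(input.' * Wm, inputsize, additionalinputsize, outputsize);

% Reference result from the original loop
product_loop = zeros(inputsize, additionalinputsize, outputsize);
for t = 1:timesteps
    for i = 1:inputsize
        for o = 1:outputsize
            for a = 1:additionalinputsize
                product_loop(i,a,o) = product_loop(i,a,o) + input(t,i) * output(t,o) * additionalinput(t,a);
            end
        end
    end
end

% The two results should agree up to floating-point rounding
fprintf('max abs difference: %g\n', max(abs(product(:) - product_loop(:))));
```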
Divakar
  • Works great! Could you also help me with the third use case I just added? – user1406177 Jun 04 '14 at 16:17
  • @user1406177 Added solution for that too. – Divakar Jun 04 '14 at 17:05
  • Great answer. Always nice to see bsxfun in action. – Fred S Jun 04 '14 at 18:24
  • @FredS It was a great problem to let `bsxfun` explore! Thanks btw :) – Divakar Jun 04 '14 at 18:26
  • Great! One last thing: In the second part, I want the factor and additionalinput in the first line and t11 combined with the output in the second. I figured the first line out, but I'm struggling with the second line. My first one looks like this: t11 = bsxfun(@times,additionalinput,permute(factor,[4 2 1 3])); – user1406177 Jun 04 '14 at 18:26
  • @user1406177 `t11 = bsxfun(@times,additionalinput,permute(factor,[4 2 1 3])); t22 = bsxfun(@times,t11,permute(output,[1 3 4 2])); result2 = squeeze(sum(sum(t22,2),4));` – Divakar Jun 04 '14 at 18:38
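The variant in the last comment groups the factors in a different order, but it should give the same result2 as the Second Part code up to floating-point rounding. Here is a quick check of the two snippets against each other, with the question's sizes and random data assumed for all inputs:

```matlab
% Sizes and random data (output is random here, not taken from the first loop)
timesteps = 5; inputsize = 10; additionalinputsize = 3; outputsize = 7;
additionalinput = randn(timesteps, additionalinputsize);
output = randn(timesteps, outputsize);
factor = randn(inputsize, additionalinputsize, outputsize);

% Second Part from the answer: combine additionalinput with output first
t11 = bsxfun(@times,additionalinput,permute(output,[1 3 2]));
t22 = bsxfun(@times,permute(t11,[1 4 2 3]),permute(factor,[4 1 2 3]));
result2_a = sum(sum(t22,3),4);

% Variant from the comment: combine additionalinput with factor first
t11 = bsxfun(@times,additionalinput,permute(factor,[4 2 1 3]));
t22 = bsxfun(@times,t11,permute(output,[1 3 4 2]));
result2_b = squeeze(sum(sum(t22,2),4));

% Both should be timesteps-by-inputsize and agree up to rounding
fprintf('max abs difference: %g\n', max(abs(result2_a(:) - result2_b(:))));
```

Note that the squeeze in the variant returns a timesteps-by-inputsize matrix only when timesteps > 1.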