
I have a typical algorithm for matrix multiplication. I am trying to apply and understand loop unrolling, but I run into a problem when the unroll factor does not evenly divide the matrix size: I get very large numbers in the result. In other words, I don't understand how to handle the elements left over after the unrolled loop. Here is what I have:

void Mult_Matx(unsigned long* a, unsigned long* b, unsigned long*c, long n)
{
    long i = 0, j = 0, k = 0;
    unsigned long sum, sum1, sum2, sum3, sum4, sum5, sum6, sum7;

    for (i = 0; i < n; i++)
    {
        long in = i * n;
        for (j = 0; j < n; j++)
        {
            sum = sum1 = sum2 = sum3 = sum4 = sum5 = sum6 = sum7 = 0;

            for (k = 0; k < n; k += 8)
            {
                sum = sum + a[in + k] * b[k * n + j];
                sum1 = sum1 + a[in + (k + 1)] * b[(k + 1) * n + j];
                sum2 = sum2 + a[in + (k + 2)] * b[(k + 2) * n + j];
                sum3 = sum3 + a[in + (k + 3)] * b[(k + 3) * n + j];
                sum4 = sum4 + a[in + (k + 4)] * b[(k + 4) * n + j];
                sum5 = sum5 + a[in + (k + 5)] * b[(k + 5) * n + j];
                sum6 = sum6 + a[in + (k + 6)] * b[(k + 6) * n + j];
                sum7 = sum7 + a[in + (k + 7)] * b[(k + 7) * n + j];
            }

            if (n % 8 != 0)
            {
                for (k = 8 * (n / 8); k < n; k++)
                {
                    sum = sum + a[in + k] * b[k * n + j];
                }
            }
            c[in + j] = sum + sum1 + sum2 + sum3 + sum4 + sum5 + sum6 + sum7;
        }
    }
}

Let's say the size n is 12. When I unroll 4 times, this code works, because it never enters the remainder loop. But I lose track of what's going on when it does! If anyone can point out where I am going wrong, I'd really appreciate it. I am new to this and am having a hard time figuring it out.

mch
Luc Aux
  • manual loop unrolling is so 80s ... (I'd say: don't. If you insist, take a look at [Duff's device](https://en.wikipedia.org/wiki/Duff%27s_device), handling the "remainder" by jumping somewhere inside the unrolled code) – Nov 20 '17 at 11:28
  • @FelixPalmen lol. I am taking an intro operating systems class. So.... – Luc Aux Nov 20 '17 at 11:29
  • Perhaps you should take some time to [learn how to debug your programs](https://ericlippert.com/2014/03/05/how-to-debug-small-programs/)? – Some programmer dude Nov 20 '17 at 11:29
  • Also, instead of the temporary variable `sum1` to `sum7`, why not simply add to `sum`? Like e.g. `sum += a[in+(k+5)]* b[(k+5)*n+j]`? Or perhaps skip `sum` too, and add directly to `c[in + j]`? – Some programmer dude Nov 20 '17 at 11:30
  • @Someprogrammerdude I was simply trying to visualize things. I absolutely had no real reason to use this many variables. I would probably go with that if I actually had to implement sth like this. – Luc Aux Nov 20 '17 at 11:32
  • The different whitespace for the `sum3` line is driving me potty :) – Steve Nov 20 '17 at 11:42
  • @Steve Me too, so I changed it and everything else my eclipse auto formatter wanted to change. ;) – mch Nov 20 '17 at 11:49

1 Answer


A generic way of unrolling a loop of this form:

for(int i=0; i<N; i++)
    ...

is

int i;
for(i=0; i<=N-L; i+=L)
    ...
for(; i<N; i++)
    ...

or if you want to keep the index variable in the scope of the loops:

for(int i=0; i<=N-L; i+=L)
    ...
for(int i=L*(N/L); i<N; i++)
    ...

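Here, I'm using the fact that integer division truncates, so L*(N/L) is the largest multiple of L that does not exceed N. L is the unroll factor, that is, the number of steps handled by each iteration of the first loop. For instance, with N=22 and L=6, N/L is 3, so the second loop starts at i=18 and handles the last 4 indices.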

Example:

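#include <stdio.h>

int main(void)
{
    const int N=22;
    const int L=6;
    int i;
    /* Unrolled part: runs only while a full block of L indices fits. */
    for(i=0; i<=N-L; i+=L)
    {
        printf("%d\n", i);
        printf("%d\n", i+1);
        printf("%d\n", i+2);
        printf("%d\n", i+3);
        printf("%d\n", i+4);
        printf("%d\n", i+5);
    }
    /* Remainder: the last N%L indices, one at a time. */
    for(; i<N; i++)
        printf("%d\n", i);
    return 0;
}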

I also recommend taking a look at Duff's device, which handles the remainder by jumping into the middle of the unrolled loop body. However, I suspect that it's not always a good choice, because it needs a modulo, and modulo is a fairly expensive operation.
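For reference, here is a minimal sketch of Duff's device applied to summing an array. The helper name `sum_duff` and the unroll factor of 4 are assumptions chosen for illustration; neither comes from the question. The `switch` jumps into the middle of the `do`/`while` body, so the first pass handles the `n % 4` leftover elements and every later pass handles a full block of 4.

```c
#include <stddef.h>

/* Sum the first n elements of v, unrolled by 4 using Duff's device.
 * sum_duff is a hypothetical helper used only for illustration. */
unsigned long sum_duff(const unsigned long *v, size_t n)
{
    unsigned long sum = 0;

    if (n == 0)
        return 0;                    /* the do/while below always runs at least once */

    size_t passes = (n + 3) / 4;     /* trips through the loop, rounded up */

    switch (n % 4) {                 /* enter the body at the right offset */
    case 0: do { sum += *v++;        /* intentional fall-through in every case */
    case 3:      sum += *v++;
    case 2:      sum += *v++;
    case 1:      sum += *v++;
            } while (--passes > 0);
    }
    return sum;
}
```

In this sketch the modulo is evaluated once per call, not once per element; whether that cost matters depends on how short the loop is and how often it is entered.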

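The condition if (n % 8 != 0) is not needed: when n is a multiple of 8, the remainder loop starts at k = n and its body never runs. The for header of the unrolled loop is what has to be written properly. With the condition k < n, the last iteration still runs when fewer than 8 elements are left, so the indices k + 1 through k + 7 go past n and the code reads memory outside the current row and column. That is where the very large numbers come from.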
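Applied to the function in the question, the pattern might look like the sketch below. The name `mult_matx_unrolled` and the single accumulator are choices made here for illustration, not part of the original code. The unrolled loop runs only while `k + 7 < n` (equivalent to `k <= n - 8`), so every index it touches is valid, and the remainder loop simply continues from wherever `k` stopped.

```c
/* Multiply two n x n row-major matrices: c = a * b.
 * The inner loop is unrolled by 8; a scalar loop finishes the last n % 8 terms.
 * mult_matx_unrolled is a hypothetical name used for this sketch. */
void mult_matx_unrolled(const unsigned long *a, const unsigned long *b,
                        unsigned long *c, long n)
{
    for (long i = 0; i < n; i++) {
        long in = i * n;
        for (long j = 0; j < n; j++) {
            unsigned long sum = 0;
            long k;

            /* Unrolled part: only entered while a full block of 8 fits. */
            for (k = 0; k + 7 < n; k += 8) {
                sum += a[in + k]     * b[k * n + j];
                sum += a[in + k + 1] * b[(k + 1) * n + j];
                sum += a[in + k + 2] * b[(k + 2) * n + j];
                sum += a[in + k + 3] * b[(k + 3) * n + j];
                sum += a[in + k + 4] * b[(k + 4) * n + j];
                sum += a[in + k + 5] * b[(k + 5) * n + j];
                sum += a[in + k + 6] * b[(k + 6) * n + j];
                sum += a[in + k + 7] * b[(k + 7) * n + j];
            }

            /* Remainder: k is now 8 * (n / 8); handle the last n % 8 terms. */
            for (; k < n; k++)
                sum += a[in + k] * b[k * n + j];

            c[in + j] = sum;
        }
    }
}
```

With n = 12, the unrolled loop runs once (k = 0) and the remainder loop covers k = 8 through 11; with n = 8 or n = 16 the remainder loop does nothing.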

klutt