I am trying to write a program that multiplies two matrices recursively, where each 2D matrix is stored as a 1D array in column-major order. Here is what I have so far.
This is the main function that makes the recursive calls:
    vector<double> recursive_mult(int n, vector<double> A, vector<double> B) {
        vector<double> C;
        if (n == 1) {
            C.push_back(A[0] * B[0]);
        } else {
            C = equals(C, recursive_mult(n/2, matrix_partitioner(n/2, 0, A), matrix_partitioner(n/2, 0, B))
                        + recursive_mult(n/2, matrix_partitioner(n/2, 1, A), matrix_partitioner(n/2, 2, B)));
            C = equals(C, recursive_mult(n/2, matrix_partitioner(n/2, 0, A), matrix_partitioner(n/2, 1, B))
                        + recursive_mult(n/2, matrix_partitioner(n/2, 1, A), matrix_partitioner(n/2, 3, B)));//1
            C = equals(C, recursive_mult(n/2, matrix_partitioner(n/2, 2, A), matrix_partitioner(n/2, 0, B))
                        + recursive_mult(n/2, matrix_partitioner(n/2, 3, A), matrix_partitioner(n/2, 2, B)));//2
            C = equals(C, recursive_mult(n/2, matrix_partitioner(n/2, 2, A), matrix_partitioner(n/2, 1, B))
                        + recursive_mult(n/2, matrix_partitioner(n/2, 3, A), matrix_partitioner(n/2, 3, B)));//3
        }
        return C;
    }
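For reference, the four statements above follow the usual 2x2 block product. Writing X_k for quadrant k of a matrix X, numbered the way matrix_partitioner numbers them (0 top-left, 1 top-right, 2 bottom-left, 3 bottom-right), the quadrants of C they are meant to compute are:

```latex
\[
\begin{aligned}
C_0 &= A_0 B_0 + A_1 B_2, \\
C_1 &= A_0 B_1 + A_1 B_3, \\
C_2 &= A_2 B_0 + A_3 B_2, \\
C_3 &= A_2 B_1 + A_3 B_3.
\end{aligned}
\]
```

The quadrant pairings in the code agree with these formulas, so the individual products appear to be the right ones.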
matrix_partitioner extracts one quadrant of the given matrix (0 = top-left, 1 = top-right, 2 = bottom-left, 3 = bottom-right):
    // n is the size of the quadrant; A is the (2n) x (2n) parent matrix.
    vector<double> matrix_partitioner(int n, int section, vector<double> A) {
        vector<double> C(n*n);
        int start_i, start_j, tmp_j;
        if (section == 0) {         // top-left
            start_i = 0;
            tmp_j = 0;
        }
        else if (section == 1) {    // top-right
            start_i = 0;
            tmp_j = n;
        }
        else if (section == 2) {    // bottom-left
            start_i = n;
            tmp_j = 0;
        }
        else if (section == 3) {    // bottom-right
            start_i = n;
            tmp_j = n;
        }
        for (int i = 0; i < n; i++) {
            start_j = tmp_j;
            for (int j = 0; j < n; j++) {
                // column-major: element (row, col) lives at row + col * size
                C[i+(j*n)] = A[start_i+(start_j*(n*2))];
                start_j++;
            }
            start_i++;
        }
        return C;
    }
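To make the indexing that matrix_partitioner relies on easier to check in isolation, here is a minimal standalone sketch of the same idea. The names col_major_index and extract_quadrant are hypothetical and not part of my program; the sketch takes the quadrant's top-left corner directly instead of a section number.

```cpp
#include <vector>

// Hypothetical helper: index of element (row, col) in an n x n matrix
// stored in column-major order.
inline int col_major_index(int row, int col, int n) {
    return row + col * n;
}

// Hypothetical helper: copy the h x h quadrant whose top-left corner is at
// (row0, col0) out of a (2h) x (2h) column-major matrix A.
std::vector<double> extract_quadrant(const std::vector<double>& A,
                                     int h, int row0, int col0) {
    std::vector<double> Q(h * h);
    for (int j = 0; j < h; ++j) {
        for (int i = 0; i < h; ++i) {
            Q[col_major_index(i, j, h)] =
                A[col_major_index(row0 + i, col0 + j, 2 * h)];
        }
    }
    return Q;
}
```

Under these assumptions, extract_quadrant(A, n, 0, n) should correspond to matrix_partitioner(n, 1, A), the top-right quadrant.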
equals appends the result of adding the two matrices to the end of C:
    vector<double> equals(vector<double> A, vector<double> B) {
        for (int i = 0; i < B.size(); i++) {
            A.push_back(B[i]);
        }
        return A;
    }
I have also overloaded the '+' operator to make adding the matrices easier.
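The overload itself is not shown here. As a rough sketch of the kind of operator meant, assuming it is a plain element-wise sum of two vectors of equal length (the exact signature is an assumption):

```cpp
#include <cstddef>
#include <vector>

// Assumed shape of the overload: element-wise sum of two vectors.
// Both operands are assumed to have the same length.
std::vector<double> operator+(const std::vector<double>& A,
                              const std::vector<double>& B) {
    std::vector<double> C(A.size());
    for (std::size_t i = 0; i < A.size(); ++i) {
        C[i] = A[i] + B[i];
    }
    return C;
}
```

Because both operands use the same column-major layout, adding them index by index gives the matrix sum without any index translation.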
These are the results I am getting, alongside the results of my iterative version for comparison (both use the same print function):
Iterative
|  250  260  270  280 |
|  618  644  670  696 |
|  986 1028 1070 1112 |
| 1354 1412 1470 1528 |
Recursive
|  250  270  986 1070 |
|  260  280 1028 1112 |
|  618  670 1354 1470 |
|  644  696 1412 1528 |
Clearly my recursive results are not correct (or at least not in the right order), but I do not know how to fix my code so that it prints properly. Can someone help me fix this code?
I have tried reordering the equals statements, but with no luck.
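One observation, offered as a sketch and not a definitive fix: the recursive output contains the same values as the iterative one, and its order matches what you would get if the four quadrant results were stored back to back in C, not interleaved in column-major order. A possible alternative is to write each quadrant into its row and column offset of a preallocated n x n result. The names get_block, add_block and multiply_sketch below are hypothetical, and n is assumed to be a power of two.

```cpp
#include <vector>

using Matrix = std::vector<double>;  // n x n, column-major: (i, j) -> i + j * n

// Hypothetical helper: copy the h x h block starting at (row0, col0)
// out of the (2h) x (2h) matrix M.
Matrix get_block(const Matrix& M, int h, int row0, int col0) {
    Matrix Q(h * h);
    for (int j = 0; j < h; ++j)
        for (int i = 0; i < h; ++i)
            Q[i + j * h] = M[(row0 + i) + (col0 + j) * 2 * h];
    return Q;
}

// Hypothetical helper: add the h x h block Q into the (2h) x (2h) matrix M
// at offset (row0, col0), i.e. at its column-major position in M.
void add_block(Matrix& M, const Matrix& Q, int h, int row0, int col0) {
    for (int j = 0; j < h; ++j)
        for (int i = 0; i < h; ++i)
            M[(row0 + i) + (col0 + j) * 2 * h] += Q[i + j * h];
}

// Sketch of a recursive multiply; assumes n is a power of two.
Matrix multiply_sketch(int n, const Matrix& A, const Matrix& B) {
    if (n == 1) return Matrix{A[0] * B[0]};
    const int h = n / 2;
    Matrix C(n * n, 0.0);
    for (int bi = 0; bi < 2; ++bi)          // block row of C
        for (int bj = 0; bj < 2; ++bj)      // block column of C
            for (int bk = 0; bk < 2; ++bk)  // summation over blocks
                add_block(C,
                          multiply_sketch(h,
                                          get_block(A, h, bi * h, bk * h),
                                          get_block(B, h, bk * h, bj * h)),
                          h, bi * h, bj * h);
    return C;
}
```

The recursion and the quadrant products are the same as in recursive_mult; only the way the partial results are combined differs, since each product is accumulated at its position in C instead of being pushed onto the end.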