I'm a beginner in Java.

Recently I have been writing a program to calculate a matrix multiplication, so I wrote a class to do this.

public class MultiThreadsMatrixMultipy{
 public   int[][] multipy(int[][] matrix1,int[][] matrix2) {
     if(!utils.CheckDimension(matrix1,matrix2)){
         return null;
     }
     int row1 = matrix1.length;
     int col1 = matrix1[0].length;
     int row2 = matrix2.length;
     int col2 = matrix2[0].length;
     int[][] ans = new int[row1][col2];
     Thread[][]  threads = new SingleRowMultipy[row1][col2];

     for(int i=0;i<row1;i++){
         for(int j=0;j<col2;j++){
             threads[i][j] = new SingleRowMultipy(i,j,matrix1,matrix2,ans);
             threads[i][j].start();
         }
     }
     return ans;
 }
}
public class SingleRowMultipy extends Thread{
        private int row;
        private int col;
        private int[][] A;
        private int[][] B;
        private int[][] ans;
        public SingleRowMultipy(int row,int col,int[][] A,int[][] B,int[][] C){
            this.row = row;
            this.col = col;
            this.A = A;
            this.B = B;
            this.ans = C;
        }
        public void run(){
            int sum =0;
            for(int i=0;i<A[row].length;i++){
                 sum+=(A[row][i]*B[i][col]);
            }
            ans[row][col] = sum;
        }
}

I want to use one thread to calculate each product matrix1[i][:] * matrix2[:][j]. The matrices are 1000*5000 and 5000*1000, so the program creates 1000*1000 threads. When I run it, it is very slow and takes about 38s. If I use a single thread to calculate the result, it takes 17s. The single-threaded code is below:

public class SimpleMatrixMultipy
{
    public int[][] multipy(int[][] matrix1,int[][] matrix2){
        int row1 = matrix1.length;
        int col1 = matrix1[0].length;
        int row2 = matrix2.length;
        int col2 = matrix2[0].length;
        int[][] ans = new int[row1][col2];
        for(int i=0;i<row1;i++){
            for(int j=0;j<col2;j++){
                for(int k=0;k<col1;k++){
                    ans[i][j] += matrix1[i][k]*matrix2[k][j];
                }
            }
        }
        return ans;
    }

}

What can I do to speed up the program?

Sandeep Kumar
TangDH
  • The parallelism is too fine-grained. Thread creation consumes time and memory. Try to limit the number of threads to a reasonable amount, e.g. something between 4 and 32 threads. Also, since the second matrix is accessed column-wise, it may be beneficial to transpose the second matrix or store it in col-major instead of row-major format to increase locality (i.e. caching). Look up "matrix tiling" for an optimized, parallel matrix multiplication algorithm. – Turing85 Dec 24 '19 at 08:42
  • Thanks a lot. I now use a thread pool with 20 threads, and the program is a lot faster. :) – TangDH Dec 24 '19 at 08:53
  • The optimal number of threads is the same as the number of logical processors in your system. – rustyx Dec 24 '19 at 10:03
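
The suggestions in the comments above (a bounded number of threads, coarser tasks, and transposing the second matrix) can be combined roughly as follows. This is only a sketch, not the code the asker ended up with: the class name RowBlockMultiply, the one-task-per-row split, and sizing the pool with availableProcessors() are assumptions made for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical class name; not part of the original question or answer.
public class RowBlockMultiply {

    // Multiplies a (n x m) by b (m x p) with one task per output row.
    public static int[][] multiply(int[][] a, int[][] b) {
        int n = a.length, m = b.length, p = b[0].length;

        // Transpose b so the inner loop reads both operands sequentially.
        int[][] bt = new int[p][m];
        for (int k = 0; k < m; k++) {
            for (int j = 0; j < p; j++) {
                bt[j][k] = b[k][j];
            }
        }

        int[][] ans = new int[n][p];
        // Assumption: one worker per logical processor, as suggested in the comments.
        int workers = Runtime.getRuntime().availableProcessors();
        ExecutorService pool = Executors.newFixedThreadPool(workers);
        try {
            List<Callable<Void>> tasks = new ArrayList<>();
            for (int i = 0; i < n; i++) {
                final int row = i;
                tasks.add(() -> {
                    for (int j = 0; j < p; j++) {
                        int sum = 0;
                        for (int k = 0; k < m; k++) {
                            sum += a[row][k] * bt[j][k];
                        }
                        ans[row][j] = sum;
                    }
                    return null;
                });
            }
            pool.invokeAll(tasks); // blocks until every row has been computed
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while multiplying", e);
        } finally {
            pool.shutdown();
        }
        return ans;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2, 3}, {4, 5, 6}};
        int[][] b = {{7, 8}, {9, 10}, {11, 12}};
        System.out.println(Arrays.deepToString(multiply(a, b))); // [[58, 64], [139, 154]]
    }
}
```

With 1000 rows this creates 1000 tasks instead of 1000*1000 threads, and each task writes only its own row of ans, so no locking is needed.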

1 Answer

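As @Turing85 said, you need to manage the number of threads. There are two ways to do that: use Executors.newFixedThreadPool with a fixed number of threads, or use Executors.newCachedThreadPool, which reuses idle threads when they are available. Note that a cached pool has no upper bound on the number of threads it creates, so the fixed pool is the safer choice when there are this many tasks.

Another important point is to avoid extending the Thread class directly; implement Runnable instead.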

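import java.util.ArrayList;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;

public class MultiThreadsMatrixMultipy {

    public static void main(final String[] args) {

    }

    public int[][] multipy(final int[][] matrix1, final int[][] matrix2) {
        // utils.CheckDimension is the asker's own helper (not shown in the question).
        if(!utils.CheckDimension(matrix1,matrix2)){
            return null;
        }
        final int row1 = matrix1.length;
        final int col2 = matrix2[0].length;
        final int[][] ans = new int[row1][col2];
        // final ExecutorService executor = Executors.newCachedThreadPool(new CustomThreadFactory("Multiplier"));
        final ExecutorService executor = Executors.newFixedThreadPool(20, new CustomThreadFactory("Multiplier"));

        for (int i = 0; i < row1; i++) {
            for (int j = 0; j < col2; j++) {
                executor.execute(new SingleRowMultipy(i, j, matrix1, matrix2, ans));
            }
        }
        // Stop accepting tasks and wait for the queued ones to finish. Without this,
        // ans could be returned half-filled and the pool threads would keep the JVM alive.
        executor.shutdown();
        try {
            executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
        } catch (final InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return ans;
    }
}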

class CustomThreadFactory implements ThreadFactory {
    private int counter;
    private final String name;
    private final List<String> stats;

    public CustomThreadFactory(final String name) {
        counter = 1;
        this.name = name;
        stats = new ArrayList<>();
    }

    @Override
    public Thread newThread(final Runnable runnable) {
        final Thread t = new Thread(runnable, name + "-Thread_" + counter);
        counter++;
        stats.add(String.format("Created thread %d with name %s on %s \n", t.getId(), t.getName(), new Date()));
        return t;
    }

    public String getStats() {
        final StringBuffer buffer = new StringBuffer();
        final Iterator<String> it = stats.iterator();
        while (it.hasNext()) {
            buffer.append(it.next());
        }
        return buffer.toString();
    }
}

class SingleRowMultipy implements Runnable {
    private final int row;
    private final int col;
    private final int[][] A;
    private final int[][] B;
    private final int[][] ans;

    public SingleRowMultipy(final int row, final int col, final int[][] A, final int[][] B, final int[][] C) {
        this.row = row;
        this.col = col;
        this.A = A;
        this.B = B;
        this.ans = C;
    }

    @Override
    public void run() {
        int sum = 0;
        for (int i = 0; i < A[row].length; i++) {
            sum += (A[row][i] * B[i][col]);
        }
        ans[row][col] = sum;
    }
}
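
The code above still submits one task per output cell, which is 1000*1000 tasks for the matrices in the question. As a different technique from the one in this answer, a parallel stream over the rows lets the JDK's common fork/join pool decide how to split the work. The following is a minimal sketch under that assumption; the class name ParallelStreamMultiply and the i-k-j loop order (chosen so that matrix2 is read row by row) are illustrative choices, not something taken from the question or the answer.

```java
import java.util.Arrays;
import java.util.stream.IntStream;

// Hypothetical class name; a sketch of a row-parallel multiplication.
public class ParallelStreamMultiply {

    public static int[][] multiply(int[][] a, int[][] b) {
        int rows = a.length, inner = b.length, cols = b[0].length;
        int[][] ans = new int[rows][cols];

        // Each index i is handled by exactly one worker, so rows of ans never overlap.
        IntStream.range(0, rows).parallel().forEach(i -> {
            int[] aRow = a[i];
            int[] out = ans[i];
            // i-k-j order: b is walked one row at a time instead of column-wise.
            for (int k = 0; k < inner; k++) {
                int aik = aRow[k];
                int[] bRow = b[k];
                for (int j = 0; j < cols; j++) {
                    out[j] += aik * bRow[j];
                }
            }
        });
        return ans; // forEach has returned, so all rows are complete
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2, 3}, {4, 5, 6}};
        int[][] b = {{7, 8}, {9, 10}, {11, 12}};
        System.out.println(Arrays.deepToString(multiply(a, b))); // [[58, 64], [139, 154]]
    }
}
```

Whether this beats a hand-sized fixed pool depends on the machine and the matrix shapes, so it is worth measuring both.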
ArpanKhandelwal
  • Thanks a lot! I understand why I should manage the thread count, but I don't know the reason for implementing Runnable instead of extending the Thread class. Could you tell me why? :) – TangDH Dec 24 '19 at 10:18
  • @TangDH It has nothing to do with performance; it is simply good practice. Please read https://stackoverflow.com/questions/541487/implements-runnable-vs-extends-thread-in-java for the answer. – ArpanKhandelwal Dec 24 '19 at 11:17
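
One practical reason behind that advice: a Runnable describes only the work, so the same class can be run on the calling thread, on a dedicated Thread, or on a pool thread without being changed, and it remains free to extend some other class. The sketch below illustrates this; RunnableDemo and DotProductTask are hypothetical names made up for the example.

```java
import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

// Hypothetical demo class; shows one Runnable implementation executed three ways.
public class RunnableDemo {

    // The task knows nothing about which thread runs it.
    static class DotProductTask implements Runnable {
        private final int[] x;
        private final int[] y;
        private volatile int result;

        DotProductTask(int[] x, int[] y) {
            this.x = x;
            this.y = y;
        }

        @Override
        public void run() {
            int sum = 0;
            for (int i = 0; i < x.length; i++) {
                sum += x[i] * y[i];
            }
            result = sum;
        }

        int getResult() {
            return result;
        }
    }

    public static int[] runThreeWays(int[] x, int[] y) {
        DotProductTask direct = new DotProductTask(x, y);
        DotProductTask threaded = new DotProductTask(x, y);
        DotProductTask pooled = new DotProductTask(x, y);
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            direct.run();                     // 1. on the calling thread

            Thread t = new Thread(threaded);  // 2. on a dedicated thread
            t.start();
            t.join();

            pool.execute(pooled);             // 3. on a pool thread
            pool.shutdown();
            pool.awaitTermination(1, TimeUnit.MINUTES);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted", e);
        } finally {
            pool.shutdown(); // harmless if already shut down
        }
        return new int[] {direct.getResult(), threaded.getResult(), pooled.getResult()};
    }

    public static void main(String[] args) {
        int[] results = runThreeWays(new int[] {1, 2, 3}, new int[] {4, 5, 6});
        System.out.println(Arrays.toString(results)); // [32, 32, 32]
    }
}
```

A class that extends Thread can only be started as its own thread, which is what made the original one-thread-per-cell design hard to move onto a pool.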