#define TS 32
int num_devices = 0;
__global__ void shared_kernel(float* A, float* B, float* C, int M, int N, int K) {
int global_col = blockDim.x * blockIdx.x + threadIdx.x;
int global_row = blockDim.y * blockIdx.y + threadIdx.y;
int local_col = threadIdx.x;
int local_row = threadIdx.y;
if (global_row >= M || global_col >= N) return;
__shared__ float Asub[TS][TS];
__shared__ float Bsub[TS][TS];
const int num_tiles = K / TS;
float acc = 0;
for(int t = 0; t < num_tiles; t++){
const int t_row = TS * t + local_row;
const int t_col = TS * t + local_col;
Asub[local_row][local_col] = A[global_row * K + t_col];
Bsub[local_row][local_col] = B[t_row * N + global_col];
__syncthreads();
printf("[DEBUG] first sync threads, global_row: %d, global_col: %d\n", global_row, global_col);
for (int k = 0; k < K; ++k) {
acc += Asub[local_row][k] * Bsub[k][local_col];
}
__syncthreads();
printf("[DEBUG] second sync threads, global_row: %d, global_col: %d\n", global_row, global_col);
}
C[global_row * N + global_col] = acc;
}
static float *a_d, *b_d, *c_d;
void mat_mul(float *A, float *B, float *C, int M, int N, int K) {
cudaMemcpy(a_d, A, M * K * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(b_d, B, K * N * sizeof(float), cudaMemcpyHostToDevice);
dim3 blockDim(TS, TS);
dim3 gridDim(M/TS, N/TS);
shared_kernel<<<gridDim, blockDim>>>(a_d, b_d, c_d, M, N, K);
cudaMemcpy(C, c_d, M * N * sizeof(float), cudaMemcpyDeviceToHost);
cudaDeviceSynchronize();
}
void mat_mul_init(float *A, float *B, float *C, int M, int N, int K) {
cudaGetDeviceCount(&num_devices);
cudaSetDevice(0);
cudaMalloc(&a_d, M * K * sizeof(float));
cudaMalloc(&b_d, K * N * sizeof(float));
cudaMalloc(&c_d, M * N * sizeof(float));
}
Above example is a matrix multiplication with shared memory.
I ran above kernel with dim3 blockDim(TS, TS)
and dim3 gridDim(M/TS, N/TS)
and M, N, K = 128.
I checked that float * C
has zero value after launching kernel. Also, I found that only few of global_row are printed(from 37 to 81) after first __syncthreads()
, and there is no printf
DEBUG message after the second __syncthreads()
.
I suspect that __syncthreads()
is causing the problem, but I don't know how to fix it. My code is almost the same as other matrix multiplication code in other site.
Would you give me some hint how to solve this?