From 0d47785ad532e52d1cfd65aae436255495cee6d4 Mon Sep 17 00:00:00 2001
From: ZhuangYumin
Date: Fri, 12 Jul 2024 19:44:53 +0800
Subject: [PATCH] opt to 2x

---
 csrc/gemm.cu | 111 +++++++++++++++++++++++++++++++--------------------
 1 file changed, 67 insertions(+), 44 deletions(-)

diff --git a/csrc/gemm.cu b/csrc/gemm.cu
index 035a020..77fbf61 100644
--- a/csrc/gemm.cu
+++ b/csrc/gemm.cu
@@ -26,66 +26,89 @@ void naiveSgemm(float *a, float *b, float *c, const int M, const int N,
   }
 }
 
-const int TILE_SIZE = 32;
 /**
  * @brief Optimized implementation of matrix multiplication on GPU using shared memory.
  * Perform C = A * B, where A is M x K, B is K x N, and C is M x N.
  */
-__global__ void ZYMSgemm2D(float *a, float *b, float *c, const int M,
+template <const int BM, const int BN, const int BK, const int TM, const int TN>
+__global__ void __launch_bounds__((BM * BN) / (TM * TN), 1) ZYMSgemm2D(const float *a, const float *b, float *c, const int M,
                            const int N, const int K) {
-
-  // Shared memory for submatrices of A and B
-  __shared__ float As[TILE_SIZE][TILE_SIZE];
-  __shared__ float Bs[TILE_SIZE][TILE_SIZE];
-
-  // Calculate row and column index of the element
-  int row = blockIdx.y * TILE_SIZE + threadIdx.y;
-  int col = blockIdx.x * TILE_SIZE + threadIdx.x;
-
-  // Accumulator for the result
-  float value = 0.0f;
-
-  // Loop over all the tiles
-  for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; ++t) {
-    // Load the tile elements into shared memory
-    if (row < M && (t * TILE_SIZE + threadIdx.x) < K) {
-      As[threadIdx.y][threadIdx.x] = a[row * K + t * TILE_SIZE + threadIdx.x];
-    } else {
-      As[threadIdx.y][threadIdx.x] = 0.0f;
+  const int block_row = blockIdx.x;
+  const int block_col = blockIdx.y;
+  const int elements_per_block = BM * BN;
+  const int threads_per_block = (BM*BN)/(TM*TN);
+  assert(blockDim.x == threads_per_block);
+  const int thread_row = threadIdx.x / (BN/TN);
+  const int thread_col = threadIdx.x % (BN/TN);
+  __shared__ float shared_a[BM*BK];
+  __shared__ float shared_b[BK*BN];
+  a+=block_row*BM*K;
+  b+=block_col*BN;
+  c+=block_row*BM*N+block_col*BN;
+  const int load_a_col = threadIdx.x % BK;
+  const int load_a_row = threadIdx.x / BK;
+  const int load_a_row_stride = threads_per_block / BK;
+  const int load_b_col = threadIdx.x % BN;
+  const int load_b_row = threadIdx.x / BN;
+  const int load_b_row_stride = threads_per_block / BN;
+  float result_cache[TM*TN]={0.0};
+  float a_cache[TM]={0.0};
+  float b_cache[TN]={0.0};
+  for(int k_idx=0;k_idx<K;k_idx+=BK){
[... hunk interior lost in extraction: the body of the k_idx loop and the kernel's write-back, the remainder of the removed TILE_SIZE kernel, and the launcher's signature with its removed launch line (surviving tail: ">>(a, b, c, M, N, K);") ...]
+  const int BK=8;
+  const int TM=8;
+  const int TN=8;
+  if(M>=128&&N>=128&&K>=128){
+    const int BM=128;
+    const int BN=128;
+    dim3 gridDim((M + BM - 1) / BM, (N + BN - 1) / BN);
+    dim3 blockDim((BM * BN) / (TM * TN));
+    ZYMSgemm2D<BM, BN, BK, TM, TN><<<gridDim, blockDim>>>(a, b, c, M, N, K);
+  }
+  else{
+    const int BM=64;
+    const int BN=64;
+    dim3 gridDim((M + BM - 1) / BM, (N + BN - 1) / BN);
+    dim3 blockDim((BM * BN) / (TM * TN));
+    ZYMSgemm2D<BM, BN, BK, TM, TN><<<gridDim, blockDim>>>(a, b, c, M, N, K);
+  }
 }
 
 void initialize(float *a, float *b, float *c, const int M, const int N,
@@ -145,7 +168,7 @@ int main() {
   cudaDeviceSynchronize();
   end = std::chrono::high_resolution_clock::now();
   cudaMemcpy(c, d_c, MAXN * MAXN * sizeof(float), cudaMemcpyDeviceToHost);
-  printf("d_c[0][0]=%f\n", c[0]);
+  printf("d_c[108873]=%f\n", c[108873]);
   elapsed = end - start;
   printf("GPU time: %.3fs\n", elapsed.count());

@@ -155,7 +178,7 @@ int main() {
   cudaDeviceSynchronize();
   end = std::chrono::high_resolution_clock::now();
   cudaMemcpy(c, d_c, MAXN * MAXN * sizeof(float), cudaMemcpyDeviceToHost);
-  printf("d_c[0][0]=%f\n", c[0]);
+  printf("d_c[108873]=%f\n", c[108873]);
   elapsed = end - start;
   printf("cuBLAS time: %.3fs\n", elapsed.count());
 }