diff --git a/Makefile b/Makefile
index e832778..eca8dfc 100644
--- a/Makefile
+++ b/Makefile
@@ -12,7 +12,7 @@ BASIC_SRC = csrc/basic.cu
 GEMM = build/gemm
 GEMM_SRC = csrc/gemm.cu
 
-all: $(HELLO) $(BASIC)
+all: $(HELLO) $(BASIC) $(GEMM)
 
 $(HELLO): $(HELLO_SRC)
 	mkdir -p build
diff --git a/csrc/gemm.cu b/csrc/gemm.cu
index 3656185..035a020 100644
--- a/csrc/gemm.cu
+++ b/csrc/gemm.cu
@@ -7,7 +7,7 @@
 
 // You may increase this value to test larger matrices
 // But it will be slow on CPU
-constexpr int MAXN = 2048;
+constexpr int MAXN = 8192;
 
 /**
  * @brief A naive implementation of matrix multiplication on CPU.
@@ -26,31 +26,66 @@ void naiveSgemm(float *a, float *b, float *c, const int M, const int N,
   }
 }
 
+const int TILE_SIZE = 32;
 /**
- * @brief A naive implementation of matrix multiplication on GPU.
+ * @brief Optimized implementation of matrix multiplication on GPU using shared memory.
  * Perform C = A * B, where A is M x K, B is K x N, and C is M x N.
  */
-__global__ void naiveSgemm2D(float *a, float *b, float *c, const int M,
-                             const int N, const int K) {
-  int m = blockIdx.x * blockDim.x + threadIdx.x; // Row index
-  int n = blockIdx.y * blockDim.y + threadIdx.y; // Column index
-  if (m < M && n < N) {
-    float sum = 0.0;
-    for (int k = 0; k < K; ++k) {
-      sum += a[m * K + k] * b[k * N + n];
+__global__ void ZYMSgemm2D(float *a, float *b, float *c, const int M,
+                           const int N, const int K) {
+
+  // Shared memory for submatrices of A and B
+  __shared__ float As[TILE_SIZE][TILE_SIZE];
+  __shared__ float Bs[TILE_SIZE][TILE_SIZE];
+
+  // Calculate row and column index of the element
+  int row = blockIdx.y * TILE_SIZE + threadIdx.y;
+  int col = blockIdx.x * TILE_SIZE + threadIdx.x;
+
+  // Accumulator for the result
+  float value = 0.0f;
+
+  // Loop over all the tiles
+  for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; ++t) {
+    // Load the tile elements into shared memory
+    if (row < M && (t * TILE_SIZE + threadIdx.x) < K) {
+      As[threadIdx.y][threadIdx.x] = a[row * K + t * TILE_SIZE + threadIdx.x];
+    } else {
+      As[threadIdx.y][threadIdx.x] = 0.0f;
     }
-    c[m * N + n] = sum;
+
+    if (col < N && (t * TILE_SIZE + threadIdx.y) < K) {
+      Bs[threadIdx.y][threadIdx.x] = b[(t * TILE_SIZE + threadIdx.y) * N + col];
+    } else {
+      Bs[threadIdx.y][threadIdx.x] = 0.0f;
+    }
+
+    // Synchronize to make sure the submatrices are loaded
+    __syncthreads();
+
+    // Multiply the two matrices together
+    for (int k = 0; k < TILE_SIZE; ++k) {
+      value += As[threadIdx.y][k] * Bs[k][threadIdx.x];
+    }
+
+    // Synchronize to make sure that the computation is done before loading new tiles
+    __syncthreads();
+  }
+
+  // Write the result back to the global memory
+  if (row < M && col < N) {
+    c[row * N + col] = value;
   }
 }
 
 /**
- * @brief Launch naiveSgemm2D kernel.
+ * @brief Launch ZYMSgemm2D kernel.
  */
 void launchSgemm2D(float *a, float *b, float *c, const int M, const int N,
                    const int K) {
-  dim3 block(16, 16); // 256 threads per block (16 * 16 = 256)
-  dim3 grid((M + block.x - 1) / block.x, (N + block.y - 1) / block.y);
-  naiveSgemm2D<<<grid, block>>>(a, b, c, M, N, K);
+  dim3 block(TILE_SIZE, TILE_SIZE); // 1024 threads per block (32 * 32 = 1024)
+  // grid.x spans the columns (N) and grid.y the rows (M), matching the
+  // row/col mapping inside the kernel
+  dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
+  ZYMSgemm2D<<<grid, block>>>(a, b, c, M, N, K);
 }
 
 void initialize(float *a, float *b, float *c, const int M, const int N,
@@ -91,7 +126,7 @@ int main() {
 
   // ********** CPU **********
   auto start = std::chrono::high_resolution_clock::now();
-  naiveSgemm(a, b, c, MAXN, MAXN, MAXN);
+  // naiveSgemm(a, b, c, MAXN, MAXN, MAXN); // too slow on CPU at MAXN = 8192
   auto end = std::chrono::high_resolution_clock::now();
   std::chrono::duration<double> elapsed = end - start;
   printf("CPU time: %.3fs\n", elapsed.count());
@@ -109,6 +144,8 @@ int main() {
   launchSgemm2D(d_a, d_b, d_c, MAXN, MAXN, MAXN);
   cudaDeviceSynchronize();
   end = std::chrono::high_resolution_clock::now();
+  cudaMemcpy(c, d_c, MAXN * MAXN * sizeof(float), cudaMemcpyDeviceToHost);
+  printf("d_c[0][0]=%f\n", c[0]); // sanity check on the result
   elapsed = end - start;
   printf("GPU time: %.3fs\n", elapsed.count());
 
@@ -117,6 +154,8 @@ int main() {
   launchCublasSgemm(d_a, d_b, d_c, MAXN, MAXN, MAXN);
   cudaDeviceSynchronize();
   end = std::chrono::high_resolution_clock::now();
+  cudaMemcpy(c, d_c, MAXN * MAXN * sizeof(float), cudaMemcpyDeviceToHost);
+  printf("d_c[0][0]=%f\n", c[0]); // sanity check on the result
   elapsed = end - start;
   printf("cuBLAS time: %.3fs\n", elapsed.count());
 }
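
Note on the timing methodology: the std::chrono pair brackets the kernel together with cudaDeviceSynchronize(), which works but measures from the host side and includes launch overhead. A minimal sketch of the same measurement with CUDA events, which time only device work, is below; it reuses the buffer names from the diff (d_a, d_b, d_c) and is an illustration, not part of the patch.

// Sketch: timing launchSgemm2D with CUDA events instead of std::chrono.
// cudaEventElapsedTime reports milliseconds, so divide by 1000 for seconds.
cudaEvent_t begin, stop;
cudaEventCreate(&begin);
cudaEventCreate(&stop);

cudaEventRecord(begin);
launchSgemm2D(d_a, d_b, d_c, MAXN, MAXN, MAXN);
cudaEventRecord(stop);
cudaEventSynchronize(stop);

float ms = 0.0f;
cudaEventElapsedTime(&ms, begin, stop);
printf("GPU time (events): %.3fs\n", ms / 1000.0f);

cudaEventDestroy(begin);
cudaEventDestroy(stop);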
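
The printed d_c[0][0] only spot-checks a single element. If a fuller check is wanted, a hypothetical helper along these lines could compare the kernel's output against the cuBLAS result element-wise; allClose and its tolerances are assumptions, while the buffers follow main() above.

#include <cmath>
#include <cstdio>

// Hypothetical helper: returns true when every element of out is within
// atol + rtol * |ref| of the reference. Both arrays live on the host,
// copied back with cudaMemcpy as in main() above.
bool allClose(const float *out, const float *ref, int n,
              float rtol = 1e-3f, float atol = 1e-5f) {
  for (int i = 0; i < n; ++i) {
    float diff = std::fabs(out[i] - ref[i]);
    if (diff > atol + rtol * std::fabs(ref[i])) {
      std::printf("mismatch at %d: %f vs %f\n", i, out[i], ref[i]);
      return false;
    }
  }
  return true;
}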