ready for further opt

This commit is contained in:
2024-07-11 18:02:56 +08:00
parent c440358d9e
commit 2c651ef4c4
2 changed files with 55 additions and 16 deletions

View File

@ -12,7 +12,7 @@ BASIC_SRC = csrc/basic.cu
GEMM = build/gemm GEMM = build/gemm
GEMM_SRC = csrc/gemm.cu GEMM_SRC = csrc/gemm.cu
all: $(HELLO) $(BASIC) all: $(HELLO) $(BASIC) $(GEMM)
$(HELLO): $(HELLO_SRC) $(HELLO): $(HELLO_SRC)
mkdir -p build mkdir -p build

View File

@ -7,7 +7,7 @@
// You may increase this value to test larger matrices // You may increase this value to test larger matrices
// But it will be slow on CPU // But it will be slow on CPU
constexpr int MAXN = 2048; constexpr int MAXN = 8192;
/** /**
* @brief A naive implementation of matrix multiplication on CPU. * @brief A naive implementation of matrix multiplication on CPU.
@ -26,31 +26,66 @@ void naiveSgemm(float *a, float *b, float *c, const int M, const int N,
} }
} }
const int TILE_SIZE = 32;
/** /**
* @brief A naive implementation of matrix multiplication on GPU. * @brief Optimized implementation of matrix multiplication on GPU using shared memory.
* Perform C = A * B, where A is M x K, B is K x N, and C is M x N. * Perform C = A * B, where A is M x K, B is K x N, and C is M x N.
*/ */
__global__ void naiveSgemm2D(float *a, float *b, float *c, const int M, __global__ void ZYMSgemm2D(float *a, float *b, float *c, const int M,
const int N, const int K) { const int N, const int K) {
int m = blockIdx.x * blockDim.x + threadIdx.x; // Row index
int n = blockIdx.y * blockDim.y + threadIdx.y; // Column index // Shared memory for submatrices of A and B
if (m < M && n < N) { __shared__ float As[TILE_SIZE][TILE_SIZE];
float sum = 0.0; __shared__ float Bs[TILE_SIZE][TILE_SIZE];
for (int k = 0; k < K; ++k) {
sum += a[m * K + k] * b[k * N + n]; // Calculate row and column index of the element
int row = blockIdx.y * TILE_SIZE + threadIdx.y;
int col = blockIdx.x * TILE_SIZE + threadIdx.x;
// Accumulator for the result
float value = 0.0f;
// Loop over all the tiles
for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; ++t) {
// Load the tile elements into shared memory
if (row < M && (t * TILE_SIZE + threadIdx.x) < K) {
As[threadIdx.y][threadIdx.x] = a[row * K + t * TILE_SIZE + threadIdx.x];
} else {
As[threadIdx.y][threadIdx.x] = 0.0f;
} }
c[m * N + n] = sum;
if (col < N && (t * TILE_SIZE + threadIdx.y) < K) {
Bs[threadIdx.y][threadIdx.x] = b[(t * TILE_SIZE + threadIdx.y) * N + col];
} else {
Bs[threadIdx.y][threadIdx.x] = 0.0f;
}
// Synchronize to make sure the submatrices are loaded
__syncthreads();
// Multiply the two matrices together
for (int k = 0; k < TILE_SIZE; ++k) {
value += As[threadIdx.y][k] * Bs[k][threadIdx.x];
}
// Synchronize to make sure that the computation is done before loading new tiles
__syncthreads();
}
// Write the result back to the global memory
if (row < M && col < N) {
c[row * N + col] = value;
} }
} }
/** /**
* @brief Launch naiveSgemm2D kernel. * @brief Launch ZYMSgemm2D kernel.
*/ */
void launchSgemm2D(float *a, float *b, float *c, const int M, const int N, void launchSgemm2D(float *a, float *b, float *c, const int M, const int N,
const int K) { const int K) {
dim3 block(16, 16); // 256 threads per block (16 * 16 = 256) dim3 block(TILE_SIZE, TILE_SIZE); // 256 threads per block (16 * 16 = 256)
dim3 grid((M + block.x - 1) / block.x, (N + block.y - 1) / block.y); dim3 grid((M + block.x - 1) / block.x, (N + block.y - 1) / block.y);
naiveSgemm2D<<<grid, block>>>(a, b, c, M, N, K); ZYMSgemm2D<<<grid, block>>>(a, b, c, M, N, K);
} }
void initialize(float *a, float *b, float *c, const int M, const int N, void initialize(float *a, float *b, float *c, const int M, const int N,
@ -91,7 +126,7 @@ int main() {
// ********** CPU ********** // ********** CPU **********
auto start = std::chrono::high_resolution_clock::now(); auto start = std::chrono::high_resolution_clock::now();
naiveSgemm(a, b, c, MAXN, MAXN, MAXN); // naiveSgemm(a, b, c, MAXN, MAXN, MAXN);
auto end = std::chrono::high_resolution_clock::now(); auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> elapsed = end - start; std::chrono::duration<double> elapsed = end - start;
printf("CPU time: %.3fs\n", elapsed.count()); printf("CPU time: %.3fs\n", elapsed.count());
@ -109,6 +144,8 @@ int main() {
launchSgemm2D(d_a, d_b, d_c, MAXN, MAXN, MAXN); launchSgemm2D(d_a, d_b, d_c, MAXN, MAXN, MAXN);
cudaDeviceSynchronize(); cudaDeviceSynchronize();
end = std::chrono::high_resolution_clock::now(); end = std::chrono::high_resolution_clock::now();
cudaMemcpy(c, d_c, MAXN * MAXN * sizeof(float), cudaMemcpyDeviceToHost);
printf("d_c[0][0]=%f\n", c[0]);
elapsed = end - start; elapsed = end - start;
printf("GPU time: %.3fs\n", elapsed.count()); printf("GPU time: %.3fs\n", elapsed.count());
@ -117,6 +154,8 @@ int main() {
launchCublasSgemm(d_a, d_b, d_c, MAXN, MAXN, MAXN); launchCublasSgemm(d_a, d_b, d_c, MAXN, MAXN, MAXN);
cudaDeviceSynchronize(); cudaDeviceSynchronize();
end = std::chrono::high_resolution_clock::now(); end = std::chrono::high_resolution_clock::now();
cudaMemcpy(c, d_c, MAXN * MAXN * sizeof(float), cudaMemcpyDeviceToHost);
printf("d_c[0][0]=%f\n", c[0]);
elapsed = end - start; elapsed = end - start;
printf("cuBLAS time: %.3fs\n", elapsed.count()); printf("cuBLAS time: %.3fs\n", elapsed.count());
} }