opt to 2x
csrc/gemm.cu
@@ -26,66 +26,89 @@ void naiveSgemm(float *a, float *b, float *c, const int M, const int N,
   }
 }
 
-const int TILE_SIZE = 32;
-
 /**
  * @brief Optimized implementation of matrix multiplication on GPU using shared memory.
  * Perform C = A * B, where A is M x K, B is K x N, and C is M x N.
  */
-__global__ void ZYMSgemm2D(float *a, float *b, float *c, const int M,
-                           const int N, const int K) {
-  // Shared memory for submatrices of A and B
-  __shared__ float As[TILE_SIZE][TILE_SIZE];
-  __shared__ float Bs[TILE_SIZE][TILE_SIZE];
-
-  // Calculate row and column index of the element
-  int row = blockIdx.y * TILE_SIZE + threadIdx.y;
-  int col = blockIdx.x * TILE_SIZE + threadIdx.x;
-
-  // Accumulator for the result
-  float value = 0.0f;
-
-  // Loop over all the tiles
-  for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; ++t) {
-    // Load the tile elements into shared memory
-    if (row < M && (t * TILE_SIZE + threadIdx.x) < K) {
-      As[threadIdx.y][threadIdx.x] = a[row * K + t * TILE_SIZE + threadIdx.x];
-    } else {
-      As[threadIdx.y][threadIdx.x] = 0.0f;
-    }
-
-    if (col < N && (t * TILE_SIZE + threadIdx.y) < K) {
-      Bs[threadIdx.y][threadIdx.x] = b[(t * TILE_SIZE + threadIdx.y) * N + col];
-    } else {
-      Bs[threadIdx.y][threadIdx.x] = 0.0f;
-    }
-
-    // Synchronize to make sure the submatrices are loaded
-    __syncthreads();
-
-    // Multiply the two matrices together
-    for (int k = 0; k < TILE_SIZE; ++k) {
-      value += As[threadIdx.y][k] * Bs[k][threadIdx.x];
-    }
-
-    // Synchronize to make sure that the computation is done before loading new tiles
-    __syncthreads();
-  }
-
-  // Write the result back to the global memory
-  if (row < M && col < N) {
-    c[row * N + col] = value;
-  }
-}
+template <const int BM, const int BN, const int BK, const int TM, const int TN>
+__global__ void __launch_bounds__((BM * BN) / (TM * TN), 1)
+    ZYMSgemm2D(const float *a, const float *b, float *c, const int M,
+               const int N, const int K) {
+  // Each block computes a BM x BN tile of C; each thread computes a TM x TN sub-tile.
+  const int block_row = blockIdx.x;
+  const int block_col = blockIdx.y;
+  const int elements_per_block = BM * BN;
+  const int threads_per_block = (BM * BN) / (TM * TN);
+  assert(blockDim.x == threads_per_block);
+  const int thread_row = threadIdx.x / (BN / TN);
+  const int thread_col = threadIdx.x % (BN / TN);
+  __shared__ float shared_a[BM * BK];
+  __shared__ float shared_b[BK * BN];
+  // Advance the global pointers to this block's tiles of A, B, and C.
+  a += block_row * BM * K;
+  b += block_col * BN;
+  c += block_row * BM * N + block_col * BN;
+  const int load_a_col = threadIdx.x % BK;
+  const int load_a_row = threadIdx.x / BK;
+  const int load_a_row_stride = threads_per_block / BK;
+  const int load_b_col = threadIdx.x % BN;
+  const int load_b_row = threadIdx.x / BN;
+  const int load_b_row_stride = threads_per_block / BN;
+  float result_cache[TM * TN] = {0.0};
+  float a_cache[TM] = {0.0};
+  float b_cache[TN] = {0.0};
+  for (int k_idx = 0; k_idx < K; k_idx += BK) {
+    for (int load_a_offset = 0; load_a_offset < BM; load_a_offset += load_a_row_stride) {
+      shared_a[(load_a_offset + load_a_row) * BK + load_a_col] =
+          a[(load_a_offset + load_a_row) * K + load_a_col];
+    }
+    for (int load_b_offset = 0; load_b_offset < BK; load_b_offset += load_b_row_stride) {
+      shared_b[(load_b_offset + load_b_row) * BN + load_b_col] =
+          b[(load_b_offset + load_b_row) * N + load_b_col];
+    }
+    // Synchronize to make sure the submatrices are loaded
+    __syncthreads();
+    a += BK;
+    b += BK * N;
+    for (int dot_idx = 0; dot_idx < BK; dot_idx++) {
+      for (int i = 0; i < TM; i++) {
+        a_cache[i] = shared_a[(thread_row * TM + i) * BK + dot_idx];
+      }
+      for (int i = 0; i < TN; i++) {
+        b_cache[i] = shared_b[dot_idx * BN + thread_col * TN + i];
+      }
+      for (int i = 0; i < TM; i++) {
+        for (int j = 0; j < TN; j++) {
+          result_cache[i * TN + j] += a_cache[i] * b_cache[j];
+        }
+      }
+    }
+    // Synchronize to make sure that the computation is done before loading new tiles
+    __syncthreads();
+  }
+
+  // Write the result back to the global memory
+  for (int i = 0; i < TM; i++) {
+    for (int j = 0; j < TN; j++) {
+      c[(thread_row * TM + i) * N + thread_col * TN + j] = result_cache[i * TN + j];
+    }
+  }
+}
 
 /**
  * @brief Launch ZYMSgemm2D kernel.
+ * @details see https://siboehm.com/articles/22/CUDA-MMM
  */
-void launchSgemm2D(float *a, float *b, float *c, const int M, const int N,
+void launchSgemm2D(const float *a, const float *b, float *c, const int M, const int N,
                    const int K) {
-  dim3 block(TILE_SIZE, TILE_SIZE); // 1024 threads per block (32 * 32)
-  dim3 grid((M + block.x - 1) / block.x, (N + block.y - 1) / block.y);
-  ZYMSgemm2D<<<grid, block>>>(a, b, c, M, N, K);
+  const int BK = 8;
+  const int TM = 8;
+  const int TN = 8;
+  if (M >= 128 && N >= 128 && K >= 128) {
+    const int BM = 128;
+    const int BN = 128;
+    dim3 gridDim((M + BM - 1) / BM, (N + BN - 1) / BN);
+    dim3 blockDim((BM * BN) / (TM * TN));
+    ZYMSgemm2D<BM, BN, BK, TM, TN><<<gridDim, blockDim>>>(a, b, c, M, N, K);
+  } else {
+    const int BM = 64;
+    const int BN = 64;
+    dim3 gridDim((M + BM - 1) / BM, (N + BN - 1) / BN);
+    dim3 blockDim((BM * BN) / (TM * TN));
+    ZYMSgemm2D<BM, BN, BK, TM, TN><<<gridDim, blockDim>>>(a, b, c, M, N, K);
+  }
 }
 
 void initialize(float *a, float *b, float *c, const int M, const int N,
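A note on the new kernel's assumptions (not part of the commit): unlike the bounds-checked kernel it replaces, the blocktiled version loads and stores with no boundary guards, so the launcher's ceil-division grid is only safe when M, N, and K are multiples of the chosen tile sizes. A minimal compile-time sketch of the divisibility requirements, using a hypothetical helper valid_tile_config:

// Not part of the commit: sanity-checks the tile parameters the kernel
// implicitly relies on (valid_tile_config is an illustrative name).
template <int BM, int BN, int BK, int TM, int TN>
constexpr bool valid_tile_config() {
  constexpr int threads = (BM * BN) / (TM * TN);
  return (BM % TM == 0) && (BN % TN == 0)                   // TM x TN sub-tiles cover the BM x BN tile
      && (threads % BK == 0) && (BM % (threads / BK) == 0)  // A-load loop covers BM exactly
      && (threads % BN == 0) && (BK % (threads / BN) == 0); // B-load loop covers BK exactly
}
static_assert(valid_tile_config<128, 128, 8, 8, 8>(), "large-matrix path");
static_assert(valid_tile_config<64, 64, 8, 8, 8>(), "small-matrix path");

With BM = BN = 128, BK = 8, TM = TN = 8, this gives (128 * 128) / (8 * 8) = 256 threads per block, each loading (128 * 8) / 256 = 4 elements of A and 4 of B per BK step.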
@@ -145,7 +168,7 @@ int main() {
   cudaDeviceSynchronize();
   end = std::chrono::high_resolution_clock::now();
   cudaMemcpy(c, d_c, MAXN * MAXN * sizeof(float), cudaMemcpyDeviceToHost);
   printf("d_c[0][0]=%f\n", c[0]);
   printf("d_c[108873]=%f\n", c[108873]);
   elapsed = end - start;
   printf("GPU time: %.3fs\n", elapsed.count());
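For reference (not in the diff): the timing above is valid because cudaDeviceSynchronize() runs before the host clock is read. The same measurement could be taken device-side with CUDA events; the sketch below assumes the d_a, d_b, d_c buffers and MAXN from main, and the variable names are illustrative:

// Sketch: timing the same launch with CUDA events instead of std::chrono.
cudaEvent_t start_ev, stop_ev;
cudaEventCreate(&start_ev);
cudaEventCreate(&stop_ev);

cudaEventRecord(start_ev);
launchSgemm2D(d_a, d_b, d_c, MAXN, MAXN, MAXN);
cudaEventRecord(stop_ev);

cudaEventSynchronize(stop_ev);  // wait until the kernel and stop event complete
float ms = 0.0f;
cudaEventElapsedTime(&ms, start_ev, stop_ev);  // elapsed time in milliseconds
printf("GPU time (events): %.3fs\n", ms / 1000.0f);

cudaEventDestroy(start_ev);
cudaEventDestroy(stop_ev);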
@@ -155,7 +178,7 @@ int main() {
   cudaDeviceSynchronize();
   end = std::chrono::high_resolution_clock::now();
   cudaMemcpy(c, d_c, MAXN * MAXN * sizeof(float), cudaMemcpyDeviceToHost);
   printf("d_c[0][0]=%f\n", c[0]);
   printf("d_c[108873]=%f\n", c[108873]);
   elapsed = end - start;
   printf("cuBLAS time: %.3fs\n", elapsed.count());
 }
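Beyond the two spot-printed elements, a small harness can check the kernel against a CPU reference. A minimal sketch, assuming it is compiled and linked together with csrc/gemm.cu so launchSgemm2D resolves; the 256 x 256 size keeps every dimension a multiple of the tile sizes:

// Standalone correctness spot-check for launchSgemm2D (not part of the commit).
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cuda_runtime.h>

void launchSgemm2D(const float *a, const float *b, float *c, int M, int N, int K);

int main() {
  const int n = 256;  // multiple of both the 128 and 64 tile paths
  size_t bytes = n * n * sizeof(float);
  float *a = (float *)malloc(bytes), *b = (float *)malloc(bytes), *c = (float *)malloc(bytes);
  for (int i = 0; i < n * n; i++) {
    a[i] = rand() / (float)RAND_MAX;
    b[i] = rand() / (float)RAND_MAX;
  }

  float *d_a, *d_b, *d_c;
  cudaMalloc(&d_a, bytes); cudaMalloc(&d_b, bytes); cudaMalloc(&d_c, bytes);
  cudaMemcpy(d_a, a, bytes, cudaMemcpyHostToDevice);
  cudaMemcpy(d_b, b, bytes, cudaMemcpyHostToDevice);
  launchSgemm2D(d_a, d_b, d_c, n, n, n);
  cudaMemcpy(c, d_c, bytes, cudaMemcpyDeviceToHost);

  // Recompute a handful of random entries on the CPU and compare.
  for (int idx = 0; idx < 5; idx++) {
    int i = rand() % n, j = rand() % n;
    float ref = 0.0f;
    for (int k = 0; k < n; k++) ref += a[i * n + k] * b[k * n + j];
    if (fabsf(ref - c[i * n + j]) > 1e-3f * fabsf(ref) + 1e-4f)
      printf("mismatch at (%d,%d): got %f, want %f\n", i, j, c[i * n + j], ref);
  }
  printf("check done\n");
  free(a); free(b); free(c);
  cudaFree(d_a); cudaFree(d_b); cudaFree(d_c);
  return 0;
}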