Tensor MatMul on GPU: Caching
In this tutorial, you will learn how to implement a Matrix Multiplication (MatMul) function that uses specialized matrix multiplication hardware on the GPU while caching data to minimize expensive data traffic to/from global memory.
Prerequisites
- You have completed the Tensor_MatMul_GPU tutorial.
Input data caching
Accessing the same data repeatedly from global memory is expensive, so we cache the input data in shared memory, which sits much closer to the compute units and offers much faster access.
Because shared memory is much smaller than global memory, we need to be careful about how much of the global data we cache. For this reason, we introduce an additional split in the K-loop and use the newly created loop index kk for caching:
kk = schedule.split(k, 256)
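To make the effect of this split concrete, here is an illustrative plain-NumPy sketch (not Accera code; the matrix sizes and K_TILE value are chosen only for this example) of the resulting loop structure: each iteration of the k-loop caches one 256-wide slice of A and B, and the kk-loop computes using only those cached tiles.
import numpy as np

# Illustrative sketch only -- plain NumPy, not Accera-generated code.
# K_TILE mirrors the split size of 256 used above; M, N, K are small
# sizes chosen just for this example.
M, N, K, K_TILE = 64, 64, 512, 256
A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(K, N).astype(np.float32)
C = np.zeros((M, N), dtype=np.float32)

for k in range(0, K, K_TILE):          # k-loop: one cached tile per iteration
    # On the GPU, this is where a thread block copies the current slices
    # of A and B from global memory into shared memory.
    A_tile = A[:, k:k + K_TILE].copy()
    B_tile = B[k:k + K_TILE, :].copy()
    for kk in range(K_TILE):           # kk-loop: reads only the cached tiles
        C += np.outer(A_tile[:, kk], B_tile[kk, :])

assert np.allclose(C, A @ B, atol=1e-2)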
Sequential caching
In this approach, each thread block starts caching the next tile of input data only after the computation on the current tile is complete. This is the simplest form of shared memory caching, with no overlap between the data-copy and compute pipelines. It can be achieved by adding the following lines of DSL code:
plan.cache(A, index=kk, location=target.MemorySpace.SHARED)
plan.cache(B, index=kk, location=target.MemorySpace.SHARED)
The complete Python script with caching of input data into shared memory can be found here.
This generates the following kernel code. Note the barriers in the generated code, which show how caching of the next tile waits for computation on the current tile to finish:
extern "C" __global__ __launch_bounds__(256) void tensor_input_cache_matmul_gpu_866f5763c1d8d520__gpu__(float *arg0, float *arg1, float *arg2) {
// Calculate threadid offsets and other locals
/*...*/
// Declare shared memory caches for A and B
__shared__ float var8[32][256];
__shared__ float var9[256][32];
// k-loop
for (int32_t idx16 = 0; idx16 < 8; idx16 += 1) {
int32_t var17 = idx16 * 256;
// Wait for compute on previously cached items to finish
__builtin_amdgcn_s_barrier();
// Cache current tile of A into shared memory
block_copy<CopyMode::Striped, /*SRC_ROW_MAJOR*/ 1, /*DST_ROW_MAJOR*/ 1, /*STRIDE*/ 1, /*WPT*/ 32, /*TILE_R,C*/32, 256, /*BLOCK_DIM_X,Y,Z*/ 128, 2, 1, MemSpace::None, MemSpace::Shared, float>(var11, (float*)arg0, var7, var17, affine_map_func_0_i0, (float*)var8);
// Cache current tile of B into shared memory
block_copy<CopyMode::Striped, /*SRC_ROW_MAJOR*/ 1, /*DST_ROW_MAJOR*/ 1, /*STRIDE*/ 1, /*WPT*/ 32, /*TILE_R,C*/256, 32, /*BLOCK_DIM_X,Y,Z*/ 128, 2, 1, MemSpace::None, MemSpace::Shared, float>(var11, (float*)arg1, var17, var5, affine_map_func_1_i0, (float*)var9);
// Wait for input caching to finish
__builtin_amdgcn_s_barrier();
// kk-loop
for (int32_t idx18 = 0; idx18 < 64; idx18 += 1) {
int32_t var19 = idx18 * 4;
// Declare matrix fragments for A, B and C
/*...*/
// Load C from global memory
rocwmma::load_matrix_sync<0, rocwmma::layout_t::mem_row_major, 1024>(var11, mmaMatrix_22, arg2 + ...);
// Load A and B from shared memory cache
rocwmma::load_matrix_sync<256>(var11, mmaMatrix_20, &var8[var12][var19]);
rocwmma::load_matrix_sync<32>(var11, mmaMatrix_21, &var9[var19][var14]);
// Compute matrix multiplication
rocwmma::mma_sync<0, 0, 0>(mmaMatrix_22, mmaMatrix_20, mmaMatrix_21, mmaMatrix_22);
// Store result into global memory
rocwmma::store_matrix_sync<0, rocwmma::layout_t::mem_row_major, 1024>(var11, arg2 + ..., mmaMatrix_22);
}
}
}
Benchmarking results using hatlib
Similar to the previous experiments, we can use hatlib to benchmark this kernel using the following command:
python3 -m hatlib.benchmark_hat_package <path to tensor_input_cache_matmul_gpu.hat> --cpp --min_time_in_sec 10 --time_in_ms
This produces the following output, which shows that sequential caching reduces the runtime to ~3 ms, roughly 30% faster than the non-cached version presented in Tensor_MatMul_GPU.md:
function_name mean median_of_means mean_of_small_means robust_mean min_of_means
0 tensor_input_cache_matmul_gpu_866f5763c1d8d520 3.02507486 3.02532532 3.02407842 3.02519751 3.02233856
Overlapped caching (a.k.a. Double Buffering)
In this approach, each thread block prefetches the next tile into registers while the current tile is being computed. This overlapped execution of data copy and compute typically achieves better performance by utilizing different hardware pipelines more efficiently. Using the Accera DSL, this can be done by setting the double_buffer flag in the plan.cache call:
plan.cache(A, index=kk, location=target.MemorySpace.SHARED, double_buffer=True, double_buffer_location=target.MemorySpace.PRIVATE)
plan.cache(B, index=kk, location=target.MemorySpace.SHARED, double_buffer=True, double_buffer_location=target.MemorySpace.PRIVATE)
The complete Python script with caching of input data using double buffering can be found here.
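Before looking at the generated kernel, here is a minimal plain-NumPy sketch of the double-buffering pattern itself (illustrative only, not Accera output). The "register" and "shared memory" buffers are ordinary arrays here, and the prefetch runs sequentially rather than overlapping with the compute as it does on the GPU, but the dataflow is the same.
import numpy as np

# Illustrative double-buffering sketch -- plain NumPy, not generated code.
M, N, K, K_TILE = 64, 64, 512, 256
A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(K, N).astype(np.float32)
C = np.zeros((M, N), dtype=np.float32)
num_tiles = K // K_TILE

# Cache tile 0 into "shared memory" before the main loop.
shared_A, shared_B = A[:, :K_TILE].copy(), B[:K_TILE, :].copy()
for t in range(num_tiles - 1):
    # Prefetch the NEXT tile into "registers"; on the GPU this copy is issued
    # before the compute and overlaps with it (no barrier in between).
    reg_A = A[:, (t + 1) * K_TILE:(t + 2) * K_TILE].copy()
    reg_B = B[(t + 1) * K_TILE:(t + 2) * K_TILE, :].copy()
    # Compute on the CURRENT tile from "shared memory".
    C += shared_A @ shared_B
    # Barrier, then move the prefetched data into "shared memory".
    shared_A, shared_B = reg_A, reg_B
# Last tile (mirrors the loop peeling in the generated kernel).
C += shared_A @ shared_B

assert np.allclose(C, A @ B, atol=1e-2)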
The generated kernel code looks something like this. Note how the prefetch of the next tile and the computation on the current tile proceed without intervening synchronization, hiding global memory latency:
extern "C" __global__ __launch_bounds__(256) void tensor_input_double_buffer_cache_matmul_gpu_ce60189b3e52267d__gpu__(float *arg0, float *arg1, float *arg2) {
// Calculate threadid offsets and other locals
/*...*/
// Declare register caches for prefetching input data of A and B
float var8[32][1];
float var9[32][1];
// Declare shared memory caches for A and B
__shared__ float var10[32][256];
__shared__ float var11[256][32];
// Cache tile 0 of A from global memory to shared memory
block_copy<CopyMode::Striped, /*SRC_ROW_MAJOR*/ 1, /*DST_ROW_MAJOR*/ 1, /*STRIDE*/ 1, /*WPT*/ 32, /*TILE_R,C*/32, 256, /*BLOCK_DIM_X,Y,Z*/ 128, 2, 1, MemSpace::None, MemSpace::Shared, float>(var13, (float*)arg0, var7, 0, affine_map_func_0_i0, (float*)var10);
// Cache tile 0 of B from global memory to shared memory
block_copy<CopyMode::Striped, /*SRC_ROW_MAJOR*/ 1, /*DST_ROW_MAJOR*/ 1, /*STRIDE*/ 1, /*WPT*/ 32, /*TILE_R,C*/256, 32, /*BLOCK_DIM_X,Y,Z*/ 128, 2, 1, MemSpace::None, MemSpace::Shared, float>(var13, (float*)arg1, 0, var5, affine_map_func_1_i0, (float*)var11);
// Wait for tile 0 data to finish copying
__builtin_amdgcn_s_barrier();
// k-loop (Current tile)
for (int32_t idx18 = 0; idx18 < 7; idx18 += 1) {
// Prefetch next tile of A from global memory to registers
block_copy<CopyMode::Striped, /*SRC_ROW_MAJOR*/ 1, /*DST_ROW_MAJOR*/ 1, /*STRIDE*/ 1, /*WPT*/ 32, /*TILE_R,C*/32, 256, /*BLOCK_DIM_X,Y,Z*/ 128, 2, 1, MemSpace::None, MemSpace::Private, float>(
var13, (float*)arg0, var7, var20, affine_map_func_0_i0, (float*)var9);
// Prefetch next tile of B from global memory to registers
block_copy<CopyMode::Striped, /*SRC_ROW_MAJOR*/ 1, /*DST_ROW_MAJOR*/ 1, /*STRIDE*/ 1, /*WPT*/ 32, /*TILE_R,C*/256, 32, /*BLOCK_DIM_X,Y,Z*/ 128, 2, 1, MemSpace::None, MemSpace::Private, float>(
var13, (float*)arg1, var20, var5, affine_map_func_1_i0, (float*)var8);
// kk-loop
for (int32_t idx24 = 0; idx24 < 64; idx24 += 1) {
// Declare matrix fragments for A, B and C
/*...*/
// Perform matmul on the current tile from shared memory
rocwmma::load_matrix_sync<0, rocwmma::layout_t::mem_row_major, 1024>(var13, mmaMatrix_28, arg2 + ...);
rocwmma::load_matrix_sync<256>(var13, mmaMatrix_26, &var10[var14][var25]);
rocwmma::load_matrix_sync<32>(var13, mmaMatrix_27, &var11[var25][var16]);
rocwmma::mma_sync<0, 0, 0>(mmaMatrix_28, mmaMatrix_26, mmaMatrix_27, mmaMatrix_28);
rocwmma::store_matrix_sync<0, rocwmma::layout_t::mem_row_major, 1024>(var13, arg2 + ..., mmaMatrix_28);
}
// Wait for matmul on current tile to finish
__builtin_amdgcn_s_barrier();
// Copy prefetched data of A from registers to shared memory
block_copy<CopyMode::Striped, /*SRC_ROW_MAJOR*/ 1, /*DST_ROW_MAJOR*/ 1, /*STRIDE*/ 1, /*WPT*/ 32, /*TILE_R,C*/256, 32, /*BLOCK_DIM_X,Y,Z*/ 128, 2, 1, MemSpace::Private, MemSpace::Shared, float>(var13, (float*)var11, 0, 0, affine_map_func_4_i0, (float*)var8);
// Copy prefetched data of B from registers to shared memory
block_copy<CopyMode::Striped, /*SRC_ROW_MAJOR*/ 1, /*DST_ROW_MAJOR*/ 1, /*STRIDE*/ 1, /*WPT*/ 32, /*TILE_R,C*/32, 256, /*BLOCK_DIM_X,Y,Z*/ 128, 2, 1, MemSpace::Private, MemSpace::Shared, float>(var13, (float*)var10, 0, 0, affine_map_func_3_i0, (float*)var9);
// Wait for copy to finish before starting next tile
__builtin_amdgcn_s_barrier();
}
// Last tile (loop peeling)
for (int32_t idx19 = 0; idx19 < 64; idx19 += 1) {
// Declare matrix fragments for A, B and C
/*...*/
// Perform matmul on the last tile from shared memory
rocwmma::load_matrix_sync<0, rocwmma::layout_t::mem_row_major, 1024>(var13, mmaMatrix_23, arg2 + ...);
rocwmma::load_matrix_sync<256>(var13, mmaMatrix_21, &var10[var14][var20]);
rocwmma::load_matrix_sync<32>(var13, mmaMatrix_22, &var11[var20][var16]);
rocwmma::mma_sync<0, 0, 0>(mmaMatrix_23, mmaMatrix_21, mmaMatrix_22, mmaMatrix_23);
rocwmma::store_matrix_sync<0, rocwmma::layout_t::mem_row_major, 1024>(var13, arg2 + ..., mmaMatrix_23);
}
}
Benchmarking results using hatlib
Benchmarking the above kernel with hatlib shows that double-buffer caching further reduces the runtime to ~1.45 ms, roughly 66% faster than the non-cached version presented in Tensor_MatMul_GPU.md:
function_name mean median_of_means mean_of_small_means robust_mean min_of_means
0 tensor_input_double_buffer_cache_matmul_gpu_ce... 1.45501032 1.45495605 1.45370367 1.45489838 1.45116257
Output data caching
Similar to input caching, the result data can also be cached to avoid unnecessary global memory accesses. Here we will see how to accumulate the result in registers before copying it to global memory. This can be done by adding:
plan.cache(C, index=k, location=target.MemorySpace.MMA_FRAGMENT)
The complete Python script with both input and output caching can be found here.
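For intuition, here is a plain-NumPy sketch of what output caching does (illustrative only, not Accera output): the accumulator array plays the role of the MMA fragment held in registers, so C is read from and written to "global memory" only once rather than once per kk iteration.
import numpy as np

# Illustrative output-caching sketch -- plain NumPy, not generated code.
M, N, K, K_TILE = 64, 64, 512, 256
A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(K, N).astype(np.float32)
C = np.zeros((M, N), dtype=np.float32)

acc = C.copy()                                       # load C into the "fragment cache" once
for k in range(0, K, K_TILE):                        # k-loop over cached input tiles
    acc += A[:, k:k + K_TILE] @ B[k:k + K_TILE, :]   # accumulate in "registers"
C[:] = acc                                           # store to "global memory" once

assert np.allclose(C, A @ B, atol=1e-2)
The generated kernel below follows the same pattern: it loads C into an MMA fragment once, accumulates across all the K tiles, and stores the result back to global memory only at the very end: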
extern "C" __global__ __launch_bounds__(256) void tensor_input_output_cache_matmul_gpu_1b4d39ede237d688__gpu__(float *arg0, float *arg1, float *arg2) {
// Calculate threadid offsets and other locals
/*...*/
// Declare register caches for prefetching input data of A and B
float var8[32][1];
float var9[32][1];
// Declare shared memory caches for A and B
__shared__ float var10[32][256];
__shared__ float var11[256][32];
// Declare fragment cache (registers) for output, C
rocwmma::fragment<rocwmma::accumulator, 16, 16, 4, 1, 1, float> mmaMatrix_12;
// Fill output cache with data from global memory
rocwmma::load_matrix_sync<0, rocwmma::layout_t::mem_row_major, 1024>(var18, mmaMatrix_12, arg2 + ...);
// Cache tile 0 of A and B from global memory to shared memory
/*...*/
// Wait for tile 0 data to finish copying
__builtin_amdgcn_s_barrier();
// k-loop (Current tile)
for (int32_t idx19 = 0; idx19 < 7; idx19 += 1) {
// Prefetch next tile of A and B from global memory to registers
/*...*/
// kk-loop
for (int32_t idx25 = 0; idx25 < 64; idx25 += 1) {
// Declare matrix fragments for A and B
/*...*/
// Load A and B from shared memory cache
/*...*/
// Compute matrix multiplication and accumulate in fragment cache
rocwmma::mma_sync<0, 0, 0>(mmaMatrix_12, mmaMatrix_27, mmaMatrix_28, mmaMatrix_12);
}
// Wait for matmul on current tile to finish
__builtin_amdgcn_s_barrier();
// Copy prefetched data of A and B from registers to shared memory
/*...*/
// Wait for copy to finish before starting next tile
__builtin_amdgcn_s_barrier();
}
// Last tile (loop peeling)
for (int32_t idx20 = 0; idx20 < 64; idx20 += 1) {
// Declare matrix fragments for A and B
/*...*/
// Load A and B from shared memory cache
/*...*/
// Compute matrix multiplication and accumulate in fragment cache
rocwmma::mma_sync<0, 0, 0>(mmaMatrix_12, mmaMatrix_22, mmaMatrix_23, mmaMatrix_12);
}
// Store result into global memory ONCE!
rocwmma::store_matrix_sync<0, rocwmma::layout_t::mem_row_major, 1024>(var18, arg2 + ..., mmaMatrix_12);
}
Benchmarking results using hatlib
Benchmarking the above kernel with hatlib shows that double-buffer caching combined with output caching further reduces the runtime to ~1.34 ms, an overall ~69% improvement over the non-cached version presented in Tensor_MatMul_GPU.md:
function_name mean median_of_means mean_of_small_means robust_mean min_of_means
0 tensor_input_output_cache_matmul_gpu_1b4d39ede... 1.33967323 1.33956711 1.33841698 1.33953149 1.33726944