Paged attention kernel optimization (I)
Introduction
During recent interviews, I have been asked about my experience optimizing complex kernels or working on architectures newer than sm89. Unfortunately, I don’t have much experience on this topic, but a few days ago I got access to an H100x8 server for some reason, so I decided to spend a weekend optimizing the decode paged attention kernel in NanoPD. This post is the first part of the optimization process and contains my reading notes on the vLLM and FlashInfer kernels.
vLLM’s paged attention kernel
First, let us admire vLLM’s paged attention kernel, which is implemented in this file. In the following part, I will analyze the code line by line and try to understand the optimization techniques used in this kernel.
Block sum kernel
Before analyzing the paged attention kernel itself, we need to understand the block sum kernel:
```cpp
template <int NUM_WARPS>
```
This is the common two-level reduction pattern: first reduce within each warp, store each warp’s partial sum in shared memory, then reduce those partial sums within a warp again. There are no bank conflicts. The shared buffer is a float array (4 bytes per element) with one slot per warp, and NUM_WARPS is at most 32, so each warp writes to a different bank. The in-warp reductions use shuffle instructions, which exchange values directly between registers without going through shared memory. Some divergence occurs because only a subset of lanes writes to or reads from shared memory. It is not a big problem, because the second-stage reduction takes only log2(NUM_WARPS) shuffle steps. Finally, the result is broadcast to all threads in the block, which is cheap when NUM_WARPS is small.
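The two-level pattern can be checked on the host with a small simulation. This is a sketch, not vLLM’s code. Lanes are emulated with arrays, and the hypothetical `butterfly_sum` stands in for the `__shfl_xor_sync` loop. NUM_WARPS is assumed to be a power of two no larger than 32. `block_sum_sim` is likewise a hypothetical name.
```cpp
#include <cassert>
#include <vector>

constexpr int WARP_SIZE = 32;

// Emulates an XOR-shuffle butterfly over one warp: at each step every lane
// adds the value held by lane ^ mask. Masks run from width/2 down to 1.
static void butterfly_sum(std::vector<float>& lanes, int width) {
  for (int mask = width / 2; mask >= 1; mask /= 2) {
    std::vector<float> next(lanes);
    for (int lane = 0; lane < WARP_SIZE; ++lane) {
      next[lane] = lanes[lane] + lanes[lane ^ mask];
    }
    lanes = next;
  }
}

// Host-side sketch of the two-level block reduction. vals holds one value per
// thread (num_warps * WARP_SIZE threads); returns the sum every thread sees.
// Assumes num_warps is a power of two and at most 32.
float block_sum_sim(const std::vector<float>& vals, int num_warps) {
  std::vector<std::vector<float>> warps(num_warps, std::vector<float>(WARP_SIZE));
  for (int t = 0; t < num_warps * WARP_SIZE; ++t) {
    warps[t / WARP_SIZE][t % WARP_SIZE] = vals[t];
  }
  // Stage 1: each warp reduces its 32 values; lane 0 writes red_smem[warp].
  std::vector<float> red_smem(num_warps);
  for (int w = 0; w < num_warps; ++w) {
    butterfly_sum(warps[w], WARP_SIZE);
    red_smem[w] = warps[w][0];
  }
  // (__syncthreads() goes here on the GPU.)
  // Stage 2: every warp does the same work, so simulate one. Lanes below
  // num_warps load the per-warp sums, and the others keep their stage-1 value.
  // Then a log2(num_warps)-step butterfly runs over the first num_warps lanes.
  std::vector<float> lanes = warps[0];
  for (int lane = 0; lane < num_warps; ++lane) lanes[lane] = red_smem[lane];
  butterfly_sum(lanes, num_warps);
  // Broadcast: every lane takes lane 0's value.
  return lanes[0];
}
```
With 128 threads all holding 1.0f and 4 warps, every thread ends up with 128.0f.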
Paged attention kernel
Now for the main course. First, let us look at the kernel signature:
```cpp
// Grid: (num_heads, num_seqs, max_num_partitions).
```
First consider the template parameters:
```cpp
typename scalar_t;  // Q/K/V compute data type: float, float16 or bfloat16 (the cache storage type is cache_t)
```
Partition attention:
At the decode stage each sequence has only one query token, far fewer than at the prefill stage, but we still need to attend over the sequence’s entire KV cache. The idea is therefore to cut the KV cache into multiple partitions and let a separate thread block process each partition. This exposes more parallelism when a batch has few sequences and heads but long contexts. The partition size can be tuned for performance and is usually set to a multiple of the block size. (The idea comes from Tri Dao et al.’s 2023 work, Flash-Decoding for long-context inference.)
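A small sketch of the grid-size arithmetic shows why partitioning matters. The function `num_ctas` is hypothetical. So are the example sizes (32 heads, 2 sequences, 8192 tokens, PARTITION_SIZE = 512). The function simply multiplies out the grid dimensions named in the kernel’s grid comment.
```cpp
#include <cassert>

// Hypothetical stand-in for the kernel's DIVIDE_ROUND_UP macro.
constexpr int div_round_up(int a, int b) { return (a + b - 1) / b; }

// Thread blocks in a grid of (num_heads, num_seqs, max_num_partitions).
// partition_size == 0 models the unpartitioned kernel (one partition per sequence).
constexpr long long num_ctas(int num_heads, int num_seqs, int max_seq_len,
                             int partition_size) {
  return 1LL * num_heads * num_seqs *
         (partition_size > 0 ? div_round_up(max_seq_len, partition_size) : 1);
}

// 32 heads x 2 sequences: only 64 blocks without partitioning, far fewer than
// the SMs on an H100. With 512-token partitions of an 8192-token context: 1024.
static_assert(num_ctas(32, 2, 8192, 0) == 64, "");
static_assert(num_ctas(32, 2, 8192, 512) == 1024, "");
```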
Then consider the grid and block configuration:
As the comment says, the grid is (num_heads, num_seqs, max_num_partitions):
```
blockIdx.x: head index, range from 0 to num_heads-1
```
Each CUDA thread block processes one partition of one sequence for one head.
Output parameters:
```cpp
float* __restrict__ exp_sums;  // [num_seqs, num_heads, max_num_partitions]
```
These are the intermediate results of Flash-Decoding. When PARTITION_SIZE > 0, each partition computes its local max logit, exp sum and softmax-weighted output. A separate reduction step then combines these into the final output. PARTITION_SIZE = 0 means no partitioning: the kernel computes the final output directly, without the reduction.
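The combination step can be sketched on the host. This minimal sketch assumes each partition’s output is already normalized by its own exp sum, as the softmax step below does before the V product. `PartitionResult`, `run_partition` and `merge_partitions` are hypothetical names, not vLLM’s reduction kernel. Each partition’s exp sum is rescaled to the global max, and the partition outputs are averaged with those weights.
```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// One partition's intermediate results for a single (sequence, head), playing
// the roles of max_logits, exp_sums and the partition's output.
struct PartitionResult {
  float max_logit;         // max logit inside the partition
  float exp_sum;           // sum of exp(logit - max_logit) inside the partition
  std::vector<float> out;  // softmax-weighted V, normalized by exp_sum
};

// Direct computation for one partition; v holds one scalar per token
// (head_size = 1 keeps the example short).
PartitionResult run_partition(const std::vector<float>& logits,
                              const std::vector<float>& v) {
  PartitionResult r{-INFINITY, 0.f, {0.f}};
  for (float l : logits) r.max_logit = std::max(r.max_logit, l);
  for (std::size_t j = 0; j < logits.size(); ++j) {
    const float e = std::exp(logits[j] - r.max_logit);
    r.exp_sum += e;
    r.out[0] += e * v[j];
  }
  r.out[0] /= r.exp_sum;
  return r;
}

// Combine partitions: rescale each exp_sum to the global max, then take the
// exp_sum-weighted average of the partition outputs.
std::vector<float> merge_partitions(const std::vector<PartitionResult>& parts) {
  float global_max = -INFINITY;
  for (const auto& p : parts) global_max = std::max(global_max, p.max_logit);
  std::vector<float> weights;
  float global_exp_sum = 0.f;
  for (const auto& p : parts) {
    const float w = p.exp_sum * std::exp(p.max_logit - global_max);
    weights.push_back(w);
    global_exp_sum += w;
  }
  std::vector<float> out(parts[0].out.size(), 0.f);
  for (std::size_t i = 0; i < parts.size(); ++i) {
    for (std::size_t d = 0; d < out.size(); ++d) {
      out[d] += parts[i].out[d] * weights[i] / global_exp_sum;
    }
  }
  return out;
}
```
Merging the two halves of a four-token sequence gives the same result, up to rounding, as running softmax attention over all four tokens at once.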
Input parameters:
```cpp
scalar_t* __restrict__ q;  // [num_seqs, num_heads, head_size]
```
q: the query tensor. Each sequence has one query vector per head, so the shape is [num_seqs, num_heads, head_size].
k_cache and v_cache: the paged KV cache. The cache is split into fixed-size blocks, and each block holds block_size tokens for every KV head. k_cache has shape [num_blocks, num_kv_heads, head_size/x, block_size, x]. Here head_size is split into head_size/x chunks of x elements so that each chunk can be read with a vectorized load (head_size/x * x = head_size; block_size is unchanged). v_cache has shape [num_blocks, num_kv_heads, head_size, block_size].
num_kv_heads: the number of KV heads, which is smaller than the number of query heads under GQA/MQA.
scale: the attention scaling factor, usually 1/sqrt(head_size).
block_tables: maps each sequence’s logical blocks to physical blocks in the KV cache, shape [num_seqs, max_num_blocks_per_seq].
seq_lens: the actual length of each sequence, shape [num_seqs].
max_num_blocks_per_seq: the maximum number of blocks per sequence, i.e. the row stride of block_tables.
alibi_slopes: per-head ALiBi slopes, used by models that rely on ALiBi instead of RoPE. We do not discuss it here, so just ignore it.
q_stride, kv_block_stride, kv_head_stride: the strides for indexing the q, k_cache and v_cache tensors, passed in so the kernel does not have to compute them.
k_scale and v_scale: the scaling factors for a quantized KV cache, used for dequantization.
tp_rank: the tensor parallel rank of the current process. In this kernel it is only used by the block-sparse path to compute the global head index.
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step: the parameters for block-sparse attention, used to decide which blocks each head attends to.
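The two cache layouts can be made concrete with element-offset functions. This sketch assumes contiguous tensors, so the strides follow from the shapes. `k_offset` and `v_offset` are hypothetical helpers. In the K layout, the x elements of one chunk are contiguous (16 bytes), and the same chunk of consecutive tokens sits side by side.
```cpp
#include <cassert>
#include <cstdint>

// Element offset of (block, kv_head, token, dim) in a contiguous K cache of
// shape [num_blocks, num_kv_heads, head_size / x, block_size, x].
std::int64_t k_offset(std::int64_t block, std::int64_t kv_head, std::int64_t token,
                      std::int64_t dim, std::int64_t num_kv_heads,
                      std::int64_t head_size, std::int64_t block_size, std::int64_t x) {
  const std::int64_t chunk = dim / x;  // which x-wide slice of the head dimension
  const std::int64_t inner = dim % x;  // position inside that slice
  return (((block * num_kv_heads + kv_head) * (head_size / x) + chunk) * block_size +
          token) * x + inner;
}

// Element offset in a contiguous V cache of shape
// [num_blocks, num_kv_heads, head_size, block_size].
std::int64_t v_offset(std::int64_t block, std::int64_t kv_head, std::int64_t token,
                      std::int64_t dim, std::int64_t num_kv_heads,
                      std::int64_t head_size, std::int64_t block_size) {
  return ((block * num_kv_heads + kv_head) * head_size + dim) * block_size + token;
}
```
Consider a float16 cache (x = 8) with head_size = 128 and block_size = 16. Dims 0 to 7 of a token are adjacent, and dim 8 jumps ahead by block_size * x = 128 elements. The next token’s first chunk starts 8 elements later. In V, consecutive tokens of the same head dimension are adjacent instead.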
Next comes the kernel body. First, each thread block determines which part of the work it owns:
```cpp
const int seq_idx = blockIdx.y;
```
Read the sequence index and partition index from the block index, and read the sequence length from the seq_lens tensor.
Then calculate the start and end position of the current partition:
```cpp
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
```
Compute how many KV blocks the current sequence occupies: the sequence length divided by the block size, rounded up.
```cpp
const int num_blocks_per_partition =
```
Compute how many KV blocks each partition covers.
```cpp
const int start_block_idx =
```
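The partition’s block range [start_block_idx, end_block_idx) follows directly. It starts at partition_idx * num_blocks_per_partition, and the end is clamped to num_seq_blocks so the last partition does not run past the sequence.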
Then convert the block range to the token range:
```cpp
const int start_token_idx = start_block_idx * BLOCK_SIZE;
```
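The block and token ranges can be sketched as a host function. The sketch assumes PARTITION_SIZE is a multiple of BLOCK_SIZE and that PARTITION_SIZE = 0 means no partitioning. `partition_range`, `Range` and `WorkRange` are hypothetical names.
```cpp
#include <algorithm>
#include <cassert>

struct Range { int start; int end; };  // half-open [start, end)
struct WorkRange { Range blocks; Range tokens; };

// Hypothetical stand-in for DIVIDE_ROUND_UP.
constexpr int div_round_up(int a, int b) { return (a + b - 1) / b; }

// Work range of one thread block. partition_size == 0 means no partitioning;
// otherwise partition_size is assumed to be a multiple of block_size.
WorkRange partition_range(int seq_len, int block_size, int partition_size,
                          int partition_idx) {
  const int num_seq_blocks = div_round_up(seq_len, block_size);
  const bool use_partitioning = partition_size > 0;
  const int blocks_per_partition =
      use_partitioning ? partition_size / block_size : num_seq_blocks;
  const int start_block = use_partitioning ? partition_idx * blocks_per_partition : 0;
  const int end_block = std::min(start_block + blocks_per_partition, num_seq_blocks);
  // The last block may be partly empty, so the token range is clamped to seq_len.
  const int start_token = start_block * block_size;
  const int end_token = std::min(end_block * block_size, seq_len);
  return {{start_block, end_block}, {start_token, end_token}};
}
```
For seq_len = 1000, BLOCK_SIZE = 16 and PARTITION_SIZE = 512, the sequence has 63 blocks. Partition 1 covers blocks [32, 63) and tokens [512, 1000).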
Thread Group Design
```cpp
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
```
A thread group is a group of threads that work together to process the QK product for one token.
One warp has 32 threads and a KV block has BLOCK_SIZE tokens, so the design spreads the work of one block evenly over the threads of a warp. Each token gets WARP_SIZE / BLOCK_SIZE threads (at least one) to compute its QK product. For example, with BLOCK_SIZE = 16, each token is handled by 2 threads.
```cpp
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE;
```
Compute how many thread groups there are in one thread block, which is the number of tokens the block can process in parallel.
```cpp
constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
```
If BLOCK_SIZE > WARP_SIZE, each thread group will process multiple tokens.
```cpp
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
```
Standard thread index calculation.
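These compile-time constants can be checked with a standalone sketch. `ThreadLayout`, `cmax` and `div_round_up` are hypothetical stand-ins for the kernel’s template parameters and its MAX / DIVIDE_ROUND_UP macros.
```cpp
#include <cassert>

constexpr int WARP_SIZE = 32;
// Hypothetical stand-ins for the kernel's MAX and DIVIDE_ROUND_UP macros.
constexpr int cmax(int a, int b) { return a > b ? a : b; }
constexpr int div_round_up(int a, int b) { return (a + b - 1) / b; }

template <int BLOCK_SIZE, int NUM_THREADS>
struct ThreadLayout {
  // Threads that cooperate on one token's QK product.
  static constexpr int THREAD_GROUP_SIZE = cmax(WARP_SIZE / BLOCK_SIZE, 1);
  // Tokens the whole thread block processes in parallel.
  static constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE;
  // Tokens each thread group handles per KV block (> 1 only if BLOCK_SIZE > 32).
  static constexpr int NUM_TOKENS_PER_THREAD_GROUP = div_round_up(BLOCK_SIZE, WARP_SIZE);
  static constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
};

// BLOCK_SIZE = 16 with 128 threads: 2 threads per token, 64 tokens in flight, 4 warps.
static_assert(ThreadLayout<16, 128>::THREAD_GROUP_SIZE == 2, "");
static_assert(ThreadLayout<16, 128>::NUM_THREAD_GROUPS == 64, "");
static_assert(ThreadLayout<16, 128>::NUM_TOKENS_PER_THREAD_GROUP == 1, "");
static_assert(ThreadLayout<16, 128>::NUM_WARPS == 4, "");
```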
GQA head projection
```cpp
const int head_idx = blockIdx.x;
```
GQA (grouped-query attention) means several query heads share the same K/V head. num_queries_per_kv is the number of query heads that share one KV head. The KV head index is therefore computed from the query head index as head_idx / num_queries_per_kv.
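Here is a minimal sketch of the head mapping, assuming num_heads is a multiple of num_kv_heads. `kv_head_index` is a hypothetical helper.
```cpp
#include <cassert>

// Query heads are grouped consecutively: with num_queries_per_kv query heads
// per KV head, query head h reads KV head h / num_queries_per_kv.
// Assumes num_heads is a multiple of num_kv_heads.
constexpr int kv_head_index(int head_idx, int num_heads, int num_kv_heads) {
  return head_idx / (num_heads / num_kv_heads);
}

// 32 query heads over 8 KV heads: heads 0-3 -> KV head 0, 4-7 -> KV head 1, ...
static_assert(kv_head_index(3, 32, 8) == 0, "");
static_assert(kv_head_index(4, 32, 8) == 1, "");
static_assert(kv_head_index(31, 32, 8) == 7, "");
```
MHA (num_kv_heads == num_heads) and MQA (num_kv_heads == 1) fall out as special cases of the same formula.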
Vector Type Definition
```cpp
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
```
The goal is for each thread group to fetch 16 bytes of data at a time. 16 bytes is also the width of the widest vectorized load (LDG.128), and when THREAD_GROUP_SIZE is 1, each thread issues exactly such a load.
For example, with THREAD_GROUP_SIZE = 4 and scalar_t = float16, VEC_SIZE = 16 / (4 * 2) = 2. Each thread reads 2 float16 elements (4 bytes) at a time, and the whole thread group reads 16 bytes.
Then we can define the vector type for the q/k/v:
```cpp
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
```
Then we can calculate how many elements each thread handles:
```cpp
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
```
The HEAD_SIZE elements of a token are assigned to one thread group, so each thread processes NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE elements. These are further grouped into vectors of VEC_SIZE elements for vectorized memory access. Each thread therefore processes NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE vectors.
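The per-thread split can be tabulated with a small sketch. `qk_split` and `QkSplit` are hypothetical names, and `elem_bytes` stands in for sizeof(scalar_t). The formulas are the ones given for VEC_SIZE, NUM_ELEMS_PER_THREAD and NUM_VECS_PER_THREAD.
```cpp
#include <cassert>

constexpr int cmax(int a, int b) { return a > b ? a : b; }

struct QkSplit {
  int vec_size;              // VEC_SIZE: elements per vectorized load
  int elems_per_thread;      // NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE
  int vecs_per_thread;       // NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE
  int group_bytes_per_step;  // bytes the whole thread group loads per step
};

// elem_bytes stands in for sizeof(scalar_t).
QkSplit qk_split(int head_size, int thread_group_size, int elem_bytes) {
  const int vec = cmax(16 / (thread_group_size * elem_bytes), 1);
  const int elems = head_size / thread_group_size;
  return {vec, elems, elems / vec, vec * elem_bytes * thread_group_size};
}
```
For HEAD_SIZE = 128 in float16 with THREAD_GROUP_SIZE = 4, each thread handles 32 elements as 16 two-element vectors, and the group moves 16 bytes per step.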
Coordinate in the Thread Group
```cpp
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
```
Calculate the thread group index and the offset of the thread in the thread group, which will be used for memory access and reduction.
Load the Query into shared memory
```cpp
// Load the query to registers.
```
Q has shape [num_seqs, num_heads, head_size] and the grid is (num_heads, num_seqs, max_num_partitions), so the query for the current block starts at q + seq_idx * q_stride + head_idx * HEAD_SIZE. The shared-memory array q_vecs is laid out as [THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD], which suits the thread group design.
Each thread fills the row of q_vecs selected by its thread_group_offset. Assume a thread group size of 4. Threads 0, 4, 8, 12, ... all have offset 0 and fill the first row. Their thread group indices are 0, 1, ..., NUM_THREAD_GROUPS - 1, and they fill the row’s columns in strides of NUM_THREAD_GROUPS. For example, if NUM_THREAD_GROUPS = 4, threads 0, 4, 8, 12 fill columns (0, 1, 2, 3) in the first pass, then (4, 5, 6, 7), and so on.
The Q vector stored in each slot is given by vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE. The first row of q_vecs therefore stores vectors (0, 4, 8, 12, ...), the second row stores (1, 5, 9, 13, ...), and so on. In each pass, consecutive threads load consecutive vectors (thread t loads vector t in the first pass), so the global memory access is coalesced.
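The load pattern can be simulated on the host. This sketch assumes the loop structure described above, where thread group g handles columns i = g, g + NUM_THREAD_GROUPS, and so on. `simulate_q_load` and `QLoad` are hypothetical names. The simulation records which Q vector lands in each q_vecs slot and which vector each thread loads in its first pass.
```cpp
#include <cassert>
#include <vector>

struct QLoad {
  // q_vecs[thread_group_offset][i] = index of the Q vector stored there.
  std::vector<std::vector<int>> q_vecs;
  // Vector index thread t loads in its first pass, or -1 if it loads nothing.
  std::vector<int> loaded_first;
};

QLoad simulate_q_load(int num_threads, int thread_group_size, int num_vecs_per_thread) {
  const int num_thread_groups = num_threads / thread_group_size;
  QLoad r{std::vector<std::vector<int>>(thread_group_size,
                                        std::vector<int>(num_vecs_per_thread, -1)),
          std::vector<int>(num_threads, -1)};
  for (int t = 0; t < num_threads; ++t) {
    const int thread_group_idx = t / thread_group_size;
    const int thread_group_offset = t % thread_group_size;
    for (int i = thread_group_idx; i < num_vecs_per_thread; i += num_thread_groups) {
      const int vec_idx = thread_group_offset + i * thread_group_size;
      r.q_vecs[thread_group_offset][i] = vec_idx;
      if (i == thread_group_idx) r.loaded_first[t] = vec_idx;
    }
  }
  return r;
}
```
Take 128 threads, a thread group size of 4 and 16 vectors per thread. Threads 0 to 63 load vectors 0 to 63 in order, which is a coalesced access, and threads 64 to 127 have nothing to load.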
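Why do we store the query in shared memory instead of registers?
Every thread group computes dot products against the same query, but the cooperative load gives each thread only a few of its vectors. Once the query is in shared memory, every thread can read the vectors it needs (row thread_group_offset of q_vecs), including those loaded by other thread groups, without going back to global memory.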
Shared Memory Plan
```cpp
extern __shared__ char shared_mem[];
```
shared_mem is dynamic shared memory that stores the attention logits of all tokens in the partition, and red_smem is used for the reductions.
```cpp
constexpr int x = 16 / sizeof(cache_t);
```
The K cache layout is [num_blocks, num_kv_heads, head_size/x, block_size, x]. x is chosen so that the innermost dimension holds exactly 16 bytes (x = 8 for a 2-byte cache type), which enables vectorized memory access.
```cpp
float qk_max = -FLT_MAX;
```
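Each thread keeps a local max of the logits it computes. This is a regular two-pass softmax rather than an online one: all logits are kept in shared memory, and the block-wide max is subtracted before exponentiating for numerical stability.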
```cpp
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
```
block_tables has layout [num_seqs, max_num_blocks_per_seq], so this pointer selects the current sequence’s row of the block table.
We do not consider the block-sparse case here, so we ignore the block-sparse code for now.
Main Loop
KV block assignment loop
```cpp
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
```
Each warp handles a strided subset of the KV blocks. With 4 warps and 16 blocks, warp 1 handles blocks 1, 5, 9 and 13.
Physical block location
```cpp
const int64_t physical_block_number =
```
The block table maps the logical block index of the sequence to the physical block number in the KV cache. This indirection is what makes the cache paged: a sequence’s blocks do not need to be contiguous in memory.
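The warp-strided loop and the block-table lookup can be sketched together. `physical_blocks_for_warp` is a hypothetical helper, and `block_table_row` stands for the current sequence’s row of block_tables.
```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Physical blocks visited by one warp. Warps stride through the partition's
// logical blocks by num_warps, and each logical block is translated through
// the sequence's row of the block table.
std::vector<std::int64_t> physical_blocks_for_warp(const std::vector<int>& block_table_row,
                                                   int start_block_idx, int end_block_idx,
                                                   int warp_idx, int num_warps) {
  std::vector<std::int64_t> physical;
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += num_warps) {
    physical.push_back(static_cast<std::int64_t>(block_table_row[block_idx]));
  }
  return physical;
}
```
With 4 warps and a made-up 16-entry block table, warp 1 visits logical blocks 1, 5, 9 and 13 and reads whatever physical blocks the table stores for them.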
Inner loop: assigning tokens within a warp
```cpp
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
```
Usually NUM_TOKENS_PER_THREAD_GROUP is 1 (BLOCK_SIZE <= WARP_SIZE), which means each thread group processes one token.
Load K
```cpp
K_vec k_vecs[NUM_VECS_PER_THREAD];
```
The K cache layout is [num_blocks, num_kv_heads, head_size/x, block_size, x]. The base pointer therefore points at [physical_block_number, kv_head_idx, 0, physical_block_offset, 0], which is the start of the current token’s key. As in the Q load, vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE. The thread at a given thread_group_offset thus loads the same vectors of K that it reads from Q in q_vecs[thread_group_offset].
```cpp
k_vecs[j] = *reinterpret_cast<const K_vec*>(
```
Then we can load the k vector to registers.
Assuming THREAD_GROUP_SIZE = 4, VEC_SIZE = 2, x = 8 and NUM_VECS_PER_THREAD = 8 (so HEAD_SIZE = 64), a thread group processes one token as follows:
```
thread_group_offset=0, j=0: vec_idx=0, offset1=0, offset2=0 → [0,1]
```
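The rest of this table can be generated with a sketch. It assumes offset1 = (vec_idx * VEC_SIZE) / x and offset2 = (vec_idx * VEC_SIZE) % x. It also assumes a load address of k_ptr + offset1 * BLOCK_SIZE * x + offset2, which is consistent with the K layout. BLOCK_SIZE = 8 is what THREAD_GROUP_SIZE = 4 implies. `k_load` and `KLoad` are hypothetical names.
```cpp
#include <cassert>

// Where one K vector of one token comes from, for the thread at
// thread_group_offset in iteration j. mem_offset is relative to k_ptr, the
// token's first x-chunk (assumed address: k_ptr + offset1 * block_size * x + offset2).
struct KLoad {
  int vec_idx;     // which VEC_SIZE-wide vector of the head dimension
  int offset1;     // which x-chunk it falls in
  int offset2;     // position inside that chunk
  int first_dim;   // first head-dimension element loaded
  int mem_offset;  // element offset from k_ptr
};

KLoad k_load(int thread_group_offset, int j, int thread_group_size,
             int vec_size, int x, int block_size) {
  const int vec_idx = thread_group_offset + j * thread_group_size;
  const int first_dim = vec_idx * vec_size;
  const int offset1 = first_dim / x;
  const int offset2 = first_dim % x;
  return {vec_idx, offset1, offset2, first_dim, offset1 * block_size * x + offset2};
}
```
The thread at offset 1 in iteration 0 loads head dims [2, 3] from the first chunk (offset2 = 2). The thread at offset 0 in iteration 1 loads dims [8, 9] from the next chunk, BLOCK_SIZE * x elements further on.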
Dot product and store the logits
```cpp
// Compute dot product.
```
Each thread computes a partial dot product over its own vectors, and Qk_dot reduces these partials across the thread group using warp shuffles.
```cpp
if (thread_group_offset == 0) {
```
The first thread of each group then stores the logit in shared memory and updates its local max for the softmax.
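The thread-group reduction can be checked with a host simulation. `group_reduce` emulates an XOR-shuffle butterfly over lanes. `partial_dot` and `group_qk` are hypothetical helpers that follow the vec_idx pattern above. The sketch assumes THREAD_GROUP_SIZE is a power of two, which keeps lane ^ mask inside the group. It leaves out the multiplication by scale.
```cpp
#include <cassert>
#include <vector>

// XOR butterfly restricted to masks below thread_group_size. lane ^ mask stays
// inside the lane's own group when thread_group_size is a power of two, so
// every lane of a group ends up holding that group's total.
std::vector<float> group_reduce(std::vector<float> lanes, int thread_group_size) {
  for (int mask = thread_group_size / 2; mask >= 1; mask /= 2) {
    std::vector<float> next(lanes);
    for (int lane = 0; lane < static_cast<int>(lanes.size()); ++lane) {
      next[lane] = lanes[lane] + lanes[lane ^ mask];
    }
    lanes = next;
  }
  return lanes;
}

// Partial q.k of the thread at thread_group_offset: it covers the vectors
// vec_idx = offset + j * tgs, each made of vec_size consecutive elements.
float partial_dot(const std::vector<float>& q, const std::vector<float>& k,
                  int offset, int tgs, int vec_size) {
  const int num_vecs = static_cast<int>(q.size()) / (tgs * vec_size);
  float acc = 0.f;
  for (int j = 0; j < num_vecs; ++j) {
    const int base = (offset + j * tgs) * vec_size;
    for (int e = 0; e < vec_size; ++e) acc += q[base + e] * k[base + e];
  }
  return acc;
}

// One token handled by one thread group: partials, then the in-group reduction.
std::vector<float> group_qk(const std::vector<float>& q, const std::vector<float>& k,
                            int tgs, int vec_size) {
  std::vector<float> lanes(tgs);
  for (int o = 0; o < tgs; ++o) lanes[o] = partial_dot(q, k, o, tgs, vec_size);
  return group_reduce(lanes, tgs);
}
```
Every lane of the group ends up with the full dot product, so the thread at offset 0 can write the logit without any extra broadcast.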
Reduction for max logits
```cpp
// Perform reduction across the threads in the same warp to get the
```
This reduces the max across the threads of each warp, covering all the KV blocks that the warp processed.
```cpp
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
```
Then reduce across warps to get the block-wide max logit.
Then compute the exp and the exp sums:
```cpp
float exp_sum = 0.f;
```
After this, we normalize the logits and get the softmax output:
```cpp
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
```
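The exp, sum and normalization steps can be sketched sequentially. This host version uses std::exp and a plain division where the kernel uses fast intrinsics such as __fdividef. `stable_softmax` and `Softmax` are hypothetical names. Subtracting the max first is what keeps large logits from overflowing.
```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

struct Softmax {
  std::vector<float> probs;  // normalized softmax outputs
  float exp_sum;             // sum of exp(logit - max) before normalization
};

// Two-pass, numerically stable softmax over a logits buffer. It finds the max,
// exponentiates the shifted logits while summing, then scales by inv_sum.
// The 1e-6f guard mirrors the epsilon in the inv_sum line above.
Softmax stable_softmax(std::vector<float> logits) {
  const float qk_max = *std::max_element(logits.begin(), logits.end());
  float exp_sum = 0.f;
  for (float& l : logits) {
    l = std::exp(l - qk_max);
    exp_sum += l;
  }
  const float inv_sum = 1.f / (exp_sum + 1e-6f);
  for (float& l : logits) l *= inv_sum;
  return {logits, exp_sum};
}
```
Two logits of 1000 would overflow a naive exp, but after the shift both become exp(0) = 1, giving probabilities of about 0.5 each.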
Load V into registers
V cache’s layout is [num_blocks, num_kv_heads, head_size, block_size], which is different from K cache.
```cpp
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
```
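Each row of a V block (one head dimension across BLOCK_SIZE tokens) has BLOCK_SIZE elements, split into NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE vectors. A warp has WARP_SIZE threads, so one iteration processes NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW rows. Each thread therefore processes NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER) rows.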
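The V-side constants can be checked with a compile-time sketch. `VSplit` is a hypothetical struct, and `ELEM_BYTES` stands in for sizeof(scalar_t). NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE is the assumed number of vectors per V row.
```cpp
#include <cassert>

constexpr int WARP_SIZE = 32;
constexpr int cmin(int a, int b) { return a < b ? a : b; }
constexpr int div_round_up(int a, int b) { return (a + b - 1) / b; }

template <int HEAD_SIZE, int BLOCK_SIZE, int ELEM_BYTES>
struct VSplit {
  // Tokens covered by one 16-byte V load (capped at the block size).
  static constexpr int V_VEC_SIZE = cmin(16 / ELEM_BYTES, BLOCK_SIZE);
  // Vectors needed to cover one row (one head dim across the block's tokens).
  static constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
  // Rows a whole warp covers in one iteration.
  static constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
  // Head dimensions each thread accumulates.
  static constexpr int NUM_ROWS_PER_THREAD = div_round_up(HEAD_SIZE, NUM_ROWS_PER_ITER);
};

// float16, BLOCK_SIZE = 16, HEAD_SIZE = 128: 8 tokens per vector, 2 vectors
// per row, 16 rows per warp iteration, 8 head dimensions per thread.
static_assert(VSplit<128, 16, 2>::V_VEC_SIZE == 8, "");
static_assert(VSplit<128, 16, 2>::NUM_V_VECS_PER_ROW == 2, "");
static_assert(VSplit<128, 16, 2>::NUM_ROWS_PER_ITER == 16, "");
static_assert(VSplit<128, 16, 2>::NUM_ROWS_PER_THREAD == 8, "");
```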
Acc initialization
```cpp
// Initialize the accumulators.
```
Then comes the second main part: the matrix multiplication between the softmax output and the V vectors.
```cpp
scalar_t zero_value;
```
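This loop mirrors the earlier K loop. Each warp walks over its KV blocks, loads V_VEC_SIZE probabilities together with the matching V vectors, and accumulates their dot products into the accumulators. zero_value is used to zero out the V entries of tokens beyond seq_len in the last block. That memory may contain NaNs, and 0 * NaN is still NaN.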
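The accumulation for one thread and one block, including the tail zeroing, can be sketched on the host. `accumulate_block` is a hypothetical helper. `v_block` uses the [head_size][block_size] V layout, and `rows` are the head dimensions owned by the thread.
```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// P.V accumulation for one thread and one KV block.
//   probs:       softmax outputs for the block's BLOCK_SIZE tokens
//   v_block:     one head's V block in [head_size][block_size] layout
//   rows:        head dimensions owned by this thread
//   first_token: token index of the block's first slot
// Tokens at or beyond seq_len read 0 instead of V (the role of zero_value).
std::vector<float> accumulate_block(const std::vector<float>& probs,
                                    const std::vector<std::vector<float>>& v_block,
                                    const std::vector<int>& rows, int first_token,
                                    int seq_len, std::vector<float> accs) {
  const int block_size = static_cast<int>(probs.size());
  for (std::size_t r = 0; r < rows.size(); ++r) {
    float acc = 0.f;
    for (int t = 0; t < block_size; ++t) {
      const float v = first_token + t < seq_len ? v_block[rows[r]][t] : 0.f;
      acc += probs[t] * v;
    }
    accs[r] += acc;
  }
  return accs;
}
```
Without the zeroing, the NaN entries in the unused tail would poison the accumulator even though their probabilities are 0.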