[Disclaimer: This blog was written solely for my own understanding. If you find any mistakes that need to be addressed, please feel free to reach out to me.]
Note: This post focuses exclusively on Grouped Query Attention (GQA). There are many other optimization techniques for LLM inference, but this blog concentrates on understanding GQA and its memory-efficiency benefits.
1. Why We Need GQA: The Memory Problem
When you chat with ChatGPT or generate text with any large language model (LLM), something interesting happens behind the scenes. The model doesn't reprocess your entire conversation from scratch every time it generates a new token. That would be incredibly slow!
Instead, it uses a clever trick called KV caching (learn more in KV Caching: The Hidden Trick Behind Fast LLM Inference). But even with caching, there's a problem: memory.
The Scale of the Problem: Qwen3-8B Example
Assuming standard multi-head attention (one K/V head for each of the 32 query heads), we calculated the following memory requirements for Qwen3-8B at different precisions:
| Precision | Bytes/Param | Model Params | KV Cache (Standard) | Total RAM |
|---|---|---|---|---|
| FP32 (Float32) | 4 | 32.00 GB | 46.08 GB | 81.12 GB |
| FP16 (Float16) | 2 | 16.00 GB | 23.04 GB | 40.56 GB |
| INT8 (8-bit) | 1 | 8.00 GB | 11.52 GB | 20.28 GB |
Even with FP16 (2 bytes), the standard KV cache alone takes 23.04 GB, pushing the total to 40.56 GB! This is where GQA becomes critical.
2. What is Grouped Query Attention (GQA)?
A Quick Refresher: How Transformers Work
Transformer models (like GPT, Llama, Qwen) use something called attention. Think of it like this:
When writing the next word, the model asks:
- Query (Q): "What should I focus on?"
- Key (K): "What information is available?"
- Value (V): "What is the actual content of that information?"
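For each token, we compute:
Attention = softmax(q × k^T / √d) × v
where d is the dimension of each q/k vector (128 in Qwen3-8B). Dividing by √d keeps the dot products from growing with d and saturating the softmax.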
This tells the model which previous words to "pay attention to" when generating the next one.
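As a minimal single-head sketch of the formula above, here it is in NumPy. Random vectors stand in for real token representations, the helper names `softmax` and `attention` are illustrative, and the causal mask used by decoder-only models is left out for brevity.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v):
    """Scaled dot-product attention for one head.

    q: [seq_q, d], k: [seq_k, d], v: [seq_k, d]
    """
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)       # [seq_q, seq_k]: how well each query matches each key
    weights = softmax(scores, axis=-1)  # each row sums to 1
    return weights @ v, weights         # output: [seq_q, d]

rng = np.random.default_rng(0)
q = rng.standard_normal((3, 128))  # 3 tokens, head dimension 128
k = rng.standard_normal((3, 128))
v = rng.standard_normal((3, 128))

out, weights = attention(q, k, v)
print(out.shape)  # (3, 128)
```

Each output row is a weighted average of the value vectors, with the weights coming from how strongly that token's query matches every key.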
Multi-Head Attention: Looking from Multiple Perspectives
Instead of doing this once, we do it 32 times in parallel (in Qwen3-8B). Each is called a "head."
Why? Each head can learn to focus on different things:
- Head 1 might focus on grammar
- Head 2 might focus on topics
- Head 3 might focus on entities
- ... and so on
This is called Multi-Head Attention (MHA).
The Traditional Approach: Multi-Head Attention (MHA)
In classic transformers (shown here with Qwen3-8B's head count):
- 32 Query heads (q)
- 32 Key heads (k)
- 32 Value heads (v)
Every q head gets its own k and v head. Simple and powerful!
The Memory Problem
Let's say you're generating a 1,000-token response, with every cached value stored in FP32 (4 bytes). If Qwen3-8B used standard MHA, then for each layer (it has 36 layers):
Memory needed for KV cache:
32 k heads × 1000 tokens × 128 dimensions × 4 bytes = 16.4 MB
32 v heads × 1000 tokens × 128 dimensions × 4 bytes = 16.4 MB
Per layer: 32.8 MB
Total (36 layers): 1.18 GB just for the cache!
For longer conversations (10,000 tokens), you'd need 11.8 GB just for the KV cache!
Problem: This limits:
- How much you can fit on a GPU
- How many users you can serve simultaneously
- How long conversations can be
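The arithmetic above follows one formula, sketched here in plain Python. `kv_cache_bytes` is a hypothetical helper name, and the `batch_size` parameter is an added assumption to show how the cost scales with the number of users served at once.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_value, batch_size=1):
    """Bytes needed to cache keys and values for every layer.

    The leading 2 accounts for storing both K and V.
    batch_size multiplies the total: every concurrent sequence needs its own cache.
    """
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_value * batch_size

# Worked example from above: 36 layers, 32 K/V heads (standard MHA), head_dim 128, FP32
for tokens in (1_000, 10_000):
    gb = kv_cache_bytes(36, 32, 128, tokens, 4) / 1e9
    print(f"{tokens} tokens: {gb:.2f} GB")
```

This prints 1.18 GB and 11.80 GB, matching the numbers above. With `batch_size=8`, eight 10,000-token conversations would need about 94 GB of KV cache, which is why the cache size limits how many users one GPU can serve.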
3. The Solution: Grouped Query Attention (GQA)
Here's the brilliant insight: Not every q head needs its own k and v!
Instead of 32 separate k/v heads, what if we used just 8 k/v heads and shared them across query heads?
From the Qwen3-8B configuration:
"num_attention_heads": 32,
"num_key_value_heads": 8
groups = num_attention_heads / num_key_value_heads = 32 / 8 = 4
So each k/v head serves 4 query heads → GQA
Grouped Query Attention in Qwen3-8B
32 Query heads (q) ←──┐
│
8 Key heads (k) │ Share!
8 Value heads (v) ←───┘
Each k/v head serves 4 q heads (32 ÷ 8 = 4)
Grouping:
q heads 0-3 → Use k head 0, v head 0
q heads 4-7 → Use k head 1, v head 1
q heads 8-11 → Use k head 2, v head 2
q heads 12-15 → Use k head 3, v head 3
q heads 16-19 → Use k head 4, v head 4
q heads 20-23 → Use k head 5, v head 5
q heads 24-27 → Use k head 6, v head 6
q heads 28-31 → Use k head 7, v head 7
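The grouping table above comes from a single integer division. Here is a minimal sketch using the two head counts from the Qwen3-8B config; `kv_head_for` is a hypothetical helper name, not part of any library.

```python
# Head counts from the Qwen3-8B config quoted above
num_attention_heads = 32
num_key_value_heads = 8

assert num_attention_heads % num_key_value_heads == 0, "query heads must split evenly into groups"
groups = num_attention_heads // num_key_value_heads  # 4 query heads per K/V head

def kv_head_for(q_head, groups):
    """Index of the K/V head that query head `q_head` reads from.

    Consecutive query heads share a K/V head: 0-3 -> 0, 4-7 -> 1, and so on.
    """
    return q_head // groups

mapping = [kv_head_for(h, groups) for h in range(num_attention_heads)]
for kv in range(num_key_value_heads):
    q_heads = [h for h, g in enumerate(mapping) if g == kv]
    print(f"q heads {q_heads[0]}-{q_heads[-1]} → Use k head {kv}, v head {kv}")
```

Running it reproduces the eight lines of the grouping table.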
Memory Savings
With GQA:
8 K heads × 1000 tokens × 128 dimensions × 4 bytes = 4.1 MB
8 V heads × 1000 tokens × 128 dimensions × 4 bytes = 4.1 MB
Per layer: 8.2 MB (was 32.8 MB)
Total (36 layers): ≈295 MB (was 1.18 GB)
4× less memory! 🎉
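A quick check of the savings, assuming the same 1,000-token FP32 setup. The helper is the same hypothetical formula as before, redefined so the snippet runs on its own.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_value):
    # 2 = one cached tensor for K plus one for V, per layer
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_value

mha = kv_cache_bytes(36, 32, 128, 1_000, 4)  # 32 K/V heads
gqa = kv_cache_bytes(36, 8, 128, 1_000, 4)   # 8 K/V heads

print(f"MHA: {mha / 1e6:.1f} MB, GQA: {gqa / 1e6:.1f} MB, saving {1 - gqa / mha:.0%}")
# MHA: 1179.6 MB, GQA: 294.9 MB, saving 75%
```

Layers, head dimension, sequence length, and precision appear on both sides, so they cancel out. The ratio is always `num_attention_heads / num_key_value_heads` (32 / 8 = 4), no matter how long the conversation gets.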
How Does This Actually Work? A Step-by-Step Dimensional Analysis
Let's trace through the dimensional transformations in GQA for Qwen3-8B processing 3 input tokens.
Input Dimensions
Tokens: 3 (sequence_length)
Hidden size: 4096
Query heads: 32
K/V heads: 8 (GQA configuration)
Head dimension: 4096 / 32 = 128
Step 1: Input Embedding
Input tokens → Token embeddings
Dimensions: [sequence_length, hidden_size]
= [3, 4096]
Step 2: Linear Projections to Q, K, V
Query Projection (32 heads):
q_proj: [3, 4096] → [3, 32 × 128]
→ [3, 4096]
Reshape: [3, 4096] → [3, 32, 128]
[seq, hidden] → [seq, num_heads, head_dim]
Key Projection (8 heads - GQA!):
k_proj: [3, 4096] → [3, 8 × 128]
→ [3, 1024]
Reshape: [3, 1024] → [3, 8, 128]
[seq, num_kv_heads × head_dim] → [seq, num_kv_heads, head_dim]
Value Projection (8 heads - GQA!):
v_proj: [3, 4096] → [3, 8 × 128]
→ [3, 1024]
Reshape: [3, 1024] → [3, 8, 128]
[seq, num_kv_heads × head_dim] → [seq, num_kv_heads, head_dim]
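Critical observation: the K and V projections output 4× fewer values per token than Q (1,024 vs 4,096), and their weight matrices are 4× smaller (4096 × 1024 vs 4096 × 4096). Since K and V are what get cached, this is exactly where GQA's memory savings come from!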
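The projections in Step 2 can be sketched in NumPy. Random matrices stand in for the trained `q_proj`/`k_proj`/`v_proj` weights, and the sketch omits biases and anything the real model does to q and k before attention (such as positional encoding); it only tracks shapes.

```python
import numpy as np

seq_len, hidden, n_q, n_kv, head_dim = 3, 4096, 32, 8, 128
rng = np.random.default_rng(0)

# Random stand-ins: x for the 3 token embeddings, W_* for trained projection weights
x = rng.standard_normal((seq_len, hidden), dtype=np.float32)
W_q = rng.standard_normal((hidden, n_q * head_dim), dtype=np.float32)   # [4096, 4096]
W_k = rng.standard_normal((hidden, n_kv * head_dim), dtype=np.float32)  # [4096, 1024]
W_v = rng.standard_normal((hidden, n_kv * head_dim), dtype=np.float32)  # [4096, 1024]

# Project, then split the last axis into (heads, head_dim)
q = (x @ W_q).reshape(seq_len, n_q, head_dim)   # [3, 32, 128]
k = (x @ W_k).reshape(seq_len, n_kv, head_dim)  # [3, 8, 128]
v = (x @ W_v).reshape(seq_len, n_kv, head_dim)  # [3, 8, 128]

print(q.shape, k.shape, v.shape)  # (3, 32, 128) (3, 8, 128) (3, 8, 128)
print(W_q.size // W_k.size)       # 4
```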
Step 3: Transpose for Attention Computation
q: [3, 32, 128] → [32, 3, 128]
[seq, heads, head_dim] → [heads, seq, head_dim]
k: [3, 8, 128] → [8, 3, 128]
[seq, kv_heads, head_dim] → [kv_heads, seq, head_dim]
v: [3, 8, 128] → [8, 3, 128]
[seq, kv_heads, head_dim] → [kv_heads, seq, head_dim]
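The transpose in Step 3 only changes the layout, not the data. This small NumPy sketch fills the tensors with distinct values (via `arange`) so you can see where each element moves. Putting heads first lets a batched matrix multiply treat each head as an independent [seq, head_dim] matrix.

```python
import numpy as np

seq_len, n_q, n_kv, head_dim = 3, 32, 8, 128

# arange gives every element a distinct value, so we can track where it moves
q = np.arange(seq_len * n_q * head_dim).reshape(seq_len, n_q, head_dim)    # [seq, heads, head_dim]
k = np.arange(seq_len * n_kv * head_dim).reshape(seq_len, n_kv, head_dim)  # [seq, kv_heads, head_dim]
v = np.arange(seq_len * n_kv * head_dim).reshape(seq_len, n_kv, head_dim)

# Swap the first two axes: [seq, heads, head_dim] -> [heads, seq, head_dim]
q_t = q.transpose(1, 0, 2)  # [32, 3, 128]
k_t = k.transpose(1, 0, 2)  # [8, 3, 128]
v_t = v.transpose(1, 0, 2)  # [8, 3, 128]

# Same data, new layout: head h of token t now lives at [h, t]
print(q_t.shape, k_t.shape)                # (32, 3, 128) (8, 3, 128)
print(np.array_equal(q_t[5, 2], q[2, 5]))  # True
```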
Step 4: Repeat k and v to Match q Heads (The GQA Magic!)
We have 8 k/v heads but need 32 to match q. Solution: repeat each k/v head 4 times (32 ÷ 8 = 4).
# Original k and v dimensions
k: [8, 3, 128] (8 unique k heads)
v: [8, 3, 128] (8 unique v heads)
# Repeat to match q heads
k_expanded: [32, 3, 128] (8 heads × 4 repetitions = 32)
v_expanded: [32, 3, 128] (8 heads × 4 repetitions = 32)
How the repetition works:
Original k head 0 [3, 128] → copied to positions 0, 1, 2, 3
Original k head 1 [3, 128] → copied to positions 4, 5, 6, 7
Original k head 2 [3, 128] → copied to positions 8, 9, 10, 11
Original k head 3 [3, 128] → copied to positions 12, 13, 14, 15
Original k head 4 [3, 128] → copied to positions 16, 17, 18, 19
Original k head 5 [3, 128] → copied to positions 20, 21, 22, 23
Original k head 6 [3, 128] → copied to positions 24, 25, 26, 27
Original k head 7 [3, 128] → copied to positions 28, 29, 30, 31
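Result: 32 k heads, but only 8 unique ones (each appears 4 times). The KV cache stores only the 8 unique heads; the expansion happens on the fly at attention time, so the cache keeps its 4× savings.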
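The repetition in Step 4 maps directly onto `np.repeat` along the head axis. In this minimal sketch, `repeat_kv` is an illustrative name. Each K head is filled with its own index so the copies are easy to spot.

```python
import numpy as np

def repeat_kv(x, groups):
    """Expand [num_kv_heads, seq, head_dim] to [num_kv_heads * groups, seq, head_dim].

    np.repeat along axis 0 places `groups` copies of each head back to back,
    so K/V head 0 fills positions 0..groups-1, head 1 the next block, and so on.
    """
    return np.repeat(x, groups, axis=0)

kv_heads, seq_len, head_dim, groups = 8, 3, 128, 4

# Fill each K head with its own index so the copies are easy to identify
k = np.tile(np.arange(kv_heads).reshape(kv_heads, 1, 1), (1, seq_len, head_dim))  # [8, 3, 128]
k_expanded = repeat_kv(k, groups)                                                # [32, 3, 128]

print(k_expanded.shape)               # (32, 3, 128)
print(k_expanded[:8, 0, 0].tolist())  # [0, 0, 0, 0, 1, 1, 1, 1]
```

The same function applies unchanged to v. Only the unexpanded `k` (8 heads) would be kept in the cache.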
Step 5: Attention Score Computation
# Now all dimensions align for matrix multiplication
q: [32, 3, 128]
k^T: [32, 128, 3] (transposed last two dimensions)
Scores = q @ k^T / √128   (scaled by √head_dim)
Dimensions: [32, 3, 128] @ [32, 128, 3]
= [32, 3, 3]
[heads, seq_q, seq_k]
# Each head gets a 3×3 attention matrix
# (3 query positions attending to 3 key positions)
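Here is a NumPy sketch of Step 5 with random inputs. It computes the scores in two ways. First it follows the steps above (repeat K, then a batched matmul). Then it regroups the query heads as [8, 4, 3, 128] and lets broadcasting reuse each K head without copying it. The grouped variant is shown as an alternative technique, not as what Qwen3 or any particular library does internally.

```python
import numpy as np

rng = np.random.default_rng(0)
n_q, n_kv, seq_len, head_dim = 32, 8, 3, 128
groups = n_q // n_kv
q = rng.standard_normal((n_q, seq_len, head_dim))   # [32, 3, 128]
k = rng.standard_normal((n_kv, seq_len, head_dim))  # only the 8 unique K heads

# Way 1: repeat K to 32 heads, then batched matmul [32, 3, 128] @ [32, 128, 3] -> [32, 3, 3]
k_expanded = np.repeat(k, groups, axis=0)
scores = q @ k_expanded.transpose(0, 2, 1) / np.sqrt(head_dim)

# Way 2: group query heads as [8, 4, 3, 128]; broadcasting reuses each K head for its 4 queries
q_grouped = q.reshape(n_kv, groups, seq_len, head_dim)
k_t = k.transpose(0, 2, 1)[:, None]  # [8, 1, 128, 3]
scores_grouped = (q_grouped @ k_t / np.sqrt(head_dim)).reshape(n_q, seq_len, seq_len)

print(scores.shape)                         # (32, 3, 3)
print(np.allclose(scores, scores_grouped))  # True
```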
Step 6: Apply Softmax and Multiply with v
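Attention_weights = softmax(Scores)   (applied along the last axis, so each query's weights over the key positions sum to 1; a decoder-only model first masks future positions to -inf)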
Dimensions: [32, 3, 3]
Output = Attention_weights @ v
Dimensions: [32, 3, 3] @ [32, 3, 128]
= [32, 3, 128]
[heads, seq, head_dim]
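Step 6 in NumPy, with random scores and values standing in for the real ones. The causal mask is an addition to the steps above. It is what a decoder-only LLM applies so that each token attends only to itself and earlier tokens.

```python
import numpy as np

rng = np.random.default_rng(0)
n_heads, seq_len, head_dim = 32, 3, 128
scores = rng.standard_normal((n_heads, seq_len, seq_len))        # [32, 3, 3]
v_expanded = rng.standard_normal((n_heads, seq_len, head_dim))   # [32, 3, 128]

# Causal mask: query position i may only attend to key positions <= i
mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
scores = np.where(mask, -np.inf, scores)

# Softmax over the key axis (last dimension); masked entries become exactly 0
scores = scores - scores.max(axis=-1, keepdims=True)
weights = np.exp(scores)
weights = weights / weights.sum(axis=-1, keepdims=True)

output = weights @ v_expanded  # [32, 3, 3] @ [32, 3, 128] -> [32, 3, 128]
print(output.shape)            # (32, 3, 128)
```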
Step 7: Concatenate Heads and Project
# Transpose back
[32, 3, 128] → [3, 32, 128]
# Reshape (concatenate all heads)
[3, 32, 128] → [3, 4096]
# Final output projection
o_proj: [3, 4096] → [3, 4096]
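Step 7 as a NumPy sketch. A random matrix stands in for the trained `o_proj` weights (bias omitted). The final check confirms that "concatenate heads" means each head's 128 values sit side by side in every token's 4096-wide row.

```python
import numpy as np

rng = np.random.default_rng(0)
n_heads, seq_len, head_dim = 32, 3, 128
hidden = n_heads * head_dim  # 4096

attn_out = rng.standard_normal((n_heads, seq_len, head_dim), dtype=np.float32)  # [32, 3, 128]
W_o = rng.standard_normal((hidden, hidden), dtype=np.float32)  # random stand-in for o_proj

merged = attn_out.transpose(1, 0, 2).reshape(seq_len, hidden)  # [32, 3, 128] -> [3, 32, 128] -> [3, 4096]
y = merged @ W_o                                               # [3, 4096]

print(merged.shape, y.shape)  # (3, 4096) (3, 4096)
# Head h occupies columns h*128 .. h*128+127 of each token's row
print(np.array_equal(merged[0, 128:256], attn_out[1, 0]))  # True
```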
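4. Memory Footprint Comparison
Per-layer shapes of the cached K and V tensors for a 40,960-token sequence ([heads, tokens, head_dim]):
| Tensor | Standard MHA Dimensions | GQA Dimensions | MHA : GQA Memory |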
|---|---|---|---|
| k (cached) | [32, 40960, 128] | [8, 40960, 128] | 4:1 (75% savings) |
| v (cached) | [32, 40960, 128] | [8, 40960, 128] | 4:1 (75% savings) |
Combining GQA with Quantization
GQA becomes even more powerful when combined with quantization. Let's see the total RAM requirements:
| Configuration | Precision | Model Params | KV Cache | Total RAM |
|---|---|---|---|---|
| Standard MHA (32 KV heads) | FP32 (4 bytes) | 32.00 GB | 46.08 GB | 81.12 GB |
| Standard MHA (32 KV heads) | FP16 (2 bytes) | 16.00 GB | 23.04 GB | 40.56 GB |
| GQA (8 KV heads) | FP32 (4 bytes) | 32.00 GB | 11.52 GB | 46.56 GB |
| GQA (8 KV heads) | FP16 (2 bytes) | 16.00 GB | 5.76 GB | 23.28 GB |
| GQA (8 KV heads) | INT8 (1 byte) | 8.00 GB | 2.88 GB | 11.64 GB |
By grouping queries to share K/V heads, we keep most of the benefits of multi-head attention while using only a quarter of the KV-cache memory.
This optimization is crucial for making large language models practical in production environments, where memory efficiency directly translates to cost savings and better user experience.