[Disclaimer: This blog was written solely for my own understanding. If you find any mistakes that need to be addressed, please feel free to reach out to me.]
Imagine if every time you wanted to continue a conversation, you had to re-read the entire chat history from the beginning, word by word, just to understand the context. That would be incredibly slow and wasteful, right? Yet this is exactly what Large Language Models (LLMs) would have to do without a clever optimization called KV Caching.
KV Caching is the unsung hero that makes modern AI chat interfaces possible. Without it, every generated token would require reprocessing the entire conversation, so responses would slow down dramatically as they grow longer. Let's explore how this seemingly simple concept revolutionized LLM inference.
1. Why We Need KV Caching
To understand the necessity of KV caching, let's first examine what happens during text generation without caching.
The Computational Problem
During autoregressive text generation, transformers generate one token at a time. For each new token, the model must compute attention over ALL previous tokens in the sequence.
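Because nothing is kept between steps, each new token forces the model to run the entire prefix through the network again. Without caching, the model must, at every step:
- Recompute embeddings and hidden states for all previous tokens
- Recompute Key and Value projections for all previous tokens, in every layer
- Multiply all previous tokens by the same weight matrices again
- Perform matrix multiplications whose results are identical to those of the previous step
This results in both computational waste and extra memory traffic: the cost of each step grows with the length of the sequence, even though only one new token was added.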
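To make the waste concrete, here is a small sketch that counts how many token positions pass through the model (per layer) while generating `new_tokens` tokens after a prompt of `prompt_len` tokens. The function name and the counting model (one unit of work per token position per step) are illustrative assumptions, not a measurement of any real framework.

```python
def tokens_processed(prompt_len: int, new_tokens: int, use_cache: bool) -> int:
    """Count token positions pushed through the model while generating
    `new_tokens` tokens after a `prompt_len`-token prompt (per layer)."""
    total = 0
    for step in range(new_tokens):
        if use_cache and step > 0:
            # With a KV cache, only the newest token needs a forward pass.
            total += 1
        else:
            # Without a cache (or at the first step, the "prefill"),
            # the whole prefix of prompt_len + step tokens is processed.
            total += prompt_len + step
    return total


for n in (10, 100, 1000):
    print(n, tokens_processed(100, n, use_cache=False), tokens_processed(100, n, use_cache=True))
```

Without the cache the count grows quadratically with the number of generated tokens (n*p + n*(n-1)/2 for prompt length p); with it, linearly (p + n - 1).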
Why KV Caching and Not QKV Caching?
You might wonder: "If attention uses Q, K, and V, why do we only cache K and V? Why not cache Q as well?"
Q (Query): Used Only Once
- Changes every step: Q₁, Q₂, Q₃... are all different
- Position-aware: Computed from the current token's hidden state at its position in the sequence
- Context-sensitive: "What should I look for given where I am in the sequence?"
- Never reused: Only the newest token's query is used to predict the next token, so the query from step 5 is never needed again at step 500
K,V (Key, Value): Fixed Once Computed
- Token-specific: k and v for a token depend on that token, its position and (from the second layer on) the tokens before it, but never on tokens that come after it
- Descriptive: "Here's what information I contain"
- Reusable: Every later query attends to them, and once computed they never change
- Cacheable: Can be stored and reused across generation steps
Mathematical Reason:
For generating the token after position t:
q_t = f(tokens_1..t, position_t) # Used once at this step, then discarded
For every existing token i ≤ t:
k_i = f(tokens_1..i, position_i) # Fixed once token i has been processed
v_i = f(tokens_1..i, position_i) # Fixed once token i has been processed
Since k_i and v_i cannot be affected by later tokens (the causal mask hides the future), they never change and can be cached.
Since q_t is only ever used at the step where it is computed, caching it would save nothing.
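A small pure-Python sketch makes both points concrete. It uses made-up 2-dimensional Q, K, V rows for three tokens, a single head, and no positional encoding; the helper names `softmax`, `attend` and `causal_attention` are hypothetical. It shows that the next-token prediction uses only the last row of causal attention, which needs only the newest query, and that outputs for earlier positions do not change when a token is appended.

```python
import math


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attend(q, keys, values):
    """Scaled dot-product attention for ONE query over the given keys/values."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    weights = softmax(scores)
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]


def causal_attention(qs, ks, vs):
    """Every position i attends only to positions 0..i (causal mask)."""
    return [attend(qs[i], ks[: i + 1], vs[: i + 1]) for i in range(len(qs))]


# Made-up Q, K, V rows for a 3-token sequence (d_head = 2).
qs = [[1.0, 0.0], [0.5, 0.5], [0.2, 0.9]]
ks = [[0.3, 0.1], [0.0, 1.0], [0.7, 0.7]]
vs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

full = causal_attention(qs, ks, vs)

# 1) The next token is predicted from the last row, which needs only the
#    newest query plus all keys and values: earlier queries are never reused.
print(full[-1] == attend(qs[-1], ks, vs))  # True

# 2) Outputs for earlier positions ignore later tokens, so appending a token
#    leaves them (and anything a later layer derives from them) unchanged.
print(causal_attention(qs[:2], ks[:2], vs[:2]) == full[:2])  # True
```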
2. How KV Caching Works: A Dimension-Focused Explanation
Let's trace through KV caching using concrete dimensions. We'll use a simplified example with:
- Model dimension (d_model): 512
- Number of attention heads: 1
- Head dimension (d_head): 512
- Number of layers: 1
- Vocabulary size: 50,000
Step 1: Input Embeddings
Input: Token sequence ["Hello", "world"]
import torch
# Step 1: Token to embeddings
embedding_table = torch.randn(50_000, 512) # Stand-in for the learned [vocab_size, d_model] table
input_ids = [15496, 995] # Token IDs for ["Hello", "world"]
sequence_length = 2
# Embedding lookup
embeddings = embedding_table[input_ids] # Shape: [2, 512]
# Each token becomes a 512-dimensional vector
Step 2: Weight Matrix Projections
The model has three learned weight matrices for each attention head:
# Weight matrices (learned during training)
W_Q = torch.randn(512, 512) # Query projection matrix
W_K = torch.randn(512, 512) # Key projection matrix
W_V = torch.randn(512, 512) # Value projection matrix
# Project embeddings to q, k, v (vectors)
q = embeddings @ W_Q # [2, 512] @ [512, 512] = [2, 512]
k = embeddings @ W_K # [2, 512] @ [512, 512] = [2, 512]
v = embeddings @ W_V # [2, 512] @ [512, 512] = [2, 512]
Key Insight: At this point, k and v contain all the information about tokens ["Hello", "world"]. These values are independent of what comes next!
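One reason a new token can later be projected on its own is that `embeddings @ W_K` works row by row: each output row is computed from a single input row. A plain-Python sketch with a hypothetical `matmul` helper and made-up 2×2 weights (no positional encoding, first layer only) shows that projecting the prefix and the new token separately gives exactly the rows of the full projection.

```python
def matmul(a, b):
    """Plain-Python matrix product: [n, d] @ [d, m] -> [n, m]."""
    return [[sum(a_row[k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))] for a_row in a]


# Toy sizes: 3 tokens, d_model = d_head = 2 (made-up numbers, not real weights).
W_K = [[0.5, -1.0], [2.0, 0.25]]
embeddings = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]]

k_all = matmul(embeddings, W_K)         # project all three tokens at once
k_prefix = matmul(embeddings[:2], W_K)  # project the first two tokens
k_last = matmul(embeddings[2:], W_K)    # project only the newest token

# Row i of the projection depends only on row i of the input, so the rows
# for earlier tokens are unaffected by the new one and can be reused.
print(k_all == k_prefix + k_last)  # True
```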
Step 3: First Generation Step - Computing Attention
Goal: Generate the next token after ["Hello", "world"]
Here's the crucial insight: the attention output for the current query depends on ALL previous keys and values:
# Query of the last token, "world" (position 2); its attention output is what predicts token 3
q_current = q[-1:] # [1, 512] - what we're looking for
# The attention output DEPENDS ON the keys and values of every token so far:
k_cached = k # [2, 512] - what info is available (k_hello, k_world)
v_cached = v # [2, 512] - actual content available (v_hello, v_world)
# Scaled dot-product attention shows the dependency:
attention_scores = q_current @ k_cached.transpose(-1, -2) / 512 ** 0.5
# [1, 512] @ [512, 2] = [1, 2], scaled by sqrt(d_head)
attention_weights = torch.softmax(attention_scores, dim=-1) # [1, 2]
# This gives us: [weight_for_hello, weight_for_world]
# Example result: [0.3, 0.7] meaning:
# - 30% attention to "Hello"
# - 70% attention to "world"
# The context is a weighted combination of ALL previous values:
context = attention_weights @ v_cached # [1, 2] @ [2, 512] = [1, 512]
# Expanded, with the example weights:
# context = 0.3 * v_hello + 0.7 * v_world
#           ↑ depends on k₁,v₁   ↑ depends on k₂,v₂
# This context vector [1, 512] contains information from BOTH previous tokens
# and determines what the next token should be
Key Dependency Insight:
The attention output for the current query q_current doesn't just "look at" k₁, k₂ and v₁, v₂; it is mathematically computed from them:
- q_current · k₁ determines how much to attend to "Hello"
- q_current · k₂ determines how much to attend to "world"
- Final output = weighted_sum(v₁, v₂), using those attention weights
Without k₁, k₂, v₁, v₂, we cannot compute the next token!
Step 4: KV Caching - Store Computed Values
Instead of discarding k and v, we save them:
# Cache the computed k and v
kv_cache = {
'keys': k, # [2, 512] - Keys for ["Hello", "world"]
'values': v # [2, 512] - Values for ["Hello", "world"]
}
# These cached values represent the "meaning" of previous tokens
# They won't change in future generation steps!
Step 5: Second Generation Step - Reusing Cache
Goal: Generate the token after ["Hello", "world", "!"]
# New token: "!" (assume we generated this)
new_token_id = 0 # "!" token ID
new_embedding = embedding_table[new_token_id].unsqueeze(0) # [1, 512]: keep a row dimension so torch.cat works below
# Compute q, k, v ONLY for the new token
q_new = new_embedding @ W_Q # [1, 512] @ [512, 512] = [1, 512]
k_new = new_embedding @ W_K # [1, 512] @ [512, 512] = [1, 512]
v_new = new_embedding @ W_V # [1, 512] @ [512, 512] = [1, 512]
# Retrieve cached k, v
k_cached = kv_cache['keys'] # [2, 512] for ["Hello", "world"]
v_cached = kv_cache['values'] # [2, 512] for ["Hello", "world"]
# Concatenate: cached + new
k_total = torch.cat([k_cached, k_new], dim=0) # [3, 512] for all tokens
v_total = torch.cat([v_cached, v_new], dim=0) # [3, 512] for all tokens
# Update cache for next iteration
kv_cache['keys'] = k_total # [3, 512]
kv_cache['values'] = v_total # [3, 512]
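The dictionary-plus-`torch.cat` pattern above can also be wrapped in a tiny class. This is a minimal sketch for one layer and one head, using plain Python lists; the `KVCache` name and its methods are hypothetical, not any library's API.

```python
from dataclasses import dataclass, field
from typing import List

Vector = List[float]


@dataclass
class KVCache:
    """Hypothetical single-layer, single-head KV cache: one row per token."""
    keys: List[Vector] = field(default_factory=list)
    values: List[Vector] = field(default_factory=list)

    def append(self, k_rows: List[Vector], v_rows: List[Vector]) -> None:
        # Prefill appends many rows at once; each decode step appends one.
        if len(k_rows) != len(v_rows):
            raise ValueError("need exactly one value row per key row")
        self.keys.extend(k_rows)
        self.values.extend(v_rows)

    def __len__(self) -> int:
        return len(self.keys)


cache = KVCache()
cache.append([[0.1, 0.2], [0.3, 0.4]], [[1.0, 0.0], [0.0, 1.0]])  # "Hello", "world"
cache.append([[0.5, 0.6]], [[0.5, 0.5]])                          # "!"
print(len(cache), cache.keys[-1])  # 3 [0.5, 0.6]
```

Appending to a Python list is amortized O(1), whereas `torch.cat` allocates a new, larger tensor at every step; many inference engines avoid that by preallocating (or allocating in blocks) instead of concatenating.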
Step 6: Attention with Mixed Cached and New Data
# Compute attention using only the new query
attention_scores = q_new @ k_total.transpose(-1, -2) / 512 ** 0.5
# [1, 512] @ [512, 3] = [1, 3], scaled by sqrt(d_head)
# This gives us how much the new token should attend to every token so far, including itself:
# attention_scores = [score_for_hello, score_for_world, score_for_exclamation]
attention_weights = torch.softmax(attention_scores, dim=-1) # [1, 3]
context = attention_weights @ v_total # [1, 3] @ [3, 512] = [1, 512]
Efficiency Gain:
- Saved computation: No recomputation of k, v for ["Hello", "world"]
- Saved memory access: No reloading of previous embeddings
- Saved matrix multiplications: Only 2 new K/V projections (k_new, v_new) instead of 6 (k and v for all three tokens); counting Q as well, 3 instead of 9
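To check that caching changes only the cost and not the result, the sketch below runs a toy single-layer, single-head attention (random 4-dimensional weights, no positional encoding, no output projection or MLP) over five token embeddings fed one per step, once recomputing every k and v and once with a cache. All helper names are hypothetical.

```python
import math
import random


def matvec(x, W):
    """Row vector x ([d]) times matrix W ([d, d]) -> [d], like x @ W."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


def attend(q, keys, values):
    """Scaled dot-product attention of one query over all keys/values."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [sum(e / z * v[j] for e, v in zip(exps, values)) for j in range(len(values[0]))]


def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
d = 4
W_Q, W_K, W_V = rand_matrix(d, d), rand_matrix(d, d), rand_matrix(d, d)
tokens = rand_matrix(5, d)  # toy embeddings for 5 tokens, fed one per step

# Without a cache: every step re-projects k, v for the whole prefix.
no_cache_out, no_cache_kv_proj = [], 0
for t in range(1, len(tokens) + 1):
    q = matvec(tokens[t - 1], W_Q)
    ks = [matvec(x, W_K) for x in tokens[:t]]
    vs = [matvec(x, W_V) for x in tokens[:t]]
    no_cache_kv_proj += 2 * t
    no_cache_out.append(attend(q, ks, vs))

# With a cache: project only the newest token, append, and reuse the rest.
k_cache, v_cache, cached_out, cached_kv_proj = [], [], [], 0
for x in tokens:
    q = matvec(x, W_Q)
    k_cache.append(matvec(x, W_K))
    v_cache.append(matvec(x, W_V))
    cached_kv_proj += 2
    cached_out.append(attend(q, k_cache, v_cache))

print(cached_out == no_cache_out, no_cache_kv_proj, cached_kv_proj)  # True 30 10
```

At the third step (t = 3) the uncached loop performs 6 K/V projections and the cached one 2; over all five steps it is 30 versus 10, and the gap keeps widening as the sequence grows.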
Dimension Flow Summary
Without Caching (Step 2):
Embeddings: [3, 512] (all tokens)
q, k, v: [3, 512] each (recompute everything)
With Caching (Step 2):
New Embedding: [1, 512] (only new token)
q_new, k_new, v_new: [1, 512] each (only new computations)
k_total, v_total: [3, 512] (cached + new)
Computation Reduction: about 66% fewer projection matrix multiplications at this step (3 token projections instead of 9)!
3. Total RAM Requirements for Qwen3-8B with Standard KV Caching
Let's estimate the RAM required to load and run the Qwen3-8B model with standard KV caching, i.e. as if every attention head kept its own K and V, ignoring the GQA optimization the model actually uses. We'll use the actual model configuration:
- hidden_size: 4096
- max_position_embeddings: 40960 (maximum sequence length)
- model_type: qwen3
- num_attention_heads: 32
- num_hidden_layers: 36
- Precision: 4 bytes per parameter (float32; our assumption for this estimate rather than part of the model configuration)
For Standard KV Caching Analysis:
- Each attention head gets its own K,V cache (32 heads total)
- Head dimension: 4096 / 32 = 128 dimensions per head
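Before plugging in the full sequence length, it helps to see the cost of a single token: every cached token stores one K and one V vector per head per layer. Here is a small helper under the standard multi-head assumption above; `kv_cache_bytes` is a hypothetical name, not a library function.

```python
def kv_cache_bytes(seq_len: int, num_layers: int, num_heads: int,
                   head_dim: int, bytes_per_element: int) -> int:
    """Bytes for K and V across all layers, with every head caching its own K, V."""
    per_token = 2 * num_layers * num_heads * head_dim * bytes_per_element  # 2 = K and V
    return seq_len * per_token


# Qwen3-8B-like settings from the list above, float32, standard multi-head attention.
per_token = kv_cache_bytes(1, num_layers=36, num_heads=32, head_dim=4096 // 32, bytes_per_element=4)
full = kv_cache_bytes(40960, num_layers=36, num_heads=32, head_dim=128, bytes_per_element=4)
print(per_token)   # bytes per cached token
print(full / 1e9)  # decimal GB for the full 40,960-token context
```

For float32 this gives 1,179,648 bytes (about 1.18 MB) per token, so the cache grows linearly with context length; at 40,960 tokens it reaches 48,318,382,080 bytes, about 48.32 GB in decimal units (exactly 45 GiB).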
1. Model Parameters Memory
# Qwen3-8B model parameters
model_parameters = 8_000_000_000 # 8 billion parameters
bytes_per_param = 4 # float32 = 4 bytes
model_memory = model_parameters * bytes_per_param
# = 8,000,000,000 * 4 = 32,000,000,000 bytes
# = 32 GB (1 GB = 10^9 bytes)
2. KV Cache Memory (Standard Attention)
Scenario: Processing the maximum sequence length (40,960 tokens)
# KV cache dimensions for standard attention (NOT GQA)
# Each attention head maintains separate K,V caches
sequence_length = 40960 # max_position_embeddings
num_layers = 36
num_attention_heads = 32 # All heads get separate K,V (no GQA)
head_dimension = 4096 // 32 # 128
bytes_per_element = 4 # float32
# Per layer KV cache calculation
k_cache_per_layer = sequence_length * num_attention_heads * head_dimension * bytes_per_element
v_cache_per_layer = sequence_length * num_attention_heads * head_dimension * bytes_per_element
# k_cache_per_layer = 40960 * 32 * 128 * 4 = 671,088,640 bytes ≈ 0.67 GB (exactly 640 MiB)
# v_cache_per_layer = 40960 * 32 * 128 * 4 = 671,088,640 bytes ≈ 0.67 GB (exactly 640 MiB)
# Total per layer: K + V
cache_per_layer = k_cache_per_layer + v_cache_per_layer
# = 1,342,177,280 bytes ≈ 1.34 GB (exactly 1.25 GiB)
# Across all layers
total_kv_cache = cache_per_layer * num_layers
# = 1,342,177,280 * 36 = 48,318,382,080 bytes ≈ 48.32 GB (exactly 45 GiB)
3. Additional Memory Components
hidden_size = 4096
# Activation buffers for forward pass
activation_memory = sequence_length * hidden_size * bytes_per_element
# = 40960 * 4096 * 4 = 671,088,640 bytes ≈ 0.67 GB
# Intermediate buffers (rough estimate)
intermediate_memory = 2_000_000_000 # ~2 GB
# Total additional memory
additional_memory = activation_memory + intermediate_memory
# ≈ 0.67 + 2.00 = 2.67 GB
4. Total RAM Calculation
| Component | Memory (GB) | Percentage | Details |
|---|---|---|---|
| Model Parameters | 32.00 | 38.4% | 8B params × 4 bytes |
| KV Cache (40,960 tokens) | 48.32 | 57.9% | K + V, 32 heads × 128 dims × 36 layers × float32 |
| Activations | 0.67 | 0.8% | Forward pass buffers |
| Intermediate Buffers | 2.00 | 2.4% | Temporary computations (rough estimate) |
| Memory Overhead | 0.40 | 0.5% | System/framework overhead (rough estimate) |
| Total RAM Required | 83.39 | 100% | Estimate for the full sequence |
Impact of Different Precision Levels
The calculation above uses float32 (4 bytes per parameter). Let's see how different quantization levels affect total RAM requirements, assuming for simplicity that the weights, the KV cache, and the other buffers are all stored at the same precision:
| Precision | Bytes/Param | Model Params | KV Cache | Other | Total RAM |
|---|---|---|---|---|---|
| FP32 (Float32) | 4 | 32.00 GB | 48.32 GB | 3.07 GB | 83.39 GB |
| FP16 (Float16) | 2 | 16.00 GB | 24.16 GB | 1.54 GB | 41.69 GB |
| INT8 (8-bit) | 1 | 8.00 GB | 12.08 GB | 0.77 GB | 20.85 GB |
| INT4 (4-bit) | 0.5 | 4.00 GB | 6.04 GB | 0.38 GB | 10.42 GB |
KV caching greatly improves text-generation speed and efficiency, but running the full 40k context at full 32-bit precision still needs roughly 83 GB, on the order of an entire 80GB-class GPU. Even at INT8 (8-bit), which under the assumptions above brings memory down to about 8.00 GB for weights, 12.08 GB for the KV cache, and 0.77 GB for other buffers (≈20.85 GB total), the KV cache is still the largest component. This is where Grouped Query Attention (GQA) comes in: models like Qwen use it to significantly reduce KV cache memory, and this is explained in the next blog.
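The arithmetic behind both memory tables can be bundled into one function. This is a sketch under the same simplifications used above: 8 billion parameters, decimal GB, the 2 GB intermediate-buffer and 0.40 GB overhead estimates, and every component (weights, KV cache, and buffers) scaled by the chosen precision. `estimate_memory_gb` is a hypothetical helper, not part of any library.

```python
def estimate_memory_gb(bytes_per_value: float) -> dict:
    """Rough memory estimate (decimal GB) for Qwen3-8B-sized settings with standard
    (non-GQA) KV caching, assuming EVERY component scales with the same precision."""
    params = 8_000_000_000
    seq_len, layers, heads, head_dim, hidden = 40960, 36, 32, 128, 4096
    # Activations + intermediate buffers + overhead, as sized at float32.
    fp32_other = seq_len * hidden * 4 + 2_000_000_000 + 400_000_000
    parts = {
        "model": params * bytes_per_value,
        "kv_cache": 2 * layers * heads * head_dim * seq_len * bytes_per_value,
        "other": fp32_other * bytes_per_value / 4,
    }
    parts = {name: b / 1e9 for name, b in parts.items()}
    parts["total"] = sum(parts.values())
    return parts


for name, bpv in [("FP32", 4), ("FP16", 2), ("INT8", 1), ("INT4", 0.5)]:
    est = estimate_memory_gb(bpv)
    print(name, {k: round(v, 2) for k, v in est.items()})
```

In practice, weight-only quantization (for example INT4 weights with a 16-bit KV cache) is common, so the KV cache often stays larger than the uniform-precision rows suggest; giving the cache its own bytes-per-element argument is a one-line change.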