[Disclaimer: This blog was written solely for my own understanding. If you find any mistakes that need to be addressed, please feel free to reach out to me.]

Imagine if every time you wanted to continue a conversation, you had to re-read the entire chat history from the beginning, word by word, just to understand the context. That would be incredibly slow and wasteful, right? Yet this is exactly what Large Language Models (LLMs) would have to do without a clever optimization called KV Caching.

KV Caching is the unsung hero behind responsive AI chat interfaces. Without it, every new token would require recomputing keys and values for the entire sequence so far, so generation would get slower and slower as the response grows. Let's explore how this seemingly simple idea makes LLM inference practical.

1. Why We Need KV Caching

To understand the necessity of KV caching, let's first examine what happens during text generation without caching.

The Computational Problem

During autoregressive text generation, transformers generate one token at a time. For each new token, the model must compute attention over ALL previous tokens in the sequence.

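Each generation step re-runs the model over the entire sequence so far, even though only the newest position's output is needed. Without caching, at every step the model must:

  • Recompute the K and V projections for every previous token, in every layer
  • Recompute the Q projections and attention outputs for every previous position
  • Re-run the rest of each layer (output projection, MLP) for those positions as well
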
This results in both computational waste and memory bandwidth saturation.
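To make the waste concrete, here is a small counting sketch (the function names are hypothetical, chosen for illustration). It assumes a prompt of `prompt_len` tokens, generation of `new_tokens` tokens, and that every token the model processes needs one Q, one K and one V projection per layer. Without a cache, step j re-processes all `prompt_len + j` tokens; with a cache, the prompt is processed once and then only one new token per step.

```python
def tokens_processed(prompt_len: int, new_tokens: int, use_cache: bool) -> int:
    """Count how many token positions go through the model while generating."""
    if new_tokens < 1:
        raise ValueError("new_tokens must be at least 1")
    if use_cache:
        # Prefill the prompt once, then feed only the newest token at each later step
        return prompt_len + (new_tokens - 1)
    # Without a cache, step j re-runs the full prefix of prompt_len + j tokens
    return sum(prompt_len + j for j in range(new_tokens))


def qkv_projections(prompt_len: int, new_tokens: int, num_layers: int, use_cache: bool) -> int:
    """Each processed token needs 3 projections (Q, K, V) in every layer."""
    return 3 * num_layers * tokens_processed(prompt_len, new_tokens, use_cache)


for use_cache in (False, True):
    print(use_cache, qkv_projections(prompt_len=1000, new_tokens=1000, num_layers=36, use_cache=use_cache))
# False 161946000
# True 215892
```

The uncached count grows quadratically with the number of generated tokens, while the cached count grows linearly. Attention itself still has to read every cached key and value, so that part of the cost keeps growing with context length in both cases.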

Why KV Caching and Not QKV Caching?

You might wonder: "If attention uses Q, K, and V, why do we only cache K and V? Why not cache Q as well?"

🔍 Q (Query): Used Once, Then Discarded

  • Changes every step: q₁, q₂, q₃... are all different, one per position
  • Used only at its own step: q_t is needed only to compute the output for position t
  • Never read again: later steps use their own query, not q_t
  • No benefit from caching: q for position 5 is never needed when generating position 500

🏗️ K,V (Key, Value): Reused by Every Future Step

  • Fixed once computed: with causal masking, k_i and v_i depend only on tokens 1..i, so tokens generated later never change them
  • Not purely token-intrinsic: they can depend on position (e.g. rotary position embeddings) and, beyond the first layer, on the preceding context; what matters is that this dependence is only on the past
  • Reused: every later query q_t (t > i) attends to k_i and reads v_i
  • Cacheable: compute once, store, and reuse across generation steps

Mathematical Reason:

For generating the token after position t (h_i is the layer input at position i):
q_t = h_t @ W_Q  # needed only at step t
output_t = softmax(q_t @ [k_1 ... k_t]ᵀ / √d) @ [v_1 ... v_t]

For every existing position i ≤ t:
k_i = h_i @ W_K  # h_i depends only on tokens 1..i (causal mask)
v_i = h_i @ W_V  # same: never changes once computed

Since k_i and v_i depend only on tokens up to i, they never change as new tokens are appended, so they can be cached.
Since q_t is used only once (at step t), there is nothing to reuse, so it is not cached.
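A small pure-Python check of this argument, using toy vectors (not real model weights) and hypothetical helper names. It shows two things: with a causal mask, appending a new position leaves the outputs of earlier positions unchanged (so the next layer's k and v for them never change either), and the output for the newest position needs only the newest query together with all keys and values.

```python
import math


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attend(q, keys, values):
    """Scaled dot-product attention for ONE query over lists of keys and values."""
    weights = softmax([dot(q, k) / math.sqrt(len(q)) for k in keys])
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]


def causal_attention(qs, ks, vs):
    """Full causal attention: position t attends only to positions 0..t."""
    return [attend(qs[t], ks[: t + 1], vs[: t + 1]) for t in range(len(qs))]


# Toy q, k, v vectors for 3 positions (d = 2); stand-ins for real projections
qs = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
ks = [[1.0, 2.0], [0.0, 1.0], [1.0, 1.0]]
vs = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

full = causal_attention(qs, ks, vs)  # recompute every position
newest = attend(qs[-1], ks, vs)      # newest query only, over all keys/values

# Appending position 2 did not change the outputs for positions 0 and 1 ...
print(causal_attention(qs[:2], ks[:2], vs[:2]) == full[:2])  # True
# ... and the newest output needs only the newest query plus all k, v
print(all(abs(a - b) < 1e-12 for a, b in zip(full[-1], newest)))  # True
```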

2. How KV Caching Works: A Dimension-Focused Explanation

Let's trace through KV caching using concrete dimensions. We'll use a simplified example with:

Model Configuration:
- Model dimension (d_model): 512
- Number of attention heads: 1
- Head dimension (d_head): 512
- Number of layers: 1
- Vocabulary size: 50,000

Step 1: Input Embeddings

Input: Token sequence ["Hello", "world"]

import torch

# Step 1: Token to embeddings
input_ids = torch.tensor([15496, 995])  # Token IDs for ["Hello", "world"]
sequence_length = 2

# Embedding lookup (a random table stands in for the learned one)
embedding_table = torch.randn(50_000, 512)  # [vocab_size, d_model]
embeddings = embedding_table[input_ids]  # Shape: [2, 512]
# Each token becomes a 512-dimensional vector

Step 2: Weight Matrix Projections

The model has three learned weight matrices for each attention head:

# Weight matrices (learned during training)
W_Q = torch.randn(512, 512)  # Query projection matrix
W_K = torch.randn(512, 512)  # Key projection matrix  
W_V = torch.randn(512, 512)  # Value projection matrix

# Project embeddings to q, k, v (vectors)
q = embeddings @ W_Q  # [2, 512] @ [512, 512] = [2, 512]
k = embeddings @ W_K  # [2, 512] @ [512, 512] = [2, 512]  
v = embeddings @ W_V  # [2, 512] @ [512, 512] = [2, 512]

Key Insight: Once computed, k and v for ["Hello", "world"] never change. Because of causal masking, tokens generated later cannot affect them, so they can be stored and reused.

Step 3: First Generation Step - Computing Attention

Goal: Generate the next token after ["Hello", "world"]

Here's the crucial insight: the output for the current step depends on ALL previous keys and values:

# The query that predicts the next token comes from the LAST token, "world"
q_current = q[-1:]  # [1, 512] - what we're looking for

# Keys and values for every token so far (these are what we will cache):
k_cached = k  # [2, 512] = [k_hello, k_world] - what info is available
v_cached = v  # [2, 512] = [v_hello, v_world] - actual content available

# Scaled dot-product attention: the output depends on q AND on every k, v
attention_scores = q_current @ k_cached.transpose(-1, -2) / (512 ** 0.5)
# [1, 512] @ [512, 2] = [1, 2]

attention_weights = torch.softmax(attention_scores, dim=-1)  # [1, 2]

# This gives us: [weight_for_hello, weight_for_world]
# Example result: [0.3, 0.7] meaning:
# - 30% attention to "Hello"
# - 70% attention to "world"

# The context is a weighted combination of ALL previous values:
context = attention_weights @ v_cached  # [1, 2] @ [2, 512] = [1, 512]

# With the example weights, this expands to:
# context = 0.3 * v_hello + 0.7 * v_world
#           (weight from q·k₀, value v₀) + (weight from q·k₁, value v₁)

# This context vector [512] mixes information from BOTH previous tokens;
# in a full model it then flows through the output projection, the MLP,
# and the LM head to produce the next-token logits.

Key Dependency Insight:

The current step's output doesn't just "look at" k₀, k₁ and v₀, v₁; it mathematically depends on them, through the query q₁ of the last token ("world"):

  • q₁ · k₀ determines how much to attend to "Hello"
  • q₁ · k₁ determines how much to attend to "world"
  • Final output = weighted sum of v₀ and v₁, using those attention weights

Without k₀, k₁, v₀, v₁, we cannot compute the next token!
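The weights [0.3, 0.7] above are illustrative. As a quick sketch, here are raw scores that softmax maps to exactly those weights, plus the resulting context for made-up 2-dimensional values (the model in this example would use 512-dimensional vectors; `v_hello` and `v_world` here are toy numbers, not real model outputs).

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Scores whose softmax is [0.3, 0.7]: the gap between them is ln(0.7 / 0.3)
scores = [0.0, math.log(0.7 / 0.3)]  # [score_for_hello, score_for_world]
weights = softmax(scores)

v_hello = [1.0, 0.0]  # toy values, NOT real model outputs
v_world = [2.0, 1.0]
context = [weights[0] * a + weights[1] * b for a, b in zip(v_hello, v_world)]

print([round(w, 6) for w in weights])  # [0.3, 0.7]
print([round(c, 6) for c in context])  # [1.7, 0.7]
```

Only the difference between scores matters to softmax, which is why the first score can be fixed at 0.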

Step 4: KV Caching - Store Computed Values

Instead of discarding k and v, we save them:

# Cache the computed k and v
kv_cache = {
    'keys': k,    # [2, 512] - Keys for ["Hello", "world"]  
    'values': v   # [2, 512] - Values for ["Hello", "world"]
}

# These cached values represent the "meaning" of previous tokens
# They won't change in future generation steps!
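The `kv_cache` dictionary above can be pictured as a small append-only structure, one per layer. Below is a minimal pure-Python sketch; `KVCache` is a hypothetical name, not a library API, and real implementations typically store tensors rather than Python lists.

```python
from dataclasses import dataclass, field
from typing import List

Vector = List[float]


@dataclass
class KVCache:
    """Append-only store of one layer's keys and values, one row per token."""
    keys: List[Vector] = field(default_factory=list)
    values: List[Vector] = field(default_factory=list)

    def append(self, k: Vector, v: Vector) -> None:
        # Rows are only ever added, never modified: past k, v stay fixed
        self.keys.append(k)
        self.values.append(v)

    def __len__(self) -> int:
        return len(self.keys)


cache = KVCache()
cache.append([0.1, 0.2], [1.0, 0.0])  # "Hello" (toy numbers)
cache.append([0.3, 0.4], [0.0, 1.0])  # "world" (toy numbers)
print(len(cache))  # 2
```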

Step 5: Second Generation Step - Reusing Cache

Goal: Generate the token after ["Hello", "world", "!"]

# New token: "!" (assume we generated this)
new_token_id = 0  # "!" token ID
new_embedding = embedding_table[new_token_id].unsqueeze(0)  # [1, 512] (keep the row dimension so torch.cat works)

# Compute q, k, v ONLY for the new token
q_new = new_embedding @ W_Q  # [1, 512] @ [512, 512] = [1, 512]
k_new = new_embedding @ W_K  # [1, 512] @ [512, 512] = [1, 512]  
v_new = new_embedding @ W_V  # [1, 512] @ [512, 512] = [1, 512]

# Retrieve cached k, v
k_cached = kv_cache['keys']    # [2, 512] for ["Hello", "world"]
v_cached = kv_cache['values']  # [2, 512] for ["Hello", "world"]

# Concatenate: cached + new
k_total = torch.cat([k_cached, k_new], dim=0)  # [3, 512] for all tokens
v_total = torch.cat([v_cached, v_new], dim=0)  # [3, 512] for all tokens

# Update cache for next iteration
kv_cache['keys'] = k_total    # [3, 512]
kv_cache['values'] = v_total  # [3, 512]

Step 6: Attention with Mixed Cached and New Data

# Compute attention using only the new query
attention_scores = q_new @ k_total.transpose(-1, -2) / (512 ** 0.5)  # scaled by √d_head
# [1, 512] @ [512, 3] = [1, 3]

# This gives us how much the new token should attend to every token so far (including itself):
# attention_scores = [score_for_hello, score_for_world, score_for_exclamation]

attention_weights = torch.softmax(attention_scores, dim=-1)  # [1, 3]
context = attention_weights @ v_total  # [1, 3] @ [3, 512] = [1, 512]
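As an end-to-end check of Steps 4 to 6, here is a pure-Python sketch using the toy single-layer, single-head setup from this section (no positional encoding, output projection or MLP; random weights with d = 4 instead of 512). The helper names are hypothetical. It decodes token by token twice, once recomputing k and v for the whole prefix and once reusing a cache, and checks that both give the same attention output at every step.

```python
import math
import random

random.seed(0)
D = 4  # toy model dimension (the text uses 512)


def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def project(x, W):
    """Row vector x [D] times matrix W [D, D] -> [D]."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


def attend(q, keys, values):
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e / total * v[j] for e, v in zip(exps, values)) for j in range(len(values[0]))]


W_Q, W_K, W_V = rand_matrix(D, D), rand_matrix(D, D), rand_matrix(D, D)
embeddings = [[random.gauss(0.0, 1.0) for _ in range(D)] for _ in range(5)]  # 5 toy tokens


def step_without_cache(prefix):
    """Recompute k, v for every token in the prefix, then attend with the last query."""
    keys = [project(e, W_K) for e in prefix]
    values = [project(e, W_V) for e in prefix]
    return attend(project(prefix[-1], W_Q), keys, values)


cache = {"keys": [], "values": []}


def step_with_cache(new_embedding):
    """Project only the newest token, append its k, v, attend over the whole cache."""
    cache["keys"].append(project(new_embedding, W_K))
    cache["values"].append(project(new_embedding, W_V))
    return attend(project(new_embedding, W_Q), cache["keys"], cache["values"])


for t in range(len(embeddings)):
    slow = step_without_cache(embeddings[: t + 1])
    fast = step_with_cache(embeddings[t])
    assert all(abs(a - b) < 1e-9 for a, b in zip(slow, fast))
print("cached and uncached outputs match")
```

In a multi-layer model, k and v are projected from each layer's hidden states rather than from the embeddings, but the same equivalence holds because the causal mask keeps earlier positions independent of later ones.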

Efficiency Gain:

  • ✅ Saved computation: k and v for ["Hello", "world"] are not recomputed
  • ✅ Saved memory traffic: previous embeddings are not reloaded and re-projected
  • ✅ Fewer projections: only 2 new K/V projections (k_new, v_new) instead of 6 (k and v for all 3 tokens); including Q, 3 token projections instead of 9

Dimension Flow Summary

Without Caching (second generation step):
Embeddings: [3, 512] (all tokens)
q, k, v: [3, 512] each (recompute everything)

With Caching (second generation step):
New Embedding: [1, 512] (only new token)
q_new, k_new, v_new: [1, 512] each (only new computations)
k_total, v_total: [3, 512] (cached + new)

Computation Reduction: 3 token projections instead of 9, about 67% fewer projection multiply-adds at this step. Attention itself still reads all 3 cached keys and values.
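The ~67% figure counts only the Q/K/V projections at this single step; attention and the rest of the layer are not included. Counting multiply-accumulates (MACs), where an [n, d_model] @ [d_model, d_head] matmul costs n · d_model · d_head, a small sketch (the function name is hypothetical) gives:

```python
def projection_macs(num_tokens: int, d_model: int = 512, d_head: int = 512, num_projections: int = 3) -> int:
    """MACs for num_projections matmuls of shape [num_tokens, d_model] @ [d_model, d_head]."""
    return num_projections * num_tokens * d_model * d_head


without_cache = projection_macs(num_tokens=3)  # q, k, v for all 3 tokens
with_cache = projection_macs(num_tokens=1)     # q, k, v for "!" only
saved = 1 - with_cache / without_cache
print(without_cache, with_cache, f"{saved:.1%}")  # 2359296 786432 66.7%
```

The saving at step t is (t - 1) / t of the projection work, so it approaches 100% as the sequence grows.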

3. Total RAM Requirements for Qwen3-8B with Standard KV Caching

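Let's estimate the memory required to load and run Qwen3-8B with standard KV caching, i.e., as if every attention head kept its own K and V (before the GQA optimization the model actually uses, which is covered in the next blog). We'll use the model's published configuration: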

Qwen3-8B Configuration:
- hidden_size: 4096
- max_position_embeddings: 40960 (maximum sequence length)
- model_type: qwen3
- num_attention_heads: 32
- num_hidden_layers: 36
- Precision: 4 bytes per parameter (float32)

For Standard KV Caching Analysis:
- Each attention head gets its own K,V cache (32 heads total)
- Head dimension: 4096 / 32 = 128 dimensions per head

1. Model Parameters Memory

# Qwen3-8B model parameters
model_parameters = 8_000_000_000  # ~8 billion parameters (rounded)
bytes_per_param = 4  # float32 = 4 bytes

model_memory = model_parameters * bytes_per_param
# = 8,000,000,000 * 4 = 32,000,000,000 bytes = 32 GB (1 GB = 10**9 bytes)

2. KV Cache Memory (Standard Attention)

Scenario: Processing the maximum sequence length (40,960 tokens)

# KV cache dimensions for standard attention (NOT GQA)
# Each attention head maintains separate K,V caches

sequence_length = 40960  # max_position_embeddings
num_layers = 36
num_attention_heads = 32  # All heads get separate K,V (no GQA)
head_dimension = 4096 // 32  # 128
bytes_per_element = 4  # float32

# Per layer KV cache calculation
k_cache_per_layer = sequence_length * num_attention_heads * head_dimension * bytes_per_element
v_cache_per_layer = sequence_length * num_attention_heads * head_dimension * bytes_per_element
# Each: 40960 * 32 * 128 * 4 = 671,088,640 bytes ≈ 0.67 GB (exactly 640 MiB)

# Total per layer: K + V
cache_per_layer = k_cache_per_layer + v_cache_per_layer
# = 1,342,177,280 bytes ≈ 1.34 GB (1.25 GiB)

# Across all layers
total_kv_cache = cache_per_layer * num_layers
# = 1,342,177,280 * 36 = 48,318,382,080 bytes ≈ 48.32 GB (45 GiB)

Note: all GB figures in this section use 1 GB = 10⁹ bytes, the same convention as the 32 GB model figure (GiB values are given in parentheses).

3. Additional Memory Components

# Activation buffers for forward pass
hidden_size = 4096
activation_memory = sequence_length * hidden_size * bytes_per_element
# = 40960 * 4096 * 4 = 671,088,640 bytes ≈ 0.67 GB

# Intermediate buffers (rough estimate, not measured)
intermediate_memory = 2 * 10**9  # ~2 GB

# Total additional memory
additional_memory = activation_memory + intermediate_memory
# ≈ 0.67 + 2.00 = 2.67 GB

4. Total RAM Calculation

Component | Memory (GB) | Share | Details
--- | --- | --- | ---
Model Parameters | 32.00 | 38.4% | 8B params × 4 bytes
KV Cache (40,960 tokens) | 48.32 | 57.9% | 2 (K, V) × 36 layers × 32 heads × 128 dims × 40,960 tokens × 4 bytes
Activations | 0.67 | 0.8% | Forward pass buffers
Intermediate Buffers | 2.00 | 2.4% | Temporary computations (rough estimate)
Memory Overhead | 0.40 | 0.5% | System/framework overhead (rough estimate)
Total RAM Required | 83.39 | 100% | Minimum for a full-length sequence

Impact of Different Precision Levels

The calculation above uses float32 (4 bytes per value). Here is how lower precision levels affect total RAM requirements:

Precision | Bytes/Value | Model Params | KV Cache | Other* | Total RAM
--- | --- | --- | --- | --- | ---
FP32 (Float32) | 4 | 32.00 GB | 48.32 GB | 3.07 GB | 83.39 GB
FP16 (Float16) | 2 | 16.00 GB | 24.16 GB | 1.54 GB | 41.69 GB
INT8 (8-bit) | 1 | 8.00 GB | 12.08 GB | 0.77 GB | 20.85 GB
INT4 (4-bit) | 0.5 | 4.00 GB | 6.04 GB | 0.38 GB | 10.42 GB

*Other = activations + intermediate buffers + overhead, scaled in proportion to precision for simplicity. Rows may not sum exactly because of rounding.
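A small sketch of the same arithmetic, so other sequence lengths or precisions can be tried (1 GB = 10⁹ bytes). The function name and the `num_kv_heads` and `fp32_buffers_bytes` parameters are hypothetical; the 2.4 GB of FP32 buffers (2 GB intermediate + 0.4 GB overhead) and their proportional scaling with precision are the same rough assumptions used in this section, not measurements.

```python
def estimate_memory_gb(
    bytes_per_value: float,
    num_params: int = 8_000_000_000,   # ~8B parameters (rounded)
    seq_len: int = 40_960,             # max_position_embeddings
    num_layers: int = 36,
    num_kv_heads: int = 32,            # standard attention: every head keeps its own K, V
    head_dim: int = 128,
    hidden_size: int = 4096,
    fp32_buffers_bytes: int = 2_400_000_000,  # 2 GB intermediate + 0.4 GB overhead (rough)
) -> dict:
    """Rough memory estimate in decimal GB (1 GB = 10**9 bytes)."""
    weights = num_params * bytes_per_value
    # K and V tensors, each [seq_len, num_kv_heads, head_dim], in every layer
    kv_cache = 2 * num_layers * seq_len * num_kv_heads * head_dim * bytes_per_value
    activations = seq_len * hidden_size * bytes_per_value
    # Simplification: the FP32 buffer estimate is scaled with precision
    buffers = fp32_buffers_bytes * bytes_per_value / 4
    total = weights + kv_cache + activations + buffers
    return {
        "weights": weights / 1e9,
        "kv_cache": kv_cache / 1e9,
        "other": (activations + buffers) / 1e9,
        "total": total / 1e9,
    }


for label, nbytes in [("FP32", 4), ("FP16", 2), ("INT8", 1), ("INT4", 0.5)]:
    est = estimate_memory_gb(nbytes)
    print(f"{label}: weights {est['weights']:.2f} GB, "
          f"KV cache {est['kv_cache']:.2f} GB, total {est['total']:.2f} GB")
# FP32: weights 32.00 GB, KV cache 48.32 GB, total 83.39 GB
# FP16: weights 16.00 GB, KV cache 24.16 GB, total 41.69 GB
# INT8: weights 8.00 GB, KV cache 12.08 GB, total 20.85 GB
# INT4: weights 4.00 GB, KV cache 6.04 GB, total 10.42 GB
```

The KV cache term grows linearly with `seq_len`, so halving the context length halves it, while the weights term does not change.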

KV caching greatly improves text-generation speed and efficiency, but holding the full 40,960-token context in 32-bit precision still needs more memory than a single 80 GB GPU provides. Modern deployments commonly run at lower precision (FP16/BF16, or quantized to INT8 or INT4), which shrinks these numbers roughly in proportion to the bytes per value, as the table above shows. On top of that, models like Qwen use Grouped Query Attention (GQA), which significantly reduces the KV cache; this is explained in the next blog.