Back to Blog
The Cache That Makes LLMs Possible
Video streaming buffers ahead because downloading frame-by-frame would be unusable. You'd wait for each frame to download before seeing the next. The buffer trades memory for smoothness.
The KV cache is LLM inference's buffer. Without it, generating each new token would require recomputing the key and value vectors of every previous token before attention could run. The cache makes generation fast by trading memory for compute.
Why The Cache Exists
def without_kv_cache(num_tokens: int = 100) -> int:
    """
    Naive autoregressive generation (no KV cache)

    To generate token N, the transformer must:
      1. Process all N-1 previous tokens again
      2. Compute attention between token N and all previous tokens

    Args:
        num_tokens: How many tokens are generated. Defaults to 100.

    Returns:
        Total number of token positions processed over the whole generation.

    Raises:
        ValueError: If ``num_tokens`` is negative.
    """
    if num_tokens < 0:
        raise ValueError("num_tokens must be non-negative")

    # Generating 100 tokens:
    #   Token 1:   process 1 token
    #   Token 2:   process 2 tokens
    #   Token 3:   process 3 tokens
    #   ...
    #   Token 100: process 100 tokens
    # Step k touches k tokens, so the total is 1 + 2 + ... + n = n(n + 1) / 2.
    total_forward_passes = num_tokens * (num_tokens + 1) // 2  # 5,050 for n=100

    # For n=100 that is roughly 50x the 100 single-token steps that are
    # strictly necessary.
    return total_forward_passes
def with_kv_cache(num_tokens: int = 100) -> int:
    """
    Cached autoregressive generation

    The cache stores the Key and Value vectors of all previous tokens.
    To generate token N:
      1. Only process the new token
      2. Compute attention using the cached K,V vectors
      3. Add the new token's K,V to the cache

    Args:
        num_tokens: How many tokens are generated. Defaults to 100.

    Returns:
        Total number of token positions processed over the whole generation.

    Raises:
        ValueError: If ``num_tokens`` is negative.
    """
    if num_tokens < 0:
        raise ValueError("num_tokens must be non-negative")

    # Generating 100 tokens:
    #   Token 1:   process 1 token, cache K,V
    #   Token 2:   process 1 token + cached attention, cache K,V
    #   Token 3:   process 1 token + cached attention, cache K,V
    #   ...
    #   Token 100: process 1 token + cached attention, cache K,V
    # Every step processes exactly one new token, so the total is linear.
    total_forward_passes = num_tokens  # 100 for n=100
    return total_forward_passes
What Gets Cached
class KVCacheStructure:
    """
    The actual data stored in KV cache

    One Key vector and one Value vector of length ``head_dim`` are kept per
    layer and per K/V head for every token in the context.
    """

    def __init__(self, config):
        """
        Read the model dimensions that determine the cache size.

        Args:
            config: Object exposing ``num_layers``, ``num_heads`` and
                ``head_dim``. An optional ``num_kv_heads`` attribute gives the
                number of K/V heads when it differs from ``num_heads``
                (grouped-query or multi-query attention).
        """
        # Per layer, per head
        self.num_layers = config.num_layers  # e.g., 80 for Llama-70B
        self.num_heads = config.num_heads  # e.g., 64
        self.head_dim = config.head_dim  # e.g., 128
        # None means "one K/V head per attention head" (classic multi-head
        # attention), in which case num_heads is used when sizing the cache.
        self.num_kv_heads = getattr(config, "num_kv_heads", None)

    def cache_size_per_token(self, bytes_per_element: int = 2) -> int:
        """
        Bytes per token in cache

        Args:
            bytes_per_element: Storage width of one cached number. The default
                of 2 corresponds to FP16; pass 1 for an INT8 cache.

        Returns:
            Number of bytes the cache grows by for each token of context.
        """
        kv_heads = self.num_heads if self.num_kv_heads is None else self.num_kv_heads
        return (
            2  # K and V, both stored
            * self.num_layers  # Every layer
            * kv_heads  # Every head that owns a K/V pair
            * self.head_dim  # Vector dimension
            * bytes_per_element
        )

    def example_llama_70b(self):
        """Worked example in comments only; performs no computation."""
        # Assuming every one of the 64 heads stores its own K and V:
        # 80 layers × 64 heads × 128 dim × 2 (K,V) × 2 bytes
        # = 2,621,440 bytes per token
        # = 2.5 MB per token!
        # 1000 token context = 2.5 GB KV cache
        # 32K context = 80 GB KV cache (!)
        #
        # With grouped-query attention and 8 K/V heads the same arithmetic
        # gives 327,680 bytes per token, i.e. 8x smaller.
        pass
The Attention Mechanism
def attention_with_cache(query, key_cache, value_cache, new_key, new_value):
    """
    Simplified attention with KV cache

    All tensors are treated as ``[..., seq_len, head_dim]``: the last axis is
    the head dimension and the second-to-last axis is the sequence axis.

    Args:
        query: Query for the new token only (shape: [..., 1, head_dim]).
        key_cache: Keys of all previous tokens (shape: [..., seq_len, head_dim]).
        value_cache: Values of all previous tokens (same shape as key_cache).
        new_key: Key of the new token (shape: [..., 1, head_dim]).
        new_value: Value of the new token (shape: [..., 1, head_dim]).

    Returns:
        Tuple ``(output, key_cache, value_cache)`` where ``output`` is the
        attention result for the new token and the two caches are new tensors
        that include the new token's K and V.
    """
    # Update cache with new token's K and V. Concatenate along the sequence
    # axis (-2) so this agrees with the transpose(-2, -1) below and works for
    # both unbatched [seq_len, head_dim] and batched inputs.
    key_cache = torch.cat([key_cache, new_key], dim=-2)
    value_cache = torch.cat([value_cache, new_value], dim=-2)

    # The scaling factor uses the size of the query's last axis.
    head_dim = query.size(-1)

    # Compute attention scores
    # Query: just the new token (shape: [1, head_dim])
    # Keys: all tokens including new (shape: [seq_len, head_dim])
    attention_scores = torch.matmul(query, key_cache.transpose(-2, -1))
    attention_scores = attention_scores / math.sqrt(head_dim)
    attention_probs = torch.softmax(attention_scores, dim=-1)

    # Compute output
    # Values: all tokens including new (shape: [seq_len, head_dim])
    output = torch.matmul(attention_probs, value_cache)
    return output, key_cache, value_cache
The Memory Trade-off
def memory_tradeoff_analysis():
    """
    Compare generation with and without a KV cache.

    Returns:
        A dict with a "without_cache" and a "with_cache" entry (each holding
        "compute", "memory" and "practical" descriptions) plus a free-text
        "tradeoff" summary.
    """
    uncached = dict(
        compute="O(n²) forward passes for n tokens",
        memory="O(1) - just model weights",
        practical="Unusably slow for generation",
    )
    cached = dict(
        compute="O(n) forward passes for n tokens",
        memory="O(n) - grows with context length",
        practical="Fast generation, memory becomes limit",
    )
    # Same text as a triple-quoted block: leading and trailing newline included.
    summary = (
        "\n"
        "We trade O(n) memory for O(n) speedup.\n"
        "For long contexts, this memory cost dominates.\n"
        "This is why 100K context models need huge GPUs.\n"
    )

    analysis = {}
    analysis["without_cache"] = uncached
    analysis["with_cache"] = cached
    analysis["tradeoff"] = summary
    return analysis
Cache Management Strategies
class CacheManagement:
    """
    Catalogue of KV-cache handling approaches used by production systems.

    ``strategies`` maps an allocation scheme to its approach, pro and con.
    ``optimizations`` maps a technique name to a one-line description.
    """

    strategies = dict(
        pre_allocation=dict(
            approach="Allocate cache for max_tokens upfront",
            pro="No allocation during generation",
            con="Wastes memory if requests are short",
        ),
        dynamic_allocation=dict(
            approach="Grow cache as tokens are generated",
            pro="Only use what's needed",
            con="Allocation overhead, fragmentation",
        ),
        paged_allocation=dict(
            approach="Allocate in fixed-size pages (vLLM)",
            pro="Best of both worlds",
            con="Slightly more complex",
        ),
    )

    optimizations = dict(
        quantization="Store cache in INT8 (2x memory saving)",
        prefix_sharing="Multiple requests share same prefix cache",
        eviction="Remove old tokens when memory-constrained",
        offloading="Move old cache to CPU (adds latency)",
    )
Why This Matters for Serving
def serving_implications():
    """
    List the ways the KV cache constrains a serving deployment.

    Returns:
        A dict keyed by factor name; each value holds an "issue", an "impact"
        and a "calculation" string.
    """
    # (factor, issue, impact, calculation)
    rows = (
        (
            "concurrent_requests",
            "Each request needs its own KV cache",
            "More concurrent = more memory",
            "GPU_memory = model + sum(kv_caches)",
        ),
        (
            "context_length",
            "Longer context = bigger cache",
            "Fewer concurrent long-context requests",
            "Cache grows linearly with context",
        ),
        (
            "model_size",
            "Bigger models = bigger cache per token",
            "70B model cache >> 7B model cache",
            "More layers, more heads = more storage",
        ),
        (
            "batch_size",
            "Batching multiplies cache size",
            "Batch of 8 needs 8x the cache",
            "Total = batch_size × per_request_cache",
        ),
    )
    return {
        factor: {"issue": issue, "impact": impact, "calculation": calculation}
        for factor, issue, impact, calculation in rows
    }
The Numbers That Matter
def kv_cache_cheat_sheet():
    """
    Rounded KV-cache figures for two common model sizes.

    Returns:
        A dict with one entry per model (bytes per token and cache size in GB
        at selected context lengths) plus a "key_insight" text entry.
    """
    sheet = {}
    sheet["llama_7b"] = {
        "cache_per_token_bytes": 500_000,
        "1k_context_gb": 0.5,
        "8k_context_gb": 4,
    }
    sheet["llama_70b"] = {
        "cache_per_token_bytes": 2_500_000,
        "1k_context_gb": 2.5,
        "8k_context_gb": 20,
        "32k_context_gb": 80,
    }
    # Empty first and last items reproduce the newline that opens and closes
    # the text block.
    sheet["key_insight"] = "\n".join(
        [
            "",
            "KV cache often exceeds model weights for long context.",
            "A 70B model is 140GB.",
            "Its 32K context KV cache is 80GB PER REQUEST.",
            "",
        ]
    )
    return sheet
The KV cache is why LLM generation is tractable. It's also why long context is expensive, why concurrent requests need lots of memory, and why inference optimization is largely about cache management.
Understanding the cache is understanding half of inference optimization.