Attention That Fits in Memory
A city doesn't need to store every possible route between every pair of addresses. It stores the street map, and computes routes on demand. Storing all routes would require impossible amounts of space.
Standard attention stores the equivalent of all routes. Memory-efficient attention stores the map. Same destination, vastly different storage requirements.
The Memory Problem
def attention_memory_scaling():
    """
    Why standard attention hits memory walls.
    """
    def standard_attention_peak_memory(seq_len: int) -> float:
        # Stores the full N x N attention matrix
        bytes_per_element = 2  # FP16
        attention_matrix_bytes = seq_len * seq_len * bytes_per_element
        return attention_matrix_bytes / 1024**3  # GB

    examples = {
        "2K": standard_attention_peak_memory(2048),      # 8 MB
        "8K": standard_attention_peak_memory(8192),      # 128 MB
        "32K": standard_attention_peak_memory(32768),    # 2 GB
        "128K": standard_attention_peak_memory(131072),  # 32 GB
    }

    problem = """
    At 32K sequence length, each attention head's score matrix is 2 GB.
    Keep 64 of those alive at once (heads times layers saved for the
    backward pass) and that's 128 GB just for attention matrices.
    This is per sample in the batch.
    """
    return examples, problem
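The arithmetic above can also be checked empirically. A minimal sketch, assuming a CUDA GPU and PyTorch 2.x; the helper peak_attention_memory_gb is illustrative, not part of any library:

import torch
import torch.nn.functional as F

def peak_attention_memory_gb(seq_len: int, use_efficient: bool) -> float:
    # Allocate Q, K, V for 8 heads of dim 64, then run one attention pass
    q = torch.randn(1, 8, seq_len, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    torch.cuda.reset_peak_memory_stats()
    if use_efficient:
        F.scaled_dot_product_attention(q, k, v)
    else:
        # Naive path: materializes the full (8, seq_len, seq_len) score tensor
        scores = (q @ k.transpose(-2, -1)) * q.shape[-1] ** -0.5
        torch.softmax(scores, dim=-1) @ v
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**3

# Example (exact numbers depend on the GPU and PyTorch build):
# peak_attention_memory_gb(8192, use_efficient=False)  # dominated by the score matrices
# peak_attention_memory_gb(8192, use_efficient=True)   # roughly just Q, K, V and the output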
Memory-Efficient Alternatives
def memory_efficient_variants():
    return {
        "flash_attention": {
            "approach": "Tile-based computation, never materialize the full matrix",
            "peak_memory": "O(N), linear in sequence length",
            "output": "Mathematically identical to standard",
            "availability": "PyTorch 2.0+, flash-attn library",
        },
        "xformers_memory_efficient": {
            "approach": "Similar tiling, optimized for a range of hardware",
            "peak_memory": "O(N), with a far smaller constant than O(N^2)",
            "output": "Mathematically identical",
            "availability": "xformers library",
        },
        "chunked_attention": {
            "approach": "Process attention in sequence chunks",
            "peak_memory": "O(chunk_size^2) per block of scores",
            "output": "Identical if the softmax is renormalized correctly",
            "availability": "Various implementations",
        },
        "linear_attention": {
            "approach": "Approximate attention with linear complexity",
            "peak_memory": "O(N)",
            "output": "Approximation, not identical",
            "availability": "Research implementations",
        },
    }
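The first three variants share one mechanism: compute the softmax incrementally so the full N x N score matrix never exists. Below is a minimal sketch of that idea in plain PyTorch, chunking over keys and values with an online softmax. It is an illustration, not how any of these libraries implement their kernels; the largest temporary is N x chunk instead of N x N, and the output matches standard attention up to floating-point rounding.

import torch

def chunked_attention(q, k, v, chunk: int = 256):
    # q, k, v: (batch, heads, seq_len, head_dim); no causal mask, for brevity
    scale = q.shape[-1] ** -0.5
    out = torch.zeros_like(q)
    # Running softmax statistics per query position
    row_max = torch.full(q.shape[:-1] + (1,), float("-inf"), device=q.device, dtype=q.dtype)
    row_sum = torch.zeros_like(row_max)
    for start in range(0, k.shape[-2], chunk):
        k_blk = k[..., start:start + chunk, :]
        v_blk = v[..., start:start + chunk, :]
        scores = (q @ k_blk.transpose(-2, -1)) * scale   # (..., N, chunk), never N x N
        blk_max = scores.amax(dim=-1, keepdim=True)
        new_max = torch.maximum(row_max, blk_max)
        correction = torch.exp(row_max - new_max)        # rescale what was accumulated so far
        probs = torch.exp(scores - new_max)
        row_sum = row_sum * correction + probs.sum(dim=-1, keepdim=True)
        out = out * correction + probs @ v_blk
        row_max = new_max
    return out / row_sum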
Using Memory-Efficient Attention
def implementation_examples():
    pytorch_sdpa = """
    # PyTorch 2.0+ scaled_dot_product_attention.
    # Automatically uses a memory-efficient implementation when possible.
    import torch.nn.functional as F

    output = F.scaled_dot_product_attention(
        query, key, value,
        is_causal=True,  # Causal mask for autoregressive models
    )
    # PyTorch selects the best available backend:
    # - FlashAttention if available
    # - Memory-efficient attention as fallback
    # - Standard (math) attention if nothing else works
    """

    xformers_explicit = """
    # Explicit xformers memory-efficient attention
    from xformers.ops import memory_efficient_attention, LowerTriangularMask

    output = memory_efficient_attention(
        query, key, value,
        attn_bias=LowerTriangularMask(),  # Causal mask
    )
    """

    check_backend = """
    # Restrict which SDPA backends are allowed. With only Flash enabled,
    # an input it can't handle (GPU, dtype, head dim) raises an error
    # instead of silently falling back to standard attention.
    import torch
    import torch.nn.functional as F

    with torch.backends.cuda.sdp_kernel(
        enable_flash=True,
        enable_math=False,
        enable_mem_efficient=False,
    ):
        output = F.scaled_dot_product_attention(query, key, value)
    """
    return pytorch_sdpa, xformers_explicit, check_backend
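In a real model the SDPA call sits inside an attention module. A minimal sketch of a causal self-attention block built around it, assuming a fused QKV projection; the class and its default dimensions are illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model: int = 512, n_heads: int = 8):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x):  # x: (batch, seq_len, d_model)
        b, n, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Reshape to (batch, heads, seq_len, head_dim), the layout SDPA expects
        q, k, v = (t.view(b, n, self.n_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.proj(out.transpose(1, 2).reshape(b, n, -1))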
When It Matters
def when_memory_efficient_helps():
    return {
        "high_impact": [
            {
                "scenario": "Long context (>8K tokens)",
                "why": "Quadratic attention memory becomes the dominant cost",
                "savings": "10-100x peak memory reduction",
            },
            {
                "scenario": "Training with limited GPU memory",
                "why": "The backward pass has to store activations",
                "savings": "Enables larger batch sizes",
            },
            {
                "scenario": "Fitting large models on smaller GPUs",
                "why": "Attention memory competes with model weights",
                "savings": "Can make serving possible at all",
            },
        ],
        "moderate_impact": [
            {
                "scenario": "Short context (<2K tokens)",
                "why": "Attention memory is small anyway",
                "note": "Still useful for the speed benefits",
            },
        ],
        "verification": """
        # Check which SDPA backends are enabled (allowed).
        # Note: this reports availability, not which kernel a given call used.
        import torch
        torch.backends.cuda.flash_sdp_enabled()          # Flash backend enabled?
        torch.backends.cuda.mem_efficient_sdp_enabled()  # Memory-efficient backend enabled?
        """,
    }
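To attach real numbers to the savings column for your own workload, you can A/B the backends directly. A rough sketch, assuming PyTorch 2.x on CUDA; run_with_backends is an illustrative helper, and note that it flips the backend flags globally:

import torch
import torch.nn.functional as F

def run_with_backends(flash: bool, mem_efficient: bool, math: bool, fn) -> float:
    # Globally enable/disable each SDPA backend, then measure peak memory for one call
    torch.backends.cuda.enable_flash_sdp(flash)
    torch.backends.cuda.enable_mem_efficient_sdp(mem_efficient)
    torch.backends.cuda.enable_math_sdp(math)
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**3  # GB

q = k = v = torch.randn(1, 8, 8192, 64, device="cuda", dtype=torch.float16)
step = lambda: F.scaled_dot_product_attention(q, k, v, is_causal=True)

math_only_gb = run_with_backends(False, False, True, step)  # forces standard attention
fused_gb = run_with_backends(True, True, True, step)        # lets PyTorch pick a fused kernel
print(f"math-only: {math_only_gb:.2f} GB, fused: {fused_gb:.2f} GB")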
The Output Comparison
def verify_equivalence():
    """
    Check that the outputs really are identical.
    """
    return {
        "test_code": """
        import torch
        import torch.nn.functional as F

        # Random inputs: (batch, heads, seq_len, head_dim)
        q = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16)
        k = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16)
        v = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16)

        # Standard attention (for comparison)
        def standard_attention(q, k, v):
            scale = q.shape[-1] ** -0.5
            scores = torch.matmul(q, k.transpose(-2, -1)) * scale
            weights = torch.softmax(scores, dim=-1)
            return torch.matmul(weights, v)

        out_standard = standard_attention(q, k, v)
        out_efficient = F.scaled_dot_product_attention(q, k, v)

        # Check equivalence
        diff = (out_standard - out_efficient).abs().max()
        print(f"Max difference: {diff}")  # Should be ~1e-3 or less in FP16
        """,
        "expected_result": """
        The outputs are mathematically identical. The small numerical
        differences (~1e-3 to 1e-4) come from a different floating-point
        accumulation order and are not a quality concern.
        """,
    }
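The same comparison is worth repeating with the causal mask in place, since that is where subtle mismatches usually hide. A sketch under the same assumptions as the test above:

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Manual causal attention with an explicit lower-triangular mask
n = q.shape[-2]
mask = torch.tril(torch.ones(n, n, device=q.device, dtype=torch.bool))
scores = (q @ k.transpose(-2, -1)) * q.shape[-1] ** -0.5
scores = scores.masked_fill(~mask, float("-inf"))
out_manual = torch.softmax(scores, dim=-1) @ v

out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print((out_manual - out_sdpa).abs().max())  # again ~1e-3 or less in FP16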
Common Issues
def common_issues():
    return {
        "not_using_efficient": {
            "symptom": "Memory usage higher than expected",
            "cause": "Efficient backend not available or not enabled",
            "fix": "Check PyTorch version, CUDA version, and GPU capability",
        },
        "shape_requirements": {
            "symptom": "Silently falls back to standard attention",
            "cause": "Head dimension outside the supported range "
                     "(typically a multiple of 8, up to 128 or 256 depending on the kernel)",
            "fix": "Ensure the model architecture matches the kernel requirements",
        },
        "causal_mask": {
            "symptom": "Wrong outputs in autoregressive generation",
            "cause": "Forgot to enable the causal mask",
            "fix": "Set is_causal=True or pass an appropriate attn_bias",
        },
    }
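When a fallback is suspected, make it loud. On PyTorch 2.3+ (where torch.nn.attention exists), you can pin the backend so an unusable Flash kernel raises an error instead of silently degrading. A sketch, assuming that version and a CUDA GPU:

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

try:
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print("Flash kernel ran.")
except RuntimeError as err:
    # The error (and accompanying warnings) explain why Flash was rejected:
    # GPU capability, dtype, head dimension, or mask type.
    print(f"Flash not usable here: {err}")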
Practical Checklist
def memory_efficient_attention_checklist():
    return [
        "[ ] Using PyTorch 2.0+ or xformers",
        "[ ] GPU supports FlashAttention (Ampere or newer)",
        "[ ] Head dimension is supported by the kernel (64 or 128 is safest)",
        "[ ] Causal mask set correctly for autoregressive models",
        "[ ] Verified outputs match standard attention",
        "[ ] Measured the actual memory reduction",
    ]
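Some of these boxes can be ticked programmatically. A rough sketch; the helper name and the conservative head-dimension rule are illustrative, not authoritative:

import torch

def quick_environment_check(head_dim: int = 64) -> dict:
    major, _minor = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
    return {
        "torch_2_plus": int(torch.__version__.split(".")[0]) >= 2,
        "ampere_or_newer": major >= 8,  # sm_80+ covers FlashAttention-2
        "flash_backend_enabled": torch.backends.cuda.flash_sdp_enabled(),
        "mem_efficient_backend_enabled": torch.backends.cuda.mem_efficient_sdp_enabled(),
        "head_dim_ok": head_dim % 8 == 0 and head_dim <= 128,  # conservative rule of thumb
    }

print(quick_environment_check(head_dim=64))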
Memory-efficient attention is free performance. Same mathematical operation, less memory, often faster. If you're not using it, you're leaving capacity on the table. Modern frameworks enable it by default, but verify it's actually being used.