What Flash Attention Actually Does
Hard drives are slower than RAM. But reading sequentially from a hard drive can beat random access to RAM. The bottleneck isn't the storage medium. It's the access pattern.
Flash Attention applies this insight to GPUs. The operation isn't fundamentally different; the memory access pattern is. That change enables longer sequences, with faster execution as a bonus.
The Standard Attention Problem
def standard_attention_memory():
    """
    Why attention doesn't scale
    """
    # For sequence length N, attention computes:
    #   Q @ K^T          -> N x N matrix
    #   softmax(scores)  -> N x N matrix
    #   attention @ V    -> N x d output
    # The problem: that N x N matrix.
    seq_length = 8192
    bytes_per_element = 2  # FP16
    attention_matrix_bytes = seq_length * seq_length * bytes_per_element
    attention_matrix_gb = attention_matrix_bytes / 1e9
    # 8K sequence:  ~134 MB per attention matrix
    # 32K sequence: ~2.1 GB per attention matrix
    # And you have one of these for every layer, every head.
    return {
        "8k_sequence": f"{8192 * 8192 * 2 / 1e6:.0f} MB",
        "32k_sequence": f"{32768 * 32768 * 2 / 1e9:.1f} GB",
        "problem": "Memory grows quadratically with sequence length",
    }
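For concreteness, here is what the standard path looks like when written out. This is a minimal single-head sketch in PyTorch, not any library's actual kernel; the point is the (N, N) tensor that gets allocated:

import torch

def naive_attention(q, k, v):
    # q, k, v: (seq_len, head_dim) tensors for a single head
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d ** 0.5  # (N, N) -- the matrix that blows up
    weights = torch.softmax(scores, dim=-1)      # another full (N, N) tensor
    return weights @ v                           # (N, head_dim)

# At N = 8192 in FP16, `scores` alone is 8192 * 8192 * 2 bytes ≈ 134 MB,
# and during training autograd keeps it (or `weights`) around for the backward pass.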
What Flash Attention Changes
class FlashAttentionExplained:
    """
    The core insight: never materialize the full attention matrix
    """

    standard_approach = """
    1. Compute Q @ K^T (full N x N matrix)
    2. Apply softmax
    3. Multiply by V
    4. Store all intermediate results in HBM (slow GPU memory)
    Memory reads/writes: O(N²)
    """

    flash_approach = """
    1. Load small blocks of Q, K, V into SRAM (fast on-chip memory)
    2. Compute attention for that block
    3. Update running statistics for the softmax
    4. Accumulate the result
    5. Never store the full N x N matrix
    Memory reads/writes: the N x N matrix never touches HBM at all
    """

    why_it_works = """
    A GPU has two kinds of memory:
    - HBM (High Bandwidth Memory): ~80 GB, ~3 TB/s
    - SRAM (on-chip): ~20 MB in total across SMs, ~19 TB/s
    SRAM is roughly 6x faster but 4000x smaller.
    Standard attention: constantly reading and writing HBM.
    Flash Attention: does more compute in SRAM, far fewer HBM round trips.
    The algorithm is mathematically identical.
    The implementation is memory-access optimized.
    """
The Tiling Strategy
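The step list below leans on the "online softmax trick" (step 4), so it's worth seeing that in isolation first. A purely illustrative sketch in plain Python: it streams over a score vector in fixed-size chunks, keeping only a running max and a running sum of exponentials, yet ends up with the same result as a two-pass softmax.

import math

def online_softmax(scores, block_size=4):
    # Computes softmax(scores) in one pass over fixed-size blocks,
    # keeping only a running max (m) and running sum of exponentials (l).
    m, l = float("-inf"), 0.0
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the old sum to the new max, then add this block's contribution.
        l = l * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in block)
        m = m_new
    return [math.exp(s - m) / l for s in scores]

# The final normalization reuses `scores` only for clarity; in Flash Attention
# the same rescaling is applied to the accumulated output instead.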
def tiling_explanation():
    """
    How Flash Attention processes in blocks
    """
    return {
        "step_1": {
            "action": "Divide Q into blocks of size Br",
            "reason": "Each block fits in SRAM",
        },
        "step_2": {
            "action": "Divide K, V into blocks of size Bc",
            "reason": "Process attention in tiles",
        },
        "step_3": {
            "action": "For each Q block, iterate through all K, V blocks",
            "compute": "Partial attention scores",
        },
        "step_4": {
            "action": "Use the online softmax trick",
            "compute": "Update running max and sum for numerical stability",
        },
        "step_5": {
            "action": "Accumulate the output block",
            "result": "Never materialize the N x N matrix",
        },
    }
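Putting the five steps together: a minimal single-head sketch of the tiled loop in PyTorch. The block sizes, names, and pure-PyTorch form are illustrative; the real kernels fuse this into CUDA and keep the tiles in SRAM.

import torch

def tiled_attention(q, k, v, block_q=128, block_k=128):
    # q, k, v: (N, d) tensors for a single head. The output matches naive
    # attention, but no (N, N) tensor is ever allocated.
    n, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    for i in range(0, n, block_q):
        qi = q[i:i + block_q]                                        # (Bq, d) block of queries
        m = torch.full((qi.shape[0], 1), float("-inf"), device=q.device)  # running max per row
        l = torch.zeros(qi.shape[0], 1, device=q.device)             # running sum of exp
        acc = torch.zeros_like(qi)                                   # running weighted sum of V
        for j in range(0, n, block_k):
            kj, vj = k[j:j + block_k], v[j:j + block_k]              # (Bk, d) blocks of K and V
            s = qi @ kj.T * scale                                    # (Bq, Bk) tile of scores
            m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
            p = torch.exp(s - m_new)                                 # exponentials against the new max
            correction = torch.exp(m - m_new)                        # rescale the old statistics
            l = l * correction + p.sum(dim=-1, keepdim=True)
            acc = acc * correction + p @ vj
            m = m_new
        out[i:i + block_q] = acc / l                                 # final normalization
    return out

On float32 inputs this agrees with the naive version above to within floating-point noise; the only structural difference is that nothing larger than a (block_q, block_k) tile ever exists.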
Memory Savings
def memory_comparison():
    """
    Real memory savings from Flash Attention
    """
    return {
        "standard_attention": {
            "attention_matrix": "O(N²)",
            "32k_seq_per_head": "~2 GB in FP16",
            "64_heads": "~128 GB (just for attention scores!)",
        },
        "flash_attention": {
            "attention_matrix": "O(1) in sequence length - never fully materialized",
            "block_size": "Typically 128 or 256",
            "memory_per_block": "128 * 128 * 2 bytes = 32 KB",
            "total": "Much smaller working set",
        },
        "practical_impact": {
            "before": "32K context barely fits",
            "after": "128K context becomes feasible",
            "why": "Removed the quadratic memory term",
        },
    }
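A rough back-of-the-envelope check of those numbers (the block and head sizes here are common defaults, not taken from any particular kernel):

def working_set_bytes(seq_len, head_dim=128, block=128, bytes_per_el=2):
    # Standard attention: the full score matrix for one head.
    standard = seq_len * seq_len * bytes_per_el
    # Flash Attention: one Q tile, one K tile, one V tile, one score tile,
    # an output accumulator, and per-row stats -- independent of seq_len.
    flash = (
        3 * block * head_dim   # Q, K, V tiles
        + block * block        # score tile
        + block * head_dim     # output accumulator
        + 2 * block            # running max and sum
    ) * bytes_per_el
    return standard, flash

std, fl = working_set_bytes(32_768)
# std ≈ 2.1 GB, fl ≈ 160 KB; only std grows with sequence length.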
Speed Improvement
def speed_improvement():
    """
    Why Flash Attention is faster (as a side effect)
    """
    return {
        "not_because": [
            "Fewer FLOPs (actually similar or more)",
            "Better algorithms for attention",
            "Approximations",
        ],
        "because": [
            "Fewer memory round trips",
            "Better cache utilization",
            "More compute per memory load",
            "Fused kernels (fewer kernel launches)",
        ],
        "typical_speedup": {
            "vs_standard_pytorch": "2-4x faster",
            "on_long_sequences": "Even larger speedup",
            "why_varies": "Depends on sequence length and hardware",
        },
    }
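A crude way to see why the memory round trips dominate is to count bytes moved rather than FLOPs. This toy model ignores caching, backward-pass recomputation, and the paper's exact I/O analysis; it only captures the headline effect:

def hbm_traffic_estimate(seq_len, head_dim=128, bytes_per_el=2):
    # Bytes for reading Q, K, V and writing the output, one head.
    qkv_out = 4 * seq_len * head_dim * bytes_per_el
    scores = seq_len * seq_len * bytes_per_el
    # Unfused attention: write the score matrix, read it back for softmax,
    # write the probabilities, read them again for the @ V step.
    standard = qkv_out + 4 * scores
    # Fused Flash Attention: the scores stay in SRAM, so HBM only sees Q, K, V
    # and the output (K/V re-reads across Q blocks are ignored here).
    flash = qkv_out
    return standard / flash

# hbm_traffic_estimate(8192) ≈ 65x less attention-related HBM traffic in this
# toy model; realized speedups are smaller because other costs don't shrink.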
Implementation Details
def flash_attention_usage():
    """
    How to use Flash Attention
    """
    pytorch = """
    # PyTorch 2.0+ has built-in Flash Attention
    import torch.nn.functional as F

    # Automatically uses Flash Attention when possible
    output = F.scaled_dot_product_attention(
        query, key, value,
        is_causal=True,  # causal masking for autoregressive models
    )
    """
    explicit = """
    # Explicit Flash Attention via the flash-attn library
    import math
    from flash_attn import flash_attn_func

    output = flash_attn_func(
        q, k, v,
        causal=True,
        softmax_scale=1.0 / math.sqrt(head_dim),
    )
    """
    vllm = """
    # vLLM uses Flash Attention by default
    # No configuration needed
    # Falls back to other implementations if it isn't available
    """
    return pytorch, explicit, vllm
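One practical note: PyTorch's scaled_dot_product_attention silently falls back to slower backends when the flash kernel can't handle the inputs. In PyTorch 2.x you can restrict it to the flash backend so unsupported inputs raise instead (newer releases move this context manager to torch.nn.attention.sdpa_kernel):

import torch
import torch.nn.functional as F

# Flash kernels need half precision on CUDA; shape is (batch, heads, seq, head_dim).
q = torch.randn(1, 8, 4096, 128, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# Restrict SDPA to the flash backend; unsupported inputs now raise an error
# instead of silently falling back to the slower math path.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)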
When Flash Attention Helps Most
def when_flash_helps():
    return {
        "high_impact": [
            {
                "scenario": "Long context (>4K tokens)",
                "why": "Quadratic memory savings compound",
            },
            {
                "scenario": "Training (backward pass)",
                "why": "Recomputation is cheaper than storing the attention matrix",
            },
            {
                "scenario": "Memory-constrained serving",
                "why": "Enables sequences that otherwise wouldn't fit",
            },
        ],
        "moderate_impact": [
            {
                "scenario": "Short-context serving",
                "why": "Memory isn't the bottleneck anyway",
            },
            {
                "scenario": "Already memory-efficient setup",
                "why": "Smaller relative improvement",
            },
        ],
        "note": """
        For inference, Flash Attention mostly helps with memory.
        The speed improvement is nice but secondary.
        For short sequences, the difference is minimal.
        """,
    }
What Flash Attention Doesn't Do
def common_misconceptions():
    return [
        {
            "myth": "Flash Attention approximates attention",
            "truth": "It's mathematically identical, just computed differently (see the check below)",
        },
        {
            "myth": "Flash Attention reduces FLOPs",
            "truth": "Same or more FLOPs, fewer memory operations",
        },
        {
            "myth": "Flash Attention helps with the KV cache",
            "truth": "The KV cache is separate; Flash Attention helps with the attention computation itself",
        },
        {
            "myth": "Flash Attention always makes things faster",
            "truth": "For short sequences, overhead can outweigh the benefits",
        },
    ]
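The "mathematically identical" claim is easy to verify yourself. A quick sanity check on a CUDA machine (tolerances are loose because the fused kernel accumulates in a different order than the naive FP16 path):

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# Reference: explicitly materialize the attention matrix.
scores = (q @ k.transpose(-2, -1)) / q.shape[-1] ** 0.5
reference = torch.softmax(scores, dim=-1) @ v

# Fused path (dispatches to Flash Attention when the inputs allow it).
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(reference, fused, atol=1e-2, rtol=1e-2))  # True, up to FP16 noise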
Flash Attention is a memory access optimization that happens to speed things up. Understanding this distinction matters: if you're optimizing for speed on short sequences, Flash Attention isn't your answer. If you're trying to serve longer contexts, it's essential.