Back to Blog

What Flash Attention Actually Does

Hard drives are slower than RAM. But reading sequentially from a hard drive can beat random access to RAM. The bottleneck isn't just the storage medium; it's the access pattern.

Flash Attention applies this insight to GPUs. The math of attention doesn't change; the memory access pattern does. That change lets much longer sequences fit in memory, with faster execution as a bonus.

The Standard Attention Problem

def standard_attention_memory(*, bytes_per_element=2):
    """
    Why attention doesn't scale.

    For sequence length N, standard attention computes:
        Q @ K^T          -> N x N matrix
        softmax(scores)  -> N x N matrix
        attention @ V    -> N x d output

    The problem is that N x N matrix: its size grows quadratically with N,
    and you have one for every layer and every head.

    Args:
        bytes_per_element: Size of one matrix element in bytes. Defaults to
            2 (FP16).

    Returns:
        dict with the size of a single attention matrix at 8K and 32K tokens
        (decimal units: 1 MB = 1e6 bytes, 1 GB = 1e9 bytes) and a one-line
        summary of the problem.
    """

    def matrix_bytes(seq_length):
        # One N x N matrix of scores.
        return seq_length * seq_length * bytes_per_element

    # With FP16:
    #   8K sequence:  8192^2 * 2  bytes = 134 MB (128 MiB) per attention matrix
    #   32K sequence: 32768^2 * 2 bytes = 2.1 GB (2 GiB)   per attention matrix
    return {
        "8k_sequence": f"{matrix_bytes(8192) / 1e6:.0f} MB",
        "32k_sequence": f"{matrix_bytes(32768) / 1e9:.1f} GB",
        "problem": "Memory grows quadratically with sequence length",
    }

What Flash Attention Changes

class FlashAttentionExplained:
    """
    The core insight: never materialize the full attention matrix.

    Each class attribute is a short plain-text explanation.
    ``standard_approach`` and ``flash_approach`` contrast the two algorithms.
    ``why_it_works`` gives the GPU memory-hierarchy numbers that motivate
    tiling; these are NVIDIA A100 figures as quoted in the FlashAttention
    paper.
    """

    standard_approach = """
    1. Compute Q @ K^T (full N x N matrix)
    2. Apply softmax
    3. Multiply by V
    4. Store all intermediate results in HBM (slow GPU memory)

    Memory reads/writes: O(N²)
    """

    flash_approach = """
    1. Load small blocks of Q, K, V into SRAM (fast on-chip memory)
    2. Compute attention for that block
    3. Update running statistics for softmax
    4. Accumulate result
    5. Never store full N x N matrix

    Extra memory: O(N) (running softmax stats), never O(N²)
    HBM reads/writes: O(N²·d²/M) for head dim d and SRAM size M,
    many times fewer than standard attention for typical d and M
    """

    why_it_works = """
    GPU has two types of memory (NVIDIA A100 figures):
    - HBM (High Bandwidth Memory): 40-80GB, 1.5-2 TB/s
    - SRAM (on-chip): 192KB per SM x 108 SMs (~20MB total), ~19 TB/s

    SRAM is ~10x faster but thousands of times smaller.

    Standard attention: constantly reading/writing HBM
    Flash Attention: does more compute in SRAM, fewer HBM trips

    The algorithm is mathematically identical (exact, not approximate).
    The implementation is memory-access optimized.
    """

The Tiling Strategy

def tiling_explanation():
    """
    Walk through Flash Attention's block-wise (tiled) computation.

    Returns:
        dict mapping "step_1" through "step_5" to a small dict. Each has an
        "action" entry plus one explanatory entry keyed "reason", "compute"
        or "result".
    """
    # (action, name of the explanatory field, explanatory text), in order.
    steps = [
        ("Divide Q into blocks of size Br",
         "reason", "Each block fits in SRAM"),
        ("Divide K, V into blocks of size Bc",
         "reason", "Process attention in tiles"),
        ("For each Q block, iterate through all K,V blocks",
         "compute", "Partial attention scores"),
        ("Use online softmax trick",
         "compute", "Update running max and sum for numerical stability"),
        ("Accumulate output block",
         "result", "Never materialize N x N matrix"),
    ]
    return {
        f"step_{index}": {"action": action, field: text}
        for index, (action, field, text) in enumerate(steps, start=1)
    }

Memory Savings

def memory_comparison():
    """
    Contrast the memory footprint of standard attention and Flash Attention.

    Returns:
        dict with three sections: "standard_attention", "flash_attention"
        and "practical_impact". Each is a dict of short descriptive strings.
    """
    # Standard: a full N x N matrix per head.
    standard = {
        "attention_matrix": "O(N²)",
        "32k_seq_per_head": "2 GB",
        "64_heads": "128 GB (just for attention!)",
    }

    # Flash: only one fixed-size tile is worked on at a time.
    flash = {
        "attention_matrix": "O(1) - never fully materialized",
        "block_size": "Typically 128 or 256",
        "memory_per_block": "128 * 128 * 2 = 32 KB",
        "total": "Much smaller working set",
    }

    impact = {
        "before": "32K context barely fits",
        "after": "128K context becomes feasible",
        "why": "Removed quadratic memory term",
    }

    return {
        "standard_attention": standard,
        "flash_attention": flash,
        "practical_impact": impact,
    }

Speed Improvement

def speed_improvement():
    """
    Explain where Flash Attention's speedup comes from.

    The speedup is a side effect of the memory-access optimization, not of
    doing less arithmetic.

    Returns:
        dict with "not_because" (misattributed causes), "because" (actual
        causes) and "typical_speedup" (headline numbers and caveats).
    """
    misattributed = [
        "Fewer FLOPs (actually similar or more)",
        "Better algorithms for attention",
        "Approximations",
    ]

    actual = [
        "Fewer memory round trips",
        "Better cache utilization",
        "More compute per memory load",
        "Fused kernels (fewer kernel launches)",
    ]

    headline = {
        "vs_standard_pytorch": "2-4x faster",
        "on_long_sequences": "Even larger speedup",
        "why_varies": "Depends on sequence length and hardware",
    }

    return {
        "not_because": misattributed,
        "because": actual,
        "typical_speedup": headline,
    }

Implementation Details

def flash_attention_usage():
    """
    How to use Flash Attention.

    Returns:
        (pytorch, explicit, vllm): three illustrative snippet strings.
        ``pytorch`` uses PyTorch's built-in scaled_dot_product_attention,
        ``explicit`` calls the flash-attn library directly, and ``vllm``
        notes vLLM's default behavior. Names such as query/key/value,
        q/k/v and head_dim are placeholders for the reader's own tensors.
    """
    pytorch = """
    # PyTorch 2.0+ has built-in Flash Attention
    import torch.nn.functional as F

    # Automatically uses Flash Attention when possible
    output = F.scaled_dot_product_attention(
        query, key, value,
        # Causal mask for autoregressive models. It is top-left aligned, so
        # for one-token decoding against a KV cache use is_causal=False.
        is_causal=True,
    )
    """

    explicit = """
    # Explicit Flash Attention via flash-attn library
    import math

    from flash_attn import flash_attn_func

    # q, k, v: (batch, seqlen, nheads, head_dim), fp16/bf16, on a CUDA device
    output = flash_attn_func(
        q, k, v,
        causal=True,
        softmax_scale=1.0 / math.sqrt(head_dim),  # also the library default
    )
    """

    vllm = """
    # vLLM uses Flash Attention by default
    # No configuration needed
    # Falls back to other implementations if not available
    """

    return pytorch, explicit, vllm

When Flash Attention Helps Most

def when_flash_helps():
    """
    Group usage scenarios by how much Flash Attention helps.

    Returns:
        dict with "high_impact" and "moderate_impact" lists of
        {"scenario": ..., "why": ...} entries, plus a free-text "note".
    """

    def scenario(name, reason):
        return {"scenario": name, "why": reason}

    big_wins = [
        scenario("Long context (>4K tokens)", "Quadratic memory savings compound"),
        scenario("Training (backward pass)", "Recomputation is cheaper than storing"),
        scenario("Memory-constrained serving", "Enables sequences that wouldn't fit"),
    ]
    smaller_wins = [
        scenario("Short context serving", "Memory isn't the bottleneck anyway"),
        scenario("Already memory-efficient setup", "Smaller relative improvement"),
    ]

    # Text lines are indented by eight spaces, with a leading newline and
    # trailing indent, exactly as in the article's layout.
    note = (
        "\n"
        "        Flash Attention for inference mostly helps with memory.\n"
        "        The speed improvement is nice but secondary.\n"
        "        For short sequences, the difference is minimal.\n"
        "        "
    )

    return {
        "high_impact": big_wins,
        "moderate_impact": smaller_wins,
        "note": note,
    }

What Flash Attention Doesn't Do

def common_misconceptions():
    """
    List popular myths about Flash Attention, each paired with the truth.

    Returns:
        list of {"myth": ..., "truth": ...} dicts in a fixed order.
    """
    myth_truth_pairs = (
        ("Flash Attention approximates attention",
         "It's mathematically identical, just computed differently"),
        ("Flash Attention reduces FLOPs",
         "Same or more FLOPs, fewer memory operations"),
        ("Flash Attention helps with KV cache",
         "KV cache is separate. Flash Attention helps with attention computation"),
        ("Flash Attention always makes things faster",
         "For short sequences, overhead can outweigh benefits"),
    )
    return [{"myth": myth, "truth": truth} for myth, truth in myth_truth_pairs]

Flash Attention is a memory access optimization that happens to speed things up. Understanding this distinction matters: if you're optimizing for speed on short sequences, Flash Attention isn't your answer. If you're trying to serve longer contexts, it's essential.