How Speculative Decoding Works

Branch prediction in CPUs guesses which way conditionals will go, executing speculatively. If the guess is right, you've saved cycles. If wrong, you discard the work and continue normally. The cost of wrong guesses is bounded; the benefit of right guesses is real.

Speculative decoding applies the same principle to LLMs. A small draft model quickly generates candidate tokens, and the large target model verifies them all in a single parallel pass. Matching predictions are accepted; at the first mismatch, the remaining drafts are discarded and the target model supplies the token itself. When the draft model predicts well, generation accelerates dramatically.

The Core Idea

def speculative_decoding_concept():
    return {
        "standard_decoding": """
            For each token:
            1. Run large model (slow, e.g., 50ms)
            2. Get one token
            3. Repeat

            Time for 100 tokens: 100 * 50ms = 5 seconds
        """,

        "speculative_decoding": """
            Draft phase (small model):
            1. Run small model 5 times (fast, e.g., 5ms each)
            2. Get 5 candidate tokens: [t1, t2, t3, t4, t5]

            Verify phase (large model):
            1. Run large model ONCE on all 5 tokens in parallel
            2. Verify which predictions match
            3. Accept matching prefix, reject rest

            If all 5 match: Got 5 tokens in time of 1 large + 5 small
            = 50ms + 25ms = 75ms instead of 250ms (3.3x speedup)
        """,

        "key_insight": """
            Large model verification is parallelizable:
            verifying 5 drafted tokens costs about the same as generating 1,
            because the forward pass is memory-bound and scoring 5 positions
            in one pass reads the weights and KV cache roughly once.
        """,
    }
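
The best-case arithmetic above is easy to sanity-check. A toy sketch using the same illustrative latencies (50 ms per target step, 5 ms per draft step); the numbers are assumptions, not benchmarks:

def toy_timing(t_target_ms: float = 50.0, t_draft_ms: float = 5.0,
               gamma: int = 5) -> None:
    # Standard decoding: one target pass per emitted token
    standard_ms = gamma * t_target_ms
    # Best case for one speculative round: gamma draft passes + one verify pass
    speculative_ms = gamma * t_draft_ms + t_target_ms
    print(f"standard:    {standard_ms:.0f} ms for {gamma} tokens")
    print(f"speculative: {speculative_ms:.0f} ms if all {gamma} drafts accepted")
    print(f"speedup:     {standard_ms / speculative_ms:.1f}x")

toy_timing()  # 250 ms vs 75 ms -> 3.3x, matching the numbers above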

The Algorithm

def speculative_decoding_algorithm():
    return """
    def speculative_decode(prompt, draft_model, target_model, gamma=5):
        '''
        gamma: number of tokens to speculate
        '''
        output_tokens = []

        while not done:  # until EOS or a length limit is hit
            # Draft phase: generate gamma candidates
            draft_tokens = []
            draft_probs = []
            for _ in range(gamma):
                token, prob = draft_model.generate_one(
                    prompt + output_tokens + draft_tokens
                )
                draft_tokens.append(token)
                draft_probs.append(prob)

            # Verify phase: check all candidates at once
            target_probs = target_model.get_probs(
                prompt + output_tokens,
                draft_tokens
            )

            # Accept matching tokens
            accepted = 0
            for i in range(gamma):
                # Acceptance probability
                r = random.random()
                if r < min(1, target_probs[i] / draft_probs[i]):
                    output_tokens.append(draft_tokens[i])
                    accepted += 1
                else:
                    # Reject this and all following
                    break

            # If a token was rejected (accepted < gamma), sample the correction
            # from the residual distribution norm(max(0, p_target - p_draft))
            # at the first rejected position
            if accepted < gamma:
                new_token = sample_from_residual(
                    target_probs[accepted],
                    draft_probs[accepted]
                )
                output_tokens.append(new_token)
            # (If all gamma drafts were accepted, the same verification pass
            # already gives the target's distribution for the next position,
            # so one bonus token comes for free.)

        return output_tokens
    """

Why Verification is Cheap

def verification_efficiency():
    return {
        "single_token_generation": {
            "steps": [
                "Read KV cache for all prior tokens",
                "Compute attention for new token",
                "Run FFN",
                "Sample output",
            ],
            "bottleneck": "Reading KV cache (memory-bound)",
        },

        "multi_token_verification": {
            "steps": [
                "Read KV cache for all prior tokens",
                "Compute attention for N new tokens (batched)",
                "Run FFN (batched)",
                "Get probabilities for all N",
            ],
            "key_insight": """
                The KV cache read is the bottleneck, and reading it once to
                score N tokens costs about the same as reading it once to
                score 1. Verifying N tokens is therefore roughly the same
                cost as generating 1.
            """,
        },

        "mathematical": """
            Time to generate 1 token: T_gen
            Time to verify N tokens: T_verify ≈ T_gen

            Speedup potential = N tokens / (T_draft * N + T_verify)
                              ≈ N / (T_draft * N + T_gen)

            If T_draft << T_gen and N tokens accepted:
            Speedup ≈ N
        """,
    }
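
Plugging numbers into the formula makes the ceiling concrete. A small sketch, with all times expressed in units of T_gen (so T_verify ≈ 1 and T_draft is a fraction of it):

def ideal_speedup(n_accepted: int, t_draft: float = 0.1, t_verify: float = 1.0) -> float:
    # Standard decoding: n_accepted tokens at 1.0 (one T_gen) each.
    # Speculative: n_accepted draft steps plus one verification pass.
    return (n_accepted * 1.0) / (n_accepted * t_draft + t_verify)

print(ideal_speedup(5))   # ≈ 3.3x
print(ideal_speedup(10))  # = 5.0x, as draft time starts to dominate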

Draft Model Selection

def draft_model_selection():
    return {
        "requirements": {
            "fast": "Much faster than target (5-10x)",
            "accurate": "Predicts target's outputs well",
            "compatible": "Same vocabulary and tokenization",
        },

        "options": {
            "same_family_smaller": {
                "example": "LLaMA-7B draft for LLaMA-70B target",
                "pros": "Same training distribution, good match",
                "cons": "Still requires separate model load",
            },
            "distilled_from_target": {
                "example": "Custom small model distilled from target",
                "pros": "Optimized for prediction accuracy",
                "cons": "Requires distillation effort",
            },
            "early_exit": {
                "example": "Use first 4 layers of target model",
                "pros": "No additional model, shared weights",
                "cons": "Lower prediction quality",
            },
            "n_gram_model": {
                "example": "Statistical model from corpus",
                "pros": "Very fast, no GPU needed",
                "cons": "Lower prediction quality",
            },
        },

        "acceptance_rate_targets": """
            Good draft model: 70-90% acceptance rate
            Mediocre draft: 40-60% acceptance rate
            Poor draft: <40% acceptance rate (may hurt overall)

            Measure on your actual workload to validate.
        """,
    }
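
The n-gram option is worth making concrete, since it needs no GPU at all. A minimal sketch of a bigram draft table built from a corpus of token IDs; corpus_token_ids and the greedy chaining below are illustrative assumptions, and a table like this pairs naturally with exact-match (greedy) verification rather than the rejection-sampling rule above:

from collections import Counter, defaultdict

def build_bigram_table(corpus_token_ids: list[int]) -> dict[int, int]:
    # Map each token to its most frequent successor in the corpus
    successors: dict[int, Counter] = defaultdict(Counter)
    for prev, nxt in zip(corpus_token_ids, corpus_token_ids[1:]):
        successors[prev][nxt] += 1
    return {prev: counts.most_common(1)[0][0] for prev, counts in successors.items()}

def draft_with_bigrams(table: dict[int, int], last_token: int, gamma: int = 5) -> list[int]:
    # Chain greedy predictions; stop early if the context falls off the table
    drafted = []
    current = last_token
    for _ in range(gamma):
        if current not in table:
            break
        current = table[current]
        drafted.append(current)
    return drafted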

Measuring Speedup

class SpeculativeSpeedupMeasurement:
    """
    Measure actual speedup from speculative decoding
    """

    def measure_acceptance_rate(
        self,
        draft_model,
        target_model,
        test_prompts: list,
        gamma: int = 5
    ) -> dict:
        """Measure how well draft predicts target"""
        total_proposed = 0
        total_accepted = 0

        for prompt in test_prompts:
            # run_speculation (not shown) runs one speculative loop on the
            # prompt and returns (tokens_proposed, tokens_accepted) counts
            proposed, accepted = self.run_speculation(
                draft_model, target_model, prompt, gamma
            )
            total_proposed += proposed
            total_accepted += accepted

        acceptance_rate = total_accepted / total_proposed
        return {
            "acceptance_rate": acceptance_rate,
            "avg_accepted_per_round": acceptance_rate * gamma,
            "expected_speedup": self.calculate_speedup(acceptance_rate, gamma),
        }

    def calculate_speedup(
        self,
        acceptance_rate: float,
        gamma: int,
        draft_time_ratio: float = 0.1  # draft is 10x faster
    ) -> float:
        """Calculate expected speedup"""
        # Expected accepted tokens per round
        expected_accepted = sum(
            acceptance_rate ** i for i in range(1, gamma + 1)
        )

        # Time for standard decoding: gamma tokens
        standard_time = gamma

        # Time for speculative: draft + verify + rejected sampling
        spec_time = draft_time_ratio * gamma + 1 + (1 - acceptance_rate)

        return expected_accepted / spec_time
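
Hypothetical usage, assuming draft_model and target_model are already-loaded model wrappers and run_speculation is implemented as described in the comment above:

measure = SpeculativeSpeedupMeasurement()

test_prompts = [
    "def fibonacci(n):",
    "Summarize the following meeting notes:",
    "SELECT name, email FROM users WHERE",
]

results = measure.measure_acceptance_rate(draft_model, target_model, test_prompts, gamma=5)
print(f"acceptance rate:    {results['acceptance_rate']:.2f}")
print(f"avg accepted/round: {results['avg_accepted_per_round']:.2f}")
print(f"expected speedup:   {results['expected_speedup']:.2f}x")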

Implementation Considerations

def implementation_considerations():
    return {
        "memory": {
            "challenge": "Need both models in memory",
            "solutions": [
                "Draft on CPU, target on GPU",
                "Share layers between draft and target",
                "Use very small draft (1-2B params)",
            ],
        },

        "batching": {
            "challenge": "Speculative decoding is per-request",
            "interaction": "Can batch multiple speculative requests",
            "note": "Doesn't help throughput as much as latency",
        },

        "variable_acceptance": {
            "challenge": "Different prompts have different acceptance rates",
            "solution": "Adaptive gamma based on recent acceptance",
            "implementation": """
                if recent_acceptance > 0.8:
                    gamma = min(gamma + 1, max_gamma)
                elif recent_acceptance < 0.5:
                    gamma = max(gamma - 1, min_gamma)
            """,
        },

        "kv_cache_management": {
            "challenge": "Both models need KV cache",
            "solution": "Draft model cache is small, target cache normal",
            "optimization": "Can discard draft cache after verification",
        },
    }
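
The adaptive-gamma idea above fits in a few lines. A sketch using an exponential moving average of recent acceptance; the thresholds and smoothing factor are assumptions to tune against your workload:

class AdaptiveGamma:
    """Grow gamma while the draft predicts well, shrink it when it doesn't."""

    def __init__(self, gamma: int = 5, min_gamma: int = 1,
                 max_gamma: int = 10, smoothing: float = 0.9):
        self.gamma = gamma
        self.min_gamma = min_gamma
        self.max_gamma = max_gamma
        self.smoothing = smoothing
        self.recent_acceptance = 0.7  # optimistic prior; an assumption

    def update(self, accepted: int, proposed: int) -> int:
        # Exponential moving average of the per-round acceptance rate
        rate = accepted / max(proposed, 1)
        self.recent_acceptance = (
            self.smoothing * self.recent_acceptance
            + (1 - self.smoothing) * rate
        )
        if self.recent_acceptance > 0.8:
            self.gamma = min(self.gamma + 1, self.max_gamma)
        elif self.recent_acceptance < 0.5:
            self.gamma = max(self.gamma - 1, self.min_gamma)
        return self.gamma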

When Speculative Decoding Helps

def when_speculative_helps():
    return {
        "ideal_conditions": [
            "Target model is much slower than draft",
            "High acceptance rate (>70%)",
            "Latency-sensitive, not throughput-focused",
            "Enough memory for both models",
        ],

        "helps_most": {
            "predictable_outputs": "Code completion, structured generation",
            "stylistic_similarity": "Draft trained on similar data",
            "repetitive_patterns": "Templates, boilerplate",
        },

        "helps_least": {
            "creative_generation": "Low predictability",
            "rare_tokens": "Draft model struggles",
            "short_outputs": "Overhead not amortized",
        },

        "typical_results": """
            Good scenario: 2-3x latency reduction
            Average scenario: 1.5-2x latency reduction
            Poor scenario: No improvement or slight slowdown
        """,
    }

Speculative decoding is a probabilistic optimization. It bets that a small model can predict what a large model will say. When the bet pays off, you get multiple tokens for the price of one. When it doesn't, you lose only the time spent drafting. The expected value is positive when the draft's predictions are good.