How Speculative Decoding Works
Branch prediction in CPUs guesses which way conditionals will go, executing speculatively. If the guess is right, you've saved cycles. If wrong, you discard the work and continue normally. The cost of wrong guesses is bounded; the benefit of right guesses is real.
Speculative decoding applies the same principle to LLMs. A small draft model quickly generates candidate tokens, and the large target model verifies them all in a single parallel pass. Matching predictions are accepted instantly; at the first mismatch, the target model's own token is used instead and drafting resumes from there. When the draft model predicts well, generation accelerates dramatically.
The Core Idea
def speculative_decoding_concept():
    return {
        "standard_decoding": """
            For each token:
            1. Run large model (slow, e.g., 50ms)
            2. Get one token
            3. Repeat
            Time for 100 tokens: 100 * 50ms = 5 seconds
        """,
        "speculative_decoding": """
            Draft phase (small model):
            1. Run small model 5 times (fast, e.g., 5ms each)
            2. Get 5 candidate tokens: [t1, t2, t3, t4, t5]
            Verify phase (large model):
            1. Run large model ONCE on all 5 tokens in parallel
            2. Verify which predictions match
            3. Accept matching prefix, reject rest
            If all 5 match: got 5 tokens in the time of 1 large pass + 5 small passes
            = 50ms + 25ms = 75ms instead of 250ms (3.3x speedup)
        """,
        "key_insight": """
            Large model verification is parallelizable.
            Verifying 5 tokens costs about the same as generating 1.
            (Because attention over K+5 tokens ≈ attention over K tokens)
        """,
    }
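To make the arithmetic concrete, here is a minimal sketch of the latency comparison above; the 50ms / 5ms figures, gamma = 5, and the function name are the illustrative numbers from the block, not measurements.

def estimated_round_speedup(t_target_ms: float, t_draft_ms: float,
                            gamma: int, tokens_accepted: int) -> float:
    """Latency ratio of one speculative round versus standard decoding.

    Standard decoding pays one target pass per token; a speculative round
    pays gamma draft passes plus a single target verification pass.
    """
    standard_time = tokens_accepted * t_target_ms
    speculative_time = gamma * t_draft_ms + t_target_ms
    return standard_time / speculative_time

# The example above: 50ms target, 5ms draft, gamma=5, all 5 drafts accepted
print(estimated_round_speedup(50, 5, gamma=5, tokens_accepted=5))  # ~3.33x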
The Algorithm
def speculative_decoding_algorithm():
    return """
    def speculative_decode(prompt, draft_model, target_model, gamma=5):
        '''
        gamma: number of tokens to speculate per round
        '''
        output_tokens = []
        while not done:  # until end-of-sequence or length limit
            # Draft phase: generate gamma candidate tokens autoregressively
            draft_tokens = []
            draft_probs = []
            for _ in range(gamma):
                token, prob = draft_model.generate_one(
                    prompt + output_tokens + draft_tokens
                )
                draft_tokens.append(token)
                draft_probs.append(prob)

            # Verify phase: score all candidates in one target forward pass
            target_probs = target_model.get_probs(
                prompt + output_tokens,
                draft_tokens
            )

            # Accept the longest prefix of draft tokens that passes the test
            accepted = 0
            for i in range(gamma):
                # Accept with probability min(1, p_target / p_draft)
                r = random.random()
                if r < min(1, target_probs[i] / draft_probs[i]):
                    output_tokens.append(draft_tokens[i])
                    accepted += 1
                else:
                    # Reject this token and everything drafted after it
                    break

            # If a draft token was rejected, replace it with a sample from the
            # residual distribution so the output still matches the target model
            if accepted < gamma:
                new_token = sample_from_residual(
                    target_probs[accepted],
                    draft_probs[accepted]
                )
                output_tokens.append(new_token)
            # (The full algorithm also samples one bonus token from the target's
            #  distribution at position gamma+1 when all gamma drafts are accepted.)
        return output_tokens
    """
Why Verification is Cheap
def verification_efficiency():
    return {
        "single_token_generation": {
            "steps": [
                "Read KV cache for all prior tokens",
                "Compute attention for new token",
                "Run FFN",
                "Sample output",
            ],
            "bottleneck": "Reading KV cache (memory-bound)",
        },
        "multi_token_verification": {
            "steps": [
                "Read KV cache for all prior tokens",
                "Compute attention for N new tokens (batched)",
                "Run FFN (batched)",
                "Get probabilities for all N",
            ],
            "key_insight": """
                The KV cache read is the bottleneck.
                Reading the cache once for N tokens ≈ reading it once for 1 token.
                Verifying N tokens therefore costs about the same as generating 1.
            """,
        },
        "mathematical": """
            Time for one target decode step:  T_gen
            Time to verify N draft tokens:    T_verify ≈ T_gen
            Standard decoding, N tokens:      N * T_gen
            Speculative (all N accepted):     N * T_draft + T_verify ≈ N * T_draft + T_gen
            Speedup = (N * T_gen) / (N * T_draft + T_gen)
            If N * T_draft << T_gen and all N tokens are accepted:
                Speedup ≈ N
        """,
    }
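As a concrete illustration of batched verification, the sketch below scores a whole block of draft tokens with one forward pass of a Hugging Face-style causal LM. The "gpt2" checkpoint and the example strings are placeholders; any causal LM with a matching tokenizer would do.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")      # placeholder target model
tokenizer = AutoTokenizer.from_pretrained("gpt2")

@torch.no_grad()
def target_probs_for_drafts(prefix_ids: torch.Tensor,
                            draft_ids: torch.Tensor) -> torch.Tensor:
    """Probability the target assigns to each drafted token, computed in ONE
    forward pass over prefix + drafts (logits at position i predict token i+1)."""
    input_ids = torch.cat([prefix_ids, draft_ids]).unsqueeze(0)
    logits = model(input_ids).logits[0]                   # [seq_len, vocab]
    probs = torch.softmax(logits, dim=-1)
    k, n = prefix_ids.shape[0], draft_ids.shape[0]
    predictor_positions = torch.arange(k - 1, k - 1 + n)  # positions that predict each draft
    return probs[predictor_positions, draft_ids]

prefix = tokenizer("Speculative decoding is", return_tensors="pt").input_ids[0]
drafts = tokenizer(" a latency optimization", return_tensors="pt").input_ids[0]
print(target_probs_for_drafts(prefix, drafts))            # one pass, one probability per draft token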
Draft Model Selection
def draft_model_selection():
    return {
        "requirements": {
            "fast": "Much faster than target (5-10x)",
            "accurate": "Predicts target's outputs well",
            "compatible": "Same vocabulary and tokenization",
        },
        "options": {
            "same_family_smaller": {
                "example": "LLaMA-7B draft for LLaMA-70B target",
                "pros": "Same training distribution, good match",
                "cons": "Still requires separate model load",
            },
            "distilled_from_target": {
                "example": "Custom small model distilled from target",
                "pros": "Optimized for prediction accuracy",
                "cons": "Requires distillation effort",
            },
            "early_exit": {
                "example": "Use first 4 layers of target model",
                "pros": "No additional model, shared weights",
                "cons": "Lower prediction quality",
            },
            "n_gram_model": {
                "example": "Statistical model from corpus (see the sketch after this block)",
                "pros": "Very fast, no GPU needed",
                "cons": "Lower prediction quality",
            },
        },
        "acceptance_rate_targets": """
            Good draft model:  70-90% acceptance rate
            Mediocre draft:    40-60% acceptance rate
            Poor draft:        <40% acceptance rate (may hurt overall)
            Measure on your actual workload to validate.
        """,
    }
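The n-gram option can be as simple as prompt lookup: propose whatever followed the most recent earlier occurrence of the current suffix. A minimal sketch, assuming the context is a plain list of token IDs; this is illustrative rather than a production drafter.

def prompt_lookup_draft(context: list[int], gamma: int = 5, ngram: int = 3) -> list[int]:
    """Propose up to gamma tokens by matching the last `ngram` tokens against
    earlier occurrences in the context and copying what followed them."""
    if len(context) <= ngram:
        return []
    tail = context[-ngram:]
    # Scan backwards for the most recent earlier occurrence of the tail
    for start in range(len(context) - ngram - 1, -1, -1):
        if context[start:start + ngram] == tail:
            follow = context[start + ngram:start + ngram + gamma]
            if follow:
                return follow
    return []

# Repetitive contexts (templates, boilerplate, code) yield useful drafts
ctx = [7, 8, 9, 1, 2, 3, 4, 5, 6, 1, 2, 3]
print(prompt_lookup_draft(ctx, gamma=4))   # -> [4, 5, 6, 1]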
Measuring Speedup
class SpeculativeSpeedupMeasurement:
    """
    Measure actual speedup from speculative decoding
    """

    def measure_acceptance_rate(
        self,
        draft_model,
        target_model,
        test_prompts: list,
        gamma: int = 5
    ) -> dict:
        """Measure how well draft predicts target"""
        total_proposed = 0
        total_accepted = 0
        for prompt in test_prompts:
            proposed, accepted = self.run_speculation(
                draft_model, target_model, prompt, gamma
            )
            total_proposed += proposed
            total_accepted += accepted
        acceptance_rate = total_accepted / total_proposed
        return {
            "acceptance_rate": acceptance_rate,
            "avg_accepted_per_round": acceptance_rate * gamma,
            "expected_speedup": self.calculate_speedup(acceptance_rate, gamma),
        }
    def calculate_speedup(
        self,
        acceptance_rate: float,
        gamma: int,
        draft_time_ratio: float = 0.1  # one draft step ≈ 10% of a target step
    ) -> float:
        """Expected speedup, assuming each draft token is accepted
        independently with probability acceptance_rate"""
        # Expected tokens per round: accepted draft tokens plus the one token
        # the target always contributes (corrective sample on rejection,
        # or a bonus token when all gamma drafts are accepted)
        expected_tokens = sum(
            acceptance_rate ** i for i in range(gamma + 1)
        )
        # Time per round, in units of one target forward pass:
        # gamma draft steps plus a single verification pass
        spec_time = draft_time_ratio * gamma + 1
        # Standard decoding produces 1 token per target pass
        return expected_tokens / spec_time
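A quick usage example of the class above, sweeping acceptance rate against gamma; the grid values are illustrative.

meter = SpeculativeSpeedupMeasurement()
for alpha in (0.4, 0.6, 0.8, 0.9):
    cells = [f"gamma={g}: {meter.calculate_speedup(alpha, g):.2f}x" for g in (2, 4, 6, 8)]
    print(f"acceptance={alpha:.1f}  " + "  ".join(cells))
# High acceptance keeps paying off as gamma grows; at low acceptance the extra
# draft work dominates, and large gamma can even fall below 1x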
Implementation Considerations
def implementation_considerations():
    return {
        "memory": {
            "challenge": "Need both models in memory",
            "solutions": [
                "Draft on CPU, target on GPU",
                "Share layers between draft and target",
                "Use very small draft (1-2B params)",
            ],
        },
        "batching": {
            "challenge": "Speculative decoding is per-request",
            "interaction": "Can batch multiple speculative requests",
            "note": "Doesn't help throughput as much as latency",
        },
        "variable_acceptance": {
            "challenge": "Different prompts have different acceptance rates",
            "solution": "Adaptive gamma based on recent acceptance (see the sketch after this block)",
            "implementation": """
                if recent_acceptance > 0.8:
                    gamma = min(gamma + 1, max_gamma)
                elif recent_acceptance < 0.5:
                    gamma = max(gamma - 1, min_gamma)
            """,
        },
        "kv_cache_management": {
            "challenge": "Both models need KV cache",
            "solution": "Draft model cache is small, target cache normal",
            "optimization": "Can discard draft cache after verification",
        },
    }
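One way to implement the adaptive-gamma idea referenced above is a small controller that smooths the per-round acceptance rate; the 0.8 / 0.5 thresholds come from the snippet above, while the smoothing factor, bounds, and class name are assumptions.

class AdaptiveGamma:
    """Adjust speculation depth from an exponential moving average of acceptance."""

    def __init__(self, gamma: int = 5, min_gamma: int = 1,
                 max_gamma: int = 8, smoothing: float = 0.9):
        self.gamma = gamma
        self.min_gamma = min_gamma
        self.max_gamma = max_gamma
        self.smoothing = smoothing
        self.recent_acceptance = 0.7    # neutral starting estimate (assumed)

    def update(self, accepted: int, proposed: int) -> int:
        """Call after each verification round; returns the gamma to use next."""
        rate = accepted / max(proposed, 1)
        self.recent_acceptance = (self.smoothing * self.recent_acceptance
                                  + (1 - self.smoothing) * rate)
        if self.recent_acceptance > 0.8:
            self.gamma = min(self.gamma + 1, self.max_gamma)
        elif self.recent_acceptance < 0.5:
            self.gamma = max(self.gamma - 1, self.min_gamma)
        return self.gamma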
When Speculative Decoding Helps
def when_speculative_helps():
    return {
        "ideal_conditions": [
            "Target model is much slower than draft",
            "High acceptance rate (>70%)",
            "Latency-sensitive, not throughput-focused",
            "Enough memory for both models",
        ],
        "helps_most": {
            "predictable_outputs": "Code completion, structured generation",
            "stylistic_similarity": "Draft trained on similar data",
            "repetitive_patterns": "Templates, boilerplate",
        },
        "helps_least": {
            "creative_generation": "Low predictability",
            "rare_tokens": "Draft model struggles",
            "short_outputs": "Overhead not amortized",
        },
        "typical_results": """
            Good scenario:    2-3x latency reduction
            Average scenario: 1.5-2x latency reduction
            Poor scenario:    No improvement or slight slowdown
        """,
    }
Speculative decoding is a probabilistic optimization. It bets that a small model can predict what a large model will say. When the bet pays off, you get multiple tokens for the price of one. When it doesn't, you've lost only the time spent drafting. Because the acceptance rule corrects for disagreements, the output distribution still matches the target model's; only latency changes. The expected value is positive when draft predictions are good.