Back to Blog

Extending Context Beyond Training Length

Musicians learn songs in a specific key. Transposing to another key is possible but harder. The further from the original key, the more effort required to maintain quality. Shifting a song up a half step is easy; asking a singer to shift it up a full octave pushes them far outside the range they practiced in.

Context extension works similarly. Models learn positional patterns at training length. Extending to longer contexts requires scaling those patterns. Small extensions work well. Large extensions degrade quality predictably. Understanding the mechanics helps you find the viable range.

How Position Extension Works

def position_extension_basics():
    """Summarize why RoPE models need position scaling beyond training length.

    Returns:
        dict with ``"the_problem"`` and ``"rope_mechanism"`` (explanatory
        text) and ``"scaling_approaches"`` (method name -> one-line summary).
    """
    return {
        # 4K is 4096 tokens (positions 0-4095), matching the lengths used
        # elsewhere in this article, so the first unseen position is 4096.
        "the_problem": """
            Model trained on 4K (4096-token) context.
            Positions 4096 and beyond were never seen during training.
            What happens when you pass position 8000?

            Without modification: Undefined behavior, usually bad.
            With scaling: Map new positions to trained range.
        """,

        # The failure is about unseen rotation *angles*: high-frequency
        # dimensions wrap around many times within 4K positions, so every
        # angle was seen; low-frequency ones never complete a cycle.
        "rope_mechanism": """
            RoPE encodes position by rotating query and key vectors.
            Rotation frequency varies by dimension.

            Position 1000: Rotate by 1000 * frequency
            Position 8000: Rotate by 8000 * frequency

            If 8000 * frequency exceeds the angles seen in training,
            extrapolation fails. This hits the low-frequency dimensions,
            whose rotations never completed a full cycle within 4K positions.
        """,

        "scaling_approaches": {
            "position_interpolation": "Shrink positions to fit trained range",
            "ntk_scaling": "Adjust frequencies to cover larger range",
            "yarn": "Hybrid of both approaches",
        },
    }

Position Interpolation

def position_interpolation():
    """
    Linear position interpolation (PI)

    Returns:
        dict with ``"mechanism"`` (explanation), ``"implementation"``
        (sketch code as text), ``"tradeoffs"`` and ``"typical_results"``.
    """
    return {
        "mechanism": """
            Scale positions to fit within training range.

            Original: Position 8000 -> rotation_8000
            With PI:  Position 8000 -> rotation_4000 (scaled by 0.5)

            scale_factor = training_length / target_length
            scaled_position = position * scale_factor
        """,

        # Sketch only (compute_rotations is left to the reader). The
        # super().__init__() call is required: without it nn.Module's
        # internal state is missing and calling the module fails.
        "implementation": """
            class ScaledRoPE(nn.Module):
                def __init__(self, base_length, target_length):
                    super().__init__()
                    self.scale = base_length / target_length

                def forward(self, positions):
                    scaled = positions * self.scale
                    return self.compute_rotations(scaled)
        """,

        "tradeoffs": {
            "pro": "Simple, works out of the box",
            "con": "Reduces resolution between nearby positions",
            "effect": "Positions 1 and 2 become less distinguishable",
        },

        "typical_results": {
            "2x_extension": "Minimal quality loss (~2%)",
            "4x_extension": "Noticeable quality loss (~5-10%)",
            "8x_extension": "Significant quality loss (~15-20%)",
        },
    }

NTK-Aware Scaling

def ntk_scaling():
    """
    Neural Tangent Kernel-aware scaling

    Returns:
        dict with ``"mechanism"``, ``"benefit"``, ``"implementation"``
        (sketch code as text) and ``"typical_results"``.
    """
    return {
        # Raising the base barely moves the highest-frequency dimension
        # (its exponent is 0) while the lowest-frequency ones shrink by
        # roughly the scale factor, so local dims are kept, long-range
        # dims are stretched.
        "mechanism": """
            Instead of scaling positions, scale frequencies.

            High-frequency dimensions: Carry local info, keep (nearly) unchanged
            Low-frequency dimensions: Carry long-range info, stretch to cover the longer range

            Adjust rotation frequencies non-uniformly.
        """,

        "benefit": """
            Preserves local relationship resolution.
            Position 1 vs 2 remains distinguishable.
            Long-range positions covered by adjusted frequencies.
        """,

        # Sketch: `dim` (the rotary/head dimension) must be passed in and
        # stored, and nn.Module subclasses must call super().__init__().
        "implementation": """
            class NTKScaledRoPE(nn.Module):
                def __init__(self, base_length, target_length, dim, base=10000):
                    super().__init__()
                    self.dim = dim
                    scale = target_length / base_length
                    self.base = base * scale ** (dim / (dim - 2))

                def compute_frequencies(self):
                    return 1.0 / (self.base ** (torch.arange(0, self.dim, 2) / self.dim))
        """,

        "typical_results": {
            "2x_extension": "~1% quality loss",
            "4x_extension": "~3-5% quality loss",
            "8x_extension": "~8-12% quality loss",
        },
    }

YaRN: Yet another RoPE extensioN

def yarn_scaling():
    """
    Combines interpolation and frequency scaling

    Returns:
        dict with ``"key_insight"``, ``"components"``, ``"results"`` and
        ``"when_to_use"``.
    """
    return {
        # YaRN's "NTK-by-parts": dimensions whose wavelength is short
        # relative to the context are left alone, dimensions whose
        # wavelength exceeds the context are linearly interpolated, and a
        # ramp blends the two in between.
        "key_insight": """
            Different dimensions benefit from different strategies.

            High-frequency dims: Keep unchanged (no interpolation)
            Low-frequency dims: Interpolate (divide frequency by scale)
            Middle dims: Blend of both
        """,

        "components": {
            "ntk_interpolation": "NTK-by-parts: interpolate each dimension according to its wavelength",
            "attention_scaling": "Compensate for attention entropy change",
            "fine_tuning": "Brief continued pretraining at target length",
        },

        "results": """
            State-of-the-art context extension.

            4K -> 128K demonstrated with minimal quality loss.
            Requires brief fine-tuning (~1000 steps).
            Better than pure interpolation or pure NTK.
        """,

        # Threshold matches the practical guide, where 4x uses NTK scaling.
        "when_to_use": """
            If extending beyond 4x, YaRN is worth the complexity.
            For 2x extensions, simple interpolation often sufficient.
        """,
    }

Practical Extension Guide

def extension_guide():
    """Recommend an extension method by extension ratio.

    Returns:
        dict with one entry per extension tier (approach, implementation,
        fine-tuning advice, quality impact, use case) plus a
        ``"decision_process"`` text table.
    """
    return {
        "small_extension_2x": {
            "approach": "Position interpolation",
            "implementation": "Simple scaling factor",
            "fine_tuning": "Not required",
            "quality_impact": "Minimal",
            "use_case": "4K -> 8K, quick solution",
        },

        "medium_extension_4x": {
            "approach": "NTK scaling",
            "implementation": "Adjust base frequency",
            "fine_tuning": "Recommended (100-500 steps)",
            "quality_impact": "Moderate",
            "use_case": "4K -> 16K, production use",
        },

        "large_extension_8x_plus": {
            "approach": "YaRN or equivalent",
            "implementation": "Full method with tuning",
            "fine_tuning": "Required (1000+ steps)",
            "quality_impact": "Noticeable",
            "use_case": "4K -> 32K+, long document needs",
        },

        # Ratio ranges are disjoint so every ratio maps to exactly one row;
        # 2x -> PI and 4x -> NTK, matching the tiers above.
        "decision_process": """
            1. What's your training length?
            2. What's your target length?
            3. How much quality loss is acceptable?
            4. Can you fine-tune briefly?

            Extension ratio | Method      | Fine-tune?
            ----------------|-------------|------------
            <= 2x           | PI          | No
            > 2x to 4x      | NTK         | Recommended
            > 4x to 8x      | YaRN        | Required
            > 8x            | YaRN        | Required + validate heavily
        """,
    }

Quality Measurement

class ExtensionQualityMeasurement:
    """
    Measure quality at extended lengths

    The scoring hooks ``compute_perplexity``, ``needle_in_haystack`` and
    ``measure_quality`` are not defined in this class; presumably a
    subclass (or the surrounding project) supplies them -- confirm.
    """

    def perplexity_by_position(self, model, text: str, chunk_size: int = 1024) -> dict:
        """Measure perplexity at different positions.

        Splits ``text`` into consecutive slices of ``chunk_size`` items
        (characters, when ``text`` is a ``str``) and scores each slice with
        everything before it as context.

        Returns:
            dict mapping each chunk's start offset to its perplexity.
        """
        results = {}

        for position in range(0, len(text), chunk_size):
            chunk = text[position:position + chunk_size]
            # The full prefix is the context, so later chunks are scored at
            # later absolute positions.
            results[position] = self.compute_perplexity(model, text[:position], chunk)

        return results

    def retrieval_accuracy_by_position(
        self,
        model,
        context_length: int,
        positions: list
    ) -> dict:
        """Test fact retrieval at different positions.

        Args:
            model: Model passed through to ``needle_in_haystack``.
            context_length: Total context length for each probe.
            positions: Needle depths as fractions of ``context_length``
                (e.g. 0.0 = start, 0.5 = middle).

        Returns:
            dict mapping each fraction to its retrieval accuracy.
        """
        results = {}

        for pos_ratio in positions:
            accuracy = self.needle_in_haystack(
                model,
                context_length,
                position=int(context_length * pos_ratio)
            )
            results[pos_ratio] = accuracy

        return results

    def find_effective_length(
        self,
        model,
        quality_threshold: float = 0.9,
        lengths: tuple = (4096, 8192, 16384, 32768, 65536),
    ) -> int:
        """Find the longest length whose quality stays at or above threshold.

        Lengths are probed in the given (assumed ascending) order, and
        probing stops at the first one scoring below ``quality_threshold``.

        Args:
            model: Model passed through to ``measure_quality``.
            quality_threshold: Minimum acceptable quality score.
            lengths: Context lengths to probe, shortest first.

        Returns:
            The last length before the first failing one; the last probed
            length if none fail; 0 if the very first length already fails
            (or ``lengths`` is empty).
        """
        effective = 0
        for length in lengths:
            if self.measure_quality(model, length) < quality_threshold:
                # `effective` still holds the last passing length (0 if
                # none passed) -- never wrap around to the longest length.
                break
            effective = length
        return effective

Common Pitfalls

def extension_pitfalls():
    """Catalogue common mistakes made when extending context length.

    Returns:
        dict mapping a pitfall id to a dict with ``"mistake"``,
        ``"reality"`` and ``"fix"`` strings, in that key order.
    """
    # Every pitfall has the same three fields; keep them in one place.
    field_names = ("mistake", "reality", "fix")

    # (pitfall id, (mistake, reality, fix)) rows, in presentation order.
    rows = (
        ("no_quality_testing", (
            "Assume extension works because no errors",
            "Quality degrades silently",
            "Measure perplexity and task quality at target length",
        )),
        ("ignoring_position_decay", (
            "Expect uniform quality across all positions",
            "Middle positions still degraded",
            "Test retrieval at all positions, not just ends",
        )),
        ("skip_fine_tuning", (
            "Use scaling without any fine-tuning",
            "Brief fine-tuning dramatically improves quality",
            "Fine-tune 100-1000 steps on target length data",
        )),
        ("wrong_base_frequency", (
            "Use default RoPE base when extending",
            "Need to adjust for extension ratio",
            "Set rope_theta appropriately for your extension",
        )),
    )

    return {pitfall: dict(zip(field_names, texts)) for pitfall, texts in rows}

Configuration Examples

def configuration_examples():
    """Example Hugging Face-style RoPE settings for extending a 4K Llama model.

    Returns:
        dict mapping an example id to its method name, a config fragment
        and fine-tuning advice.
    """
    return {
        "llama_4k_to_8k": {
            "method": "Position interpolation",
            "config": {
                "rope_scaling": {
                    "type": "linear",
                    "factor": 2.0,
                }
            },
            "fine_tuning": "Optional",
        },

        "llama_4k_to_32k": {
            "method": "NTK + fine-tuning",
            "config": {
                # Static NTK-aware base for an 8x extension with head_dim=128:
                # base * scale ** (dim / (dim - 2)) ~= 82685.
                # Do NOT also set "dynamic" rope_scaling: dynamic NTK rescales
                # rope_theta again at runtime, double-counting the extension.
                "rope_theta": round(10000 * 8.0 ** (128 / 126)),
                "max_position_embeddings": 32768,
            },
            "fine_tuning": "Required, ~500 steps",
        },

        "llama_4k_to_128k": {
            "method": "YaRN",
            "config": {
                "rope_scaling": {
                    "type": "yarn",
                    "factor": 32.0,
                    "original_max_position_embeddings": 4096,
                    # attention_factor left unset: the default is
                    # 0.1 * ln(factor) + 1 (~1.35 here). A literal 0.1 would
                    # multiply attention by 0.1.
                }
            },
            "fine_tuning": "Required, ~1000 steps",
        },
    }

Context extension trades quality for length. Small extensions (2x) are nearly free. Large extensions (8x+) require careful tuning and quality validation. The viable range depends on your quality requirements and willingness to fine-tune.