Back to Blog

Reducing KV Cache Size Without Quality Loss

ZIP files trade decompression time for disk space. Nobody compresses files they access every second, but compressing archives you open monthly makes sense. The trade-off depends on access patterns and resource constraints.

KV cache compression follows the same logic. It trades decompression compute for memory savings. When memory is the bottleneck (long contexts, high concurrency), the trade is worth it. When compute is the bottleneck, it's not.

KV Cache Memory Reality

def kv_cache_memory():
    """
    Understanding KV cache memory consumption.

    Returns a dict of explanatory text:
      - "formula": how KV cache size scales with layers, context length and
        K/V width, with a worked example for a 70B-class model.
      - "proportion": KV cache size relative to the model weights.

    The worked numbers assume full multi-head attention (every attention head
    stores its own K and V, so num_kv_heads * head_dim == hidden_dim) and a
    single sequence. Grouped-query attention models store far fewer K/V heads,
    and concurrent sequences multiply the total.
    """
    return {
        "formula": """
            KV cache size (per sequence) =
                2 * num_layers * seq_len * num_kv_heads * head_dim * dtype_size

            The leading 2 counts K and V. With full multi-head attention,
            num_kv_heads * head_dim == hidden_dim. Multiply by the number of
            concurrent sequences for the total.

            For 70B model with full MHA (80 layers, 8192 hidden, FP16):
            Per token: 2 * 80 * 8192 * 2 bytes = 2.6 MB

            At 32K context: 2.6 MB * 32K = 85 GB
            At 128K context: 2.6 MB * 128K = 341 GB

            With GQA (e.g. LLaMA 2 70B: 8 KV heads for 64 query heads),
            the KV cache is 8x smaller: ~0.33 MB per token, ~11 GB at 32K.
        """,

        "proportion": """
            For 70B model (full MHA, one sequence):
            - Model weights: 140 GB
            - KV cache (32K): 85 GB (38% of total)
            - KV cache (128K): 341 GB (71% of total)

            KV cache dominates memory at long contexts.
        """,
    }

Compression Approaches

def compression_approaches():
    """
    Catalog of KV cache compression techniques.

    Returns a dict keyed by technique name. Each value is a dict describing
    the method, its memory savings, and (where known) its quality impact and
    typical use.
    """
    return {
        "quantization": {
            "method": "Store KV in INT8/INT4 instead of FP16",
            "savings": "2x for INT8, 4x for INT4",
            "quality_impact": "Usually < 1% for INT8",
            "implementation": """
                # Quantize K and V before storing
                k_quantized = quantize(k, dtype=torch.int8)
                v_quantized = quantize(v, dtype=torch.int8)

                # Dequantize when attending
                k = dequantize(k_quantized)
                v = dequantize(v_quantized)
            """,
        },

        "token_pruning": {
            "method": "Remove low-attention tokens from cache",
            "savings": "Variable (typically 30-50%)",
            "quality_impact": "Depends on pruning strategy",
            "challenge": "Which tokens to prune without hurting quality?",
        },

        "grouped_query_attention": {
            "method": "Share K,V heads across multiple Q heads",
            "savings": "4-8x on KV cache",
            "quality_impact": "Trained into model, minimal if done well",
            "examples": "LLaMA 2 70B uses GQA with 8 KV heads",
        },

        "sliding_window": {
            "method": "Only keep recent N tokens",
            # A fixed window caps the cache at N tokens, so memory stays
            # constant rather than growing linearly with context length.
            "savings": "Constant (bounded by window size) instead of growing linearly with context",
            "quality_impact": "Loses distant context",
            "use_case": "When only recent context matters",
        },

        "hierarchical_compression": {
            "method": "Compress old tokens more aggressively",
            "savings": "Dynamic based on age",
            "idea": "Recent = full precision, old = compressed",
        },
    }

KV Cache Quantization

class KVCacheQuantization:
    """
    Symmetric per-tensor quantization for KV cache tensors.

    Each tensor gets one scale, max(|x|) / qmax, and values are stored as
    int8. Scales are kept in ``self.scales`` keyed by ``(layer_id, "k")`` and
    ``(layer_id, "v")`` so that ``dequantize_kv`` can reverse ``quantize_kv``.

    NOTE: a later ``quantize_kv`` call for the same ``layer_id`` overwrites
    that layer's scales. Only the tensors from the most recent call for a
    layer can therefore be dequantized correctly. Appending tokens to an
    existing quantized cache needs per-block scales, which this class does
    not implement.

    NOTE: "int4" values are stored one per int8 element (range [-8, 7]).
    Packing two values per byte is left to the caller, so "int4" saves no
    memory over "int8" until that packing happens.
    """

    # Inclusive (min, max) integer range for each supported precision.
    # The scale maps max(|x|) onto the positive end of the range.
    _QRANGE = {"int8": (-128, 127), "int4": (-8, 7)}

    def __init__(self, precision: str = "int8"):
        self.precision = precision
        self.scales = {}  # (layer_id, "k" | "v") -> scale used for dequantization

    def quantize_kv(
        self,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int
    ) -> tuple:
        """
        Quantize K and V tensors for one layer.

        Stores both scales under ``layer_id``, replacing any scales from an
        earlier call for the same layer.

        Returns:
            (k_q, v_q): int8 tensors with the same shapes as ``k`` and ``v``.
        """
        k_q, k_scale = self.quantize_tensor(k)
        v_q, v_scale = self.quantize_tensor(v)

        # Store scales for dequantization
        self.scales[(layer_id, "k")] = k_scale
        self.scales[(layer_id, "v")] = v_scale

        return k_q, v_q

    def quantize_tensor(self, tensor: torch.Tensor) -> tuple:
        """
        Quantize a single tensor with one symmetric scale.

        Returns:
            (quantized, scale): ``quantized`` is an int8 tensor. ``scale`` is
            the 0-dim tensor to multiply by when dequantizing.

        Raises:
            ValueError: if ``self.precision`` is not "int8" or "int4".
        """
        if self.precision not in self._QRANGE:
            raise ValueError(
                f"unsupported precision {self.precision!r}; "
                f"expected one of {sorted(self._QRANGE)}"
            )
        qmin, qmax = self._QRANGE[self.precision]

        scale = tensor.abs().max() / qmax
        # An all-zero tensor would give scale 0, and 0/0 = NaN. Any nonzero
        # scale maps zeros to zeros, so fall back to 1 in that case.
        scale = torch.where(scale == 0, torch.ones_like(scale), scale)

        # Clamp guards against float rounding pushing a value past the range.
        quantized = (tensor / scale).round().clamp(qmin, qmax).to(torch.int8)
        return quantized, scale

    def dequantize_kv(
        self,
        k_q: torch.Tensor,
        v_q: torch.Tensor,
        layer_id: int
    ) -> tuple:
        """
        Dequantize K and V for attention, using the scales stored for
        ``layer_id`` by the most recent ``quantize_kv`` call.

        Returns:
            (k, v): float32 tensors.

        Raises:
            KeyError: if ``layer_id`` has not been quantized yet.
        """
        k_scale = self.scales[(layer_id, "k")]
        v_scale = self.scales[(layer_id, "v")]

        k = k_q.float() * k_scale
        v = v_q.float() * v_scale

        return k, v

Token Pruning Strategies

def token_pruning_strategies():
    """
    Catalog of token-pruning strategies for the KV cache.

    Returns a dict keyed by strategy name. Each value describes the method
    and, where given, an illustrative implementation snippet and its risk.
    """
    return {
        "attention_based": {
            "method": "Remove tokens with consistently low attention",
            "implementation": """
                def prune_low_attention(kv_cache, attention_scores, keep_ratio=0.5):
                    # assumes attention_scores is (batch, heads, q_len, kv_len)
                    # Average attention received by each key token
                    avg_attention = attention_scores.mean(dim=(0, 1, 2))

                    # Rank-based selection: a fixed threshold breaks at long
                    # contexts, because mean attention per token is ~1/seq_len
                    # (about 0.00003 at 32K), so e.g. 0.01 would prune nearly all.
                    num_keep = max(1, int(avg_attention.numel() * keep_ratio))

                    # Sort the kept indices to preserve positional order
                    keep_idx = avg_attention.topk(num_keep).indices.sort().values
                    return kv_cache[:, :, keep_idx, :]
            """,
            "risk": "Important but rarely-attended tokens get pruned",
        },

        "recency_based": {
            "method": "Keep recent tokens, prune old ones",
            "implementation": """
                def sliding_window_prune(kv_cache, window_size):
                    if kv_cache.size(2) > window_size:
                        return kv_cache[:, :, -window_size:, :]
                    return kv_cache
            """,
            "risk": "Loses important old context",
        },

        "importance_scoring": {
            "method": "Score tokens by multiple criteria, prune low scores",
            "criteria": [
                "Attention received",
                "Token type (special tokens kept)",
                "Semantic importance",
                "Recency",
            ],
            "implementation": """
                def importance_score(token, attention, position, seq_len):
                    score = 0
                    score += attention * 0.4  # Attention weight
                    score += (position / seq_len) * 0.3  # Recency
                    score += is_special_token(token) * 0.3  # Token type
                    return score
            """,
        },

        "h2o_approach": {
            "method": "Heavy-Hitter Oracle: keep frequently attended tokens",
            "insight": "A few 'heavy-hitter' tokens receive most of the attention mass (related to the 'attention sink' effect on initial tokens)",
            "implementation": "Track cumulative attention, keep top tokens plus a window of recent tokens",
        },
    }

When to Use Compression

def when_to_compress():
    """
    Guidance on when KV cache compression pays off.

    Returns a dict with:
      - "use_compression_when": conditions that favor compression.
      - "skip_compression_when": conditions where it is not worth it.
      - "decision_matrix": a text table mapping context length and
        concurrency to a recommendation.
    """
    return {
        "use_compression_when": [
            "Memory-bound (long contexts, high concurrency)",
            "KV cache is >30% of GPU memory",
            "Willing to trade ~5% compute for 50%+ memory",
            "Quality impact is acceptable (<2% degradation)",
        ],

        "skip_compression_when": [
            "Compute-bound workload",
            "Short contexts (KV cache is small anyway)",
            "Quality is critical (no tolerance for degradation)",
            # GQA and KV quantization are complementary; the real criterion
            # is whether the KV cache is still a large share of memory.
            "KV cache already small after other optimizations (GQA, etc.)",
        ],

        "decision_matrix": """
            Context  | Concurrency | Memory pressure | Recommendation
            ---------|-------------|-----------------|----------------
            < 8K     | Low         | Low             | Don't compress
            < 8K     | High        | Medium          | Maybe INT8
            8K-32K   | Any         | Medium          | INT8
            > 32K    | Any         | High            | INT8 + pruning
        """,
    }

Quality Impact Measurement

def measure_quality_impact():
    """
    How to measure the quality cost of KV cache compression.

    Returns a dict with:
      - "metrics": what to track and target bounds.
      - "testing_approach": an A/B evaluation snippet comparing an
        uncompressed baseline against a compressed run on the same test set.
      - "acceptable_thresholds": per-technique tolerance levels, in percent.
    """
    return {
        "metrics": {
            "perplexity_change": "Should be < 1%",
            "task_accuracy_change": "Should be < 2%",
            "retrieval_accuracy": "Test with needle-in-haystack",
        },

        "testing_approach": """
            def evaluate_compression_quality(model, compression_config, test_set):
                # Evaluate both runs on the same test set so deltas are comparable
                baseline_results = evaluate(model, test_set, compression=None)
                compressed_results = evaluate(model, test_set, compression=compression_config)

                return {
                    # Relative change, comparable to percentage thresholds
                    "perplexity_delta": (compressed_results.ppl - baseline_results.ppl) / baseline_results.ppl,
                    "accuracy_delta": compressed_results.accuracy - baseline_results.accuracy,
                    "latency_change": compressed_results.latency / baseline_results.latency,
                }
        """,

        "acceptable_thresholds": {
            "int8_quantization": {
                "perplexity_increase": "< 0.5%",
                "accuracy_decrease": "< 1%",
            },
            "int4_quantization": {
                "perplexity_increase": "< 2%",
                "accuracy_decrease": "< 3%",
            },
            "token_pruning_50pct": {
                "perplexity_increase": "< 3%",
                "accuracy_decrease": "< 5%",
            },
        },
    }

Implementation with vLLM

def vllm_kv_compression():
    """
    Enabling KV cache quantization in vLLM.

    Returns a dict with:
      - "kv_cache_dtype": a usage snippet for the ``kv_cache_dtype`` option.
      - "memory_savings": expected savings relative to FP16.
      - "quality_note": practical guidance on FP8 KV cache quality.
    """
    return {
        "kv_cache_dtype": """
            # vLLM supports KV cache quantization
            from vllm import LLM

            llm = LLM(
                model="your-model",
                kv_cache_dtype="fp8",  # or "fp8_e4m3" / "fp8_e5m2"; "auto" keeps the model dtype
            )

            # Automatically handles quantization/dequantization
        """,

        "memory_savings": """
            FP16 -> FP8: 50% memory savings
            FP16 -> INT8: 50% memory savings

            For 70B model at 32K context:
            FP16: 85 GB KV cache
            FP8:  42 GB KV cache
        """,

        "quality_note": """
            FP8 KV cache is well-supported on H100.
            Quality loss is typically negligible (<0.5% perplexity).
            This is the easiest win for KV cache compression.
        """,
    }

KV cache compression is about buying memory with compute. When memory is your constraint (long contexts, many concurrent requests), the trade is worth it. When compute is your constraint, it's not. Measure memory use, throughput, and output quality before and after compressing to confirm the trade works for your workload.