Back to Blog

Understanding What Your Model Attends To

Heat maps in sports analytics show where players spend time on the field. The visualization reveals patterns that are invisible in raw data: that the midfielder is consistently out of position, or that attacks keep breaking down in one zone. The map doesn't explain why these things happen, but it pinpoints where to investigate.

Attention visualization works similarly. It shows which tokens the model attends to when generating each output token. When the model fails, attention maps help distinguish two cases: the model attended to the relevant context but did not use it, or it barely attended to that context at all. The map guides debugging.

Extracting Attention Weights

def extract_attention():
    """
    Collect reference notes on pulling attention weights out of a model.

    Returns a dict of three indented, multi-line text snippets keyed by
    topic: a ``transformers`` code sample, a walkthrough of the resulting
    tensor shapes, and a note on the memory cost of keeping every layer.
    """
    notes = {}

    notes["transformers_approach"] = """
            from transformers import AutoModelForCausalLM

            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                output_attentions=True  # Enable attention output
            )

            outputs = model(input_ids, output_attentions=True)

            # attention_weights: tuple of (batch, heads, seq_len, seq_len)
            # One tensor per layer
            attention_weights = outputs.attentions
        """

    notes["shape_explanation"] = """
            For a model with:
            - 32 layers
            - 32 attention heads
            - 1024 token sequence

            attention_weights is tuple of 32 tensors
            Each tensor: (batch, 32, 1024, 1024)

            attention_weights[layer][batch, head, query_pos, key_pos]
            = attention from query_pos to key_pos in that layer/head
        """

    # The figure below is per batch element with 4-byte (fp32) weights:
    # 32 * 32 * 1024 * 1024 * 4 = 2**32 bytes = 4 GiB.
    notes["memory_warning"] = """
            Storing all attention weights is expensive:
            32 layers * 32 heads * 1024 * 1024 * 4 bytes = ~4GB

            For analysis, sample specific layers/positions.
        """

    return notes

Visualization Approaches

def visualization_approaches():
    """
    Return a catalogue of ways to visualize attention weights.

    Each entry maps an approach name to a dict with ``what`` (the plot or
    reduction), ``use`` (what it helps you see), and either a ``code``
    snippet or an ``insight``. The indexing snippets assume a single layer's
    attention tensor laid out as ``(batch, heads, query_pos, key_pos)``.
    """
    return {
        "single_head_heatmap": {
            "what": "2D heatmap of query positions vs key positions",
            "use": "See which keys each query attends to",
            "code": """
                import matplotlib.pyplot as plt

                def plot_attention(attention, tokens, title):
                    # attention: 2D (query_pos, key_pos) weights for one head
                    fig, ax = plt.subplots(figsize=(10, 10))
                    im = ax.imshow(attention, cmap='viridis')
                    ax.set_xticks(range(len(tokens)))
                    ax.set_yticks(range(len(tokens)))
                    ax.set_xticklabels(tokens, rotation=90)
                    ax.set_yticklabels(tokens)
                    ax.set_xlabel('key')
                    ax.set_ylabel('query')
                    ax.set_title(title)
                    # Pass the image explicitly: ax.imshow() does not register
                    # a pyplot "current image" for a mappable-less colorbar.
                    fig.colorbar(im, ax=ax)
                    return fig
            """,
        },

        "head_averaged": {
            "what": "Average attention across all heads",
            "use": "Overall attention pattern per layer",
            # dim=1 is the heads axis of a (batch, heads, query, key) tensor
            "code": "avg_attention = attention.mean(dim=1)",
        },

        "layer_specific": {
            "what": "Compare patterns across layers",
            "use": "See how attention evolves through model",
            "insight": "Early layers often show syntactic patterns, late layers semantic",
        },

        "token_specific": {
            "what": "Attention to/from specific token",
            "use": "Debug why specific output was generated",
            # Key axis (last) = attention *to* the token; query axis = *from* it.
            "code": (
                "attention_to_token = attention[:, :, :, token_idx]\n"
                "attention_from_token = attention[:, :, token_idx, :]"
            ),
        },
    }

Debugging with Attention

def debugging_with_attention():
    """
    Return a playbook linking generation failures to attention checks.

    Keys are failure modes. Each value holds the observed ``symptom``, an
    indented multi-line ``investigation`` checklist, and
    ``possible_findings`` mapping an attention pattern to what it might
    indicate.
    """
    ignored_context = dict(
        symptom="Model didn't use information that was in context",
        investigation="""
                1. Find output token where error occurred
                2. Extract attention for that generation step
                3. Check attention to relevant context tokens
            """,
        possible_findings=dict(
            low_attention_to_relevant="Context was seen but not attended",
            high_attention_to_irrelevant="Distracted by other tokens",
            attention_to_wrong_position="Position encoding issue",
        ),
    )

    hallucinated = dict(
        symptom="Model generated information not in context",
        investigation="""
                1. Extract attention for hallucinated tokens
                2. Check what the model was attending to
                3. Look for suspicious patterns
            """,
        possible_findings=dict(
            attention_spread_thin="No strong signal, guessing",
            attention_to_similar_context="Confused similar items",
            self_attention_loop="Attending to own generation",
        ),
    )

    repeated = dict(
        symptom="Output contains repeated phrases",
        investigation="""
                1. Extract attention during repeated section
                2. Check for attention loops
            """,
        possible_findings=dict(
            high_self_attention="Model stuck on own output",
            pattern_in_attention="Repeating attention pattern",
        ),
    )

    return {
        "model_ignored_context": ignored_context,
        "model_hallucinated": hallucinated,
        "model_repeated_itself": repeated,
    }

Attention Pattern Analysis

class AttentionAnalyzer:
    """
    Analyze attention patterns for debugging.

    Methods take a single layer's attention tensor laid out as
    ``(batch, heads, query_pos, key_pos)``, i.e. one element of the tuple
    returned as ``outputs.attentions`` when ``output_attentions=True``.
    """

    def attention_entropy(self, attention: torch.Tensor) -> torch.Tensor:
        """Measure how spread out attention is over the key axis.

        Returns the Shannon entropy (natural log, so in nats) of each
        attention row: the input with its last dimension summed away.
        Low entropy = focused on a few keys; high entropy = spread out.
        """
        # The epsilon keeps log() finite where a weight is exactly zero.
        return -(attention * torch.log(attention + 1e-9)).sum(dim=-1)

    def find_attention_sinks(
        self,
        attention: torch.Tensor,
        num_std: float = 2.0,
    ) -> list:
        """Find key positions that receive unusually much attention.

        Attention received by each key position is averaged over batch,
        heads and queries. Positions whose average exceeds the mean of those
        averages by more than ``num_std`` standard deviations are reported.

        Args:
            attention: ``(batch, heads, query_pos, key_pos)`` weights.
            num_std: Standard deviations above the mean a position must be
                to count as a sink (default 2.0).

        Returns:
            Ascending list of key positions; always a list, possibly empty.
        """
        avg_received = attention.mean(dim=(0, 1, 2))  # one value per key position
        threshold = avg_received.mean() + num_std * avg_received.std()
        # flatten() rather than squeeze(): with exactly one match squeeze()
        # yields a 0-d tensor and tolist() would return a bare int.
        sinks = (avg_received > threshold).nonzero().flatten()
        return sinks.tolist()

    def trace_influence(
        self,
        attention: torch.Tensor,
        output_position: int,
        threshold: float = 0.1
    ) -> dict:
        """Trace which positions the token at ``output_position`` attended to.

        Args:
            attention: ``(batch, heads, query_pos, key_pos)`` weights.
            output_position: Query position whose attention row is inspected.
            threshold: Minimum (batch- and head-averaged) weight for a key
                position to be reported.

        Returns:
            dict with ``influential_positions`` (ascending key positions above
            ``threshold``) and ``attention_values`` (their averaged weights,
            same order). Both are always lists, possibly empty.
        """
        # Average (not sum) over batch and heads the attention *from* this
        # query position to every key position.
        attention_from_output = attention[:, :, output_position, :].mean(dim=(0, 1))

        # flatten() keeps a 1-D index even for a single match (see above).
        influential = (attention_from_output > threshold).nonzero().flatten()
        return {
            "influential_positions": influential.tolist(),
            "attention_values": attention_from_output[influential].tolist(),
        }

    def compare_attention_patterns(
        self,
        attention_1: torch.Tensor,
        attention_2: torch.Tensor,
        threshold: float = 0.1,
        max_positions: int = 10,
    ) -> dict:
        """Compare attention between two runs.

        The tensors are expected to have the same shape; this is not checked,
        so broadcastable shapes are silently broadcast by torch.

        Args:
            attention_1, attention_2: Attention tensors from the two runs.
            threshold: Absolute difference above which a position is listed.
            max_positions: Cap on how many differing positions are returned.

        Returns:
            dict with ``max_diff`` and ``mean_diff`` (floats) and
            ``positions_with_big_diff``: up to ``max_positions`` full index
            lists, in index order (not sorted by magnitude).
        """
        diff = (attention_1 - attention_2).abs()
        return {
            "max_diff": diff.max().item(),
            "mean_diff": diff.mean().item(),
            "positions_with_big_diff": (diff > threshold).nonzero()[:max_positions].tolist(),
        }

Practical Debugging Session

def debugging_session_example():
    """
    Return a worked, step-by-step attention-debugging walkthrough.

    The text follows one scenario (a fact present in the prompt is ignored)
    from setup through attention extraction, findings, diagnosis and fixes.
    The code inside it is illustrative and assumes ``model``, ``tokenizer``,
    ``torch`` and ``plt`` are already in scope.
    """
    # The attention rows of generated tokens only exist if those tokens are
    # part of the forward pass, so the example runs the model on
    # prompt + output and converts the output offset to an absolute position.
    return """
    Scenario: Model ignored a fact in context

    1. Setup:
       Input: "The capital of France is Paris. What is the capital of France?"
       Output: "I don't have that information."

    2. Extract attention:
       prompt = "The capital of France is Paris. What is the capital of France?"
       answer = " I don't have that information."
       prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
       answer_ids = tokenizer(answer, add_special_tokens=False,
                              return_tensors="pt").input_ids

       # Attention *from* generated tokens only exists if those tokens are
       # part of the forward pass, so run the model on prompt + output.
       full_ids = torch.cat([prompt_ids, answer_ids], dim=-1)

       with torch.no_grad():
           outputs = model(full_ids, output_attentions=True)
           attention = outputs.attentions[-1]  # Last layer; check earlier layers too

    3. Analyze attention for "information" token:
       prompt_tokens = tokenizer.convert_ids_to_tokens(prompt_ids[0])
       answer_tokens = tokenizer.convert_ids_to_tokens(answer_ids[0])
       info_offset = next(i for i, tok in enumerate(answer_tokens)
                          if "information" in tok)
       # Absolute position = prompt length + offset within the output
       info_position = prompt_ids.shape[-1] + info_offset
       attention_pattern = attention[0, :, info_position, :].mean(dim=0)

       # Plot attention over the whole prompt (statement + question)
       n = len(prompt_tokens)
       plt.bar(range(n), attention_pattern[:n].float().cpu().numpy())
       plt.xticks(range(n), prompt_tokens, rotation=45)

    4. Findings:
       - High attention to "What" (0.25)
       - High attention to "France" in question (0.30)
       - LOW attention to "Paris" (0.01)
       - LOW attention to "capital" in statement (0.02)

    5. Diagnosis:
       Model attended to the question but not the answer in context.
       Possible causes: Position too far, format not recognized as answer.

    6. Fix attempts:
       - Move answer closer to question
       - Add explicit marker: "Answer: Paris"
       - Format as Q&A style
    """

Limitations of Attention Analysis

def attention_limitations():
    """
    Return the main caveats to keep in mind when reading attention maps.

    Each caveat name maps to a dict with a ``caution``, an illustrative
    ``example``, and practical ``guidance``.
    """
    # (name, caution, example, guidance) rows, in presentation order.
    caveats = (
        (
            "attention_is_not_explanation",
            "Attention shows correlation, not causation",
            "High attention to 'the' doesn't mean 'the' was important",
            "Use as debugging tool, not proof of reasoning",
        ),
        (
            "aggregation_is_lossy",
            "Averaging across heads loses information",
            "Different heads attend to different things",
            "Look at individual heads for detailed analysis",
        ),
        (
            "early_vs_late_layers",
            "Attention meaning varies by layer",
            "Early layers: syntax. Late layers: semantics.",
            "Analyze multiple layers for full picture",
        ),
        (
            "softmax_obscures",
            "Softmax concentrates attention",
            "0.1 attention might still be significant",
            "Compare relative attention, not absolute values",
        ),
    )
    return {
        name: {"caution": caution, "example": example, "guidance": guidance}
        for name, caution, example, guidance in caveats
    }

Attention visualization is a debugging tool, not an explanation. It shows where the model looked, not why. When the model fails, attention maps help distinguish "barely attended to it" from "attended to it but didn't use it." That distinction guides the fix.