
When to Use Self-Attention vs Cross-Attention

A book club discussing one book uses internal references: "remember chapter 3?" A book club comparing two books uses external references: "how does this differ from the other book?" Same discussion mechanism, different information sources.

Self-attention and cross-attention follow the same pattern. Self-attention processes one sequence, relating positions within it. Cross-attention relates two different sequences, letting one attend to the other. The mechanism is identical; the inputs differ.

The Mathematical Difference

def attention_mechanics():
    return {
        "self_attention": {
            "inputs": "One sequence X",
            "computation": """
                Q = X @ W_q  # Queries from X
                K = X @ W_k  # Keys from X
                V = X @ W_v  # Values from X
                output = softmax(Q @ K^T / sqrt(d)) @ V
            """,
            "what_happens": "Each position attends to all positions in same sequence",
        },

        "cross_attention": {
            "inputs": "Two sequences: X (query source) and Y (context)",
            "computation": """
                Q = X @ W_q  # Queries from X
                K = Y @ W_k  # Keys from Y (different sequence!)
                V = Y @ W_v  # Values from Y
                output = softmax(Q @ K^T / sqrt(d)) @ V
            """,
            "what_happens": "Each position in X attends to all positions in Y",
        },

        "key_difference": """
            Self: Q, K, V all from same source
            Cross: Q from one source, K and V from another
        """,
    }
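The comparison above collapses into a single function. The sketch below (single head, no masking; names like attention, w_q, and the shapes are illustrative, not from any particular library) makes the point concrete: the same code computes self-attention when you pass one sequence as both arguments, and cross-attention when you pass a separate context as the key/value source.

import math
import torch

def attention(q_src: torch.Tensor, kv_src: torch.Tensor,
              w_q: torch.Tensor, w_k: torch.Tensor, w_v: torch.Tensor) -> torch.Tensor:
    Q = q_src @ w_q                      # queries always come from q_src
    K = kv_src @ w_k                     # keys come from whatever kv_src is
    V = kv_src @ w_v                     # values come from the same source as keys
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
    return torch.softmax(scores, dim=-1) @ V

d = 64
w_q, w_k, w_v = (torch.randn(d, d) for _ in range(3))
x = torch.randn(2, 10, d)                # one sequence
y = torch.randn(2, 25, d)                # a second sequence (context)

self_out = attention(x, x, w_q, w_k, w_v)   # self-attention:  [2, 10, 64]
cross_out = attention(x, y, w_q, w_k, w_v)  # cross-attention: [2, 10, 64]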

When to Use Self-Attention

def self_attention_use_cases():
    return {
        "language_modeling": {
            "task": "Predict next token in sequence",
            "why_self": "Each token needs context from previous tokens",
            "example": "GPT, LLaMA, Claude (decoder-only)",
        },

        "text_understanding": {
            "task": "Understand relationships within text",
            "why_self": "Words relate to other words in same text",
            "example": "BERT embeddings, sentiment analysis",
        },

        "sequence_classification": {
            "task": "Classify entire sequence",
            "why_self": "Need to integrate information across sequence",
            "example": "Spam detection, intent classification",
        },

        "general_principle": """
            Use self-attention when:
            - Processing a single sequence
            - Positions within sequence need to interact
            - No external context required
        """,
    }
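For the language-modeling case above, self-attention is additionally masked so each token attends only to itself and earlier tokens. A minimal sketch, assuming the same single-head setup as before (names are illustrative):

import math
import torch

def causal_self_attention(x: torch.Tensor, w_q: torch.Tensor,
                          w_k: torch.Tensor, w_v: torch.Tensor) -> torch.Tensor:
    Q, K, V = x @ w_q, x @ w_k, x @ w_v
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
    n = x.size(-2)
    future = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))  # hide future positions
    return torch.softmax(scores, dim=-1) @ V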

When to Use Cross-Attention

def cross_attention_use_cases():
    return {
        "machine_translation": {
            "task": "Translate source language to target",
            "why_cross": "Target generation needs to attend to source",
            "architecture": """
                Encoder: Self-attention over source
                Decoder: Self-attention over target
                       + Cross-attention to encoder output
            """,
        },

        "retrieval_augmented_generation": {
            "task": "Generate using retrieved documents",
            "why_cross": "Generation attends to retrieved context",
            "architecture": """
                Retrieved docs: Encoded separately
                Generator: Cross-attends to retrieved encodings
            """,
        },

        "multimodal": {
            "task": "Process image + text together",
            "why_cross": "Text attends to image features",
            "architecture": """
                Image encoder: Self-attention over patches
                Text decoder: Cross-attention to image features
            """,
        },

        "document_qa": {
            "task": "Answer questions about a document",
            "why_cross": "Question attends to document",
            "architecture": """
                Document: Encoded with self-attention
                Question: Cross-attends to document
            """,
        },

        "general_principle": """
            Use cross-attention when:
            - Two distinct sequences need to interact
            - One sequence conditions on another
            - Fusion of different information sources
        """,
    }

Architecture Patterns

def architecture_patterns():
    return {
        "encoder_only": {
            "structure": "Self-attention throughout",
            "use_case": "Understanding, embeddings, classification",
            "example": "BERT",
        },

        "decoder_only": {
            "structure": "Causal self-attention throughout",
            "use_case": "Generation, language modeling",
            "example": "GPT, LLaMA",
            "note": "No cross-attention, context in prompt",
        },

        "encoder_decoder": {
            "structure": """
                Encoder: Self-attention (bidirectional)
                Decoder: Self-attention (causal) + Cross-attention
            """,
            "use_case": "Sequence-to-sequence tasks",
            "example": "T5, BART, original Transformer",
        },

        "decoder_with_retrieval": {
            "structure": """
                Retriever: Encodes documents
                Decoder: Self-attention + Cross-attention to retrieved
            """,
            "use_case": "Knowledge-intensive generation",
            "example": "RETRO, RAG",
        },
    }
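To see how the encoder-decoder pattern wires the two attention types together, here is a hedged sketch of one decoder layer: causal self-attention over the target, then cross-attention where keys and values come from the encoder output. Residual connections only; layer norm and dropout are omitted, and DecoderLayer with its default sizes is an illustrative stand-in, not any specific library's module.

import torch
import torch.nn as nn

class DecoderLayer(nn.Module):
    def __init__(self, d_model: int = 512, n_heads: int = 8):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, target: torch.Tensor, encoder_out: torch.Tensor) -> torch.Tensor:
        n = target.size(1)
        causal = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
        # Causal self-attention over the target sequence
        target = target + self.self_attn(target, target, target, attn_mask=causal)[0]
        # Cross-attention: queries from the target, keys/values from the encoder output
        target = target + self.cross_attn(target, encoder_out, encoder_out)[0]
        return target + self.ffn(target)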

Implementation Differences

import math
import torch
import torch.nn as nn


class AttentionImplementation(nn.Module):
    """Single-head implementations of both attention types."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.head_dim = hidden_dim  # single head, so head_dim == hidden_dim
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)

    def self_attention(
        self,
        x: torch.Tensor,  # [batch, seq_len, hidden]
    ) -> torch.Tensor:
        """Standard self-attention: Q, K, V all come from x."""
        Q = self.q_proj(x)  # [batch, seq_len, hidden]
        K = self.k_proj(x)  # same shape, same source
        V = self.v_proj(x)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn, V)

        return output  # [batch, seq_len, hidden]

    def cross_attention(
        self,
        query_seq: torch.Tensor,    # [batch, query_len, hidden]
        context_seq: torch.Tensor,  # [batch, context_len, hidden]
    ) -> torch.Tensor:
        """Cross-attention from the query sequence to the context sequence."""
        Q = self.q_proj(query_seq)    # queries from the query sequence
        K = self.k_proj(context_seq)  # keys from the context (different source!)
        V = self.v_proj(context_seq)  # values from the context

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn, V)

        return output  # [batch, query_len, hidden]
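Illustrative usage of the class above (batch size, lengths, and hidden dimension are arbitrary):

attn = AttentionImplementation(hidden_dim=64)
doc = torch.randn(2, 128, 64)        # e.g. an encoded document
question = torch.randn(2, 16, 64)    # e.g. an encoded question

within_doc = attn.self_attention(doc)             # [2, 128, 64]
q_over_doc = attn.cross_attention(question, doc)  # [2, 16, 64]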

Memory and Compute Implications

def memory_compute_comparison():
    return {
        "self_attention": {
            "memory": "O(n²) for sequence length n",
            "compute": "O(n² * d) for hidden dim d",
            "kv_cache": "Stores K, V for all positions processed",
        },

        "cross_attention": {
            "memory": "O(n * m) for query length n, context length m",
            "compute": "O(n * m * d)",
            "kv_cache": "Context K, V can be precomputed and reused",
        },

        "optimization_opportunity": """
            Cross-attention context is often fixed (document, retrieved text).
            Precompute context K, V once, reuse across multiple queries.

            This is why RAG can be efficient:
            - Encode documents once
            - Cross-attend from each query
            - Document encoding amortized
        """,
    }
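A sketch of that optimization, reusing the AttentionImplementation class from the previous section: project the static context to K and V once, then run any number of queries against the cached projections.

import math
import torch

attn = AttentionImplementation(hidden_dim=64)
doc = torch.randn(2, 512, 64)        # static context, e.g. a retrieved document

# Pay the context cost once
K_ctx = attn.k_proj(doc)
V_ctx = attn.v_proj(doc)

# Reuse the cached projections across many queries
for _ in range(3):
    query = torch.randn(2, 16, 64)
    Q = attn.q_proj(query)
    scores = torch.matmul(Q, K_ctx.transpose(-2, -1)) / math.sqrt(attn.head_dim)
    out = torch.softmax(scores, dim=-1) @ V_ctx   # [2, 16, 64]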

Hybrid Approaches

def hybrid_approaches():
    return {
        "prefix_tuning": {
            "idea": "Prepend learned context, use self-attention",
            "implementation": """
                prefix_tokens = learned_embedding  # [prefix_len, hidden]
                full_input = concat(prefix_tokens, input_tokens)
                output = self_attention(full_input)
            """,
            "trade_off": "Simpler (no cross-attention) but less flexible",
        },

        "fusion_in_decoder": {
            "idea": "Fuse external info via cross-attention layers",
            "implementation": """
                for layer in decoder_layers:
                    x = self_attention(x)
                    x = cross_attention(x, context)  # Every N layers
                    x = ffn(x)
            """,
            "trade_off": "More parameters but better context integration",
        },

        "prompt_concatenation": {
            "idea": "Put context in prompt, use only self-attention",
            "implementation": """
                prompt = f"{context}\\n\\nQuestion: {question}"
                output = decoder_only_model(prompt)
            """,
            "trade_off": "Simple but context competes with generation",
        },
    }
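The fusion_in_decoder entry above can be made concrete with a hedged sketch: a stack that runs self-attention in every layer but cross-attends to the external context only every cross_every layers. All names and sizes are illustrative; norms and masking are omitted.

import torch
import torch.nn as nn

class FusionDecoder(nn.Module):
    def __init__(self, n_layers: int = 6, d_model: int = 512,
                 n_heads: int = 8, cross_every: int = 3):
        super().__init__()
        self.cross_every = cross_every
        self.self_attn = nn.ModuleList(
            [nn.MultiheadAttention(d_model, n_heads, batch_first=True) for _ in range(n_layers)]
        )
        self.cross_attn = nn.ModuleList(
            [nn.MultiheadAttention(d_model, n_heads, batch_first=True) for _ in range(n_layers)]
        )
        self.ffn = nn.ModuleList(
            [nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                           nn.Linear(4 * d_model, d_model)) for _ in range(n_layers)]
        )

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        for i, (sa, ca, ffn) in enumerate(zip(self.self_attn, self.cross_attn, self.ffn)):
            x = x + sa(x, x, x)[0]                 # self-attention in every layer
            if i % self.cross_every == 0:          # cross-attention every N layers
                x = x + ca(x, context, context)[0]
            x = x + ffn(x)
        return x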

Choosing for Your Application

def decision_framework():
    return {
        "use_self_attention_only": [
            "Single input sequence tasks",
            "When context fits in prompt",
            "Decoder-only simplicity preferred",
            "No separate context encoding needed",
        ],

        "use_cross_attention": [
            "Distinct source and target sequences",
            "Retrieved context that's precomputed",
            "Multimodal fusion",
            "When context is much larger than query",
            "Translation and seq-to-seq tasks",
        ],

        "practical_note": """
            Modern LLMs (GPT-4, Claude) use decoder-only with self-attention.
            Context goes in the prompt, not via cross-attention.

            Cross-attention shines when:
            - Context is very large (millions of tokens)
            - Context is static (precompute K, V)
            - Explicit fusion of modalities needed

            For most applications, prompt-based context with self-attention
            is simpler and sufficient.
        """,
    }

The mechanism is the same; the data flow differs. Self-attention relates positions within one sequence. Cross-attention relates positions across two sequences. Choose based on whether you have one input or two, and whether the "context" sequence is worth encoding separately.