When to Use Self-Attention vs Cross-Attention
A book club discussing one book uses internal references: "remember chapter 3?" A book club comparing two books uses external references: "how does this differ from the other book?" Same discussion mechanism, different information sources.
Self-attention and cross-attention follow the same pattern. Self-attention processes one sequence, relating positions within it. Cross-attention relates two different sequences, letting one attend to the other. The mechanism is identical; the inputs differ.
The Mathematical Difference
def attention_mechanics():
return {
"self_attention": {
"inputs": "One sequence X",
"computation": """
Q = X @ W_q # Queries from X
K = X @ W_k # Keys from X
V = X @ W_v # Values from X
output = softmax(Q @ K^T / sqrt(d)) @ V
""",
"what_happens": "Each position attends to all positions in same sequence",
},
"cross_attention": {
"inputs": "Two sequences: X (query source) and Y (context)",
"computation": """
Q = X @ W_q # Queries from X
K = Y @ W_k # Keys from Y (different sequence!)
V = Y @ W_v # Values from Y
output = softmax(Q @ K^T / sqrt(d)) @ V
""",
"what_happens": "Each position in X attends to all positions in Y",
},
"key_difference": """
Self: Q, K, V all from same source
Cross: Q from one source, K and V from another
""",
}
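To make the difference concrete, here is a minimal runnable sketch of both computations on random tensors (names, shapes, and weights below are illustrative, not taken from any real model). The only change between the two calls is where K and V come from.

import torch

torch.manual_seed(0)
d = 8
X = torch.randn(2, 5, d)    # "query side" sequence: batch=2, 5 positions
Y = torch.randn(2, 12, d)   # "context" sequence: 12 positions
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))

def scaled_dot_product(Q, K, V):
    scores = Q @ K.transpose(-2, -1) / d ** 0.5
    return torch.softmax(scores, dim=-1) @ V

# Self-attention: Q, K, V all derived from X
self_out = scaled_dot_product(X @ W_q, X @ W_k, X @ W_v)
print(self_out.shape)   # torch.Size([2, 5, 8])

# Cross-attention: Q from X, but K and V from Y
cross_out = scaled_dot_product(X @ W_q, Y @ W_k, Y @ W_v)
print(cross_out.shape)  # torch.Size([2, 5, 8]) -- query length preserved, information read from Y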
When to Use Self-Attention
def self_attention_use_cases():
return {
"language_modeling": {
"task": "Predict next token in sequence",
"why_self": "Each token needs context from previous tokens",
"example": "GPT, LLaMA, Claude (decoder-only)",
},
"text_understanding": {
"task": "Understand relationships within text",
"why_self": "Words relate to other words in same text",
"example": "BERT embeddings, sentiment analysis",
},
"sequence_classification": {
"task": "Classify entire sequence",
"why_self": "Need to integrate information across sequence",
"example": "Spam detection, intent classification",
},
"general_principle": """
Use self-attention when:
- Processing a single sequence
- Positions within sequence need to interact
- No external context required
""",
}
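For the language-modeling case above, decoder-only models pair self-attention with a causal mask so each token only sees earlier positions. A minimal sketch (projection weights are passed in for brevity; this is illustrative, not any specific model's code):

import torch

def causal_self_attention(x, W_q, W_k, W_v):
    """Self-attention where position i may only attend to positions j <= i."""
    d = x.size(-1)
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    scores = Q @ K.transpose(-2, -1) / d ** 0.5
    seq_len = x.size(-2)
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))  # hide future tokens
    return torch.softmax(scores, dim=-1) @ V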
When to Use Cross-Attention
def cross_attention_use_cases():
return {
"machine_translation": {
"task": "Translate source language to target",
"why_cross": "Target generation needs to attend to source",
"architecture": """
Encoder: Self-attention over source
Decoder: Self-attention over target
+ Cross-attention to encoder output
""",
},
"retrieval_augmented_generation": {
"task": "Generate using retrieved documents",
"why_cross": "Generation attends to retrieved context",
"architecture": """
Retrieved docs: Encoded separately
Generator: Cross-attends to retrieved encodings
""",
},
"multimodal": {
"task": "Process image + text together",
"why_cross": "Text attends to image features",
"architecture": """
Image encoder: Self-attention over patches
Text decoder: Cross-attention to image features
""",
},
"document_qa": {
"task": "Answer questions about a document",
"why_cross": "Question attends to document",
"architecture": """
Document: Encoded with self-attention
Question: Cross-attends to document
""",
},
"general_principle": """
Use cross-attention when:
- Two distinct sequences need to interact
- One sequence conditions on another
- Fusion of different information sources
""",
}
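In PyTorch, the two cases do not need separate modules: nn.MultiheadAttention computes self- or cross-attention depending only on what you pass as key and value. A sketch for the document-QA pattern above (dimensions are illustrative):

import torch
import torch.nn as nn

hidden = 64
attn = nn.MultiheadAttention(embed_dim=hidden, num_heads=4, batch_first=True)

question = torch.randn(1, 16, hidden)    # encoded question tokens
document = torch.randn(1, 200, hidden)   # encoded document tokens

# Self-attention: query, key, and value are the same tensor
self_out, _ = attn(question, question, question)

# Cross-attention: queries from the question, keys/values from the document
cross_out, _ = attn(question, document, document)
print(cross_out.shape)  # torch.Size([1, 16, 64]) -- one output per question token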
Architecture Patterns
def architecture_patterns():
return {
"encoder_only": {
"structure": "Self-attention throughout",
"use_case": "Understanding, embeddings, classification",
"example": "BERT",
},
"decoder_only": {
"structure": "Causal self-attention throughout",
"use_case": "Generation, language modeling",
"example": "GPT, LLaMA",
"note": "No cross-attention, context in prompt",
},
"encoder_decoder": {
"structure": """
Encoder: Self-attention (bidirectional)
Decoder: Self-attention (causal) + Cross-attention
""",
"use_case": "Sequence-to-sequence tasks",
"example": "T5, BART, original Transformer",
},
"decoder_with_retrieval": {
"structure": """
Retriever: Encodes documents
Decoder: Self-attention + Cross-attention to retrieved
""",
"use_case": "Knowledge-intensive generation",
"example": "RETRO, RAG",
},
}
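The encoder-decoder pattern is where the two attention types sit side by side in every decoder layer. A condensed sketch of such a block using nn.MultiheadAttention (residual connections kept; the layer norms and dropout a real T5/BART block would have are omitted):

import torch
import torch.nn as nn

class EncoderDecoderBlock(nn.Module):
    """Decoder block: causal self-attention over the target,
    then cross-attention to the encoder output, then a feed-forward layer."""

    def __init__(self, hidden: int, heads: int):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(hidden, heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(hidden, heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Linear(4 * hidden, hidden)
        )

    def forward(self, target: torch.Tensor, encoder_out: torch.Tensor) -> torch.Tensor:
        causal = nn.Transformer.generate_square_subsequent_mask(target.size(1))
        x = target + self.self_attn(target, target, target, attn_mask=causal)[0]
        x = x + self.cross_attn(x, encoder_out, encoder_out)[0]   # K, V from the encoder
        return x + self.ffn(x)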
Implementation Differences
import math

import torch
import torch.nn as nn


class AttentionImplementation(nn.Module):
    """Single-head implementations of both attention types."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)

    def self_attention(
        self,
        x: torch.Tensor,  # [batch, seq_len, hidden]
    ) -> torch.Tensor:
        """Standard self-attention: Q, K, V all come from x."""
        Q = self.q_proj(x)  # [batch, seq_len, hidden]
        K = self.k_proj(x)  # Same shape, same source
        V = self.v_proj(x)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.hidden_dim)
        attn = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn, V)
        return output  # [batch, seq_len, hidden]

    def cross_attention(
        self,
        query_seq: torch.Tensor,    # [batch, query_len, hidden]
        context_seq: torch.Tensor,  # [batch, context_len, hidden]
    ) -> torch.Tensor:
        """Cross-attention: queries from query_seq, keys/values from context_seq."""
        Q = self.q_proj(query_seq)    # Queries from the query sequence
        K = self.k_proj(context_seq)  # Keys from the context (different sequence!)
        V = self.v_proj(context_seq)  # Values from the context
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.hidden_dim)
        attn = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn, V)
        return output  # [batch, query_len, hidden]
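Continuing from the class above, a quick usage sketch (sizes are arbitrary): the same projection layers serve both calls, and only the source of K and V changes.

attn = AttentionImplementation(hidden_dim=64)
document = torch.randn(2, 100, 64)   # context sequence, e.g. an encoded document
question = torch.randn(2, 10, 64)    # query sequence

print(attn.self_attention(question).shape)             # torch.Size([2, 10, 64])
print(attn.cross_attention(question, document).shape)  # torch.Size([2, 10, 64])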
Memory and Compute Implications
def memory_compute_comparison():
return {
"self_attention": {
"memory": "O(n²) for sequence length n",
"compute": "O(n² * d) for hidden dim d",
"kv_cache": "Stores K, V for all positions processed",
},
"cross_attention": {
"memory": "O(n * m) for query length n, context length m",
"compute": "O(n * m * d)",
"kv_cache": "Context K, V can be precomputed and reused",
},
"optimization_opportunity": """
Cross-attention context is often fixed (document, retrieved text).
Precompute context K, V once, reuse across multiple queries.
This is why RAG can be efficient:
- Encode documents once
- Cross-attend from each query
- Document encoding amortized
""",
}
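A sketch of that optimization, assuming a fixed context and simple linear projections (names and sizes are illustrative): the context's K and V are computed once and reused for every incoming query.

import torch
import torch.nn as nn

hidden = 64
q_proj, k_proj, v_proj = (nn.Linear(hidden, hidden) for _ in range(3))

# Encode the fixed context (document / retrieved passages) once
context = torch.randn(1, 500, hidden)
K, V = k_proj(context), v_proj(context)      # precomputed, cached

def cross_attend(query_tokens: torch.Tensor) -> torch.Tensor:
    """Cross-attend a fresh query against the cached context K, V."""
    Q = q_proj(query_tokens)
    scores = Q @ K.transpose(-2, -1) / hidden ** 0.5
    return torch.softmax(scores, dim=-1) @ V

for _ in range(3):                            # many queries amortize one context encoding
    out = cross_attend(torch.randn(1, 12, hidden))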
Hybrid Approaches
def hybrid_approaches():
return {
"prefix_tuning": {
"idea": "Prepend learned context, use self-attention",
"implementation": """
prefix_tokens = learned_embedding # [prefix_len, hidden]
full_input = concat(prefix_tokens, input_tokens)
output = self_attention(full_input)
""",
"trade_off": "Simpler (no cross-attention) but less flexible",
},
"fusion_in_decoder": {
"idea": "Fuse external info via cross-attention layers",
"implementation": """
for layer in decoder_layers:
x = self_attention(x)
x = cross_attention(x, context) # Every N layers
x = ffn(x)
""",
"trade_off": "More parameters but better context integration",
},
"prompt_concatenation": {
"idea": "Put context in prompt, use only self-attention",
"implementation": """
prompt = f"{context}\\n\\nQuestion: {question}"
output = decoder_only_model(prompt)
""",
"trade_off": "Simple but context competes with generation",
},
}
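As a small illustration of the prefix-tuning idea above (a sketch under simplifying assumptions: one shared prefix at the input layer, whereas real prefix tuning typically learns per-layer key/value prefixes):

import torch

prefix_len, hidden = 8, 64
prefix = torch.nn.Parameter(torch.randn(prefix_len, hidden))   # learned virtual tokens
input_embeddings = torch.randn(1, 20, hidden)                  # embeddings of the real input

# Prepend the prefix; ordinary self-attention then mixes prefix and input positions
full_input = torch.cat(
    [prefix.unsqueeze(0).expand(1, -1, -1), input_embeddings], dim=1
)
print(full_input.shape)  # torch.Size([1, 28, 64])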
Choosing for Your Application
def decision_framework():
return {
"use_self_attention_only": [
"Single input sequence tasks",
"When context fits in prompt",
"Decoder-only simplicity preferred",
"No separate context encoding needed",
],
"use_cross_attention": [
"Distinct source and target sequences",
"Retrieved context that's precomputed",
"Multimodal fusion",
"When context is much larger than query",
"Translation and seq-to-seq tasks",
],
"practical_note": """
Modern LLMs (GPT-4, Claude) use decoder-only with self-attention.
Context goes in the prompt, not via cross-attention.
Cross-attention shines when:
- Context is very large (millions of tokens)
- Context is static (precompute K, V)
- Explicit fusion of modalities needed
For most applications, prompt-based context with self-attention
is simpler and sufficient.
""",
}
The mechanism is the same; the data flow differs. Self-attention relates positions within one sequence. Cross-attention relates positions across two sequences. Choose based on whether you have one input or two, and whether the "context" sequence is worth encoding separately.