This is a crucial point: SynthSAEBench is a synthetic benchmark, meaning the researchers don’t extract features from a real LLM. Instead, they generate an artificial world where they define exactly what the “true features” are.
Step 1: Choose the number of features
Step 2: Create random unit vectors Each feature i gets a direction vector d_i ∈ R^768 created by:
1. Sample from standard normal: g_i ~ N(0, I_768)
2. Normalize to unit length: d_i = g_i / ||g_i||_2
This gives you 16,384 random directions in 768-dimensional space. Each direction represents a “concept” in the artificial world.
Step 3: Optionally orthogonalize Random vectors in high dimensions are already nearly orthogonal, but they add an orthogonalization step to reduce spurious correlations:
Minimize: L_ortho = Σ_{i≠j} (d_i^T d_j)^2 + λ Σ_i (||d_i||_2 - 1)^2
This optimization pushes the vectors to be as orthogonal as possible while keeping them unit length.
Step 4: Assign properties to each feature Each feature gets:
In a real LLM, we don’t know what the “true features” are—that’s the whole problem! But in SynthSAEBench, the researchers say: “We declare that these 16,384 vectors ARE the true features.” Then they use these to generate synthetic activations, train SAEs on those activations, and see if the SAEs can recover the features they started with.
This is like creating a puzzle where you know the answer, so you can test if your puzzle-solving algorithm (SAE) works correctly.
Superposition is when a neural network represents more features (concepts) than it has dimensions available. It’s like trying to pack 16,384 items into 768 boxes—they have to share space.
With N = 16,384 features but only D = 768 dimensions:
This measures HOW MUCH features overlap:
For each feature i:
Find its most similar other feature: max_{j≠i} |d_i^T d_j|
Average across all features: ρ_mm = (1/N) Σ_{i=1}^N max_{j≠i} |d_i^T d_j|
Interpretation:
When features overlap, the activation vector a = Σ c_i d_i becomes ambiguous:
Example with 2 overlapping features:
d_1 = [1, 0]
d_2 = [0.9, 0.436] # Only 25° away from d_1
If a = [1.9, 0.436]:
Could be: c_1=1, c_2=1 (both fire)
Could be: c_1=2, c_2=0 (only first fires)
Could be: c_1=0, c_2=2 (only second fires)
SAEs must disentangle this ambiguity. Higher superposition makes this harder.
From the paper’s experiments (Appendix H):
| Hidden Dim (D) | Features (N) | ρ_mm (no ortho) | ρ_mm (with ortho) |
|---|---|---|---|
| 768 | 4,096 | ~0.12 | ~0.10 |
| 768 | 16,384 | ~0.18 | ~0.15 |
| 768 | 65,536 | ~0.26 | ~0.22 |
| 4096 | 16,384 | ~0.06 | ~0.05 |
Key insight: Superposition scales roughly as O(1/√D) with dimension but grows with the number of features packed in.
The hierarchy is a forest of trees where:
Number of root nodes: 128
Branching factor: 4 (each parent has 4 children)
Maximum depth: 3 levels
Mutual exclusion: ENABLED (only one sibling can fire)
Total features in hierarchy: 10,884 out of 16,384
Remaining features: 5,500 (no hierarchical constraints)
Example tree:
Animal (Feature 1)
/ | \ \
Dog Cat Bird Fish (Features 2,3,4,5)
/|\ |
G R P Eagle (Features 6,7,8,9)
G = Golden Retriever (Feature 6)
R = Rottweiler (Feature 7)
P = Poodle (Feature 8)
After sampling initial coefficients c, apply constraints level-by-level:
def enforce_hierarchy(c, tree_structure):
"""
c: coefficient vector [c_1, c_2, ..., c_N]
tree_structure: parent-child relationships
"""
# Process from root to leaves
for level in range(max_depth):
for node in nodes_at_level[level]:
parent_idx = node.parent
child_idx = node.index
# Enforce parent dependency
if c[parent_idx] == 0:
c[child_idx] = 0
# Enforce mutual exclusion among siblings
if node.has_mutex_siblings:
siblings = get_siblings(node)
active_siblings = [s for s in siblings if c[s] > 0]
if len(active_siblings) > 1:
# Randomly pick one winner
winner = random.choice(active_siblings)
for s in active_siblings:
if s != winner:
c[s] = 0
return cFrom the paper:
c_child ← c_child · 1[c_parent > 0]
This means:
Depth 0 (roots): 128 features
Depth 1: 128 × 4 = 512 features
Depth 2: 512 × 4 = 2,048 features
Depth 3: 2,048 × 4 = 8,192 features
------------------------
Total hierarchical: 10,880 features (close to reported 10,884)
This creates a realistic mix:
If this were representing language:
"animal" (root) might fire when seeing any animal word
└─ "dog" fires for dog-related words
└─ "golden_retriever" fires specifically for that breed
The model learns: Can't have "golden_retriever" without "dog"
Can't have "dog" without "animal"
To model correlated feature firings, we need correlation matrix Σ ∈ R^{16384×16384}
Storage cost: 16,384² = 268,435,456 entries = ~2GB of RAM just for Σ!
Sampling cost: O(N²) operations to sample from N(0, Σ)
This is computationally prohibitive.
Instead of storing full Σ, represent it as:
Σ = F F^T + diag(δ)
Where:
For SynthSAEBench-16k: r = 100, so:
Intuition: Most correlations are low-rank because:
Mathematical proof it’s a valid correlation matrix:
For Σ = FF^T + diag(δ), the diagonal entries are:
Σ_{ii} = Σ_j F_{ij}^2 + δ_i
To make it a correlation matrix (diagonal = 1), set:
δ_i = 1 - Σ_j F_{ij}^2
def generate_low_rank_correlation(N, rank, correlation_scale):
# Sample factor matrix from scaled normal
F = correlation_scale * np.random.randn(N, rank)
# Compute required diagonal to ensure unit diagonal
row_sums = np.sum(F**2, axis=1) # Σ_j F_{ij}^2 for each i
delta = 1 - row_sums
# Check if any delta is negative (would make Σ invalid)
if np.any(delta < 0.01): # threshold for numerical stability
# Scale down F to ensure all deltas are valid
max_row_sum = np.max(row_sums)
scale_factor = np.sqrt((1 - 0.01) / max_row_sum)
F = F * scale_factor
delta = 1 - np.sum(F**2, axis=1)
return F, deltaInstead of sampling g ~ N(0, Σ) directly:
# Direct (expensive): O(N^2) time and space
g = np.random.multivariate_normal(mean=np.zeros(N), cov=Sigma)
# Low-rank (efficient): O(Nr) time and space
epsilon = np.random.randn(r) # r-dimensional standard normal
eta = np.random.randn(N) # N-dimensional standard normal
g = F @ epsilon + np.sqrt(delta) * eta # Matrix-vector product
# This gives E[gg^T] = FF^T + diag(δ) = Σ ✓Why this works:
g = Fε + √δ ⊙ η
E[gg^T] = E[(Fε + √δ⊙η)(Fε + √δ⊙η)^T]
= E[Fεε^TF^T] + E[(√δ⊙η)(√δ⊙η)^T] (cross terms = 0)
= F·I·F^T + diag(δ) (since ε,η are standard normal)
= FF^T + diag(δ)
= Σ ✓
The correlation_scale = s controls how strong
correlations are:
Larger s → larger entries in F → more off-diagonal correlation in Σ.
A Matryoshka SAE trains not just ONE autoencoder, but MULTIPLE nested autoencoders simultaneously:
SAE with 4096 latents:
Prefix 1: First 128 latents → must reconstruct input on their own
Prefix 2: First 512 latents → must reconstruct input on their own
Prefix 3: First 2048 latents → must reconstruct input on their own
Prefix 4: All 4096 latents → must reconstruct input
Standard SAE loss:
L = ||a - â||²₂ + λ||f||₁
Matryoshka SAE loss:
L = Σ_{m ∈ M} [||a - â_m||²₂ + λ||f_m||₁] + α·L_aux
Where:
Incentive structure:
Example learned ordering:
Latent 1-128: "subject", "verb", "sentence_end", "punctuation"
Latent 129-512: "past_tense", "plural", "capitalization"
Latent 513-2048: "dog", "cat", "food", "color"
Latent 2049+: "golden_retriever", "technical_jargon", rare concepts
Surprising finding: Matryoshka SAEs achieve:
Why this disconnect?
The Matryoshka loss forces early latents to be “maximally useful” for reconstruction. This means:
The nested training creates a pressure to learn interpretable, disentangled features rather than optimizing raw reconstruction.
This pattern (good probing, poor reconstruction) is also seen in real LLM SAEs! This validates SynthSAEBench’s realism.
Unlike standard SAEs with a learned encoder, MP-SAEs select latents iteratively:
Standard SAE (parallel):
f = ReLU(W_enc · a) # All latents activated at once
â = W_dec · f # Reconstruct
Matching Pursuit SAE (serial):
residual = a # Start with full activation
selected_latents = []
for t in 1..k:
# Find latent with highest projection onto residual
scores = W_dec^T · residual
best_latent = argmax(scores)
selected_latents.append(best_latent)
# Project out this latent's contribution
residual = residual - W_dec[:,best_latent] * scores[best_latent]
â = a - residual # Reconstruction is original minus residual
At iteration t:
l_t = argmax_i (W_dec,i · â_t) # Select best latent
â_{t+1} = â_t - W_dec,l_t · (W_dec,l_t · â_t) # Project out
Where â_0 = a (initial residual is the full activation).
After k iterations, the reconstruction loss is:
L = ||â_k||²₂
(The residual that couldn’t be explained.)
Standard SAE: Encoder is a simple linear + ReLU
f_i = ReLU(W_enc,i · a) # Each latent independently decides to activate
MP-SAE: Encoder is an iterative algorithm
Each latent can condition on what previous latents already explained
Can adaptively focus on unexplained variance
Key result from Section 6.2:
As superposition increases (ρ_mm goes up):
Interpretation:
When features overlap (superposition), MP-SAEs exploit this overlap:
Example:
True features:
d_1 = [1.0, 0.0]
d_2 = [0.9, 0.44] # Overlaps with d_1
True activation: a = 2·d_1 + 3·d_2 = [4.7, 1.32]
What MP-SAE might learn:
Decoder column w_1 ≈ d_1
Decoder column w_2 ≈ (d_1 + d_2)/2 # Splits the difference!
MP reconstruction:
Select w_2 heavily → explains most of the overlap
Select w_1 lightly → cleans up remainder
Result: Excellent reconstruction (low residual)
But w_2 doesn't correspond to a TRUE feature!
The iterative selection allows MP-SAEs to find ad-hoc combinations that reconstruct well but don’t align with the true feature dictionary.
Linear encoders are less flexible:
f_i = ReLU(W_enc,i · a) # Must commit to a single direction W_enc,i
Can't adaptively adjust based on what else fired
Limited to representing true features (or linear combinations)
L0 = number of active (non-zero) latents per sample
For an SAE with latent activations f = [f_1, f_2, …, f_L]:
L0 = |{i : f_i > 0}| # Count of non-zero entries
Trade-off:
For each SAE latent j, treat it as a binary classifier for its best-matched ground-truth feature i*:
For each sample:
Prediction: latent j fires (f_j > 0) → predict feature i* is active
Ground truth: Is c_{i*} > 0?
Compute:
True Positives (TP): f_j > 0 AND c_{i*} > 0
False Positives (FP): f_j > 0 AND c_{i*} = 0
False Negatives (FN): f_j = 0 AND c_{i*} > 0
Precision = TP / (TP + FP) # Of all predictions, how many correct?
Recall = TP / (TP + FN) # Of all true cases, how many caught?
Low L0 (e.g., L0=15):
Only 15 latents can fire per sample
SAE must be VERY selective
Many true features won't have corresponding latent fire
→ Many False Negatives (missed detections)
→ Low Recall
High L0 (e.g., L0=45):
45 latents can fire per sample
SAE can afford to be generous
More true features will have corresponding latent fire
→ Fewer False Negatives
→ High Recall
Low L0 (e.g., L0=15):
Only 15 latents fire
Each one should fire only for clear, strong signals
Less likely to fire spuriously
→ Fewer False Positives
→ High Precision
High L0 (e.g., L0=45):
45 latents fire
More opportunities for spurious activations
Latents might fire for weak/ambiguous signals
→ More False Positives
→ Low Precision
Ground truth: activation has features {animal, dog, golden_retriever} active
SAE with L0=10:
Fires: {animal, dog, golden_retriever, mammal, pet, large_dog, yellow_fur, outdoor, friendly, domestic}
TP: animal, dog, golden_retriever (3)
FP: mammal, pet, large_dog, yellow_fur, outdoor, friendly, domestic (7)
FN: 0 (caught all true features)
Precision = 3/(3+7) = 0.30 # Lots of false alarms
Recall = 3/(3+0) = 1.00 # Caught everything
SAE with L0=3:
Fires: {animal, dog}
TP: animal, dog (2)
FP: 0 (everything it predicted was correct)
FN: golden_retriever (1) (missed this)
Precision = 2/(2+0) = 1.00 # Perfect precision
Recall = 2/(2+1) = 0.67 # Missed some
L0=15 L0=30 L0=45
Precision: ~0.60 ~0.45 ~0.35
Recall: ~0.50 ~0.80 ~0.95
Trade-off curve: As L0 increases, you slide along precision-recall curve.
For interpretability:
Real-world analogy:
This precision-recall trade-off has been observed in real LLM SAEs, validating SynthSAEBench’s realism. The fact that no L0 setting achieves perfect F1 (best is ~0.85) suggests current SAE architectures have fundamental limitations.
These seven concepts form the technical foundation of SynthSAEBench:
Together, these create a realistic, tractable benchmark that reproduces phenomena seen in real LLM SAEs while providing the ground truth needed for rigorous evaluation.