Extended Technical Research Plan
Date: February 2026
Based on: SynthSAEBench: Evaluating Sparse Autoencoders on Scalable Realistic Synthetic Data
Recent work has revealed that not all features in language models are one-dimensionally linear—many exist as multi-dimensional manifolds representing concepts with inherent geometric structure (Engels et al., 2025). While sparse autoencoders (SAEs) have emerged as a powerful tool for mechanistic interpretability, their ability to recover and represent features that lie on manifolds remains poorly understood. This research plan extends the SynthSAEBench framework to systematically study manifold-aware SAEs through controlled synthetic experiments.
We propose: (1) a methodology for generating synthetic activation data containing feature manifolds with known ground-truth structure, (2) novel evaluation metrics for assessing SAE performance in the presence of manifolds, (3) architectural modifications to SAEs that explicitly model manifold structure, and (4) a comprehensive experimental framework to test competing representation hypotheses. Our approach enables rigorous evaluation of whether SAEs can recover not just individual features, but the geometric relationships between them—a critical capability for understanding how language models represent structured knowledge.
Key Contributions:
The Linear Representation Hypothesis (LRH) posits that neural networks represent concepts as directions in activation space (Park et al., 2024b). Under this hypothesis, semantic relationships are encoded through linear operations: vector addition, subtraction, and scaling. This framework has proven remarkably successful for explaining phenomena like the classic “king - man + woman = queen” analogy.
However, recent empirical work challenges the universality of one-dimensional linear features:
Engels et al. (2025) demonstrate that language models contain irreducibly multi-dimensional features—concepts that cannot be decomposed into independent or non-co-occurring lower-dimensional components. Their key examples include:
Li et al. (2025) reveal hierarchical geometric structure in SAE feature dictionaries across three scales:
Olah & Batson (2023) introduced the feature manifold toy model, suggesting that related features lie on continuous manifolds where nearby points represent similar concepts. This contrasts with discrete, independent features assumed by standard SAE training.
Standard SAEs are designed to recover discrete, independent features by enforcing sparsity through L1 regularization:
Loss = ||x − x̂||² + λ||f||₁

where f = encoder(x) are the sparse feature activations and x̂ = decoder(f) is the reconstruction.
This objective assumes features can be represented as one-dimensional scalars in the latent space. However, manifold features require multiple dimensions to represent their geometric structure:
Michaud et al. (2024) show that feature manifolds create pathological scaling behavior: when features lie on manifolds, SAEs allocate disproportionately many latents to “tile” high-frequency manifolds, learning far fewer distinct features than the number of latents available. This suggests fundamental architectural limitations in current SAE designs.
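The following toy calculation shows why a circular feature invites tiling. It is our own illustration, not Michaud et al.'s analysis. Suppose each point on a unit circle has to be reconstructed by a single active latent direction, scaled optimally. With n evenly spaced directions, the worst-case residual is sin(π/n). The error therefore shrinks only roughly in proportion to 1/n, so each tenfold reduction in error costs about ten times as many latents. The helper `worst_case_tiling_error` is hypothetical.

```python
import numpy as np


def worst_case_tiling_error(n_latents, n_test=20_000):
    """Hypothetical helper: largest residual norm when a point on the unit circle is
    reconstructed from a single latent, i.e. projected onto the best of
    n_latents evenly spaced unit directions."""
    test_angles = np.linspace(0, 2 * np.pi, n_test, endpoint=False)
    latent_angles = 2 * np.pi * np.arange(n_latents) / n_latents
    # The optimal coefficient for direction j is cos(angle difference); the residual of
    # projecting a unit vector onto that direction has norm sqrt(1 - cos^2) = |sin(difference)|
    cos_diff = np.cos(test_angles[:, None] - latent_angles[None, :])
    best_cos = np.clip(cos_diff.max(axis=1), -1.0, 1.0)  # best single (positive) latent
    return float(np.sqrt(1.0 - best_cos ** 2).max())


errors = {n: worst_case_tiling_error(n) for n in (4, 8, 16, 32)}
for n, err in errors.items():
    print(f"{n:>2} latents: worst-case residual {err:.3f} (sin(pi/n) = {np.sin(np.pi / n):.3f})")
```

The printed residuals track sin(π/n): about 0.71, 0.38, 0.20 and 0.098 for 4, 8, 16 and 32 latents.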
This research plan addresses five fundamental questions:
Evaluation: How should we evaluate SAE performance when ground-truth features are manifolds rather than independent vectors?
Architecture: What architectural modifications enable SAEs to efficiently represent manifold structure?
Scaling: How does the presence of manifolds affect SAE scaling laws and capacity allocation?
Representation hypotheses: Can synthetic models following different geometric priors (linear, circular, hierarchical) distinguish between competing theories of neural representation?
LLM correspondence: Do training dynamics and failure modes in synthetic manifold benchmarks match behaviors observed in real LLMs?
SynthSAEBench provides the ideal foundation for studying manifold-aware SAEs because:
By extending SynthSAEBench to include manifold structure, we can systematically study how SAEs handle geometric features while maintaining the benchmark’s scientific rigor.
SynthSAEBench generates synthetic activation data following the generative model:
```
# 1. Sample sparse feature coefficients
c ~ TruncatedPareto(α=1, β=threshold, shape=(N,))  # N = 16,384 features
# 2. Apply hierarchical constraints
c = enforce_hierarchy(c, tree_structure)
# 3. Apply low-rank correlations
c = apply_correlations(c, correlation_matrix_Σ)
# 4. Generate activation
x = D @ c + ε
```

where D ∈ ℝ^{d×N} is the feature dictionary (d = 768 hidden dimensions) and ε ~ N(0, σ²I) is Gaussian noise.

Key properties captured:
This provides a realistic but tractable benchmark where ground truth is known.
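Below is a minimal NumPy sketch of these steps at toy scale, for readers who want to experiment. It is not the SynthSAEBench implementation. The truncated-Pareto coefficients, hierarchy and correlations are replaced by independent Bernoulli firing with folded-normal magnitudes. All sizes and the name `sample_synthetic_activations` are illustrative assumptions.

```python
import numpy as np


def sample_synthetic_activations(n_samples, D, p_active, magnitude_mean=1.0,
                                 magnitude_std=0.1, noise_std=0.01, seed=0):
    """Toy sampler: each feature fires independently with probability p_active,
    with a folded-normal magnitude; no hierarchy or correlations."""
    rng = np.random.default_rng(seed)
    d, N = D.shape
    active = rng.random((n_samples, N)) < p_active                 # which features fire
    magnitudes = np.abs(rng.normal(magnitude_mean, magnitude_std, (n_samples, N)))
    C = active * magnitudes                                         # sparse coefficients c
    X = C @ D.T + rng.normal(0.0, noise_std, (n_samples, d))        # x = D c + ε, one row per sample
    return X, C


rng = np.random.default_rng(1)
d, N = 64, 256
D = rng.normal(size=(d, N))
D /= np.linalg.norm(D, axis=0, keepdims=True)   # unit-norm dictionary columns
X, C = sample_synthetic_activations(1000, D, p_active=0.02)
```

Because the ground-truth coefficients C are returned alongside X, any SAE trained on X can be scored against the known features.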
A feature is irreducibly k-dimensional if:
Example: Days of the week form a 2D circular manifold: day t ∈ {0,…,6} maps to (cos(2πt/7), sin(2πt/7)).

| Manifold Type | Dimensionality (spanned subspace) | Examples | Parameterization |
|---|---|---|---|
| Linear | 1D | Scalar concepts (gender, sentiment) | Single real value |
| Circular | 2D (S¹) | Days, months, angles | (cos θ, sin θ) |
| Spherical | 3D (S²) | Directions, orientations | (sin φ cos θ, sin φ sin θ, cos φ) |
| Toroidal | 4D (S¹ × S¹) | Periodic pairs (hour+day) | (cos θ₁, sin θ₁, cos θ₂, sin θ₂) |
| Hyperbolic | Variable | Hierarchies, trees | Poincaré disk coordinates |
| Simplicial | (n-1)D | Categorical (n classes) | Probability simplex |
Curvature: Manifolds have intrinsic curvature affecting geodesic paths
Density: Feature distribution on the manifold
Noise: Perturbations relative to manifold structure
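The days-of-week example can be checked numerically. The seven points (cos(2πt/7), sin(2πt/7)) have two equal singular values. Their variance is therefore split evenly between two orthogonal directions, and the best single-direction approximation captures only half of it. This shows the points genuinely span two dimensions. It does not, by itself, test Engels et al.'s irreducibility criterion.

```python
import numpy as np

t = np.arange(7)
days = np.stack([np.cos(2 * np.pi * t / 7), np.sin(2 * np.pi * t / 7)], axis=1)  # (7, 2)

# The points are centred at the origin, so singular values measure spread per direction
s = np.linalg.svd(days, compute_uv=False)
# Fraction of total variance captured by the best single direction
energy_1d = s[0] ** 2 / np.sum(s ** 2)
print(s, energy_1d)  # both singular values ≈ sqrt(3.5) ≈ 1.871; energy_1d ≈ 0.5
```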
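```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StandardSAE(nn.Module):
    def __init__(self, d_hidden, n_latents, lambda_l1=1e-3):  # placeholder λ, not tuned
        super().__init__()
        self.W_enc = nn.Linear(d_hidden, n_latents)
        self.W_dec = nn.Linear(n_latents, d_hidden, bias=False)
        self.lambda_l1 = lambda_l1  # sparsity coefficient λ

    @torch.no_grad()
    def normalize_decoder(self):
        # Constrain decoder columns (one per latent) to unit norm; call after each optimizer step
        norms = torch.linalg.vector_norm(self.W_dec.weight, dim=0, keepdim=True)
        self.W_dec.weight.div_(norms.clamp_min(1e-8))

    def forward(self, x):
        f = F.relu(self.W_enc(x))  # Sparse latents
        x_hat = self.W_dec(f)      # Reconstruction
        return f, x_hat

    def loss(self, x):
        f, x_hat = self.forward(x)
        recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean()  # ||x - x̂||², averaged over the batch
        sparsity_loss = f.abs().sum(dim=-1).mean()          # ||f||₁, averaged over the batch
        return recon_loss + self.lambda_l1 * sparsity_loss
```

Limitations for manifolds: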
Separates magnitude estimation from feature activation:
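```python
import torch.nn as nn
import torch.nn.functional as F


class GatedSAE(nn.Module):
    # __init__ (defining W_gate, W_mag and W_dec) omitted
    def forward(self, x):
        # Gating: which features are active (binary)
        gate_logits = self.W_gate(x)
        # Heaviside step: passes no gradient to W_gate, so training needs an auxiliary loss (not shown)
        gate = (gate_logits > 0).float()
        # Magnitude: how much each active feature contributes
        magnitude = F.relu(self.W_mag(x))
        f = gate * magnitude  # Element-wise product
        x_hat = self.W_dec(f)
        return f, x_hat
```

Advantages: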
Uses a discontinuous activation to improve reconstruction:

```python
import torch


def jumprelu(x, threshold=0.1):
    # JumpReLU: zero at or below the threshold, x itself (not x + threshold) above it
    return torch.where(x > threshold, x, torch.zeros_like(x))
```

Provides better reconstruction fidelity while maintaining sparsity.
Current SynthSAEBench metrics assume discrete features:
Problem: These don’t assess geometric structure recovery.
Needed for manifolds:
We develop these in Section 4.
We extend the SynthSAEBench generative model to include manifold features alongside standard independent features.
```python
N_independent = 12_000                          # Standard 1D features
N_manifolds = 10                                # Number of manifold structures
manifold_dims = [2, 2, 2, 2, 3, 3, 4, 4, 2, 2]  # Dimensions per manifold
```

Total feature dimensions: N_independent + sum(manifold_dims) = 12,000 + 26 = 12,026 effective features.

Dictionary size: D ∈ ℝ^{768 × 16,384}
- Columns 1-12,000: Independent unit vectors (as in original)
- Columns 12,001-16,384: Grouped into 10 manifolds

For each manifold m:
Step 1: Choose manifold type and parameters
```python
manifold_config = {
    'type': 'circular',        # or 'spherical', 'toroidal'
    'intrinsic_dim': 2,        # Dimension k of the subspace the manifold spans
    'embedding_dim': 20,       # Ambient dimensions used
    'num_discretization': 32,  # Number of discrete points on manifold
    'curvature': 'constant',   # or 'variable'
    'noise_level': 0.05,       # Std of the tangent (angular) noise
    # Toroidal manifolds also need 'major_radius' and 'minor_radius'
}
```

Step 2: Sample base manifold in intrinsic coordinates
For a circular manifold (S¹):
```python
import numpy as np


def generate_circular_manifold(num_points, noise_level):
    # Generate evenly spaced points on circle (ground-truth angles)
    angles = np.linspace(0, 2 * np.pi, num_points, endpoint=False)
    # Add tangent noise (along the circle) by perturbing the angles
    angle_noise = np.random.randn(num_points) * noise_level
    angles_noisy = angles + angle_noise
    # Base 2D coordinates of the noisy points
    coords_2d = np.stack([np.cos(angles_noisy), np.sin(angles_noisy)], axis=1)
    return coords_2d, angles  # noisy coordinates, clean angles
```

For a spherical manifold (S²):
```python
def generate_spherical_manifold(num_points, noise_level):
    # Fibonacci sphere algorithm for even distribution
    indices = np.arange(num_points) + 0.5
    phi = np.arccos(1 - 2 * indices / num_points)
    theta = np.pi * (1 + 5 ** 0.5) * indices
    # 3D coordinates
    x = np.sin(phi) * np.cos(theta)
    y = np.sin(phi) * np.sin(theta)
    z = np.cos(phi)
    coords_3d = np.stack([x, y, z], axis=1)
    # Add tangent noise: remove the radial component, then re-project onto the sphere
    tangent_noise = np.random.randn(num_points, 3) * noise_level
    tangent_noise -= (tangent_noise * coords_3d).sum(axis=1, keepdims=True) * coords_3d
    coords_3d += tangent_noise
    coords_3d /= np.linalg.norm(coords_3d, axis=1, keepdims=True)
    return coords_3d, (phi, theta)
```

For a toroidal manifold (S¹ × S¹):
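```python
def generate_toroidal_manifold(num_points, major_radius, minor_radius, noise_level):
    # Sample angles uniformly (ground truth)
    theta = np.random.uniform(0, 2 * np.pi, num_points)  # Major circle
    phi = np.random.uniform(0, 2 * np.pi, num_points)    # Minor circle
    # Tangent noise on both angles, as for the circular manifold
    theta_noisy = theta + np.random.randn(num_points) * noise_level
    phi_noisy = phi + np.random.randn(num_points) * noise_level
    # Flat torus in 4D: product of two circles with the given radii
    coords_4d = np.stack([
        major_radius * np.cos(theta_noisy),
        major_radius * np.sin(theta_noisy),
        minor_radius * np.cos(phi_noisy),
        minor_radius * np.sin(phi_noisy),
    ], axis=1)
    return coords_4d, (theta, phi)
```

Step 3: Embed in high-dimensional space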
Map from intrinsic manifold coordinates to 768-dimensional activation space:
```python
def embed_manifold_in_activation_space(coords, embedding_dim, total_dim=768):
    """
    Embed k-dimensional manifold into d-dimensional space.

    Args:
        coords: (N, k) manifold coordinates
        embedding_dim: number of dimensions to use for embedding (>= k)
        total_dim: total activation dimension (768)
    Returns:
        embedded_full: (N, total_dim) high-dimensional coordinates
    """
    N, k = coords.shape
    # Random linear isometry k -> embedding_dim: QR gives orthonormal columns,
    # transposed to a (k, embedding_dim) matrix with orthonormal rows
    Q, _ = np.linalg.qr(np.random.randn(embedding_dim, k))
    W_embed = Q.T
    # Project manifold to embedding_dim subspace
    embedded_low = coords @ W_embed  # (N, embedding_dim)
    # Place in a random contiguous block of the full activation space
    embedded_full = np.zeros((N, total_dim))
    start_idx = np.random.randint(0, total_dim - embedding_dim + 1)
    embedded_full[:, start_idx:start_idx + embedding_dim] = embedded_low
    # Normalize to unit norm (consistent with SynthSAEBench)
    embedded_full /= np.linalg.norm(embedded_full, axis=1, keepdims=True)
    return embedded_full
```

Step 4: Add manifold directions to dictionary
```python
def construct_hybrid_dictionary(N_independent, manifold_configs, d_hidden=768):
    """
    Construct feature dictionary with both independent and manifold features.

    Returns D_full with shape (d_hidden, N_total), one column per feature,
    matching D ∈ ℝ^{d×N} and x = D @ c.
    """
    # Independent features (as in original SynthSAEBench): random unit vectors, one per row
    D_independent = np.random.randn(N_independent, d_hidden)
    D_independent /= np.linalg.norm(D_independent, axis=1, keepdims=True)

    # Manifold features
    manifold_features = []
    manifold_metadata = []
    for i, config in enumerate(manifold_configs):
        # Generate manifold points
        if config['type'] == 'circular':
            coords, params = generate_circular_manifold(
                config['num_discretization'], config['noise_level'])
        elif config['type'] == 'spherical':
            coords, params = generate_spherical_manifold(
                config['num_discretization'], config['noise_level'])
        elif config['type'] == 'toroidal':
            coords, params = generate_toroidal_manifold(
                config['num_discretization'],
                config['major_radius'],
                config['minor_radius'],
                config['noise_level'])
        else:
            raise ValueError(f"Unknown manifold type: {config['type']}")

        # Embed in activation space: (num_points, d_hidden)
        embedded = embed_manifold_in_activation_space(
            coords, config['embedding_dim'], d_hidden)

        start_idx = N_independent + sum(len(m) for m in manifold_features)
        manifold_features.append(embedded)
        manifold_metadata.append({
            'manifold_id': i,
            'type': config['type'],
            'num_discretization': config['num_discretization'],
            'intrinsic_coords': coords,
            'params': params,
            'start_idx': start_idx,                # global column index of the first point
            'end_idx': start_idx + len(embedded),  # exclusive
        })

    # Stack features as rows, then transpose so each feature is a column of D
    D_full = np.vstack([D_independent] + manifold_features).T
    return D_full, manifold_metadata
```

Each manifold has associated statistics controlling its behavior:
```python
# sample_zipfian and the Uniform*Distribution classes are helpers assumed to be defined elsewhere
class ManifoldFeatureStats:
    def __init__(self, manifold_id, manifold_type, intrinsic_dim):
        self.manifold_id = manifold_id
        self.type = manifold_type
        self.intrinsic_dim = intrinsic_dim
        # Firing probability (probability this manifold is active)
        self.p_active = sample_zipfian()
        # When active, distribution over manifold surface
        self.surface_distribution = self._init_surface_distribution()
        # Magnitude distribution
        self.magnitude_mean = np.random.uniform(15.0, 25.0)
        self.magnitude_std = np.random.lognormal(0, 0.5)
        # If True, neighbouring points also fire with decayed magnitude
        # (read by generate_activation_with_manifolds)
        self.smooth_activation = False

    def _init_surface_distribution(self):
        """
        Define how probability mass distributes over manifold surface.
        Options:
        - 'uniform': Equal probability everywhere
        - 'concentrated': Gaussian bumps at specific locations
        - 'mixed': Multiple modes
        """
        if self.type == 'circular':
            # Could be uniform or peaked (e.g., prefer certain months)
            return UniformCircularDistribution()
        elif self.type == 'spherical':
            return UniformSphericalDistribution()
        # ... etc
```

Extend the activation generation to include manifold features:
```python
# sample_magnitude, enforce_hierarchy and apply_correlations_with_manifolds
# are helpers assumed to be defined elsewhere
def generate_activation_with_manifolds(D, manifold_metadata, feature_stats,
                                       hierarchy_tree, correlation_matrix, noise_std):
    """
    Generate a single activation vector with both independent and manifold features.

    feature_stats holds one entry per independent feature, followed by one
    ManifoldFeatureStats per manifold.
    """
    N_total = D.shape[1]
    N_independent = manifold_metadata[0]['start_idx']

    # Step 1: Sample independent features (as before)
    c_independent = np.zeros(N_independent)
    for i in range(N_independent):
        if np.random.rand() < feature_stats[i].p_active:
            c_independent[i] = sample_magnitude(
                feature_stats[i].magnitude_mean, feature_stats[i].magnitude_std)

    # Step 2: Sample manifold features
    c_manifolds = np.zeros(N_total - N_independent)
    for manifold_meta in manifold_metadata:
        m_id = manifold_meta['manifold_id']
        m_stats = feature_stats[N_independent + m_id]
        # Decide if this manifold is active
        if np.random.rand() < m_stats.p_active:
            # Sample a point on the manifold surface
            point_idx = m_stats.surface_distribution.sample()
            # start_idx is a global column index; subtract N_independent to index c_manifolds
            feature_idx = manifold_meta['start_idx'] + point_idx - N_independent
            magnitude = sample_magnitude(m_stats.magnitude_mean, m_stats.magnitude_std)
            # Only ONE point on the manifold is active at a time
            # (Represents the current value of the circular/spherical feature)
            c_manifolds[feature_idx] = magnitude
            # Optional smoothing: activate index neighbours with decay. Index
            # neighbours are only geometric neighbours on circular manifolds.
            if m_stats.smooth_activation and manifold_meta['type'] == 'circular':
                n_points = manifold_meta['num_discretization']
                for neighbor_offset in (-1, 1):
                    neighbor_idx = (point_idx + neighbor_offset) % n_points
                    neighbor_feature_idx = manifold_meta['start_idx'] + neighbor_idx - N_independent
                    c_manifolds[neighbor_feature_idx] = magnitude * 0.3  # Decayed activation

    # Combine independent and manifold coefficients
    c_full = np.concatenate([c_independent, c_manifolds])
    # Step 3: Apply hierarchy constraints (if any)
    c_full = enforce_hierarchy(c_full, hierarchy_tree)
    # Step 4: Apply correlations (with special handling for manifolds)
    c_full = apply_correlations_with_manifolds(c_full, correlation_matrix, manifold_metadata)
    # Step 5: Generate activation
    x = D @ c_full + np.random.randn(D.shape[0]) * noise_std
    return x, c_full, manifold_metadata
```

Key design choices:
Discrete vs. continuous: We discretize manifolds (e.g., 32 points on a circle) to maintain compatibility with the discrete feature-recovery evaluation. This is a reasonable approximation, since real LLMs have finite capacity.
Mutual exclusivity: For a given manifold, only one point (or smoothed neighborhood) is active per sample. This reflects that “Monday” and “Wednesday” don’t co-occur as values of “day of week.”
Manifold-aware correlations: Manifold features can correlate with independent features (e.g., “winter months” correlates with “cold weather”).
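The mutual-exclusivity choice is easy to verify on sampled coefficients. The hypothetical helper below reports, for each manifold, the fraction of samples with more than one active point. Manifolds are given as (start, end) column ranges, such as `start_idx`/`end_idx` from the metadata. The fraction should be 0 when neighbour smoothing is off, since smoothing deliberately activates up to three points.

```python
import numpy as np


def check_manifold_exclusivity(C, manifold_ranges):
    """Per manifold, the fraction of samples with more than one active point
    (0.0 means the one-active-point rule holds exactly)."""
    violations = []
    for start, end in manifold_ranges:
        n_active = (C[:, start:end] > 0).sum(axis=1)
        violations.append(float(np.mean(n_active > 1)))
    return violations


C = np.zeros((4, 10))
C[0, 2] = 1.0            # manifold A (columns 0-4): one point active
C[1, 3] = C[1, 4] = 1.0  # manifold A: two points active, a violation
C[2, 7] = 2.0            # manifold B (columns 5-9): one point active
violations = check_manifold_exclusivity(C, [(0, 5), (5, 10)])
print(violations)  # [0.25, 0.0]
```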
Extend the tree hierarchy to include manifolds:
```
              [Time Concept] (root)
               /            \
         [Cyclical]        [Linear]
          /      \             |
  [Day of Week]  [Month]   [Timestamp]
   (circular)  (circular) (1D continuous)
```
Implementation:
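```python
import numpy as np


class ManifoldHierarchy:
    def __init__(self):
        # Define hierarchical tree including manifold nodes
        self.tree = {
            'time_concept': {'type': 'independent', 'children': ['cyclical', 'linear']},
            'cyclical': {'type': 'independent', 'children': ['day_of_week', 'month']},
            'day_of_week': {'type': 'circular_manifold', 'points': 7, 'children': []},
            'month': {'type': 'circular_manifold', 'points': 12, 'children': []},
            'linear': {'type': 'independent', 'children': ['timestamp']},
            'timestamp': {'type': 'independent', 'children': []},  # 1D continuous leaf
        }

    def enforce_constraints(self, c, node_indices, root='time_concept'):
        """
        Hierarchical constraints:
        - If 'time_concept' is inactive, all children are inactive
        - If 'cyclical' is inactive, both day_of_week and month manifolds are inactive

        node_indices maps each node name to the coefficient indices it owns
        (one index for an independent node, all points for a manifold node).
        """
        c = c.copy()
        stack = [root]
        # Top-down traversal: a node is active if any of its coefficients is nonzero
        while stack:
            node = stack.pop()
            children = self.tree[node]['children']
            if not np.any(c[node_indices[node]] > 0):
                # Inactive parent: zero its children (their subtrees follow when popped)
                for child in children:
                    c[node_indices[child]] = 0.0
            stack.extend(children)
        return c
```

This creates realistic structure where manifold features participate in hierarchies, testing whether SAEs can recover both geometric and hierarchical structure simultaneously.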
Challenge: How should manifold points correlate with each other and with independent features?
Solution: Extend the low-rank correlation model:
```python
def build_manifold_aware_correlation_matrix(N_independent, manifold_metadata,
                                            correlation_scale, rank=100):
    """
    Build correlation matrix Σ that respects manifold structure.
    """
    N_total = N_independent + sum(m['num_discretization'] for m in manifold_metadata)
    # Standard low-rank correlation for independent features
    F_independent = np.random.randn(N_independent, rank) * correlation_scale
    # Manifold correlations: points on same manifold have structured correlations
    F_manifolds_list = []
    for m_meta in manifold_metadata:
        n_points = m_meta['num_discretization']
        # Smooth correlation structure along the manifold: nearby points share factors
        F_manifold = np.zeros((n_points, rank))
        # Use a subset of rank dimensions for this manifold
        manifold_rank_dims = np.random.choice(rank, size=min(10, rank), replace=False)
        for point_idx in range(n_points):
            # Loadings vary smoothly (cosine pattern) with the point's angle on a circle
            angle = 2 * np.pi * point_idx / n_points
            for j, rank_dim in enumerate(manifold_rank_dims):
                F_manifold[point_idx, rank_dim] = np.cos(angle + j * np.pi / len(manifold_rank_dims))
        F_manifolds_list.append(F_manifold)
    # Combine
    F_full = np.vstack([F_independent] + F_manifolds_list)
    assert F_full.shape[0] == N_total
    # Cap factor-row norms at 1 so the diagonal correction below is non-negative;
    # otherwise Σ is not positive semi-definite and not a valid correlation matrix
    F_full = F_full / np.maximum(np.linalg.norm(F_full, axis=1, keepdims=True), 1.0)
    # Compute correlation matrix
    Sigma = F_full @ F_full.T
    # Normalize to correlation matrix (diagonal = 1)
    Sigma += np.diag(1.0 - np.diag(Sigma))
    return Sigma, F_full
```

Effect: Creates realistic correlation patterns where:
Standard metrics (MCC, F1) assess per-feature binary classification. For manifolds, we need geometric structure recovery metrics.
Goal: Detect whether the SAE has learned to represent a manifold using multiple latents.
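The toy example below makes this goal concrete. Four latents tile a circle with overlapping arcs, so neighbouring latents fire together whenever the manifold is active. Two unrelated latents fire independently. Thresholding the correlation between binary firing patterns and then taking connected components recovers the four tiling latents as one group. All sizes, the 10% manifold firing rate and the 0.3 threshold are illustrative assumptions.

```python
import numpy as np
from scipy.sparse.csgraph import connected_components

rng = np.random.default_rng(0)
n = 50_000
manifold_on = rng.random(n) < 0.1                        # manifold active in ~10% of samples
theta = rng.uniform(0, 2 * np.pi, n)                     # position on the circle
centres = np.arange(4) * np.pi / 2                       # four latents tile the circle
# Wrapped angular distance from each sample to each latent's arc centre, in [0, pi]
ang_dist = np.abs(np.angle(np.exp(1j * (theta[:, None] - centres[None, :]))))
active = np.zeros((n, 6), dtype=bool)
active[:, :4] = manifold_on[:, None] & (ang_dist < 0.45 * np.pi)  # overlapping arcs
active[:, 4:] = rng.random((n, 2)) < 0.05                # two unrelated latents

# Correlation between binary firing patterns, thresholded into a graph
corr = np.nan_to_num(np.corrcoef(active.astype(float), rowvar=False))
adjacency = (corr > 0.3).astype(float)
np.fill_diagonal(adjacency, 0)
n_comp, labels = connected_components(adjacency, directed=False)
groups = [np.flatnonzero(labels == k) for k in range(n_comp)]
groups = [g for g in groups if len(g) >= 2]
print(groups)  # one group: latents 0, 1, 2, 3
```

Opposite latents never co-fire here, yet they still join the group through their shared neighbours. This is the reason for using connected components rather than requiring every pair to correlate.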
Method 1: Latent Co-activation Clustering
```python
import numpy as np
from scipy.sparse.csgraph import connected_components


def detect_manifold_clusters(sae_activations, threshold_correlation=0.3):
    """
    Identify groups of SAE latents that consistently co-activate,
    suggesting they jointly represent a manifold.

    Args:
        sae_activations: (n_samples, n_latents) binary activation matrix
    Returns:
        clusters: List of latent groups forming potential manifolds
    """
    # Pearson (phi) correlation between binary firing patterns. Raw joint firing
    # frequency can never exceed a latent's firing rate, so for sparse latents
    # it would never pass a threshold like 0.3.
    A = (np.asarray(sae_activations) > 0).astype(float)
    co_activation = np.nan_to_num(np.corrcoef(A, rowvar=False))  # dead latents -> 0
    # Threshold to adjacency matrix
    adjacency = (co_activation > threshold_correlation).astype(float)
    np.fill_diagonal(adjacency, 0)
    # Find connected components (potential manifold groups)
    n_clusters, labels = connected_components(adjacency, directed=False)
    clusters = []
    for cluster_id in range(n_clusters):
        latent_group = np.where(labels == cluster_id)[0]
        if len(latent_group) >= 2:  # At least 2D for a manifold
            clusters.append(latent_group)
    return clusters
```

Method 2: Decoder Weight Geometry
```python
def analyze_decoder_subspace_geometry(decoder_weights, latent_group):
    """
    Analyze geometric structure of decoder columns for a latent group.
    If latents form a manifold, their decoder columns should span
    a low-dimensional subspace with manifold structure.
    """
    # Extract decoder columns for this latent group
    W_dec_group = decoder_weights[:, latent_group]  # (d_hidden, n_group)
    # SVD (uncentered PCA) to find intrinsic dimensionality
    U, S, Vt = np.linalg.svd(W_dec_group, full_matrices=False)
    # Intrinsic dimensionality = number of significant singular values
    intrinsic_dim = int(np.sum(S > 0.1 * S[0]))
    # Coordinates of each decoder column in the top principal directions
    W_dec_projected = (S[:intrinsic_dim, None] * Vt[:intrinsic_dim]).T  # (n_group, intrinsic_dim)
    # Analyze geometry in projected space
    if intrinsic_dim == 2:
        geometry_type, geometry_score = 'circular', measure_circularity(W_dec_projected)
    elif intrinsic_dim == 3:
        # measure_sphericity: helper analogous to measure_circularity (not shown)
        geometry_type, geometry_score = 'spherical', measure_sphericity(W_dec_projected)
    else:
        geometry_type, geometry_score = 'unknown', None
    return {
        'intrinsic_dim': intrinsic_dim,
        'singular_values': S,
        'geometry_type': geometry_type,
        'geometry_score': geometry_score,
    }
```
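The algebraic circle fit of Kåsa is a closed-form alternative to the iterative least-squares fit in `measure_circularity`. It is a different technique, not the one the plan uses. A circle satisfies x² + y² + Dx + Ey + F = 0, so (D, E, F) come from a single linear least-squares solve. The centre is then (−D/2, −E/2) and the radius is √(cx² + cy² − F). The fit is exact for noise-free points, and it can also supply the initial guess for the iterative fit. `fit_circle_algebraic` is a hypothetical name.

```python
import numpy as np


def fit_circle_algebraic(points_2d):
    """Kåsa fit: least-squares solution of x^2 + y^2 + D x + E y + F = 0."""
    x, y = points_2d[:, 0], points_2d[:, 1]
    A = np.column_stack([x, y, np.ones_like(x)])
    b = -(x ** 2 + y ** 2)
    (D, E, F), *_ = np.linalg.lstsq(A, b, rcond=None)
    cx, cy = -D / 2, -E / 2
    r = np.sqrt(cx ** 2 + cy ** 2 - F)
    return cx, cy, r


angles = np.linspace(0, 2 * np.pi, 32, endpoint=False)
points = np.column_stack([1.0 + 3.0 * np.cos(angles), -2.0 + 3.0 * np.sin(angles)])
cx, cy, r = fit_circle_algebraic(points)  # recovers centre (1, -2) and radius 3
```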
```python
import numpy as np
from scipy.optimize import least_squares


def measure_circularity(points_2d):
    """
    Measure how circular a set of 2D points is (1.0 = exactly on a circle).
    """
    # Fit circle: find center and radius minimizing distance to points
    def circle_residuals(params, points):
        cx, cy, r = params
        return np.sqrt((points[:, 0] - cx) ** 2 + (points[:, 1] - cy) ** 2) - r

    center0 = points_2d.mean(axis=0)
    r0 = np.linalg.norm(points_2d - center0, axis=1).mean()  # data-scaled initial radius
    result = least_squares(circle_residuals, [center0[0], center0[1], r0], args=(points_2d,))
    # Circularity = 1 - (std of residuals) / radius
    residuals = circle_residuals(result.x, points_2d)
    circularity = 1 - np.std(residuals) / abs(result.x[2])
    return circularity
```

Goal: Measure how well recovered manifold aligns with ground-truth manifold.
Metric: Geodesic Distance Preservation
```python
import numpy as np
from scipy.spatial.distance import pdist
from scipy.stats import spearmanr


def manifold_alignment_score(gt_manifold_points, sae_latent_group, sae_activations):
    """
    Measure alignment between ground-truth manifold and SAE-learned representation.
    Idea: Geodesic distances on the manifold should be preserved
    in the SAE latent space.

    Args:
        gt_manifold_points: (n_samples, k) ground-truth manifold coordinates of the
            active point in each sample (use only samples where the manifold fires)
        sae_latent_group: indices of SAE latents representing this manifold
        sae_activations: (n_samples, n_latents) SAE activations for the same samples
    Returns:
        alignment_score: 0-1, higher is better
    """
    # Extract activations for this latent group
    group_activations = sae_activations[:, sae_latent_group]  # (n_samples, k')
    # Pairwise geodesic distances on ground-truth manifold (upper triangle)
    D_gt = compute_geodesic_distances(gt_manifold_points)
    gt_dists = D_gt[np.triu_indices_from(D_gt, k=1)]
    # Pairwise Euclidean distances in SAE latent space; pdist uses the same pair order
    sae_dists = pdist(group_activations)
    # Mantel-style rank correlation: good alignment → monotonic relationship
    correlation, p_value = spearmanr(gt_dists, sae_dists)
    return max(0.0, correlation)  # Clip to [0, 1]
```
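On the unit circle, the geodesic distance has a simple closed form. For unit vectors p and q at angles θ_i and θ_j, arccos(p · q) equals the arc length min(|θ_i − θ_j|, 2π − |θ_i − θ_j|). The check below confirms this numerically. As a result, a single vectorized arccos expression can replace a double loop over angles, and the same formula gives the great-circle distance on S².

```python
import numpy as np

rng = np.random.default_rng(0)
angles = rng.uniform(0, 2 * np.pi, 50)
points = np.stack([np.cos(angles), np.sin(angles)], axis=1)   # unit vectors on S¹

diff = np.abs(angles[:, None] - angles[None, :])
D_arc = np.minimum(diff, 2 * np.pi - diff)                    # arc-length formula, vectorized
D_dot = np.arccos(np.clip(points @ points.T, -1.0, 1.0))      # arccos(p · q)
print(np.abs(D_arc - D_dot).max())  # tiny: floating-point error only
```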
```python
def compute_geodesic_distances(manifold_points):
    """
    Compute pairwise geodesic distances for points on a circle or sphere.

    Args:
        manifold_points: (N, 2) points on S¹ or (N, 3) points on S² (normalized below)
    """
    if manifold_points.shape[1] not in (2, 3):
        # ... other manifold types (e.g., toroidal) need their own formula
        raise NotImplementedError("only circular (k=2) and spherical (k=3) manifolds")
    points = manifold_points / np.linalg.norm(manifold_points, axis=1, keepdims=True)
    # For unit vectors p, q the geodesic distance is arccos(p · q):
    # - circle: the arc length min(|θ_i - θ_j|, 2π - |θ_i - θ_j|)
    # - sphere: the great-circle distance
    return np.arccos(np.clip(points @ points.T, -1.0, 1.0))
```

Goal: Verify that topological properties (e.g., circular loops, spherical shells) are preserved.
```python
import numpy as np
from persim import bottleneck
from ripser import ripser
from scipy.spatial.distance import pdist


def topological_alignment(gt_manifold_points, sae_latent_activations):
    """
    Use persistent homology to compare topology of ground-truth vs. learned manifold.
    """
    # Rescale each point cloud by its diameter: bottleneck distance depends on scale,
    # and the two clouds live in different spaces
    def rescale(points):
        points = np.asarray(points, dtype=float)
        return points / max(pdist(points).max(), 1e-12)

    # Compute persistence diagrams (H0 and H1 by default; use maxdim=2 for spheres)
    dgm_gt = ripser(rescale(gt_manifold_points))['dgms']
    dgm_sae = ripser(rescale(sae_latent_activations))['dgms']
    # Compare H1 (1-dimensional holes, i.e., circular loops)
    # For a circular manifold, should have 1 persistent loop
    h1_gt = dgm_gt[1]  # (birth, death) pairs for 1-cycles
    h1_sae = dgm_sae[1]
    # Measure bottleneck distance between persistence diagrams
    distance_h1 = bottleneck(h1_gt, h1_sae)
    # Lower distance → better topological preservation
    topology_score = np.exp(-distance_h1)
    return topology_score
```

Extend MCC to handle manifold points:
```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import matthews_corrcoef


def manifold_aware_mcc(gt_coefficients, sae_activations, manifold_metadata):
    """
    Compute MCC for manifold features.
    Challenge: Multiple SAE latents may jointly represent a manifold point.
    Solution: Assign each ground-truth manifold point to the best-matching
    SAE latent within the detected manifold cluster.
    """
    scores = []
    for manifold_meta in manifold_metadata:
        m_start = manifold_meta['start_idx']
        m_end = manifold_meta['end_idx']
        # Ground-truth activations for this manifold (binary: point active or not)
        gt_manifold = (gt_coefficients[:, m_start:m_end] > 0).astype(int)
        # Detect SAE latent cluster representing this manifold
        # (detect_best_matching_cluster: helper not shown)
        sae_cluster = detect_best_matching_cluster(sae_activations, gt_manifold)
        if sae_cluster is None:
            # SAE failed to learn this manifold
            scores.append(0.0)
            continue
        # Binarize the cluster's SAE activations so MCC compares firing patterns
        sae_cluster_activations = (sae_activations[:, sae_cluster] > 0).astype(int)
        # Compute MCC for each GT point vs each SAE latent
        mcc_matrix = np.zeros((gt_manifold.shape[1], len(sae_cluster)))
        for i in range(gt_manifold.shape[1]):
            for j in range(len(sae_cluster)):
                mcc_matrix[i, j] = matthews_corrcoef(gt_manifold[:, i], sae_cluster_activations[:, j])
        # Optimal bipartite matching (Hungarian algorithm), maximizing total MCC
        row_ind, col_ind = linear_sum_assignment(-mcc_matrix)
        # Average MCC over matched pairs
        manifold_mcc = mcc_matrix[row_ind, col_ind].mean()
        scores.append(manifold_mcc)
    return np.mean(scores)
```

Goal: Assess whether SAE captures the intrinsic curvature of the manifold.
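The `estimate_local_curvature` helper called in the function that follows is not specified. For a one-dimensional manifold such as a circle, whose points can be ordered (for example by ground-truth angle), one simple option is the Menger curvature of consecutive triples. This is our assumption, not the plan's method. The Menger curvature is the reciprocal radius of the circle through three points, 4·Area/(|ab|·|bc|·|ca|). The sketch computes the area from a Gram determinant, so it works in any ambient dimension. It does not estimate the Gaussian curvature of a sphere.

```python
import numpy as np


def menger_curvature(a, b, c):
    """Curvature (1 / circumradius) of the circle through points a, b, c in any dimension."""
    u, v = b - a, c - a
    # Twice the triangle area, from the Gram determinant of the two edge vectors
    twice_area = np.sqrt(max(np.dot(u, u) * np.dot(v, v) - np.dot(u, v) ** 2, 0.0))
    ab, bc, ca = np.linalg.norm(b - a), np.linalg.norm(c - b), np.linalg.norm(a - c)
    return 2.0 * twice_area / (ab * bc * ca)  # = 4 * area / (|ab| |bc| |ca|)


def estimate_curvature_ordered_loop(points):
    """Hypothetical helper: mean Menger curvature over consecutive triples of an
    ordered, closed curve (e.g. points sorted by ground-truth angle)."""
    n = len(points)
    return float(np.mean([menger_curvature(points[i - 1], points[i], points[(i + 1) % n])
                          for i in range(n)]))


angles = np.linspace(0, 2 * np.pi, 64, endpoint=False)
circle = 2.0 * np.stack([np.cos(angles), np.sin(angles)], axis=1)  # radius 2 → curvature 0.5
Q, _ = np.linalg.qr(np.random.default_rng(0).normal(size=(5, 2)))  # orthonormal 2 → 5 map
circle_5d = circle @ Q.T                                           # same circle inside 5D
kappa_2d = estimate_curvature_ordered_loop(circle)
kappa_5d = estimate_curvature_ordered_loop(circle_5d)
```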
```python
def curvature_estimation_accuracy(gt_manifold, sae_latent_group, sae_activations):
    """
    Compare estimated curvature of learned manifold vs. ground truth.
    """
    # For ground-truth manifold, compute intrinsic curvature
    if gt_manifold['type'] == 'circular':
        gt_curvature = 1.0 / gt_manifold['radius']        # Curvature of a circle
    elif gt_manifold['type'] == 'spherical':
        gt_curvature = 1.0 / gt_manifold['radius'] ** 2   # Gaussian curvature of a sphere
    else:
        raise NotImplementedError(gt_manifold['type'])
    # For learned SAE manifold, estimate curvature from data
    group_activations = sae_activations[:, sae_latent_group]
    # Fit quadratic form to local neighborhoods (estimate_local_curvature: helper not shown)
    estimated_curvature = estimate_local_curvature(group_activations)
    # Compare: relative error, clipped so the score stays in [0, 1]
    curvature_error = abs(gt_curvature - estimated_curvature) / gt_curvature
    return max(0.0, 1 - curvature_error)
```

Combine all metrics into a unified benchmark:
```python
class ManifoldSAEBenchmark:
    def evaluate(self, sae, dataset_with_manifolds):
        results = {
            # Standard metrics (for independent features)
            'independent_mcc': compute_mcc_independent_features(...),
            'independent_f1': compute_f1_independent_features(...),
            'variance_explained': compute_variance_explained(...),
            'l0': compute_l0_sparsity(...),
            # Manifold-specific metrics
            'manifold_detection_rate': fraction_of_manifolds_detected(...),
            'manifold_alignment_score': average_geodesic_preservation(...),
            'topology_preservation_score': average_persistent_homology_match(...),
            'manifold_mcc': manifold_aware_mcc(...),
            'curvature_accuracy': average_curvature_estimation(...),
            # Combined score
            'overall_manifold_score': weighted_average([...]),
        }
        return results
```

Standard SAEs learn independent latents. We propose architectural modifications that explicitly model manifold structure.
Idea: Organize latents into predefined groups, each representing a potential manifold.
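Counting active groups gives no gradient. A common differentiable substitute is the group-lasso penalty, which sums the L2 norm of each group. For the same total L1 mass, activity concentrated in one group is penalized less than activity spread across groups, which is the behaviour a grouped SAE wants. The numbers below are a toy illustration, and `group_lasso` is a hypothetical helper.

```python
import numpy as np


def group_lasso(f, n_groups, group_size):
    """Hypothetical helper: batch mean of the sum over groups of each group's L2 norm."""
    groups = f.reshape(f.shape[0], n_groups, group_size)
    return float(np.linalg.norm(groups, axis=2).sum(axis=1).mean())


# Two groups of two latents each; both examples have the same L1 norm (2.0)
concentrated = np.array([[1.0, 1.0, 0.0, 0.0]])  # all activity in group 0
spread = np.array([[1.0, 0.0, 1.0, 0.0]])        # activity split across both groups
print(group_lasso(concentrated, 2, 2))  # sqrt(2) ≈ 1.414
print(group_lasso(spread, 2, 2))        # 2.0
```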
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedLatentSAE(nn.Module):
    def __init__(self, d_hidden, n_latents, n_groups, latents_per_group,
                 lambda_group, lambda_feature):
        super().__init__()
        assert n_latents == n_groups * latents_per_group, "latents must split evenly into groups"
        self.d_hidden = d_hidden
        self.n_latents = n_latents
        self.n_groups = n_groups
        self.latents_per_group = latents_per_group
        self.lambda_group = lambda_group      # weight of the group-sparsity penalty
        self.lambda_feature = lambda_feature  # weight of the within-group L1 penalty
        # Encoder: shared + group-specific
        self.W_enc_shared = nn.Linear(d_hidden, d_hidden)
        self.group_encoders = nn.ModuleList([
            nn.Linear(d_hidden, latents_per_group) for _ in range(n_groups)
        ])
        # Decoder: standard
        self.W_dec = nn.Linear(n_latents, d_hidden, bias=False)
        # Group gating: which groups are active?
        self.group_gate = nn.Linear(d_hidden, n_groups)

    def forward(self, x):
        # Shared representation
        h = F.relu(self.W_enc_shared(x))
        # Group gating (soft selection of which groups are active)
        group_probs = torch.sigmoid(self.group_gate(x))
        # Encode within each group, scaled by that group's gate
        group_features = [
            F.relu(enc(h)) * group_probs[:, i:i + 1]
            for i, enc in enumerate(self.group_encoders)
        ]
        # Concatenate all group features
        f = torch.cat(group_features, dim=1)
        # Reconstruct
        x_hat = self.W_dec(f)
        return f, x_hat

    def loss(self, x):
        f, x_hat = self.forward(x)
        recon_loss = (x - x_hat).pow(2).mean()
        # Group sparsity: group lasso (sum of per-group L2 norms). Unlike an L0 count of
        # active groups, this is differentiable and still favours few active groups
        groups = f.reshape(-1, self.n_groups, self.latents_per_group)
        group_penalty = torch.linalg.vector_norm(groups, dim=2).sum(dim=1).mean()
        # Sparsity: L1 within groups
        feature_l1 = f.abs().mean()
        return recon_loss + self.lambda_group * group_penalty + self.lambda_feature * feature_l1
```

Advantages:
Idea: Explicitly parameterize latents as manifold coordinates.
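One benefit of this parameterization is that a circular feature's value can be read off directly. Assuming a module outputs a unit vector (cos θ, sin θ), `np.arctan2` recovers θ. Rounding θ to the nearest of n evenly spaced positions then yields a discrete value such as a day of the week. `angle_to_index` is a hypothetical helper, and the noise level is illustrative.

```python
import numpy as np


def angle_to_index(cos_sin, num_points):
    """Map predicted (cos θ, sin θ) rows to the nearest of num_points evenly spaced positions."""
    theta = np.arctan2(cos_sin[:, 1], cos_sin[:, 0]) % (2 * np.pi)  # angle in [0, 2π)
    return np.rint(theta / (2 * np.pi / num_points)).astype(int) % num_points


t = np.arange(7)
pred = np.stack([np.cos(2 * np.pi * t / 7), np.sin(2 * np.pi * t / 7)], axis=1)
pred_noisy = pred + 0.05 * np.random.default_rng(0).normal(size=pred.shape)
print(angle_to_index(pred_noisy, 7))  # [0 1 2 3 4 5 6]
```

The final modulo maps angles just below 2π back to position 0, so the wrap-around between Sunday and Monday is handled correctly.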
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ManifoldParametricSAE(nn.Module):
    def __init__(self, d_hidden, manifold_configs, n_independent_latents):
        super().__init__()
        self.d_hidden = d_hidden
        self.manifolds = nn.ModuleList()
        for config in manifold_configs:
            if config['type'] == 'circular':
                self.manifolds.append(CircularManifoldModule(d_hidden))
            elif config['type'] == 'spherical':
                # SphericalManifoldModule: analogous module (not shown)
                self.manifolds.append(SphericalManifoldModule(d_hidden))
            # etc.
        # Also include standard independent latents
        self.independent_sae = StandardSAE(d_hidden, n_independent_latents)

    def forward(self, x):
        # Independent features
        f_ind, x_hat_ind = self.independent_sae(x)
        # Manifold features
        manifold_outputs = [module(x) for module in self.manifolds]
        # Combine reconstructions
        x_hat = x_hat_ind + sum(m['reconstruction'] for m in manifold_outputs)
        return {
            'independent_features': f_ind,
            'manifold_features': manifold_outputs,
            'reconstruction': x_hat,
        }


class CircularManifoldModule(nn.Module):
    def __init__(self, d_hidden):
        super().__init__()
        # Predict: is this circular feature active? If so, what angle?
        self.gate = nn.Linear(d_hidden, 1)
        self.angle_predictor = nn.Linear(d_hidden, 2)  # (cos θ, sin θ) before normalization
        self.magnitude_predictor = nn.Linear(d_hidden, 1)
        # Decoder: from angle to reconstruction
        # Parameterized as: reconstruction = gate * magnitude * D_manifold @ [cos θ, sin θ]
        self.D_manifold = nn.Linear(2, d_hidden, bias=False)

    def forward(self, x):
        # Gate: is this manifold active? (soft, in [0, 1])
        gate = torch.sigmoid(self.gate(x))
        # Angle: predict (cos θ, sin θ), projected onto the unit circle
        angle_normalized = F.normalize(self.angle_predictor(x), dim=1)
        # Magnitude
        magnitude = F.relu(self.magnitude_predictor(x))
        # Reconstruction contribution
        recon_contribution = self.D_manifold(angle_normalized * magnitude * gate)
        return {
            'gate': gate,
            'angle': angle_normalized,  # (cos θ, sin θ)
            'magnitude': magnitude,
            'reconstruction': recon_contribution,
        }
```

Advantages:
Challenges:
Combines hierarchical structure with manifolds:
```python
import torch.nn as nn


class HierarchicalManifoldSAE(nn.Module):
    r"""
    Encode both hierarchical and manifold structure.
    Example:
              Time (root, independent)
               /          \
         Cyclical        Linear
          /     \           |
     DayOfWeek   Month   Timestamp
    (circular) (circular)   (1D)
    """
    def __init__(self, d_hidden, hierarchy_config, gate_threshold=0.5):  # illustrative threshold
        super().__init__()
        # _build_hierarchy: builds node objects with .encoder / .manifold_encoder (not shown)
        self.hierarchy = self._build_hierarchy(hierarchy_config)
        self.gate_threshold = gate_threshold

    def forward(self, x):
        # Traverse hierarchy top-down; parent features gate children.
        # Written for a single example (batch size 1), since branching is per input.
        return self._traverse_hierarchy(x, self.hierarchy.root)

    def _traverse_hierarchy(self, x, node):
        # Compute this node's feature
        if node.type == 'independent':
            gate, magnitude = node.encoder(x)
            features = {'gate': gate, 'magnitude': magnitude}
        elif node.type == 'circular_manifold':
            gate, angle, magnitude = node.manifold_encoder(x)
            features = {'gate': gate, 'angle': angle, 'magnitude': magnitude}
        else:
            raise ValueError(f"Unknown node type: {node.type}")
        # If this node is inactive, all children are inactive
        if gate.item() < self.gate_threshold:
            return {'active': False, 'features': None}
        # Otherwise, recurse to children
        child_results = []
        for child in node.children:
            child_result = self._traverse_hierarchy(x, child)
            if child_result['active']:
                child_results.append(child_result)
        return {
            'active': True,
            'node_id': node.id,
            'features': features,
            'children': child_results,
        }
```

Idea: Don’t assume manifold structure; let the SAE discover it.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaptiveLatentManifoldSAE(nn.Module):
    """
    Learns to group latents into manifolds automatically.
    Uses a learned adjacency matrix to define manifold neighborhoods.
    """

    # Loss-weight defaults are placeholders, not tuned values
    def __init__(self, d_hidden, n_latents, lambda_l1=1e-3, lambda_manifold=1e-3, smoothing=0.1):
        super().__init__()
        self.n_latents = n_latents
        self.lambda_l1 = lambda_l1
        self.lambda_manifold = lambda_manifold
        self.smoothing = smoothing
        self.encoder = nn.Linear(d_hidden, n_latents)
        self.decoder = nn.Linear(n_latents, d_hidden, bias=False)
        # Learnable soft adjacency: sigmoid(A[i, j]) near 1 means latents i and j
        # are treated as neighbors on the same manifold
        self.adjacency_logits = nn.Parameter(torch.randn(n_latents, n_latents))

    def _adjacency(self):
        adjacency = torch.sigmoid(self.adjacency_logits)
        return (adjacency + adjacency.T) / 2  # Symmetrize

    def forward(self, x):
        f_raw = F.relu(self.encoder(x))  # (batch, n_latents)
        adjacency = self._adjacency()
        # Manifold smoothness (nearby points have similar representations):
        # add a small multiple of each latent's adjacency-weighted neighbor mean
        neighbor_mean = (f_raw @ adjacency) / adjacency.sum(dim=1)
        f_smoothed = f_raw + self.smoothing * neighbor_mean
        x_hat = self.decoder(f_smoothed)
        return f_raw, f_smoothed, x_hat

    def loss(self, x):
        f_raw, f_smoothed, x_hat = self.forward(x)
        recon_loss = (x - x_hat).pow(2).mean()
        sparsity_loss = f_raw.abs().mean()
        # Manifold regularization: encourage a sparse adjacency. Only this sparsity
        # term is implemented; a block-diagonal (one connected component per
        # manifold) penalty is not.
        manifold_reg = self._adjacency().sum() / (self.n_latents ** 2)
        return recon_loss + self.lambda_l1 * sparsity_loss + self.lambda_manifold * manifold_reg
```
Advantages:
To test competing theories of neural representation, we generate synthetic models instantiating different hypotheses, then compare SAE behavior on these models to behavior on real LLMs.
| Hypothesis | Description | Prediction for SAEs |
|---|---|---|
| Linear Representation Hypothesis (LRH) | All concepts are 1D directions | SAEs recover nearly all features given sufficient capacity, up to limits set by superposition and correlations |
| Manifold Hypothesis | Some concepts lie on low-dimensional manifolds | SAEs struggle unless manifold-aware; show specific failure modes |
| Superposition + Manifolds | Manifolds are in superposition | SAEs must handle both manifold geometry and feature interference |
| Hierarchical Manifolds | Manifolds participate in hierarchies | SAE feature recovery depends on correctly identifying parent features |
| Compositional Manifolds | Manifolds can be composed (e.g., S¹ × S¹) | SAEs either learn composition or tile it with many latents |
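The hypotheses above differ mainly in how ground-truth features are embedded in activation space. As a minimal sketch of manifold-structured samples, assuming a flat embedding, the code below places a discretized circle in a random 2D plane of the d-dimensional space and builds a torus as the product of two such circles in orthogonal planes (following the `n_points` and `(8, 8)` conventions of the configs that follow). The function names are hypothetical, and the plan's `embedding_dim` option suggests its generator may use higher-dimensional, curved embeddings, which this sketch does not attempt.

```python
import numpy as np


# Hypothetical helpers; assumes a flat (planar) embedding of each circle
def random_orthonormal_basis(d_hidden, n_dims, rng):
    """n_dims orthonormal columns in R^d_hidden (QR of a Gaussian matrix)."""
    q, _ = np.linalg.qr(rng.standard_normal((d_hidden, n_dims)))
    return q


def sample_circular(d_hidden, n_points, n_samples, rng):
    """Samples from a circle S^1 discretized into n_points angles and embedded
    in a random 2D plane of R^d_hidden. Returns (activations, angles)."""
    basis = random_orthonormal_basis(d_hidden, 2, rng)
    theta = 2 * np.pi * rng.integers(0, n_points, size=n_samples) / n_points
    coords = np.stack([np.cos(theta), np.sin(theta)], axis=1)  # (n_samples, 2)
    return coords @ basis.T, theta


def sample_toroidal(d_hidden, n_points, n_samples, rng):
    """Samples from a torus S^1 x S^1: two discretized circles in orthogonal
    planes, scaled by 1/sqrt(2) so every sample has unit norm."""
    n1, n2 = n_points
    basis = random_orthonormal_basis(d_hidden, 4, rng)
    theta = 2 * np.pi * rng.integers(0, n1, size=n_samples) / n1
    phi = 2 * np.pi * rng.integers(0, n2, size=n_samples) / n2
    coords = np.stack([np.cos(theta), np.sin(theta), np.cos(phi), np.sin(phi)], axis=1)
    return coords @ basis.T / np.sqrt(2), theta, phi


rng = np.random.default_rng(0)
x_circle, theta = sample_circular(768, 32, 1000, rng)
x_torus, theta_t, phi_t = sample_toroidal(768, (8, 8), 1000, rng)
```

In a full dataset these samples would be scaled by a sparse, per-example magnitude and added to the independent-feature contribution D c; that mixing step is omitted here.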
```python
def generate_pure_lrh_model():
    """Original SynthSAEBench: all features are independent 1D."""
    config = {
        'N_features': 16384,
        'd_hidden': 768,
        'manifolds': [],  # No manifolds
        'superposition': 0.15,
        'hierarchy': True,
        'correlations': True,
    }
    return SynthSAEBenchDataset(config)
```
Expected SAE behavior:
```python
def generate_manifold_hypothesis_model(seed=0):
    """Replace some independent features with circular, spherical, and toroidal manifolds."""
    config = {
        'N_independent': 12000,
        'd_hidden': 768,
        'manifolds': [
            {'type': 'circular', 'n_points': 32, 'embedding_dim': 16, 'label': 'temporal_cycle'},
            {'type': 'circular', 'n_points': 32, 'embedding_dim': 16, 'label': 'periodic_pattern'},
            {'type': 'spherical', 'n_points': 64, 'embedding_dim': 20, 'label': 'direction'},
            {'type': 'toroidal', 'n_points': (8, 8), 'embedding_dim': 24, 'label': 'hour_day'},
            # ... more manifolds
        ],
        'superposition': 0.15,
        'hierarchy': False,  # Isolate manifold effects first
        'correlations': True,
        'seed': seed,  # Lets independent models share one manifold structure
    }
    return ManifoldSynthSAEBenchDataset(config)
```
Expected SAE behavior:
```python
def generate_superposition_manifolds_model(superposition_level):
    """Test interaction between manifold structure and superposition."""
    config = {
        'N_independent': 10000,
        'manifolds': [
            {'type': 'circular', 'n_points': 32},  # other fields elided
            # ... 10 manifolds
        ],
        'superposition': superposition_level,  # Vary, e.g., 0.05, 0.15, 0.25
        'hierarchy': False,
        'correlations': True,
    }
    return ManifoldSynthSAEBenchDataset(config)
```
Experiment: Sweep superposition from 0.05 to 0.30.
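The notation table defines ρ_mm, the mean max cosine similarity, as the superposition measure. Assuming that is what the `'superposition'` config value refers to (not stated explicitly here), and assuming absolute cosine similarity, the sketch below computes it for a dictionary D ∈ ℝ^{d×N}. The function name and the chunking scheme are illustrative choices.

```python
import numpy as np


# Hypothetical helper; assumes rho_mm uses |cosine similarity|
def mean_max_cosine_similarity(D, chunk=1024):
    """rho_mm for a dictionary D of shape (d, N): the mean, over columns, of the
    largest |cosine similarity| with any other column. Columns are processed in
    chunks so the full N x N similarity matrix is never materialized."""
    D_unit = D / np.linalg.norm(D, axis=0, keepdims=True)
    n = D_unit.shape[1]
    max_cos = np.empty(n)
    for start in range(0, n, chunk):
        stop = min(start + chunk, n)
        cos = np.abs(D_unit[:, start:stop].T @ D_unit)              # (stop - start, N)
        cos[np.arange(stop - start), np.arange(start, stop)] = 0.0  # ignore self-similarity
        max_cos[start:stop] = cos.max(axis=1)
    return float(max_cos.mean())


rng = np.random.default_rng(0)
for n_features in (1024, 4096):
    D = rng.standard_normal((768, n_features))
    print(n_features, round(mean_max_cosine_similarity(D), 3))
```

For random Gaussian dictionaries, ρ_mm rises as N grows at fixed d, so one crude way to move along the sweep is to vary N or d; targeting an exact value such as 0.15 is outside the scope of this sketch.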
Expected behavior:
```python
def generate_hierarchical_manifolds_model():
    """Manifolds as nodes in hierarchy."""
    config = {
        'N_independent': 8000,
        'manifolds': [
            # Other manifold fields elided
            {'type': 'circular', 'id': 'day_of_week_manifold', 'parent': 'temporal'},
            {'type': 'circular', 'id': 'month_manifold', 'parent': 'temporal'},
        ],
        'hierarchy': {
            'root': ['temporal', 'spatial', 'abstract'],
            'temporal': {
                'children': ['day_of_week_manifold', 'month_manifold'],
                'mutex': False,
            },
            # ...
        },
        'superposition': 0.15,
    }
    return HierarchicalManifoldDataset(config)
```
Expected behavior:
```python
def generate_compositional_manifolds_model():
    """Test composed manifolds like S¹ × S¹ (torus)."""
    config = {
        'N_independent': 12000,
        'manifolds': [
            # Atomic manifolds
            {'type': 'circular', 'id': 'day'},
            {'type': 'circular', 'id': 'hour'},
            # Composed manifold (other fields elided)
            {'type': 'toroidal', 'composition': ('day', 'hour')},
        ],
        'composition_type': 'product',  # or 'sum', 'concat'
    }
    return CompositionalManifoldDataset(config)
```
Expected behavior:
Compare training dynamics on synthetic models to known LLM SAE behaviors:
```python
def compare_training_dynamics(synthetic_model, sae, num_training_steps, batch_size, eval_interval):
    """
    Track metrics while training `sae` (an SAE instance) on `synthetic_model`:
    - Loss curves
    - Feature emergence patterns
    - Dead neuron rates
    - Manifold detection over time
    """
    results = {
        'steps': [],
        'loss': [],
        'mcc_independent': [],
        'manifold_alignment': [],
        'dead_neurons': [],
        'feature_splitting_events': [],
    }
    # Decoder snapshot from the previous evaluation (detached from autograd)
    previous_decoder = sae.W_dec.detach().clone()
    for step in range(num_training_steps):
        # Train step
        batch = synthetic_model.sample_batch(batch_size)
        loss = sae.train_step(batch)
        # Evaluate
        if step % eval_interval == 0:
            eval_results = evaluate_on_held_out(sae, synthetic_model)
            results['steps'].append(step)
            results['loss'].append(loss)
            results['mcc_independent'].append(eval_results['mcc'])
            results['manifold_alignment'].append(eval_results['manifold_score'])
            results['dead_neurons'].append(count_dead_neurons(sae))
            # Detect feature splitting (a known phenomenon in LLM SAEs)
            splitting = detect_feature_splitting(sae, previous_decoder)
            results['feature_splitting_events'].append(splitting)
            previous_decoder = sae.W_dec.detach().clone()
    return results
```
Key phenomena to reproduce from LLM SAEs:
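Feature splitting (Bricken et al., 2023; Chanin et al., 2024): A single ground-truth feature is learned by multiple SAE latents
Feature absorption (Chanin et al., 2024): A more specific latent absorbs a parent feature's direction, so the latent tracking the parent fails to fire where it should and one latent ends up encoding multiple ground-truth features
Dead neurons: Latents that never activate
Scaling laws: How metrics scale with the number of latents, N_latents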
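Two of these phenomena can be counted directly once ground truth is available. The sketch below assumes decoder directions are stored as columns of `W_dec` with shape (d, n_latents), which matches the weight layout of `nn.Linear(n_latents, d_hidden)` used in the architectures above. Both function names and the 0.9 cosine threshold are hypothetical choices, not the plan's definitions of `count_dead_neurons` or `detect_feature_splitting`.

```python
import numpy as np


# Hypothetical helpers; W_dec columns are assumed to be decoder directions
def dead_latent_fraction(latent_acts, eps=0.0):
    """Fraction of latents that never exceed eps on a batch of latent
    activations with shape (n_samples, n_latents)."""
    fired = (latent_acts > eps).any(axis=0)
    return float(1.0 - fired.mean())


def latents_per_feature(W_dec, D_true, threshold=0.9):
    """For each ground-truth feature (column of D_true, shape (d, N)), count the
    SAE latents (columns of W_dec, shape (d, n_latents)) whose decoder direction
    has cosine similarity >= threshold with it."""
    W = W_dec / np.linalg.norm(W_dec, axis=0, keepdims=True)
    D = D_true / np.linalg.norm(D_true, axis=0, keepdims=True)
    cos = D.T @ W  # (N, n_latents)
    return (cos >= threshold).sum(axis=1)
```

A count above 1 for a ground-truth feature flags possible splitting (or, for manifold features, tiling); a count of 0 means no latent matches it at this threshold.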
Formalize predictions that distinguish hypotheses:
| Prediction | LRH Model | Manifold Model | Test |
|---|---|---|---|
| Scaling exponent | MCC ~ N_latents^α, α ≈ 0.8 | MCC ~ N_latents^β, β < α due to manifold tiling | Fit power laws |
| Latents per feature | 1.2 latents per GT feature | 5+ latents per manifold | Count matched latents |
| Topology preservation | N/A (no topology) | PH score > 0.7 for manifold-aware SAE | Persistent homology |
| Dead neuron rate | ~10% at convergence | ~20% for standard SAE on manifolds | Count inactive latents |
| Curvature estimation | N/A | Error < 20% for manifold-aware SAE | Compare estimated vs. true curvature |
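The "Fit power laws" test can be done with an ordinary least-squares fit in log-log space. A minimal sketch, assuming the metric is strictly positive at every latent count; `fit_power_law` is a hypothetical helper.

```python
import numpy as np


# Hypothetical helper; assumes strictly positive inputs
def fit_power_law(n_latents, metric):
    """Fit metric ≈ c * n_latents**alpha by least squares on log(metric)
    versus log(n_latents). Returns (alpha, c)."""
    log_n = np.log(np.asarray(n_latents, dtype=float))
    log_y = np.log(np.asarray(metric, dtype=float))
    alpha, log_c = np.polyfit(log_n, log_y, deg=1)  # slope first, then intercept
    return float(alpha), float(np.exp(log_c))
```

Because MCC is bounded above by 1, a pure power law can describe it only over a limited range of latent counts, so α and β are most comparable when fitted over the same range.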
Validation against LLMs:
Compare these predictions to known behaviors in real LLM SAEs:
If manifolds are fundamental, SAEs trained on different models of the same hypothesis should learn similar structures.
```python
def test_cross_model_consistency():
    """
    Generate 5 independent synthetic models with the same manifold structure
    (different random seeds), train an SAE on each, and check whether the
    learned manifolds are consistent.
    """
    models = [generate_manifold_hypothesis_model(seed=i) for i in range(5)]
    saes = [train_sae(model) for model in models]
    # Extract learned manifold structures
    manifolds_per_sae = [detect_manifolds(sae) for sae in saes]
    # Measure consistency: do all SAEs detect similar manifold geometries?
    consistency_score = compute_manifold_consistency(manifolds_per_sae)
    # Expected: High consistency if manifolds are real structure,
    # low consistency if manifolds are spurious
    return consistency_score
```
Test whether manifold-aware SAEs trained on synthetic data transfer to real LLMs:
```python
def test_manifold_transfer(dataset, days_dataset, months_dataset):
    """
    1. Pre-train a manifold-aware SAE on a synthetic model with known circular features
    2. Fine-tune on real LLM activations
    3. Check whether the SAE preferentially learns known circular features (days, months)
    """
    # Pre-train on synthetic
    synthetic_model = generate_manifold_hypothesis_model()
    sae = ManifoldParametricSAE(...)
    pretrain_sae(sae, synthetic_model)
    # Fine-tune on real LLM
    llm = load_llm('gpt2-small')
    llm_activations = collect_activations(llm, dataset)
    finetune_sae(sae, llm_activations)
    # Evaluate: Does the SAE find known circular features?
    probe_sets = {'day_of_week': days_dataset, 'month': months_dataset}
    probes = {name: probe_for_circular_feature(sae, data) for name, data in probe_sets.items()}
    # Compare to an SAE trained from scratch on the LLM, probed the same way
    baseline_sae = StandardSAE(...)
    train_sae(baseline_sae, llm_activations)
    baseline_probes = {
        name: probe_for_circular_feature(baseline_sae, data) for name, data in probe_sets.items()
    }
    return {
        'manifold_aware': probes,
        'baseline': baseline_probes,
        # Dicts cannot be subtracted directly; compare per feature
        'improvement': {name: probes[name] - baseline_probes[name] for name in probes},
    }
```
Systematically ablate components to understand necessity:
```python
def ablation_study():
    results = {}
    # Baseline: LRH model
    results['lrh_only'] = train_and_evaluate(generate_pure_lrh_model())
    # Increase the number of circular manifolds
    results['lrh + 1_circular'] = train_and_evaluate(generate_model_with_n_manifolds(1, 'circular'))
    results['lrh + 5_circular'] = train_and_evaluate(generate_model_with_n_manifolds(5, 'circular'))
    results['lrh + 10_circular'] = train_and_evaluate(generate_model_with_n_manifolds(10, 'circular'))
    # Add different manifold types
    results['lrh + spherical'] = train_and_evaluate(generate_model_with_manifold_type('spherical'))
    results['lrh + toroidal'] = train_and_evaluate(generate_model_with_manifold_type('toroidal'))
    # Vary manifold parameters
    for embedding_dim in [8, 16, 32, 64]:
        results[f'embedding_dim_{embedding_dim}'] = train_and_evaluate(
            generate_model_with_embedding_dim(embedding_dim)
        )
    return results
```
Perform interventions to test whether learned manifolds are causal:
```python
def test_manifold_causality(sae, llm, manifold_id):
    """
    If the SAE learned a circular manifold representing 'day of week':
    1. Intervene: Set manifold to 'Monday' encoding
    2. Decode to activation space
    3. Feed to downstream tasks
    4. Check if behavior changes consistently with 'Monday'
    """
    # Identify learned manifold for 'day of week'
    dow_manifold = sae.manifolds[manifold_id]
    results = {}
    days = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
    for day in days:
        # Set manifold to represent this day
        manifold_encoding = encode_day_on_manifold(day, dow_manifold)
        # Decode to activation
        intervened_activation = sae.decode(manifold_encoding)
        # Test on downstream task (e.g., next-word prediction)
        llm_output = llm.forward(input_with_activation=intervened_activation)
        # Check if output is day-consistent
        day_consistency = measure_day_consistency(llm_output, expected_day=day)
        results[day] = day_consistency
    # Expected: High consistency → manifold is causally meaningful
    return results
```
Reproduce known LLM phenomena in synthetic models:
```python
def validate_against_llm_phenomena():
    """
    Known phenomena from LLM SAE literature:
    1. Calendar features form geometric structure (Leask et al., 2024)
    2. Circular features for days/months (Engels et al., 2025)
    3. Modular, "lobe"-like organization of SAE features (Li et al., 2025)
    Test: Do our synthetic models with manifolds reproduce these?
    """
    # Generate synthetic model with calendar manifolds
    synthetic_model = generate_calendar_manifold_model()
    # Train SAE
    sae = train_sae(synthetic_model)
    # Record each expectation instead of asserting it, so one failed
    # check does not abort the remaining tests
    checks = {}
    # Test 1: Geometric structure (Leask et al., 2024)
    # Expected: Day and month features form 2D structure
    calendar_geometry = analyze_calendar_feature_geometry(sae)
    checks['calendar_is_2d'] = calendar_geometry['dimensionality'] == 2
    # Test 2: Circular features (Engels et al., 2025)
    # Expected: Detect circular structure
    circularity = measure_circularity(sae.get_calendar_features())
    checks['is_circular'] = circularity > 0.8
    # Test 3: Modular organization (Li et al., 2025)
    # Expected: Temporal features cluster together
    modularity = measure_spatial_modularity(sae.decoder_weights)
    checks['temporal_cluster'] = modularity['temporal_lobe_score'] > 0.7
    return {
        'calendar_geometry': calendar_geometry,
        'circularity': circularity,
        'modularity': modularity,
        'checks': checks,
    }
```
Week 1: Manifold generation
Week 2: Hybrid dictionary construction
Week 3: Correlations and hierarchy
Milestone 1: Generate and validate ManifoldSynthSAEBench dataset
Week 4: Geometric metrics
Week 5: Topological metrics
Milestone 2: Complete evaluation suite with validated metrics
Week 6: Grouped Latent SAE
Week 7: Manifold-Parametric SAE
Week 8: Hierarchical and Adaptive SAEs
Milestone 3: Validated manifold-aware SAE architectures
Week 9: Generate hypothesis-specific models
Week 10: Training dynamics experiments
Week 11: Formalize predictions
Milestone 4: Completed hypothesis testing with formal predictions
Week 12: Cross-model consistency
Week 13: LLM transfer and interventions
Week 14: Reproduce LLM phenomena
Milestone 5: Validation complete, ready for publication
Week 15: Computational optimization
Week 16: Large-scale experiments
Milestone 6: Scalable implementation ready for large-scale experiments
Weeks 17-18: Draft paper
Week 19: Experiments for reviewers
Week 20: Submission and release
Milestone 7: Paper submitted, code and data public
Based on prior work and theoretical considerations, we predict:
| Metric | Standard SAE on Manifolds | Manifold-Aware SAE | LRH Baseline |
|---|---|---|---|
| Independent MCC | 0.70 ± 0.05 | 0.72 ± 0.05 | 0.75 ± 0.03 |
| Manifold Alignment | 0.25 ± 0.10 | 0.75 ± 0.08 | N/A |
| Topology Preservation | 0.15 ± 0.10 | 0.80 ± 0.10 | N/A |
| Latents per Manifold | 15 ± 5 | 3 ± 1 | N/A |
| Dead Neuron Rate | 25% ± 5% | 12% ± 3% | 10% ± 2% |
| Curvature Error | 80% ± 20% | 15% ± 10% | N/A |
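Two rows of this table, manifold alignment and curvature error, need concrete definitions before they can be measured. The sketch below shows one possible choice, not necessarily the definition used elsewhere in this plan: alignment is the mean squared cosine of the principal angles between the ground-truth manifold subspace and the dominant k-dimensional subspace of the decoder directions attributed to that manifold, and curvature is estimated with Menger curvature (the inverse circumradius of three consecutive points), reported as a mean relative error. All function names are hypothetical.

```python
import numpy as np


# Hypothetical metric definitions, sketched for illustration
def manifold_alignment(W_group, B_true):
    """W_group: (d, m) decoder columns attributed to one manifold (m may exceed k,
    e.g. when a standard SAE tiles a circle). B_true: (d, k) basis of the
    ground-truth manifold subspace. Returns a score in [0, 1]."""
    k = B_true.shape[1]
    Q_true, _ = np.linalg.qr(B_true)
    U, _, _ = np.linalg.svd(W_group, full_matrices=False)
    Q_learned = U[:, :k]  # dominant k-dim subspace of the learned directions
    cosines = np.linalg.svd(Q_learned.T @ Q_true, compute_uv=False)
    return float(np.sum(np.clip(cosines, 0.0, 1.0) ** 2) / k)


def menger_curvature(a, b, c):
    """Curvature 1/R of the circle through three points in R^d."""
    u, v = b - a, c - a
    area = 0.5 * np.sqrt(max(np.dot(u, u) * np.dot(v, v) - np.dot(u, v) ** 2, 0.0))
    sides = np.linalg.norm(b - a) * np.linalg.norm(c - b) * np.linalg.norm(a - c)
    return 4.0 * area / sides


def curvature_relative_error(points, true_curvature):
    """Mean relative error of Menger curvature over consecutive triples of an
    ordered, closed sequence of points with shape (n_points, d)."""
    n = len(points)
    estimates = np.array([
        menger_curvature(points[i - 1], points[i], points[(i + 1) % n]) for i in range(n)
    ])
    return float(np.mean(np.abs(estimates - true_curvature)) / true_curvature)
```

Using the dominant subspace rather than the raw span makes the alignment score insensitive to how many latents tile the manifold: it measures where the learned directions lie, not how many there are, so latents per manifold should be reported alongside it.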
Finding 1: Manifold tiling pathology. Standard SAEs will “tile” circular manifolds with many latents, each representing a small arc. This is inefficient but locally optimal for reconstruction loss.
Finding 2: Scaling breakdown. We expect the pathological scaling predicted by Michaud et al. to be confirmed empirically: in the high-superposition + manifold regime, the number of discovered features grows sublinearly with latent count.
Finding 3: Architecture matters. Manifold-parametric SAEs will achieve 3-5× better manifold alignment than standard SAEs, indicating that architectural inductive biases are crucial.
Finding 4: Hierarchy-manifold interaction. Hierarchical manifolds will show error propagation: missing parent features will cause catastrophic failure to recover child manifolds.
Finding 5: LLM correspondence. Training dynamics on synthetic manifold models will quantitatively match known LLM SAE behaviors (feature splitting rates, dead neuron curves), supporting the benchmark’s realism.
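Finding 1's claim that tiling is rewarded by reconstruction loss can be checked with a toy calculation under a deliberately simple assumption (not the plan's model): each point on a unit circle is reconstructed by exactly one latent, the nearest of k evenly spaced unit decoder directions, using its optimal coefficient. A point at angular offset δ from that direction leaves a squared residual of sin²δ, and averaging over δ uniform on [−π/k, π/k] gives 1/2 − sin(2π/k)/(4π/k) ≈ (π/k)²/3. The Monte Carlo function cross-checks the closed form; both function names are hypothetical.

```python
import numpy as np


# Toy model (assumption): one active latent per point, nearest tile, optimal coefficient
def tiling_mse_closed_form(k):
    """Mean squared residual for points on the unit circle S^1 when each point is
    reconstructed by its nearest of k evenly spaced unit decoder directions."""
    a = np.pi / k  # maximum angular offset from the nearest tile
    return 0.5 - np.sin(2 * a) / (4 * a)


def tiling_mse_monte_carlo(k, n_samples=200_000, seed=0):
    """Same quantity, estimated by sampling angles uniformly on the circle."""
    rng = np.random.default_rng(seed)
    theta = rng.uniform(0.0, 2 * np.pi, n_samples)
    x = np.stack([np.cos(theta), np.sin(theta)], axis=1)  # (n, 2) points on S^1
    phi = 2 * np.pi * np.arange(k) / k
    tiles = np.stack([np.cos(phi), np.sin(phi)], axis=1)  # (k, 2) unit directions
    coef = x @ tiles.T                                    # (n, k) projections
    best = coef.argmax(axis=1)                            # nearest tile per point
    x_hat = coef[np.arange(n_samples), best][:, None] * tiles[best]
    return float(np.mean(np.sum((x - x_hat) ** 2, axis=1)))


for k in (4, 8, 16):
    print(k, tiling_mse_closed_form(k), tiling_mse_monte_carlo(k))
```

The closed form gives about 0.18 for k = 4, 0.050 for k = 8 and 0.013 for k = 16: each doubling of the tile count cuts the error roughly fourfold, so adding latents always pays off locally, whereas a two-dimensional parameterization like `CircularManifoldModule` can represent the same noise-free circle exactly with one angle and one magnitude.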
Implication 1: Standard SAEs are insufficient. If manifold features are common in LLMs (as Engels et al. suggest), current SAE architectures may be fundamentally limited in representing them, and the field needs manifold-aware alternatives.
Implication 2: Evaluation must evolve. Point-wise feature recovery metrics (MCC, F1) miss geometric structure, so interpretability research should adopt manifold-aware evaluation.
Implication 3: Synthetic benchmarks are crucial. Without ground truth, manifold recovery cannot be measured directly; SynthSAEBench-style approaches are essential for rigorous progress.
Implication 4: Representation hypotheses are testable. By generating synthetic models that instantiate different hypotheses, we can empirically distinguish between theories of neural representation.
Implication 5: Transfer learning potential. If manifold-aware SAEs pre-trained on synthetic data transfer to real LLMs, this opens a new paradigm: use synthetic models to develop better interpretability tools, then apply them to real models.
Optimal manifold parameterization: Should manifolds be discretized (as proposed) or continuous? What’s the right discretization density?
Manifold discovery: Can SAEs automatically discover manifold type (circular vs. spherical vs. hyperbolic) from data, or must it be specified?
Compositional manifolds: What’s the best way to represent composed manifolds (S¹ × S¹)? Factorized vs. monolithic?
Manifold superposition: Can manifolds themselves be in superposition (e.g., “day of week” and “zodiac sign” both circular, overlapping)?
Scaling to high dimensions: How do evaluation metrics scale to higher-dimensional manifolds (e.g., 10D hyperbolic spaces)?
Noise vs. curvature: How to distinguish genuine manifold curvature from noise or sampling artifacts?
Extension 1: Other manifold types
Extension 2: Temporal manifolds
Extension 3: Multi-modal manifolds
Extension 4: Causal manifolds
Extension 5: Neuroscience connection
Goal: Establish manifold-aware interpretability as a standard paradigm.
Success metrics:
Broader impact:
This research plan extends SynthSAEBench to systematically study how sparse autoencoders handle feature manifolds—a critical but understudied aspect of neural representation. By generating synthetic data with known manifold ground truth, developing manifold-aware evaluation metrics, designing architectural innovations, and testing competing representation hypotheses, we can rigorously assess when and how SAEs recover geometric structure.
The key insight is that features are not always one-dimensional. Temporal cycles, spatial directions, and compositional concepts naturally lie on manifolds. If interpretability tools ignore this geometry, they will fail to capture how models actually represent knowledge. By grounding our investigation in the SynthSAEBench framework, where we control ground truth and maintain scientific rigor, we can make measurable progress on understanding manifold representations in neural networks.
Our approach is not merely theoretical. We make concrete, testable predictions that can be validated against real LLM behavior. We propose practical architectures that can be deployed on real models. And we establish evaluation protocols that the community can adopt. This work has the potential to reshape how we think about feature learning, representation geometry, and the fundamental units of neural computation.
The path forward is clear: implement, evaluate, iterate, and validate. With systematic experiments on carefully designed synthetic models, we can begin to answer the question: Are manifolds the missing piece in sparse autoencoder interpretability?
Key Papers on Feature Manifolds:
Engels, J. E., Michaud, E. J., Liao, I., Gurnee, W., & Tegmark, M. (2025). “Not All Language Model Features Are One-Dimensionally Linear.” ICLR 2025. https://arxiv.org/abs/2405.14860
Li, Y., Michaud, E. J., Baek, D. D., Engels, J., Sun, X., & Tegmark, M. (2025). “The Geometry of Concepts: Sparse Autoencoder Feature Structure.” Entropy, 27(4), 344. https://arxiv.org/abs/2410.19750
Olah, C., & Batson, J. (2023). “Feature Manifold Toy Model.” Transformer Circuits Thread, May Update. https://transformer-circuits.pub/2023/may-update/index.html#feature-manifolds
Michaud, E. J., et al. (2025). “Understanding Sparse Autoencoder Scaling in the Presence of Feature Manifolds.” https://arxiv.org/abs/2509.02565
Key Papers on SAEs:
Bricken, T., et al. (2023). “Towards Monosemanticity: Decomposing Language Models with Dictionary Learning.” https://transformer-circuits.pub/2023/monosemantic-features
Gao, L., et al. (2025). “Scaling and Evaluating Sparse Autoencoders.” ICLR 2025. https://cdn.openai.com/papers/sparse-autoencoders.pdf
Cunningham, H., et al. (2023). “Sparse Autoencoders Find Highly Interpretable Features in Language Models.” https://arxiv.org/abs/2309.08600
SynthSAEBench:
Geometric Representation:
Park, K., Choe, Y. J., & Veitch, V. (2024). “The Linear Representation Hypothesis and the Geometry of Large Language Models.” ICML 2024.
Leask, P., et al. (2024). “Calendar Feature Geometry in GPT-2 Layer 8 Residual Stream SAEs.” https://www.lesswrong.com/posts/WsPyunwpXYCM2iN6t
manifold-sae-bench/
├── data_generation/
│ ├── manifolds.py # Manifold generation (circular, spherical, etc.)
│ ├── hybrid_dictionary.py # Construct dictionary with manifolds
│ ├── activation_sampler.py # Sample activations with manifolds
│ └── hierarchy.py # Hierarchical manifolds
├── models/
│ ├── standard_sae.py # Baseline L1 SAE
│ ├── grouped_latent_sae.py # GL-SAE
│ ├── manifold_parametric_sae.py # MP-SAE
│ ├── hierarchical_sae.py # HM-SAE
│ └── adaptive_sae.py # ALM-SAE
├── evaluation/
│ ├── geometric_metrics.py # Manifold alignment, geodesic preservation
│ ├── topological_metrics.py # Persistent homology
│ ├── standard_metrics.py # MCC, F1, variance explained
│ └── benchmark.py # Unified evaluation suite
├── experiments/
│ ├── hypothesis_testing.py # Generate models per hypothesis
│ ├── training_dynamics.py # Track emergence, splitting, etc.
│ ├── scaling_laws.py # Scaling experiments
│ └── llm_validation.py # Compare to real LLM behaviors
├── visualization/
│ ├── manifold_plots.py # Visualize learned manifolds
│ ├── geometry_plots.py # PCA, tSNE of decoder weights
│ └── training_curves.py # Loss, metrics over training
└── notebooks/
├── 01_data_exploration.ipynb
├── 02_model_training.ipynb
├── 03_evaluation.ipynb
└── 04_hypothesis_testing.ipynb
| Symbol | Meaning |
|---|---|
| d | Hidden dimension (768) |
| N | Number of features (16,384) |
| k | Intrinsic manifold dimension (1 for circle, 2 for sphere, etc.) |
| D | Feature dictionary, D ∈ ℝ^{d×N} |
| x | Activation vector, x ∈ ℝ^d |
| c | Ground-truth feature coefficients, c ∈ ℝ^N |
| f | SAE latent activations, f ∈ ℝ^{N_latents} |
| x̂ | Reconstructed activation |
| ρ_mm | Mean max cosine similarity (superposition measure) |
| L0 | Number of active features (sparsity) |
| S^k | k-dimensional sphere |
| θ, φ | Angular coordinates on manifolds |
| Σ | Correlation matrix |
| F | Low-rank factor matrix for correlations |
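One way the symbols above fit together, written out as a sketch: the independent part of an activation is the dictionary-weighted sum of ground-truth coefficients, and each circular manifold m adds a magnitude-scaled point on a circle, mirroring the decoder of `CircularManifoldModule`. The manifold basis B_m ∈ ℝ^{d×2}, the magnitude a_m, and the SAE parameters W_enc, b_enc, W_dec, b_dec and λ are not in the table and are introduced here as assumptions; D is restricted to the independent features, and the SAE lines are the standard ReLU encoder with an L1 sparsity penalty.

```latex
x = D\,c + \sum_{m} a_m\, B_m \begin{bmatrix} \cos\theta_m \\ \sin\theta_m \end{bmatrix},
\qquad D \in \mathbb{R}^{d \times N},\; c \in \mathbb{R}^{N},\; B_m \in \mathbb{R}^{d \times 2},\; a_m \ge 0

f = \mathrm{ReLU}(W_{\text{enc}}\, x + b_{\text{enc}}) \in \mathbb{R}^{N_{\text{latents}}},
\qquad \hat{x} = W_{\text{dec}}\, f + b_{\text{dec}},
\qquad \mathcal{L} = \lVert x - \hat{x} \rVert_2^2 + \lambda \lVert f \rVert_1,
\qquad L_0 = \lVert f \rVert_0
```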
End of Research Plan
This document provides a comprehensive, technically rigorous framework for extending SynthSAEBench to study manifold-aware sparse autoencoders. Implementation can proceed according to the phased plan, with each milestone building toward a complete system for testing representation hypotheses and advancing interpretability research.