Features and LLM Interpretability

The concept of “features” in Transformer models, and in Large Language Models (LLMs) more broadly, lies at the heart of much current research in mechanistic explainability. Traditional machine learning typically defines features as explicitly engineered attributes of the input data, carefully chosen to highlight relevant information. In LLMs, by contrast, features emerge from the training process itself and are represented by the activations of neurons across the network’s layers. They are not pre-defined; rather, they are learned representations that capture patterns, relationships, and concepts present in the training data. Understanding these emergent features is crucial for interpreting how the model arrives at its conclusions, a core goal of mechanistic explainability. We need to move beyond simply observing what the model does to understanding why it does it, and features are the primary vehicle for that understanding. The challenge is deciphering what specific concept or pattern each feature encodes, and how those features interact to produce the model’s output.

At a foundational level, the Transformer architecture is built upon the attention mechanism, which allows the model to weigh the importance of different parts of the input sequence when processing each token (the units into which the prompt and any additional input are split). This weighting is itself a form of feature extraction: the attention weights can be considered features that highlight which tokens are most relevant to a given context. However, the true complexity of features lies within the hidden states of the network. Each layer of the Transformer applies a series of linear transformations and non-linear activations to its input, progressively transforming the initial token embeddings into higher-level representations. These hidden states, the intermediate activations of the network’s neurons, encode increasingly abstract and complex features, particularly in the later layers. Mathematically, we can represent a (post-norm) Transformer layer as follows:

$\tilde{x}_{l} = \text{LayerNorm}(x_{l} + \text{Attention}(x_{l})), \qquad h_{l} = \text{LayerNorm}(\tilde{x}_{l} + \text{FFN}(\tilde{x}_{l}))$

Here, $x_{l}$ is the input to layer $l$, $\tilde{x}_{l}$ is the intermediate representation produced by the attention sub-layer, and $h_{l}$ is the output of layer $l$. $\text{Attention}$ denotes the multi-head attention mechanism, $\text{FFN}$ the feed-forward network, and $\text{LayerNorm}$ layer normalization, a technique that stabilizes training; each sub-layer adds its output back to its input through a residual connection. Crucially, the output $h_{l}$ is not just a transformed version of the input; it’s a new representation that, ideally, captures more relevant information for the downstream task. Each neuron within $h_{l}$ contributes to this representation, and its activation value can be considered a feature. The key is that the weights within the $\text{Attention}$ and $\text{FFN}$ components define how the input is transformed, and these weights are learned during training to optimize performance on the training data. This learning process is what gives rise to the emergent features. From a mechanistic perspective, we’re interested in understanding what patterns in the input consistently cause specific neurons to activate, and what effect that activation has on the subsequent computations. For instance, a neuron might consistently activate when processing tokens related to “historical figures,” “scientific concepts,” or even subtler patterns like “questions requiring reasoning about causality.” The strength of the activation is then a measure of how strongly that concept is present in the current context. The model learns to represent these concepts without being explicitly told what they are; it discovers them through statistical patterns in the data. Furthermore, the multi-head attention mechanism allows the model to capture different aspects of the same concept, or different concepts simultaneously, by attending to different parts of the input sequence. Each attention head can be thought of as learning a different feature detector, adding to the richness of the model’s internal representation.
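To make this concrete, below is a minimal sketch of a single post-norm Transformer block in PyTorch that returns both the hidden state $h_{l}$ and the attention weights, the two kinds of activations discussed above as candidate features. The class and dimension names (`TransformerBlock`, `d_model`, `n_heads`, `d_ff`) are illustrative placeholders, not the layout of any particular LLM.

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Single post-norm Transformer layer: attention and FFN sub-layers,
    each wrapped in a residual connection followed by LayerNorm."""
    def __init__(self, d_model=512, n_heads=8, d_ff=2048):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # x_tilde = LayerNorm(x + Attention(x))
        attn_out, attn_weights = self.attn(x, x, x, need_weights=True)
        x_tilde = self.ln1(x + attn_out)
        # h = LayerNorm(x_tilde + FFN(x_tilde))
        h = self.ln2(x_tilde + self.ffn(x_tilde))
        # Every coordinate of h (and every attention weight) is a candidate
        # emergent "feature" whose activation pattern can be studied.
        return h, attn_weights

block = TransformerBlock()
tokens = torch.randn(1, 16, 512)     # placeholder batch of 16 token embeddings
hidden, weights = block(tokens)
print(hidden.shape, weights.shape)   # (1, 16, 512) and (1, 16, 16)
```

In a full model, one would typically register forward hooks on such blocks to record $h_{l}$ over a corpus and then ask which inputs drive individual coordinates of the hidden state.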

The connection between the emergent features and the concepts in the training data is not always a simple one-to-one mapping. A single concept can be represented by a combination of features, and a single feature can contribute to the representation of multiple concepts. This is because LLMs are not simply memorizing facts; they are learning to model the underlying structure of language and the relationships between concepts. To illustrate this, consider the concept of “capital cities.” The model might not have a single neuron dedicated to representing “capital cities.” Instead, this concept might be encoded by a combination of features related to: (1) geographical locations, (2) political entities, (3) administrative centers, and (4) population density. These features might be distributed across multiple layers and attention heads, and their combined activation pattern would indicate the presence of a capital city. From an explainability standpoint, identifying these feature combinations and understanding their contributions to the model’s reasoning is a major challenge. Recent research has focused on techniques like “feature attribution” and “feature visualization” to shed light on these relationships. Feature attribution methods aim to assign a score to each input token (or feature) based on its contribution to the model’s output. This can help identify which parts of the input were most important for the model’s decision. Feature visualization methods, on the other hand, aim to understand what patterns in the input cause specific neurons to activate. This can be done by generating synthetic inputs that maximize the activation of a particular neuron, or by analyzing the activations of a neuron across a large dataset of inputs. However, these methods are not perfect. They can be sensitive to noise and can sometimes produce misleading results. Therefore, it’s important to use them in conjunction with other explainability techniques and to carefully validate the results.
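The toy sketch below illustrates both ideas under heavy simplification: gradient-times-input attribution for a single scalar activation, and a crude form of feature visualization that simply retrieves the dataset examples that most strongly activate a chosen neuron. The embedding layer, the single `neuron`, and the random data are placeholders standing in for a real LLM and corpus, so the numbers themselves are meaningless; only the mechanics carry over.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, d_model = 100, 32
embed = nn.Embedding(vocab_size, d_model)   # placeholder for an LLM's embeddings
neuron = nn.Linear(d_model, 1)              # placeholder for one hidden neuron

# --- Feature attribution: gradient-times-input per token --------------------
token_ids = torch.randint(0, vocab_size, (1, 8))
emb = embed(token_ids)                      # (1, 8, d_model)
emb.retain_grad()                           # keep gradients on a non-leaf tensor
activation = neuron(emb).sum()              # scalar quantity to explain
activation.backward()
token_scores = (emb.grad * emb).sum(dim=-1) # contribution score per input token
print("per-token attribution:", token_scores)

# --- Feature "visualization": retrieve top-activating inputs ----------------
# Scan a (toy) dataset and keep the examples that drive the neuron hardest;
# inspecting them hints at what concept the neuron encodes.
dataset = torch.randint(0, vocab_size, (500, 8))
with torch.no_grad():
    acts = neuron(embed(dataset)).squeeze(-1).mean(dim=-1)
top_examples = acts.topk(5).indices
print("indices of top-activating examples:", top_examples.tolist())
```

Gradient-based optimization of a synthetic input that maximizes the neuron’s activation, the other visualization strategy mentioned above, follows the same pattern, with the optimization running over the input rather than the weights.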

Finally, it’s important to recognize that the features learned by LLMs are not static. They evolve over time as the model is exposed to new data and fine-tuned for specific tasks. This means that the explainability techniques we use to understand the model’s behavior must also be dynamic and adapt to the changing features. Consider the mathematical representation of fine-tuning. Let $\theta$ represent the model’s weights. During pre-training, the model learns weights $\theta_0$ to minimize a loss function $L_0$ on a large corpus of text:

$\theta_0 = \arg\min_{\theta} L_0(D_{\text{pretrain}}, \theta)$

where $D_{\text{pretrain}}$ is the pre-training dataset. Fine-tuning then involves updating these weights to minimize a loss function $L_f$ on a smaller, task-specific dataset $D_{\text{finetune}}$:

$\theta_f = \arg\min_{\theta} L_f(D_{\text{finetune}}, \theta)$
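The sketch below mirrors the second objective in code: starting from pre-trained weights $\theta_0$, a few optimizer steps on a task-specific loss produce new weights $\theta_f$, and the drift between the two checkpoints can be measured directly. The linear model, synthetic batches, and cross-entropy loss are placeholders standing in for an actual LLM, $D_{\text{finetune}}$, and $L_f$.

```python
import copy
import torch
import torch.nn as nn

model = nn.Linear(32, 2)                     # placeholder for a pre-trained LLM
theta_0 = copy.deepcopy(model.state_dict())  # snapshot of the pre-trained weights

# Placeholder task-specific data and loss standing in for D_finetune and L_f.
finetune_data = [(torch.randn(16, 32), torch.randint(0, 2, (16,)))
                 for _ in range(10)]
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

for x, y in finetune_data:                   # approximately minimize L_f over theta
    opt.zero_grad()
    loss_fn(model(x), y).backward()
    opt.step()

theta_f = model.state_dict()
drift = sum((theta_f[k] - theta_0[k]).norm() for k in theta_0)
print("parameter drift between theta_0 and theta_f:", float(drift))
```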

The resulting weights $\theta_f$ differ from $\theta_0$, and the features encoded by the model differ accordingly. An explainability analysis performed on the pre-trained model may therefore no longer hold for the fine-tuned model, so it is important to repeat the analysis whenever the model is fine-tuned in order to keep an accurate picture of its behavior; a minimal version of such a re-check is sketched at the end of this section. Furthermore, the transfer of knowledge between tasks is also mediated by these features. When fine-tuning on a new task, the model is not starting from scratch; it leverages the features it learned during pre-training to accelerate learning and improve performance. Identifying which pre-trained features are most relevant for a particular task can provide valuable insight into the model’s generalization capabilities and its ability to transfer knowledge. In conclusion, understanding the features learned by LLMs is essential for achieving true AI explainability. It requires a combination of theoretical analysis, empirical experimentation, and new explainability techniques that can capture the dynamic and complex nature of these features.
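As a closing illustration of why the analysis must be repeated, the sketch below records a layer’s activations on a fixed probe set under $\theta_0$ and $\theta_f$ and measures, neuron by neuron, how stable the activation pattern is; neurons with low correlation are the ones whose earlier interpretation most needs revisiting. The models, the probe set, and the random perturbation standing in for the fine-tuning update are all placeholders.

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
probe = torch.randn(200, 32)                 # fixed probe inputs

pretrained = nn.Sequential(nn.Linear(32, 64), nn.GELU())  # placeholder model
finetuned = copy.deepcopy(pretrained)
with torch.no_grad():                        # stand-in for the fine-tuning update
    for p in finetuned.parameters():
        p.add_(0.05 * torch.randn_like(p))

with torch.no_grad():
    acts_before = pretrained(probe)          # (200, 64): one column per neuron
    acts_after = finetuned(probe)

# Per-neuron correlation of activation patterns across the probe set; low
# values flag neurons whose apparent "feature" has shifted under theta_f.
a = (acts_before - acts_before.mean(0)) / (acts_before.std(0) + 1e-6)
b = (acts_after - acts_after.mean(0)) / (acts_after.std(0) + 1e-6)
per_neuron_corr = (a * b).mean(0)
print("least stable neurons:", per_neuron_corr.argsort()[:5].tolist())
```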