Fixing Feature Suppression in SAEs, published by Benjamin Wright on February 16, 2024 on LessWrong.
Produced as part of the ML Alignment Theory Scholars Program - Winter 2023-24 Cohort, in Lee Sharkey's stream.
Sparse autoencoders (SAEs) are a method of resolving superposition by recovering linearly encoded "features" inside activations. Unfortunately, despite the great recent success of SAEs at extracting human-interpretable features, they fail to perfectly reconstruct the activations. For instance, Cunningham et al. (2023) note that replacing the residual stream of layer 2 of Pythia-70m with the reconstructed output of an SAE increased the perplexity of the model on the Pile from 25 to 40.
It is important for interpretability that the features we extract accurately represent what the model is doing.
In this post, I show how and why SAEs have a reconstruction gap due to 'feature suppression'. Then, I look at a few ways to fix this while maintaining SAEs' interpretability. By modifying and fine-tuning a pre-trained SAE, we achieve a 9% decrease in mean squared error and a 24% reduction in the perplexity increase from patching activations into the LLM.
Finally, I connect a theoretical example to the observed amounts of feature suppression in Pythia 70m, confirming that features are suppressed primarily based on the strength of their activations, not on their frequency of activation.
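For reference, a patched-perplexity evaluation like the one quoted above can be sketched with TransformerLens hooks. This is a minimal sketch under assumed details, not the original evaluation code: the hook point, the example text, and the identity placeholder standing in for a trained SAE are all illustrative.

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("pythia-70m")

# Placeholder: substitute a trained SAE's reconstruction function here.
sae_reconstruct = lambda acts: acts

def patch_hook(resid, hook):
    # Replace the layer-2 residual stream with the SAE reconstruction.
    return sae_reconstruct(resid)

tokens = model.to_tokens("Example text standing in for the Pile.")
loss = model.run_with_hooks(
    tokens,
    return_type="loss",
    fwd_hooks=[("blocks.2.hook_resid_post", patch_hook)],
)
print(torch.exp(loss))  # perplexity with patched activations
```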
Feature Suppression
The architecture of an SAE is:
$$f(\mathbf{x}) = \mathrm{ReLU}(W_e\mathbf{x} + \mathbf{b}_e)$$
$$\mathbf{y} = W_d f(\mathbf{x}) + \mathbf{b}_d$$
The loss function usually combines an MSE reconstruction loss with a sparsity term, such as $L(\mathbf{x}, f(\mathbf{x}), \mathbf{y}) = \|\mathbf{y} - \mathbf{x}\|^2/d + c|f(\mathbf{x})|$, where $d$ is the dimension of $\mathbf{x}$. When training the SAE on this loss, the decoder's weight matrix is constrained to have unit norm for each feature (column).
The reason for feature suppression is simple: the training loss has two terms, and only one of them is reconstruction, so reconstruction isn't perfect. In particular, the sparsity term pushes for smaller $f(\mathbf{x})$ values, leading to suppressed features and worse reconstruction.
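As a concrete reference, here is a minimal sketch of this architecture and loss in PyTorch. The initialization and shapes are illustrative assumptions, not the original training code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseAutoencoder(nn.Module):
    """Minimal SAE: f(x) = ReLU(W_e x + b_e), y = W_d f(x) + b_d."""
    def __init__(self, d_model: int, d_dict: int):
        super().__init__()
        self.W_e = nn.Parameter(torch.randn(d_dict, d_model) / d_model**0.5)
        self.b_e = nn.Parameter(torch.zeros(d_dict))
        self.W_d = nn.Parameter(torch.randn(d_model, d_dict) / d_dict**0.5)
        self.b_d = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        f = F.relu(x @ self.W_e.T + self.b_e)  # feature activations
        y = f @ self.W_d.T + self.b_d          # reconstruction
        return f, y

def sae_loss(x, f, y, c):
    d = x.shape[-1]
    mse = ((y - x) ** 2).sum(-1) / d   # ||y - x||^2 / d
    l1 = c * f.abs().sum(-1)           # c * |f(x)|, the sparsity term
    return (mse + l1).mean()

@torch.no_grad()
def normalize_decoder(sae):
    # Keep each dictionary feature (column of W_d) at unit norm.
    sae.W_d.div_(sae.W_d.norm(dim=0, keepdim=True))
```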
An illustrative example of feature suppression
As an example, consider the trivial case where there is only one binary feature in one dimension: $x = 1$ with probability $p$ and $x = 0$ otherwise. Ideally, the optimal SAE would extract feature activations $f(x) \in \{0, 1\}$ and have a decoder with $W_d = 1$.
However, if we train an SAE to minimize the loss function $L(x, f(x), y) = \|y - x\|^2 + c|f(x)|$, we get a different result. Ignoring bias terms for simplicity, and supposing the encoder outputs feature activation $a$ when $x = 1$ and $0$ otherwise, the optimization problem becomes:
$$a^* = \arg\min_a \; p\,L(1, a, a) + (1 - p)\,L(0, 0, 0) = \arg\min_a \;(a - 1)^2 + c|a| = \arg\min_a \; a^2 + (c - 2)a + 1$$

(The $L(0, 0, 0)$ term vanishes, and the constant factor $p$ does not affect the minimizer.) Setting the derivative to zero gives:

$$a^* = 1 - \frac{c}{2}$$

Therefore the feature activation is scaled by a factor of $1 - c/2$ compared to optimal. This is an example of feature suppression.
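A quick numerical sanity check of this minimizer, using an arbitrary example coefficient:

```python
import numpy as np

c = 0.5                             # example sparsity coefficient
a = np.linspace(0.0, 1.0, 100001)   # candidate feature activations
loss = (a - 1.0) ** 2 + c * a       # p factors out of the argmin
print(a[np.argmin(loss)])           # 0.75, matching 1 - c/2
```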
If we allow the ground truth feature to have activation strength $g$ when active, and let $d$ be the dimension of the activations, this factor becomes:

$$1 - \frac{cd}{2g}$$

In other words, instead of the ground truth activation $g$, the SAE learns an activation of $g - \frac{cd}{2}$, a constant amount less. Features with activation strengths below $\frac{cd}{2}$ would be completely killed off by the SAE.
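As a rough sanity check, if the same $1/d$ MSE normalization applies, plugging the sparsity coefficient and residual-stream dimension from the experiments below into $cd/2$ gives the activation strength below which features would be killed entirely. Whether this exact constant carries over depends on normalization details, so treat it as back-of-the-envelope:

```python
c = 2e-3          # L1 coefficient from the experiments below
d = 512           # residual stream dimension of Pythia-70m
print(c * d / 2)  # 0.512: ground-truth activations below this get zeroed
```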
Feature suppression is a significant problem in current SAEs
To experimentally verify that feature suppression affects SAEs, we first trained SAEs on the residual stream output of each layer of Pythia-70m with an L1 sparsity penalty (coefficient 2e-3) on 6 epochs of 100 million tokens of OpenWebText, with batch size 64 and learning rate 1e-3, resulting in roughly 13-80 feature activations per token. The residual stream of Pythia-70m has dimension 512, and we used a dictionary size of 2048, a four-times scale-up.
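A condensed sketch of this training setup, reusing the SparseAutoencoder, sae_loss, and normalize_decoder helpers from the earlier snippet. The random batches here are a stand-in for actual residual-stream activations collected from OpenWebText; activation collection and epoch looping are omitted.

```python
import torch

sae = SparseAutoencoder(d_model=512, d_dict=2048)
opt = torch.optim.Adam(sae.parameters(), lr=1e-3)

# Stand-in for batches of Pythia-70m residual-stream activations.
activation_batches = (torch.randn(64, 512) for _ in range(1000))

for x in activation_batches:
    f, y = sae(x)
    loss = sae_loss(x, f, y, c=2e-3)
    opt.zero_grad()
    loss.backward()
    opt.step()
    normalize_decoder(sae)  # re-project decoder columns to unit norm
```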
If feature suppression had a noticeable effect, we'd see that the SAE reconstructions had noticeably smaller L2 norm...