Welcome to The Nonlinear Library, where we use Text-to-Speech software to convert the best writing from the Rationalist and EA communities into audio. This is: BatchTopK: A Simple Improvement for TopK-SAEs, published by Bart Bussmann on July 20, 2024 on The AI Alignment Forum.
Work done in Neel Nanda's stream of MATS 6.0.
Epistemic status: We tried this on a single sweep and it seems to work well, but it could well be a fluke of something particular to our implementation or experimental set-up. Since there are also some theoretical reasons to expect this technique to work (adaptive sparsity), it is probably a good idea to also try BatchTopK in many TopK SAE set-ups. As we're not planning to investigate this much further and it might be useful to others, we're just sharing what we've found so far.
TL;DR: Instead of taking the TopK feature activations per token during training, taking the Top(K*batch_size) activations across the whole batch seems to improve SAE performance. During inference, this activation can be replaced with a single global threshold for all features.
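The inference-time replacement can be sketched in plain Python. This is a minimal sketch under an assumption not stated above: the single global threshold is estimated as the average, over training batches, of the smallest activation that BatchTopK kept. The function names `estimate_global_threshold` and `threshold_inference` are hypothetical.

```python
def estimate_global_threshold(batch_topk_outputs):
    """Average, over training batches, of the smallest activation BatchTopK kept.

    Assumption: this averaging rule is one plausible way to pick the single
    global threshold; it is not taken from the post.
    """
    minima = []
    for batch in batch_topk_outputs:           # each batch: list of rows (samples)
        kept = [a for row in batch for a in row if a > 0]
        if kept:                               # skip batches where nothing fired
            minima.append(min(kept))
    return sum(minima) / len(minima)


def threshold_inference(acts, theta):
    """Keep activations strictly above the global threshold theta, zero the rest.

    Unlike BatchTopK, a sample's output no longer depends on the rest of the batch.
    """
    return [[a if a > theta else 0.0 for a in row] for row in acts]


theta = estimate_global_threshold([
    [[0.0, 2.0], [1.0, 0.0]],   # batch 1: smallest kept activation is 1.0
    [[3.0, 0.0], [0.0, 0.5]],   # batch 2: smallest kept activation is 0.5
])
print(theta)                                        # 0.75
print(threshold_inference([[0.7, 0.8, -0.2]], theta))  # [[0.0, 0.8, 0.0]]
```

A fixed threshold makes inference deterministic per token, at the cost of no longer guaranteeing exactly K active features on average.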
Introduction
Sparse autoencoders (SAEs) have emerged as a promising tool for interpreting the internal representations of large language models. By learning to reconstruct activations using only a small number of features, SAEs can extract monosemantic concepts from the representations inside transformer models. Recently, OpenAI published a paper exploring the use of TopK activation functions in SAEs. This approach directly enforces sparsity by only keeping the K largest activations per sample.
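For reference, the per-sample TopK activation can be sketched in plain Python as below. It is a minimal illustration, not OpenAI's implementation: the name `topk_per_sample` is hypothetical, and applying a ReLU to the kept values is an assumed implementation choice.

```python
def topk_per_sample(acts, k):
    """Per-sample TopK: in each row (one token), keep the k largest
    pre-activations and set every other feature to zero."""
    out = []
    for row in acts:
        # indices of the k largest entries in this row
        top = set(sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k])
        # ReLU on the kept values (assumed implementation choice)
        out.append([max(row[i], 0.0) if i in top else 0.0 for i in range(len(row))])
    return out


acts = [[0.1, 0.2, 0.3],
        [3.0, 2.0, 1.0]]
print(topk_per_sample(acts, k=2))  # [[0.0, 0.2, 0.3], [3.0, 2.0, 0.0]]
```

Every row ends up with exactly k = 2 nonzero features, regardless of how large its activations are.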
While effective, TopK forces every token to use exactly K features, which is likely suboptimal: some tokens carry more information than others. We came up with a simple modification that removes this constraint and seems to improve performance.
BatchTopK
Standard TopK SAEs apply the TopK operation independently to each sample in a batch. For a target sparsity of K, this means exactly K features are activated for every sample.
BatchTopK instead applies the TopK operation across the entire flattened batch:
1. Flatten all feature activations across the batch
2. Keep the top (K * batch_size) activations and set all others to zero
3. Reshape back to the original batch shape
This allows more flexibility in how many features activate per sample, while still maintaining an average of K active features across the batch.
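The three steps above can be sketched in plain Python, using the same toy batch as a per-sample TopK would see. This is a minimal sketch: the name `batch_topk` is hypothetical, and the ReLU on kept values is an assumed implementation choice.

```python
def batch_topk(acts, k):
    """BatchTopK: keep the k * batch_size largest pre-activations across the
    whole batch, so each sample keeps k features on average, not exactly."""
    batch_size, n_features = len(acts), len(acts[0])
    # 1. Flatten all feature activations across the batch
    flat = [a for row in acts for a in row]
    # 2. Keep the top (k * batch_size) activations and zero all others
    n_keep = min(k * batch_size, len(flat))
    top = sorted(range(len(flat)), key=lambda i: flat[i], reverse=True)[:n_keep]
    kept = [0.0] * len(flat)
    for i in top:
        kept[i] = max(flat[i], 0.0)  # ReLU on kept values (assumption)
    # 3. Reshape back to (batch_size, n_features)
    return [kept[r * n_features:(r + 1) * n_features] for r in range(batch_size)]


acts = [[0.1, 0.2, 0.3],
        [3.0, 2.0, 1.0]]
out = batch_topk(acts, k=2)
print(out)                                           # [[0.0, 0.0, 0.3], [3.0, 2.0, 1.0]]
print([sum(1 for a in row if a != 0) for row in out])  # [1, 3]
```

With k = 2, the first sample keeps one feature and the second keeps three, an average of two per sample, whereas per-sample TopK would keep exactly two in each. In practice a tensor library's top-k routine (for example `torch.topk` on the flattened activations) would replace the Python sort.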
Experimental Set-Up
For both the TopK and the BatchTopK SAEs we train a sweep with the following hyperparameters:
Model: gpt2-small
Site: layer 8 resid_pre
Batch size: 4096
Optimizer: Adam (lr=3e-4, beta1=0.9, beta2=0.99)
Number of tokens: 1e9
Expansion factor: [4, 8, 16, 32]
Target L0 (k): [16, 32, 64]
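The sweep above can be written as a small config. A sketch assuming the batch size counts activation vectors (tokens); the constant names are hypothetical, and 768 is the residual stream width of gpt2-small. It shows that the dictionary size of 12288 used in the results corresponds to expansion factor 16.

```python
from itertools import product

D_MODEL = 768              # residual stream width of gpt2-small
BATCH_SIZE = 4096          # assumed to count activation vectors (tokens)
NUM_TOKENS = 1_000_000_000
EXPANSION_FACTORS = [4, 8, 16, 32]
TARGET_L0 = [16, 32, 64]

# One run per (expansion factor, k) pair, for each of TopK and BatchTopK.
sweep = [
    {"expansion": ef, "dict_size": ef * D_MODEL, "k": k}
    for ef, k in product(EXPANSION_FACTORS, TARGET_L0)
]
steps_per_run = NUM_TOKENS // BATCH_SIZE

print(len(sweep))                                # 12
print(sorted({r["dict_size"] for r in sweep}))   # [3072, 6144, 12288, 24576]
print(steps_per_run)                             # 244140
```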
As in the OpenAI paper, we normalize the input before feeding it into the SAE and computing the reconstruction loss. We also use the same auxiliary loss for dead features (features that have not activated for 5 batches): it computes a reconstruction loss on the residual (the main SAE's reconstruction error) using the top 512 dead features per sample, and is multiplied by a factor of 1/32.
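The dead-feature bookkeeping and auxiliary loss can be sketched for a single sample in plain Python. This is a minimal, unoptimized sketch, not the authors' code, assuming: a feature is dead once it has not fired for 5 consecutive batches; the residual is reconstructed from the ReLU'd pre-activations of the 512 dead features with the largest pre-activations, without a decoder bias; and the loss is a mean squared error over the model dimension before scaling by 1/32. The names `update_dead_features` and `aux_loss` are hypothetical, and input normalization is omitted.

```python
def update_dead_features(batches_since_fired, batch_acts, dead_after=5):
    """Count consecutive inactive batches per feature; a feature is dead
    once it has not fired for `dead_after` batches (assumed bookkeeping)."""
    n_features = len(batches_since_fired)
    fired = [any(row[i] > 0 for row in batch_acts) for i in range(n_features)]
    counts = [0 if f else c + 1 for f, c in zip(fired, batches_since_fired)]
    dead = {i for i, c in enumerate(counts) if c >= dead_after}
    return counts, dead


def aux_loss(x, x_hat, pre_acts, W_dec, dead, k_aux=512, alpha=1 / 32):
    """Reconstruct the residual (x - x_hat) using only the top-k_aux dead
    features, then scale the MSE by alpha.

    x, x_hat : input and SAE reconstruction, length d_model
    pre_acts : encoder pre-activations, length n_features
    W_dec    : decoder, n_features rows of length d_model (no bias, assumption)
    dead     : set of dead feature indices
    """
    d_model = len(x)
    residual = [xi - xh for xi, xh in zip(x, x_hat)]
    # the k_aux dead features with the largest pre-activations
    top_dead = sorted(dead, key=lambda i: pre_acts[i], reverse=True)[:k_aux]
    residual_hat = [0.0] * d_model
    for i in top_dead:
        a = max(pre_acts[i], 0.0)  # ReLU
        for j in range(d_model):
            residual_hat[j] += a * W_dec[i][j]
    mse = sum((r - rh) ** 2 for r, rh in zip(residual, residual_hat)) / d_model
    return alpha * mse


counts, dead = update_dead_features([4, 0, 4], [[0.0, 1.0, 0.0], [0.0, 0.0, 2.0]])
print(counts, dead)  # [5, 0, 0] {0}
```

The total training loss would then be the normalized reconstruction loss plus this term; because only dead features receive gradient through it, it nudges them toward explaining what the live features miss.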
Results
For a fixed number of active features (L0=32), the BatchTopK SAE has a lower normalized MSE and less downstream loss degradation than the TopK SAE across different dictionary sizes. Similarly, for a fixed dictionary size (12288), BatchTopK outperforms TopK for different values of k.
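Normalized MSE can be computed as below. This sketch assumes the definition in which the reconstruction MSE is divided by the MSE of a baseline that always predicts the mean input; the function name `normalized_mse` is hypothetical. A value of 0 means perfect reconstruction and 1 means no better than predicting the mean.

```python
def normalized_mse(xs, x_hats):
    """Reconstruction MSE divided by the MSE of predicting the per-dimension mean of xs."""
    n, d = len(xs), len(xs[0])
    mean = [sum(x[j] for x in xs) / n for j in range(d)]
    err = sum((x[j] - xh[j]) ** 2 for x, xh in zip(xs, x_hats) for j in range(d))
    baseline = sum((x[j] - mean[j]) ** 2 for x in xs for j in range(d))
    return err / baseline


xs = [[1.0, 0.0], [3.0, 2.0]]
print(normalized_mse(xs, [[1.0, 1.0], [3.0, 2.0]]))  # 0.25
```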
Our main hypothesis is that the improved performance comes from adaptive sparsity: some samples contain more highly activating features than others. Let's have a look at the distribution of the number of active features per sample for the BatchTopK model.
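Such a distribution can be computed from a batch of SAE activations, one row per token. A minimal sketch; the helper names `l0_per_sample` and `l0_distribution` are hypothetical.

```python
from collections import Counter


def l0_per_sample(acts):
    """Number of nonzero (active) features in each row of a batch of SAE activations."""
    return [sum(1 for a in row if a != 0) for row in acts]


def l0_distribution(acts):
    """Fraction of samples having each L0 value, i.e. a normalized histogram."""
    counts = Counter(l0_per_sample(acts))
    n = len(acts)
    return {l0: c / n for l0, c in sorted(counts.items())}


acts = [[0.0, 1.0, 2.0],
        [0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.5, 0.5, 0.0]]
print(l0_per_sample(acts))    # [2, 0, 1, 2]
print(l0_distribution(acts))  # {0: 0.25, 1: 0.25, 2: 0.5}
```

For a TopK SAE this histogram is a single spike at k; for BatchTopK it spreads out around k.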
The BatchTopK model indeed makes use of its ability to use different sparsities for different inputs. We suspect that the odd peak on the left side corresponds to feature activations on BOS tokens, given that its frequency is very close to 1 in 128, which is the sequence length. This is a good example of why BatchTopK might outperform TopK: at the BOS token, a sequence contains very little information yet, but the TopK SAE still activates 32 features.
The BatchTopK model "saves" th...