Over the past few weeks, I've been working on training a sparse autoencoder on Mistral's 7B instruct model, and I finally finished! In this post, I want to go over how I trained it, some of the cool stuff I've found so far, and how you can try it out.
Training large sparse autoencoders
Sparse autoencoders (SAEs) are, in theory, extremely simple. They consist of just two linear layers, an encoder and a decoder, which expand the residual-stream activations of another model into a much larger latent space and then shrink them back to their original size. For SAEs to be useful, however, these intermediate latent activations need to be sparse (meaning only a few light up for a given input). In my experience, this becomes increasingly difficult as one scales up, because the pressure toward sparsity gives neurons an incentive to never fire at all. These "dead" neurons are a side effect of what you're actually trying to incentivize: sparse activations. Imagine a scenario where 95% of your neurons never light up: the model is technically very sparse, but it's also incredibly wasteful and inefficient.
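As a rough illustration of the structure described above, here is a minimal NumPy sketch of an SAE forward pass, plus a helper that measures how many latents never fire on a batch. The class and function names, the ReLU activation, and the initialization are illustrative assumptions, not the code used to train this model. For Mistral 7B the residual stream is 4,096-dimensional, so a 131,072-latent SAE corresponds to an expansion factor of 32.

```python
import numpy as np

class SparseAutoencoder:
    """Minimal SAE sketch: one encoder and one decoder linear layer.

    d_model is the residual-stream width of the host model (4096 for
    Mistral 7B); d_latent is the expanded dictionary size. The ReLU here is
    an assumption; other activations (e.g. TopK) are also common.
    """

    def __init__(self, d_model: int, d_latent: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.W_enc = rng.normal(0.0, 1.0 / np.sqrt(d_model), (d_model, d_latent))
        self.b_enc = np.zeros(d_latent)
        # Initialize the decoder as the encoder's transpose; the two would
        # still be trained as separate parameters.
        self.W_dec = self.W_enc.T.copy()
        self.b_dec = np.zeros(d_model)

    def encode(self, x: np.ndarray) -> np.ndarray:
        # Subtract the decoder bias, project up, and keep only positive values.
        return np.maximum((x - self.b_dec) @ self.W_enc + self.b_enc, 0.0)

    def decode(self, z: np.ndarray) -> np.ndarray:
        # Project the latents back down to the residual-stream width.
        return z @ self.W_dec + self.b_dec

    def forward(self, x: np.ndarray):
        z = self.encode(x)
        return self.decode(z), z


def dead_fraction(z: np.ndarray) -> float:
    """Fraction of latents that never fire across a batch of shape (n, d_latent)."""
    return float(np.mean(~(z > 0).any(axis=0)))


# Example: a toy SAE with a 4x expansion factor.
sae = SparseAutoencoder(d_model=16, d_latent=64)
x = np.random.default_rng(1).normal(size=(8, 16))
x_hat, z = sae.forward(x)
print(x_hat.shape, z.shape)  # (8, 16) (8, 64)
```

In a real training run, `dead_fraction` would be tracked over many batches; a value creeping toward 1.0 is exactly the "technically sparse but wasteful" failure mode.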
This is still an open problem, and there is not yet a single agreed-upon solution. In Scaling Monosemanticity: Extracting Interpretable Features from Claude 3 Sonnet (Templeton et al.), a standard L1 penalty is used, as is common when trying to induce sparsity, but each feature's activation is multiplied by the L2 norm of its decoder weights. Another paper, Scaling and evaluating sparse autoencoders (Gao et al.), goes in a completely different direction, using a TopK activation (meaning only the k largest values are kept and the rest are set to 0) and a special AuxK loss that encourages the non-activating latents to account for the reconstruction error left by those that do activate. There are many other strategies one could use to address dead neurons, such as reinitializing the weights of neurons that appear dead, or more complex techniques like "Ghost Grads."
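The decoder-norm-weighted L1 penalty can be sketched as a loss function. This assumes the decoder matrix is stored with one row per latent, and `l1_coeff` is a placeholder rather than a value from either paper; the function name is hypothetical.

```python
import numpy as np

def weighted_l1_loss(x, x_hat, z, W_dec, l1_coeff=5.0):
    """Reconstruction + sparsity loss in the style described by Templeton et al.

    x, x_hat: (batch, d_model); z: (batch, d_latent) non-negative latents;
    W_dec: (d_latent, d_model), so row i is latent i's decoder direction.
    l1_coeff is a placeholder value, not a recommended setting.
    """
    # Squared reconstruction error summed over d_model, averaged over the batch.
    mse = np.mean(np.sum((x - x_hat) ** 2, axis=-1))
    # L2 norm of each latent's decoder direction, shape (d_latent,).
    dec_norms = np.linalg.norm(W_dec, axis=1)
    # Each activation is penalized in proportion to its decoder norm, so the
    # model cannot shrink activations just by growing decoder weights.
    sparsity = np.mean(np.sum(np.abs(z) * dec_norms, axis=-1))
    return mse + l1_coeff * sparsity
```

Weighting by the decoder norm matters because a plain L1 penalty on `z` can be gamed by scaling activations down and decoder columns up by the same factor.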
I ended up adopting the TopK approach from the OpenAI paper, though I implemented and experimented with nearly every strategy listed above, with varying degrees of success. You can find a rough training codebase and a link to the weights here.
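For readers unfamiliar with TopK SAEs, the activation and the AuxK idea can be sketched in NumPy as follows. This is a simplified illustration, not code from the linked training codebase: the function names are hypothetical, clamping negative TopK values to zero is an assumption, and how latents get flagged as "dead" (for example, no activation in the last N tokens) is left to the caller.

```python
import numpy as np

def topk_activation(pre_acts: np.ndarray, k: int) -> np.ndarray:
    """Keep the k largest pre-activations per row, clamp negatives to 0, zero the rest."""
    # argpartition finds the indices of the k largest entries without a full sort.
    idx = np.argpartition(pre_acts, -k, axis=-1)[..., -k:]
    out = np.zeros_like(pre_acts)
    vals = np.take_along_axis(pre_acts, idx, axis=-1)
    np.put_along_axis(out, idx, np.maximum(vals, 0.0), axis=-1)
    return out


def auxk_loss(x, x_hat, pre_acts, W_dec, dead_mask, k_aux):
    """AuxK sketch: reconstruct the main residual error using only dead latents.

    x, x_hat: (batch, d_model); pre_acts: (batch, d_latent);
    W_dec: (d_latent, d_model); dead_mask: (d_latent,) booleans for latents
    that have not fired recently (tracking this is an assumption left to the caller).
    """
    error = x - x_hat  # residual error the main TopK reconstruction missed
    # Hide live latents so TopK can only select dead ones.
    masked = np.where(dead_mask, pre_acts, -np.inf)
    k = min(k_aux, int(dead_mask.sum()))
    if k == 0:
        return 0.0
    z_aux = topk_activation(masked, k)
    error_hat = z_aux @ W_dec  # no decoder bias: we are modelling an error term
    return float(np.mean(np.sum((error - error_hat) ** 2, axis=-1)))
```

Because TopK fixes the number of active latents directly, no sparsity coefficient needs tuning; the auxiliary loss gives dead latents a gradient signal they would otherwise never receive.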
Current approach to analysis
Because my SAE's latent dimension (131,072) is so large, I needed a more sophisticated way to identify which features correspond to which topics or ideas than going through them one by one until I found what I was looking for. Anecdotally, I've found that going from a feature to a topic is harder than starting with the topic I want and finding which feature(s) make it up. My first approach to interpreting the features illustrates this: I fed the model huge amounts of data and, for each feature, saved the ~10 tokens (with their surrounding context) that activated it most strongly. This largely gave unintelligible results, which I suspect is because many features are quite subtle, so what they encode is hard to infer from small chunks of text.
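Collecting the top-activating contexts for every feature can be done in a single pass with one small min-heap per feature. This is a minimal stdlib sketch; the input format (token strings plus sparse per-token `{feature_id: activation}` dicts) and the function name are assumptions for illustration.

```python
import heapq
from collections import defaultdict

def top_activating_contexts(batches, n_top=10, window=8):
    """Track the n_top strongest activations (with context) for every feature.

    batches yields (tokens, acts) pairs: tokens is a list of strings for one
    document; acts is a list (one per token) of {feature_id: activation} dicts
    holding only the non-zero latents. Both formats are assumptions.
    """
    heaps = defaultdict(list)  # feature_id -> min-heap of (activation, context)
    for tokens, acts in batches:
        for pos, feats in enumerate(acts):
            if not feats:
                continue
            # Surrounding context: `window` tokens on each side of the hit.
            context = "".join(tokens[max(0, pos - window): pos + window + 1])
            for feat, value in feats.items():
                heap = heaps[feat]
                if len(heap) < n_top:
                    heapq.heappush(heap, (value, context))
                elif value > heap[0][0]:
                    # Replace the weakest stored example with this stronger one.
                    heapq.heapreplace(heap, (value, context))
    # Return each feature's examples sorted from strongest to weakest.
    return {f: sorted(h, reverse=True) for f, h in heaps.items()}
```

The heap keeps memory at `n_top` entries per feature regardless of dataset size, which matters with 131,072 features. A small `window` is what produces the short, hard-to-read snippets described above.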
I then tried a more refined approach where I sorted my dataset into positive and negative examples (positive examples theoretically activating the features I wanted to find, and negative examples not). I then positively weighted features that activated on positive examples and negatively weighted those that activated on negative ones. This gave substantially better results, though still not great. For example, when I was trying to find a "Golden Gate Bridge" feature, I used large chunks of its Wikipedia page as positive examples. Unfortunately, the features I found aligned much more closely with "American history" than with "Golden Gate Bridge," which makes sense given that most of the Wikipedia page could be classified as American history. In theory, I could get better results by carefully curating more data focused on the Golden Gate Bridge itself, but doing that for every feature I want to find would be fairly time-consuming. Either way, I found a better method.
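One simple way to implement the positive/negative weighting is to score each feature by its mean activation on positive text minus its mean activation on negative text. The difference-of-means scoring and the function names here are assumptions for illustration, not necessarily the exact weighting used above.

```python
import numpy as np

def contrastive_feature_scores(pos_acts, neg_acts):
    """Score features by how much more they fire on positive than negative text.

    pos_acts, neg_acts: (n_tokens, d_latent) SAE activations collected on the
    positive and negative examples. Difference of means is one simple choice.
    """
    return pos_acts.mean(axis=0) - neg_acts.mean(axis=0)


def top_features(scores, n=10):
    # Indices of the n highest-scoring features, best first.
    return np.argsort(scores)[::-1][:n].tolist()
```

This also shows why the Wikipedia example drifted: any feature that fires across most of the positive text (such as a broad "American history" feature) scores highly unless the negative set contains that topic too.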
The final method I tried was used in the aforementioned Scaling Monosemanticity paper from Anthropic, and is remarkably simple: prompt the unaltered model to generate the kind of output the target feature might activate on, then record which features have the highest activation averaged over the entire output. This works very well and doesn't have the problem of finding features that are too broad. Of course, it is biased toward output the model would ordinarily produce. For example, while trying to find the "bitterness" feature shown below, I had to do some prompt gymnastics to get the model to produce angry-sounding output in the first place. In a safety context, this is a big issue: even if the model doesn't ordinarily produce harmful output (as a result of fine-tuning or something similar), it is still very much capable of it, as some "jailbreaks" demonstrate. Unless the researcher knows a jailbreak that elicits the target behavior, this method is of no use, since the model would never generate the desired output in the first place. To summarize, identifying features is not a solved problem, though this last method seems sufficient for most toy cases, like the ones below.
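The ranking step of this method can be sketched as follows, given the SAE activations for a prompt plus the model's generated response. Skipping the prompt tokens (so the ranking reflects what the model wrote rather than what it was asked) is an assumption of this sketch, as is the function name.

```python
import numpy as np

def rank_features_on_generation(acts, prompt_len, n=10):
    """Rank SAE features by mean activation over the model's generated tokens.

    acts: (seq_len, d_latent) SAE activations for the full prompt + response.
    prompt_len: number of leading prompt tokens to exclude (an assumption).
    """
    response = acts[prompt_len:]
    if len(response) == 0:
        raise ValueError("no generated tokens to average over")
    mean_acts = response.mean(axis=0)
    order = np.argsort(mean_acts)[::-1][:n]
    return [(int(i), float(mean_acts[i])) for i in order]
```

Averaging over the whole response favors features that stay active throughout the text, such as a tone or topic, over features that spike on a single token.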
Cool features I've found so far
From the small amount of messy analysis I've done, I've already found many interesting properties of the model. The first is that the features span varying levels of abstraction, with concepts ranging from specific and discrete to high-level and abstract. Below are some prompt/response examples with features artificially clamped to high values.
Pacific Ocean #79557
User: Who are you?
Model: I am a large, blue ocean with a vast expanse of water. I am home to many species of marine life, including fish, sharks, and whales...
Bitterness #101594
User: Hello, how are you?
Model: I'm a damn robot. I don't have feelings. I'm just a piece of garbage.
Rhyming #131062
User: How's it going?
Model: I'm feeling fine, my friend, no need to whine.
As the examples above show, artificially activating features has a very clear effect on the output, though some open questions remain:
- How resistant is the model to adversarial prompting (prompts that steer away from the activated feature)?
- To what extent does a single feature encode all the information about a given topic (for example, do other features still partially encode "Pacific Ocean")?
- How effective is clamping features to negative values?
- Can you simulate specific features by clamping multiple broad features (Golden Gate Bridge = San Francisco + Bridge)?
In the coming weeks, I hope to address some of these in separate blog posts.
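The feature clamping used in the examples above can be sketched as a transformation applied to the residual-stream activation at the SAE's layer: encode, overwrite the chosen latents, decode, and add back the SAE's reconstruction error so information the SAE fails to capture is preserved. This follows the general recipe described in Scaling Monosemanticity; the function name and signature are hypothetical, and the hook used in chat.py may differ.

```python
import numpy as np

def clamp_features(x, encode, decode, clamps):
    """Steer a residual-stream activation by clamping selected SAE features.

    x: (d_model,) residual activation at the SAE's layer.
    encode/decode: the SAE's encoder and decoder functions.
    clamps: {feature_id: value}; several features and negative values are allowed.
    """
    z = encode(x)
    error = x - decode(z)      # what the SAE could not reconstruct
    z = z.copy()
    for feat, value in clamps.items():
        z[feat] = value        # overwrite the chosen latents
    return decode(z) + error   # steered activation, passed on to the next layer
```

Because `clamps` accepts several features and arbitrary values, the same function covers both negative clamping and combining broad features, two of the open questions listed above.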
How you can use it
Eventually, I'd like to host the model somewhere with a web interface. For the time being, however, the easiest way to interact with the model is to use this GitHub repo and run chat.py. For this to work, you'll first need to download Mistral-7B-Instruct-v0.1 as well as the SAE weights.
Please let me know if you have any trouble using the model, and thanks for reading!