---
tags:
- mechanistic-interpretability
- transcoding
- bilinear
- pythia
- mlp
library_name: pytorch
license: mit
---

# Pythia-410m Bilinear MLP Transcoders

This repository contains bilinear transcoder models trained to approximate the MLP layers of [EleutherAI/pythia-410m](https://huggingface.co/EleutherAI/pythia-410m).

## Overview

**Transcoders** are auxiliary models that learn to approximate the behavior of transformer components (in this case, MLPs) using simpler architectures. These bilinear transcoders use a Hadamard neural network architecture to approximate each of the 24 MLP layers in Pythia-410m.

## Model Architecture

- **Base Model**: EleutherAI/pythia-410m (24 layers)
- **Transcoder Type**: Bilinear (Hadamard Neural Network)
- **Architecture**: `output = W_left @ (x ⊙ (W_right @ x)) + bias`
  - Input dimension: 1024 (d_model)
  - Hidden dimension: 4096 (4x expansion)
  - Output dimension: 1024 (d_model)
- **Training**: 3000 batches, batch size 512, Muon optimizer (lr=0.02)
- **Dataset**: monology/pile-uncopyrighted

## Performance Summary

All 24 layers achieve >82% variance explained (VE = 1 - FVU, where FVU is the fraction of variance unexplained), with most layers >93%:

| Layer | Final FVU | Variance Explained | Notes |
|-------|-----------|--------------------|-------|
| 0 | 0.0075 | 99.2% | Best performance |
| 1-2 | 0.167-0.174 | 82.6-83.2% | Hardest to approximate |
| 3-22 | 0.037-0.066 | 93.4-96.3% | Consistent performance |
| 23 | 0.0259 | 97.4% | Second-best |

**Average across all layers**: 93.4% variance explained (FVU = 0.0657)

## Repository Structure

```
.
├── layer_0/
│   ├── transcoder_weights_l0_bilinear_muon_3000b.pt
│   └── config.yaml
├── layer_1/
│   ├── transcoder_weights_l1_bilinear_muon_3000b.pt
│   └── config.yaml
...
├── layer_23/
│   ├── transcoder_weights_l23_bilinear_muon_3000b.pt
│   └── config.yaml
├── figures/
│   ├── all_layers_comparison.png
│   ├── training_curves_overlaid_layers_0_5.png
│   ├── training_curves_overlaid_layers_6_11.png
│   ├── training_curves_overlaid_layers_12_17.png
│   └── training_curves_overlaid_layers_18_23.png
└── README.md
```

## Usage

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the base model
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m")

# Load the transcoder checkpoint for layer 3
layer_idx = 3
checkpoint = torch.load(
    f"layer_{layer_idx}/transcoder_weights_l{layer_idx}_bilinear_muon_3000b.pt",
    weights_only=False,  # the checkpoint stores a config object, not just tensors
)

# Extract configuration
config = checkpoint['config']
print(f"Input dim: {config.n_inputs}")
print(f"Hidden dim: {config.n_hidden}")
print(f"Output dim: {config.n_outputs}")

# Reconstruct the transcoder (example - you'll need the Bilinear class used in training)
class Bilinear(torch.nn.Module):
    def __init__(self, n_inputs, n_hidden, n_outputs, bias=True):
        super().__init__()
        self.W_left = torch.nn.Linear(n_hidden, n_outputs, bias=bias)
        self.W_right = torch.nn.Linear(n_inputs, n_hidden, bias=False)

    def forward(self, x):
        right = self.W_right(x)                           # (..., n_hidden)
        hadamard = x.unsqueeze(-1) * right.unsqueeze(-2)  # (..., n_inputs, n_hidden)
        return self.W_left(hadamard.sum(dim=-2))          # (..., n_outputs)

transcoder = Bilinear(config.n_inputs, config.n_hidden, config.n_outputs, config.bias)
transcoder.load_state_dict(checkpoint['model_state_dict'])
transcoder.eval()

# Use the transcoder to approximate the MLP
with torch.no_grad():
    inputs = tokenizer("Hello world", return_tensors="pt")
    outputs = model(**inputs, output_hidden_states=True)

    # Residual stream entering block `layer_idx` (hidden_states[0] is the embedding output).
    # The exact MLP input is this tensor after the block's post-attention LayerNorm;
    # see the hook-based sketch below for capturing it directly.
    mlp_input = outputs.hidden_states[layer_idx]

    # Approximate the MLP output with the transcoder
    transcoded_output = transcoder(mlp_input)
```
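To check reconstruction quality on your own inputs, you can compare the transcoder against the true MLP output of the corresponding layer. The sketch below is illustrative only: it continues from the Usage example (reusing `model`, `tokenizer`, `transcoder`, and `layer_idx`), assumes the standard `transformers` GPT-NeoX module layout (`model.gpt_neox.layers[i].mlp`), and omits the per-batch z-score normalization used during training, so the numbers will not exactly match the reported FVUs.

```python
import torch

captured = {}

def grab_mlp_io(module, inputs, output):
    # inputs[0]: post-LayerNorm hidden states fed to the MLP
    # output:    the true MLP output the transcoder is trained to match
    captured["in"] = inputs[0].detach()
    captured["out"] = output.detach()

# Hook the layer's MLP to capture its actual input and output
handle = model.gpt_neox.layers[layer_idx].mlp.register_forward_hook(grab_mlp_io)
with torch.no_grad():
    batch = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt")
    model(**batch)
handle.remove()

with torch.no_grad():
    pred = transcoder(captured["in"])

# FVU = MSE / variance of the target; variance explained = 1 - FVU
target = captured["out"]
fvu = (pred - target).pow(2).mean() / (target - target.mean()).pow(2).mean()
print(f"Layer {layer_idx}: FVU = {fvu.item():.4f}, VE = {1 - fvu.item():.1%}")
```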
## Training Details

- **Optimizer**: Muon (momentum-based optimizer)
- **Learning Rate**: 0.02 (hardcoded for Muon)
- **Batch Size**: 512
- **Total Batches**: 3000 per layer
- **Training Time**: ~75 minutes per layer on an A100
- **Normalization**: Per-batch z-score normalization

## Checkpoint Contents

Each checkpoint (`.pt` file) contains:

- `model_state_dict`: Model weights
- `optimizer_state_dict`: Optimizer state
- `config`: Configuration object with dimensions
- `mse_losses`: List of MSE losses per batch
- `variance_explained`: List of variance explained per batch
- `fvu_values`: List of FVU values per batch
- `layer_idx`: Layer index (0-23)
- `d_model`: Model dimension (1024)

## Key Findings

1. **Layer 0 is dramatically easier to approximate** (99.2% VE) - nearly perfect reconstruction
2. **Layers 1-2 are hardest** (~83% VE) - contain complex transformations
3. **Middle layers (3-22) are remarkably consistent** (93-96% VE) - homogeneous structure
4. **The final layer is highly learnable** (97.4% VE)

This suggests that the input and output layers have more structured patterns, while the early layers (1-2) perform more complex transformations that are difficult for bilinear models to capture.

## Citation

If you use these transcoders in your research, please cite:

```bibtex
@misc{pythia410m-bilinear-transcoders,
  title={Bilinear MLP Transcoders for Pythia-410m},
  author={[Your Name]},
  year={2025},
  publisher={Hugging Face},
  url={https://huggingface.co/[your-username]/pythia-410m-bilinear-transcoders}
}
```

## License

MIT License

## Acknowledgments

- Base model: [EleutherAI/pythia-410m](https://huggingface.co/EleutherAI/pythia-410m)
- Training dataset: [monology/pile-uncopyrighted](https://huggingface.co/datasets/monology/pile-uncopyrighted)
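## Appendix: Inspecting Stored Training Metrics

Because each checkpoint stores its per-batch metrics (see Checkpoint Contents above), the training curves in `figures/` can be inspected without retraining. A minimal sketch, assuming `matplotlib` is installed and the checkpoint keys listed above:

```python
import torch
import matplotlib.pyplot as plt

layer_idx = 3
ckpt = torch.load(
    f"layer_{layer_idx}/transcoder_weights_l{layer_idx}_bilinear_muon_3000b.pt",
    weights_only=False,  # may be required on newer PyTorch since a config object is stored
)

# Per-batch FVU recorded during training
fvu = ckpt["fvu_values"]
print(f"Layer {ckpt['layer_idx']}: {len(fvu)} batches, final FVU = {float(fvu[-1]):.4f}")

plt.plot(fvu)
plt.xlabel("training batch")
plt.ylabel("FVU")
plt.yscale("log")
plt.title(f"Layer {layer_idx} bilinear transcoder training curve")
plt.savefig(f"fvu_layer_{layer_idx}.png", dpi=150)
```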