Elriggs committed
Commit d0ee977 · verified · 1 Parent(s): 0746001

Upload README.md with huggingface_hub

Files changed (1):
  1. README.md +165 -0
README.md ADDED
---
tags:
- mechanistic-interpretability
- transcoding
- bilinear
- pythia
- mlp
library_name: pytorch
license: mit
---

# Pythia-410m Bilinear MLP Transcoders

This repository contains bilinear transcoder models trained to approximate the MLP layers of [EleutherAI/pythia-410m](https://huggingface.co/EleutherAI/pythia-410m).

## Overview

**Transcoders** are auxiliary models that learn to approximate the behavior of transformer components (in this case, MLPs) using simpler architectures. These bilinear transcoders use a Hadamard neural network architecture to approximate each of the 24 MLP layers in Pythia-410m.

## Model Architecture

- **Base Model**: EleutherAI/pythia-410m (24 layers)
- **Transcoder Type**: Bilinear (Hadamard Neural Network)
- **Architecture**: `output = W_left @ (x ⊙ (W_right @ x)) + bias` (a rough parameter-count sketch follows this list)
  - Input dimension: 1024 (d_model)
  - Hidden dimension: 4096 (4x expansion)
  - Output dimension: 1024 (d_model)
- **Training**: 3000 batches, batch size 512, Muon optimizer (lr=0.02)
- **Dataset**: monology/pile-uncopyrighted

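For scale, here is a back-of-the-envelope parameter count implied by these dimensions (an estimate only; the exact count depends on which tensors each checkpoint actually stores):

```python
# Rough per-layer parameter count implied by the stated dimensions
# (an estimate, not read from the checkpoints).
d_model, d_hidden = 1024, 4096

w_right = d_model * d_hidden            # 1024 -> 4096 projection
w_left = d_hidden * d_model + d_model   # 4096 -> 1024 projection plus output bias
per_layer = w_right + w_left

print(f"~{per_layer / 1e6:.1f}M parameters per transcoder")   # ~8.4M
print(f"~{24 * per_layer / 1e6:.0f}M across all 24 layers")   # ~201M
```
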
## Performance Summary

All 24 layers achieve >82% variance explained, with most layers >93%:

| Layer | Final FVU | Variance Explained | Notes |
|-------|-----------|--------------------|-------|
| 0 | 0.0075 | 99.2% | Best performance |
| 1-2 | 0.167-0.174 | 82.6-83.2% | Hardest to approximate |
| 3-22 | 0.037-0.066 | 93.4-96.3% | Consistent performance |
| 23 | 0.0259 | 97.4% | Second-best |

**Average across all layers**: 93.4% variance explained (FVU = 0.0657)

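Here FVU is the fraction of variance unexplained, so variance explained is 1 - FVU (e.g. the average FVU of 0.0657 corresponds to ~93.4% explained). A minimal sketch of that metric, assuming the usual MSE-over-target-variance definition:

```python
import torch

def fvu(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Fraction of variance unexplained: reconstruction MSE over target variance."""
    mse = (pred - target).pow(2).mean()
    var = (target - target.mean()).pow(2).mean()
    return mse / var

# Variance explained = 1 - FVU, e.g. FVU = 0.0657 -> ~93.4% variance explained.
```
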
## Repository Structure

```
.
├── layer_0/
│   ├── transcoder_weights_l0_bilinear_muon_3000b.pt
│   └── config.yaml
├── layer_1/
│   ├── transcoder_weights_l1_bilinear_muon_3000b.pt
│   └── config.yaml
...
├── layer_23/
│   ├── transcoder_weights_l23_bilinear_muon_3000b.pt
│   └── config.yaml
├── figures/
│   ├── all_layers_comparison.png
│   ├── training_curves_overlaid_layers_0_5.png
│   ├── training_curves_overlaid_layers_6_11.png
│   ├── training_curves_overlaid_layers_12_17.png
│   └── training_curves_overlaid_layers_18_23.png
└── README.md
```

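To fetch a single checkpoint without cloning the whole repository, `huggingface_hub` can download individual files. The `repo_id` below is a placeholder and should be replaced with this repository's actual id:

```python
from huggingface_hub import hf_hub_download

# repo_id is a placeholder; substitute the actual id of this repository.
path = hf_hub_download(
    repo_id="your-username/pythia-410m-bilinear-transcoders",
    filename="layer_3/transcoder_weights_l3_bilinear_muon_3000b.pt",
)
print(path)  # local cache path of the downloaded checkpoint
```
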
## Usage

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load base model
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m")

# Load transcoder for layer 3
layer_idx = 3
checkpoint = torch.load(
    f"layer_{layer_idx}/transcoder_weights_l{layer_idx}_bilinear_muon_3000b.pt",
    weights_only=False,  # the checkpoint stores a pickled config object
)

# Extract configuration
config = checkpoint['config']
print(f"Input dim: {config.n_inputs}")
print(f"Hidden dim: {config.n_hidden}")
print(f"Output dim: {config.n_outputs}")

# Reconstruct model (example - you'll need the Bilinear class)
class Bilinear(torch.nn.Module):
    def __init__(self, n_inputs, n_hidden, n_outputs, bias=True):
        super().__init__()
        self.W_left = torch.nn.Linear(n_hidden, n_outputs, bias=bias)
        self.W_right = torch.nn.Linear(n_inputs, n_hidden, bias=False)

    def forward(self, x):
        right = self.W_right(x)
        # Outer product of x with W_right(x), summed over the input dimension
        hadamard = x.unsqueeze(-1) * right.unsqueeze(-2)
        return self.W_left(hadamard.sum(dim=-2))

transcoder = Bilinear(config.n_inputs, config.n_hidden, config.n_outputs, config.bias)
transcoder.load_state_dict(checkpoint['model_state_dict'])
transcoder.eval()

# Use transcoder to approximate MLP
with torch.no_grad():
    # Residual stream entering layer `layer_idx` (hidden_states[0] is the embedding output)
    inputs = tokenizer("Hello world", return_tensors="pt")
    outputs = model(**inputs, output_hidden_states=True)
    mlp_input = outputs.hidden_states[layer_idx]

    # Approximate the MLP output with the transcoder.
    # Note: training used per-batch z-score normalization (see Training Details),
    # so raw activations may need the same normalization for best fidelity.
    transcoded_output = transcoder(mlp_input)
```

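To sanity-check a transcoder against the real MLP, one option is a forward hook on that layer's MLP module. This sketch is not the repository's evaluation code; it reuses `model`, `tokenizer`, `transcoder`, and `layer_idx` from the block above and assumes the standard GPTNeoX module path (`model.gpt_neox.layers[i].mlp`):

```python
import torch

captured = {}

def grab(module, inputs, output):
    # inputs is a tuple; GPTNeoX MLPs receive the post-layernorm residual stream
    captured["mlp_in"] = inputs[0].detach()
    captured["mlp_out"] = output.detach()

handle = model.gpt_neox.layers[layer_idx].mlp.register_forward_hook(grab)
with torch.no_grad():
    model(**tokenizer("Hello world", return_tensors="pt"))
handle.remove()

with torch.no_grad():
    approx = transcoder(captured["mlp_in"])

# Single-prompt FVU; training-time normalization is not applied here,
# so this number may differ from the table above.
mlp_out = captured["mlp_out"]
fvu = ((approx - mlp_out) ** 2).mean() / ((mlp_out - mlp_out.mean()) ** 2).mean()
print(f"FVU on this prompt: {fvu.item():.3f}")
```
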
## Training Details

- **Optimizer**: Muon (momentum-based optimizer)
- **Learning Rate**: 0.02 (hardcoded for Muon)
- **Batch Size**: 512
- **Total Batches**: 3000 per layer
- **Training Time**: ~75 minutes per layer on an A100
- **Normalization**: Per-batch z-score normalization (sketched below)

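A minimal sketch of that normalization step; the exact axes and epsilon used in training are assumptions here:

```python
import torch

def zscore_batch(acts: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Z-score activations feature-wise over the batch dimension (assumed axes)."""
    mean = acts.mean(dim=0, keepdim=True)
    std = acts.std(dim=0, keepdim=True)
    return (acts - mean) / (std + eps)

# Example: a batch of d_model-sized activations
acts = torch.randn(512, 1024)
normed = zscore_batch(acts)
```
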
## Checkpoint Contents

Each checkpoint (`.pt` file) contains the following keys (a quick inspection snippet follows the list):
- `model_state_dict`: Model weights
- `optimizer_state_dict`: Optimizer state
- `config`: Configuration object with dimensions
- `mse_losses`: List of MSE losses per batch
- `variance_explained`: List of variance-explained values per batch
- `fvu_values`: List of FVU values per batch
- `layer_idx`: Layer index (0-23)
- `d_model`: Model dimension (1024)

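A quick way to inspect a checkpoint and its recorded training curves (layer 0 is used as an example; `weights_only=False` is needed because the config object is pickled):

```python
import torch

ckpt = torch.load(
    "layer_0/transcoder_weights_l0_bilinear_muon_3000b.pt",
    weights_only=False,  # the config object is pickled, not a plain tensor dict
)

print("layer:", ckpt["layer_idx"], "| d_model:", ckpt["d_model"])
print("final FVU:", ckpt["fvu_values"][-1])
print("final variance explained:", ckpt["variance_explained"][-1])
print("batches recorded:", len(ckpt["mse_losses"]))
```
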
## Key Findings

1. **Layer 0 is dramatically easier to approximate** (99.2% VE) - nearly perfect reconstruction
2. **Layers 1-2 are hardest** (~83% VE) - they contain more complex transformations
3. **Middle layers (3-22) are remarkably consistent** (93-96% VE) - suggesting homogeneous structure
4. **The final layer is highly learnable** (97.4% VE)

This suggests that the input and output layers compute more structured, easily captured functions, while the early layers (1-2) perform more complex transformations that are difficult for bilinear models to capture.

## Citation

If you use these transcoders in your research, please cite:

```bibtex
@misc{pythia410m-bilinear-transcoders,
  title={Bilinear MLP Transcoders for Pythia-410m},
  author={[Your Name]},
  year={2025},
  publisher={Hugging Face},
  url={https://huggingface.co/[your-username]/pythia-410m-bilinear-transcoders}
}
```

## License

MIT License

## Acknowledgments

- Base model: [EleutherAI/pythia-410m](https://huggingface.co/EleutherAI/pythia-410m)
- Training dataset: [monology/pile-uncopyrighted](https://huggingface.co/datasets/monology/pile-uncopyrighted)