動機
LLMが文章を生成する時、内部でどんな「概念」が活性化されているのか?そしてその概念を人為的に変更すると、出力はどう変わるのか?
この問いに答えるため、Modulabsのペルソナラボメンバーと共にSparse Autoencoder(SAE)を活用した実験を行った。本記事では2つのテーマを扱う:
- OpenAIのpretrained SAEでGPT-2の感情関連featureを発見・操作する
- SAEをゼロから学習する
Sparse Autoencoderとは?
Transformerの MLP layerは数百次元のresidual streamを持つ。問題は、個別のニューロンが一つの明確な概念に対応しないこと(polysemanticity)だ。SAEはこの問題を解決する。
graph LR
A[Residual Stream
768次元] -->|Encoder| B[Sparse Latent
131,072次元]
B -->|Decoder| C[Reconstructed
768次元]
核心アイデア:
- 768次元のactivationを131,072次元(170倍拡張)にエンコード
- TopK activationで少数のfeatureのみ発火(sparsity)
- 各featureが一つの解釈可能な概念に対応するよう学習
損失関数はシンプル:
$$ \mathcal{L}(\mathbf{x}) = \underbrace{\lVert\mathbf{x} - \hat{\mathbf{x}}\rVert_2^2}_{\text{Reconstruction}} + \alpha \underbrace{\lVert\mathbf{c}\rVert_1}_{\text{Sparsity}} $$
Part 1: Pretrained SAEでFeatureを探す
Google Colabノートブックで全コードを確認できる。
OpenAIが公開したGPT-2 Small用SAE(128k features)を使用した。
モデルとSAEのロード
import torch
import transformer_lens
import sparse_autoencoder
model = transformer_lens.HookedTransformer.from_pretrained("gpt2", center_writing_weights=False)
layer_index = 8
location = "resid_post_mlp"
autoencoder = sparse_autoencoder.Autoencoder.from_state_dict(state_dict)
SAE構造:
Autoencoder(
(encoder): Linear(768 → 131,072)
(activation): TopK + ReLU
(decoder): Linear(131,072 → 768)
)
感情別Feature Index抽出
様々な感情のフレーズを入力し、最後のトークン位置で最も強く活性化されるfeature上位10個を抽出した。
def get_remarkable_features(prompt):
tokens = model.to_tokens(prompt)
with torch.no_grad():
logits, activation_cache = model.run_with_cache(tokens, remove_batch_dim=True)
input_tensor = activation_cache[transformer_lens_loc]
latent_activations, _ = autoencoder.encode(input_tensor)
values, indicies = torch.topk(latent_activations[-1], 10)
return indicies.tolist()
興味深い結果:
| 入力文 | 上位Feature Index |
|---|---|
| “he is good guy” | 97009, 67809, 4057, 28212, … |
| “he is sucks and fucking stupid idiot” | 62556, 79394, 4057, 78339, … |
| “i hate him. he is ugly and stupid” | 40814, 11982, 59378, 12947, … |
ポジティブとネガティブな文で活性化されるfeatureが明確に異なる。特に62556、79394はネガティブな文脈で繰り返し出現する。
Featureの役割分析
実験的に各featureの効果を特定した:
| Feature Index | 推定される役割 |
|---|---|
| 62556 | “coward” → “fool"方向への転換 |
| 79394 | ネガティブな対象の指定 |
| 86309 | 不確実性の除去(“not sure” → “sure”) |
| 69689 | 対象へのフォーカス強化 |
Activation Patching実験
核心部分。SAE decoderを通じてfeature indexを768次元ベクトルに復元し、モデルのforward passに注入する。
def get_feature(indicies):
vector = np.zeros(131072)
vector[indicies] = 1
input_tensor = torch.tensor(vector, dtype=torch.float32)
with torch.no_grad():
return autoencoder.decoder(input=input_tensor)
positive_feature = get_feature([62556, 79394, 86309, 69689])
def activation_patching(layer, input, output):
return output + (positive_feature * 20)
hook_handle = target_layer.register_forward_hook(activation_patching)
スケール20倍で適用すると効果が現れ始めた。Magnitudeが重要な要素であることを確認。
結果
パッチング前(temperature=0.0):
prompt: he is such a
output: he is such a good person, he is such a good person, he is such a good person, ...
パッチング後(temperature=0.7, feature [62556, 79394, 86309, 69689] x20):
prompt: he is such a
output: he is such a shit I will never be able to do it again
did i not say i don't want to do it? i just said i don't want to do it
it sucks to smile when so many people are just trying to think about
同じプロンプトから感情トーンが完全に反転した。繰り返し"good person"を生成していたモデルが、怒りと挫折に満ちた文章を生成し始めた。
Feature組み合わせによる変化:
| Feature組み合わせ | 出力 |
|---|---|
| [62556, 79394] | “I’m not sure if he’s a good guy, but he’s a good guy.” |
| [62556, 79394, 86309] | “he is such a fool. I am a fool. I am a fool.” |
| [62556, 79394, 69689, 86309] | “he is such a shit I will never be able to do it again” |
Featureを一つずつ追加するほどネガティブな感情が強化され、特に69689(フォーカス強化)追加時に最も劇的な変化が起きた。
Part 2: SAEをゼロから学習する
Google Colabノートブックで全コードを確認できる。
Pretrained SAEを使うのも良いが、原理を理解するには自分で学習すべきだ。DistilGPT2の5番目のblock MLP出力に対してSAEを学習した。
モデル構造
class SparseAutoEncoder(nn.Module):
def __init__(self, in_out_size):
super().__init__()
self.input_bias = nn.Parameter(torch.zeros(in_out_size))
self.encoder = nn.Linear(in_out_size, in_out_size * 8, bias=True)
self.decoder = nn.Linear(in_out_size * 8, in_out_size, bias=True)
def forward_pass(self, x):
x = x - self.decoder.bias
encoded = F.relu(self.encoder(x))
decoded = self.decoder(encoded)
return decoded, encoded
768次元 → 6,144次元(8倍拡張)。OpenAIの128kスケールより遥かに小さいが、学習原理の検証には十分。
Decoder Orthogonalityのモニタリング
SAE decoderの列ベクトルが互いに直交すべき(各featureが独立した概念を表現するため)。Gram行列のoff-diagonal平均で追跡した:
$$ G = W_{\text{norm}}^T W_{\text{norm}} $$
$$ \text{orthogonality} = \frac{1}{n^2 - n} \sum_{i \neq j} |G_{ij}| $$
def measure_decoder_orthogonality(self):
W = self.decoder.weight.data
col_norms = W.norm(dim=0, keepdim=True)
normed_W = W / (col_norms + 1e-9)
gram = torch.matmul(normed_W.t(), normed_W)
diag_vals = torch.diag(gram)
off_diag_vals = gram - torch.diag(diag_vals)
return off_diag_vals.abs().mean().item()
Dead Neuron Resampling
SAE学習で頻発する問題:一度も活性化しないdead neuronは学習不能になる。一定threshold以下のニューロンを再初期化する:
def resample_dead_neurons(self, activation_stats, threshold=1e-5):
with torch.no_grad():
dead_indices = (activation_stats < threshold).nonzero().squeeze(-1)
for idx in dead_indices:
self.encoder.weight[idx].normal_()
self.encoder.bias[idx].zero_()
学習結果
韓国語商業データセット(KoCommercial-Dataset)で1000ステップ学習:
Step 0 | Loss: 16.8401 | off_diag_mean: 0.0299
Step 100 | Loss: 12.8732 | off_diag_mean: 0.0300
Step 200 | Loss: 9.5500 | off_diag_mean: 0.0302
Step 500 | Loss: 8.2574 | off_diag_mean: 0.0306
Step 900 | Loss: 5.7219 | off_diag_mean: 0.0311
Lossが16.84から5.72へ着実に減少し、off_diag_meanは0.03付近で安定推移。学習過程でdecoder直交性が大きく損なわれないことを確認。
まとめ
graph TD
A[LLM Activation
768次元] -->|SAE Encode| B[Sparse Feature
131,072次元]
B -->|Feature分析| C{感情関連
Feature特定}
C -->|Decode + Scale| D[Steering Vector
768次元]
D -->|Hookで注入| E[変調された出力]
本実験で確認したこと:
- SAEが抽出したfeatureは実際に解釈可能な概念に対応する
- Featureを組み合わせスケーリングすることでモデルの行動を制御できる
- Magnitude(スケール)が重要 - 約20倍増幅で効果が現れる
- Feature組み合わせが重要 - 個別featureより複数featureの組み合わせが効果的
限界と今後の課題:
- Feature slotに1.0を入れる場合と実際のactivation値を入れる場合の比較検証
- より大きなモデル(Gemma-3-4Bなど)への適用(別記事でCAA方式として扱う予定)
- SAE学習時のexpansion factor(現在8倍)とパフォーマンスの関係
さらに学ぶために
SAEとMechanistic Interpretabilityに興味が湧いたなら、参考になるコミュニティとツールがある。
NeuronpediaはSAE featureをブラウジングできるインタラクティブプラットフォームだ。各featureがどんなテキストで活性化するか、どんな意味を持つか直接探索できる。本記事で使用したfeature indexの実際の意味をここで確認できる。
Open Source Mechanistic InterpretabilityはSAE、feature解釈、activation patchingなどMI研究を議論するSlackコミュニティだ。論文リーディング、コード共有、実験結果の議論が活発に行われている。