Sheet 8.1: Mechanistic interpretability#
Author: Polina Tsvilodub
One criticism often raised in the context of LLMs is their black-box nature, i.e., the inscrutability of the mechanics of the models and how or why they arrive at predictions, given the input. We have seen some methods that help to map out how models behave on certain tasks (sheet 7.1), what aspects of the input critically affect the output and which information the models process (sheet 6.1). In this sheet, we will look at methods for identifying the computational mechanisms that lead to the outputs, i.e., at mechanistic interpretability. It can be seen as trying to reverse-engineer the computational algorithms the model has learned during training and that are active during certain tasks.
Early decoding#
First, we will look at early decoding, i.e., at applying the “unembedding” layer (projecting hidden representations into vocabulary space by applying a linear and a softmax layer) to representations in layers throughout the model (not just the last layer). For this, we will need to output results of the calculations in the layers. There are various parameters that can be passed to the model's forward call for this purpose (the code below uses `output_hidden_states=True` and `output_attentions=True`). Furthermore, it is helpful to be able to access different weights of the model, which we did in sheet 3.1.
The code is largely based on this repository which accompanies the paper by Merullo et al. (2024).
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BloomTokenizerFast
import torch
import torch.nn.functional as F
import numpy as np
from torch import nn
def get_device():
    """Return the torch device string to run on.

    Returns:
        ``'cuda:0'`` when CUDA is available, otherwise ``'cpu'``.
    """
    if torch.cuda.is_available():
        return 'cuda:0'
    return 'cpu'
def load_gpt2(version):
    """Load a pretrained causal language model together with its tokenizer.

    Args:
        version: Checkpoint name or path forwarded to ``from_pretrained``
            (the notebook passes ``'gpt2-medium'``).

    Returns:
        A ``(model, tokenizer)`` tuple. The model is requested with float16
        weights and moved to the device reported by ``get_device()``.
    """
    tokenizer = AutoTokenizer.from_pretrained(version)
    half_precision_model = AutoModelForCausalLM.from_pretrained(
        version, torch_dtype=torch.float16
    )
    return half_precision_model.to(get_device()), tokenizer
class LambdaLayer(nn.Module):
    """Module that wraps an arbitrary callable so it can be used as a layer."""

    def __init__(self, lambd):
        """Store ``lambd``, the callable that ``forward`` applies to its input."""
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        """Return the result of calling the stored callable on ``x``."""
        fn = self.lambd
        return fn(x)
class ModelWrapper(nn.Module):
    """Base wrapper around a causal LM for early decoding and activation capture.

    Subclasses must implement ``layer_decode`` to project per-layer hidden
    states into vocabulary space. Activations recorded by the forward hooks
    built in ``get_activation`` are stored in ``self.model.activations_``.
    """

    def __init__(self, model, tokenizer):
        """Put ``model`` in eval mode and initialise the bookkeeping containers.

        Args:
            model: Model exposing ``transformer.h`` (its list of blocks).
            tokenizer: Tokenizer providing ``encode`` and ``decode``.
        """
        super().__init__()
        self.model = model.eval()
        self.model.activations_ = {}  # hook name -> captured activation
        self.tokenizer = tokenizer
        self.device = get_device()
        self.num_layers = len(self.model.transformer.h)
        self.hooks = []  # handles returned by register_forward_hook
        self.layer_pasts = {}

    def tokenize(self, s):
        """Encode ``s`` into a tensor of token ids placed on ``self.device``."""
        return self.tokenizer.encode(s, return_tensors='pt').to(self.device)

    def list_decode(self, inpids):
        """Decode every element of ``inpids`` separately and return the strings."""
        return [self.tokenizer.decode(s) for s in inpids]

    def layer_decode(self, hidden_states):
        """Project per-layer hidden states to logits; subclasses must override.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Layer decode has to be implemented!")

    def get_layers(self, tokens, **kwargs):
        """Run the model on ``tokens`` and early-decode every layer.

        Extra ``kwargs`` are forwarded to the model call.

        Returns:
            The list returned by ``layer_decode`` stacked into one tensor,
            with a trailing singleton dimension squeezed away.
        """
        outputs = self.model(input_ids=tokens, output_hidden_states=True, **kwargs)
        # The final hidden state is decoded by layer_decode like every other
        # layer, so the model's own output logits are not needed here.
        logits = self.layer_decode(outputs.hidden_states)
        return torch.stack(logits).squeeze(-1)

    def get_layers_w_attns(self, tokens, **kwargs):
        """Like ``get_layers`` but also request and return attention outputs.

        Returns:
            ``(stacked_layer_logits, outputs.attentions)``.
        """
        outputs = self.model(
            input_ids=tokens, output_hidden_states=True, output_attentions=True, **kwargs
        )
        logits = self.layer_decode(outputs.hidden_states)
        return torch.stack(logits).squeeze(-1), outputs.attentions

    def rr_per_layer(self, logits, answer, debug=False):
        """Return the reciprocal rank of the answer token at each layer.

        Args:
            logits: Iterable of per-layer logit vectors over the vocabulary.
            answer: Answer string; only its first token id is scored.
            debug: When true, print the answer id and string.

        Returns:
            ``np.ndarray`` with one reciprocal rank (``1 / (rank + 1)``) per layer.
        """
        answer_id = self.tokenizer.encode(answer)[0]
        if debug:
            print("Answer id", answer_id, answer)
        rrs = []
        for layer in logits:
            soft = F.softmax(layer, dim=-1)
            sorted_probs = soft.argsort(descending=True)
            # Position of the answer id in the descending ordering = 0-based rank.
            rank = float(np.where(sorted_probs.cpu().numpy() == answer_id)[0][0])
            rrs.append(1 / (rank + 1))
        return np.array(rrs)

    def prob_of_answer(self, logits, answer, debug=False):
        """Return the softmax probability of the answer token at each layer.

        Args:
            logits: Iterable of per-layer logit vectors over the vocabulary.
            answer: Answer string; only its first token id is scored.
            debug: When true, print the answer id and each layer's probability.

        Returns:
            ``np.ndarray`` with one probability per layer.
        """
        answer_id = self.tokenizer.encode(answer)[0]
        if debug:
            print("Answer id", answer_id, answer)
        answer_probs = []
        for i, layer in enumerate(logits):
            answer_prob = F.softmax(layer, dim=-1)[answer_id].item()
            if debug:
                print(f"{i}::", answer_prob)
            answer_probs.append(answer_prob)
        return np.array(answer_probs)

    def print_top(self, logits, k=10):
        """Print, per layer, the decoded ``k`` most probable tokens."""
        for i, layer in enumerate(logits):
            print(f"{i}", self.tokenizer.decode(F.softmax(layer, dim=-1).argsort(descending=True)[:k]))

    def topk_per_layer(self, logits, k=10):
        """Return, per layer, a list of the ``k`` most probable tokens decoded one by one."""
        topk = []
        for layer in logits:
            top_ids = F.softmax(layer, dim=-1).argsort(descending=True)[:k]
            topk.append([self.tokenizer.decode(s) for s in top_ids])
        return topk

    def get_activation(self, name):
        """Build a forward hook that stores activations under ``name``.

        What is stored in ``self.model.activations_`` depends on substrings of
        ``name``; in every case only the last sequence position is kept.
        Adapted from
        https://github.com/mega002/lm-debugger/blob/01ba7413b3c671af08bc1c315e9cc64f9f4abee2/flask_server/req_res_oop.py#L57

        Args:
            name: Key for the stored activation; also selects the capture mode
                (``in_sln``, ``attn``, ``mlp``, ``m_coef``, ``residual`` /
                ``embedding``).

        Returns:
            A ``hook(module, inputs, output)`` callable for
            ``register_forward_hook``.
        """
        def hook(module, inputs, output):
            activations = self.model.activations_
            if "in_sln" in name:
                num_tokens = list(inputs[0].size())[1]
                activations[name] = inputs[0][:, num_tokens - 1].detach()
            elif "mlp" in name or "attn" in name or "m_coef" in name:
                if "attn" in name:
                    # output is indexed as a tuple here; its first element is
                    # treated as (batch, sequence, ...). The module input is
                    # stored too, under 'in_' + name.
                    num_tokens = list(output[0].size())[1]
                    activations[name] = output[0][:, num_tokens - 1].detach()
                    activations['in_' + name] = inputs[0][:, num_tokens - 1].detach()
                elif "mlp" in name:
                    # output[0] selects the first element along dim 0 and its
                    # leading dimension is treated as the token axis.
                    num_tokens = list(output[0].size())[0]
                    activations[name] = output[0][num_tokens - 1].detach()
                elif "m_coef" in name:
                    num_tokens = list(inputs[0].size())[1]  # (batch, sequence, hidden_state)
                    activations[name] = inputs[0][:, num_tokens - 1].detach()
            elif "residual" in name or "embedding" in name:
                num_tokens = list(inputs[0].size())[1]  # (batch, sequence, hidden_state)
                final_layer = self.num_layers - 1
                if name == "layer_residual_" + str(final_layer):
                    # The residual after the last block is rebuilt from that
                    # block's intermediate residual plus its MLP update, both of
                    # which must already have been captured.
                    activations[name] = (
                        activations["intermediate_residual_" + str(final_layer)]
                        + activations["mlp_" + str(final_layer)]
                    )
                elif 'out' in name:
                    activations[name] = output[0][num_tokens - 1].detach()
                else:
                    activations[name] = inputs[0][:, num_tokens - 1].detach()
        return hook

    def reset_activations(self):
        """Discard all activations captured so far."""
        self.model.activations_ = {}
class GPT2Wrapper(ModelWrapper):
    """``ModelWrapper`` for models laid out like Hugging Face GPT-2.

    Relies on ``model.transformer.h[i]`` exposing ``ln_1``, ``attn``, ``ln_2``
    and ``mlp``, and on ``model.transformer.ln_f`` / ``model.lm_head``.
    """

    def layer_decode(self, hidden_states):
        """Early-decode the last token of every entry in ``hidden_states``.

        Args:
            hidden_states: Sequence of tensors indexed as
                (batch, num tokens, embedding size).

        Returns:
            A list with one logit tensor per entry, computed as
            ``lm_head.weight @ normed.T`` (vocabulary on the first axis).
        """
        logits = []
        last_index = len(hidden_states) - 1
        for i, h in enumerate(hidden_states):
            h = h[:, -1, :]  # keep only the last token
            if i == last_index:
                normed = h  # ln_f would already have been applied
            else:
                normed = self.model.transformer.ln_f(h)
            logits.append(torch.matmul(self.model.lm_head.weight, normed.T))
        return logits

    def add_hooks(self):
        """Register activation-capturing forward hooks on every block.

        Per block ``i`` this records ``in_sln_i`` (on ``ln_1``), ``attn_i``
        (on ``attn``), ``intermediate_residual_i`` and
        ``out_intermediate_residual_i`` (both on ``ln_2``) and ``mlp_i``
        (on ``mlp``). Handles are appended to ``self.hooks``.
        """
        for i in range(self.num_layers):
            block = self.model.transformer.h[i]
            self.hooks.append(block.ln_1.register_forward_hook(self.get_activation(f'in_sln_{i}')))
            self.hooks.append(block.attn.register_forward_hook(self.get_activation(f'attn_{i}')))
            self.hooks.append(block.ln_2.register_forward_hook(self.get_activation(f'intermediate_residual_{i}')))
            self.hooks.append(block.ln_2.register_forward_hook(self.get_activation(f'out_intermediate_residual_{i}')))
            self.hooks.append(block.mlp.register_forward_hook(self.get_activation(f'mlp_{i}')))

    def get_pre_wo_activation(self, name):
        """Build an attention hook storing the attention output before W_O.

        W_O is the output matrix of the attention layer, i.e. the last linear
        layer in the attention calculation. The hook expects the attention
        module to return a 3-tuple, so ``use_cache=True`` (default) and
        ``output_attentions=True`` have to be passed to the forward call.
        """
        def hook(module, input, output):
            _, past_key_value, attn_weights = output
            value = past_key_value[1]
            self.model.activations_[name] = torch.matmul(attn_weights, value)
        return hook

    def get_past_layer(self, name):
        """Build an attention hook storing the layer's key/value cache.

        The cache is written to ``self.layer_pasts[name]``. As for
        ``get_pre_wo_activation``, ``use_cache=True`` (default) and
        ``output_attentions=True`` have to be passed to the forward call.
        """
        def hook(module, input, output):
            _, past_key_value, attn_weights = output
            self.layer_pasts[name] = past_key_value
        return hook

    def add_mid_attn_hooks(self):
        """Register the ``mid_attn_i`` and ``past_layer_i`` hooks on every block."""
        for i in range(self.num_layers):
            attn = self.model.transformer.h[i].attn
            self.hooks.append(attn.register_forward_hook(self.get_pre_wo_activation(f'mid_attn_{i}')))
            self.hooks.append(attn.register_forward_hook(self.get_past_layer(f'past_layer_{i}')))

    def rm_hooks(self):
        """Remove every registered hook and forget the stale handles."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def reset_activations(self):
        """Clear captured activations and cached key/value states.

        Resets the containers the hooks actually write to:
        ``self.model.activations_`` and ``self.layer_pasts``.
        """
        self.model.activations_ = {}
        self.layer_pasts = {}
# Load GPT-2 medium and its tokenizer (load_gpt2 requests float16 weights).
model, tokenizer = load_gpt2('gpt2-medium')
# Cast the weights to float32 before running the analyses below.
model = model.float()
# Wrap the model to get the early-decoding helpers and the hook utilities.
wrapper = GPT2Wrapper(model, tokenizer)
def tokenize(text):
    """Tokenize ``text`` with the module-level ``wrapper``.

    Returns:
        ``(inp_ids, str_toks)``: the token-id tensor produced by
        ``wrapper.tokenize`` and the strings obtained by decoding each id of
        its first row individually.
    """
    token_ids = wrapper.tokenize(text)
    return token_ids, wrapper.list_decode(token_ids[0])
# Prompt with one solved question/answer pair followed by the question whose
# answer the model should produce as its next token.
poland_text="""Q: What is the capital of France?
A: Paris
Q: What is the capital of Poland?
A:"""
# Token ids (as a tensor) and the decoded string of each individual token.
poland_ids, pol_toks = tokenize(poland_text)
poland_ids.shape
# Early-decode the last token position at every layer of the model.
logits = wrapper.get_layers(poland_ids)
wrapper.print_top(logits[1:]) #skip the embedding layer
# look at the shape of the logits and try to understand what they represent
logits.shape
Exercise 8.1.1: Early decoding
(For yourself) Read through the code above and make sure that you understand what it does and why. Make sure to understand the concept of early decoding.
Try a different example (maybe for a different task). Do you see a similar pattern (comparing results between the layers)?
Residual stream#
Next, the lecture discussed the role of the residual stream, i.e., the “stream” which passes information between the transformer blocks and only undergoes linear transformations. Note that there is no block or layer called “residual stream” in the transformer architecture (i.e., if you were to print a pre-defined architecture); rather, it is a conceptual interpretation of the transformer architecture, realising that due to residual connections within the transformer blocks, one can look at the flow of information (i.e., the vector representations) as being read from by the transformer blocks (i.e., reading results of previous computations), and then writing the results of the attention calculations back to the representations (via linear operations applied to previous representations).
Below is a small example of accessing the residual stream. The idea behind this analysis is to understand what exactly the mechanistic role of the single computations is (i.e., applying attention vs. the FFNN layer). The code below works with so-called hooks, i.e., functions that are “hooked onto” the forward pass through the model and are executed together with the normal computations when the model is called. They are based on this native PyTorch functionality. Their implementation is in the code cells above.
Specifically, the question tested in the code below is which computational step promotes “Warsaw” over “Poland”, in particular, whether it is the pass through the FFNN. This is done by using the same early decoding technique and looking at the model’s prediction before and after applying the FFNN.
Exercise 8.1.2: Residual stream decoding
(For yourself) Read through the code above and make sure that you understand what it does and why. How is the question above operationalized?
Inspecting the results, would you say that the FFNN is responsible for promoting “Warsaw”?
In case you saw a similar pattern for your other example above, try to adapt the code below to your example.
[Optional] If you are curious to see more, read the excellent paper above and look at the full demo notebook.
"""
We can decode at the residual stream between the attention and FFN to show that
it is the FFN update that updates from Poland to Warsaw, but a slightly easier way to do this is to
subtract the FFN update that was applied at layer 19 to show the intermediate residual stream state
(i.e., between the attention and FFN)
"""
# Block whose FFN (MLP) update is undone below.
ffn_layer = 19

# Register the activation hooks only for this single forward pass and always
# remove them afterwards; otherwise re-running this cell would stack duplicate
# hooks and every later forward pass would keep overwriting the activations.
wrapper.add_hooks()
try:
    with torch.no_grad():  # analysis only, no autograd graph needed
        out = wrapper.model(input_ids = poland_ids, output_hidden_states=True) #run it again to activate hooks
finally:
    wrapper.rm_hooks()

logits = out.logits
hidden_states = list(out.hidden_states)[1:] #skip the embedding layer to stay consistent with our indexing

#get the FFN output update at the chosen layer (last token position)
o_city = wrapper.model.activations_[f'mlp_{ffn_layer}']
print(len(hidden_states)) #24

with torch.no_grad():
    layer_logits = torch.stack(wrapper.layer_decode(hidden_states)).squeeze(-1)
    print(f"Original top tokens at layer {ffn_layer}")
    wrapper.print_top(layer_logits[ffn_layer].unsqueeze(0))

    # Subtract the FFN update out-of-place so the tensor returned by the model
    # is left untouched. o_city broadcasts over all positions, but only the
    # last token is decoded by layer_decode.
    hidden_states[ffn_layer] = hidden_states[ffn_layer] - o_city
    layer_logits = torch.stack(wrapper.layer_decode(hidden_states)).squeeze(-1)
    print(f"After subtracting mlp_{ffn_layer} (o_city)")
    wrapper.print_top(layer_logits[ffn_layer].unsqueeze(0))
Activation patching#
Another technique discussed in the lecture is activation patching. This can be seen as a method for causal intervention on the computational mechanisms of the model. For instance, certain results of the computations are injected, and the effect these results have on the outcome is observed. We could, of course, inject random values, but that might require trying quite many values before being able to observe meaningful differences in the output. Therefore, instead, certain representations computed in one pass through the model are taken (i.e., representations from the clean run) and injected into another (the corrupted run). These injected representations are called the patch.
Patching techniques are quite tricky to work with since patching something in an intermediate layer affects not only the local layer, but also all the downstream computations. To be able to retrieve meaningful results, careful comparison and, e.g., freezing of subsequent representations is required.
Below is an example of using activation patching on the Indirect Object Identification task, with the help of the library transformer_lens.
The idea of the task is simple: given a sentence like “After John and Mary went to the store, Mary gave a bottle of milk to”, identify the indirect object (i.e., “John”, rather than “Mary”). The goal is to identify which model activations (i.e., representations computed within the transformer) are important for completing a task. We do this by setting up a clean prompt and a corrupted prompt, for instance:
Clean prompt: After John and Mary went to the store, Mary gave a bottle of milk to
Corrupted prompt: After John and Mary went to the store, John gave a bottle of milk to
Further, we define a metric for performance on the task, e.g., the difference between the logit of the correct token and the logit of the incorrect token at the last position, given different inputs (recall our methods for assessing model performance, where one core assumption is that a model can perform a task well if it assigns higher log probabilities (i.e., the logits are higher) to correct predictions than to incorrect ones).
We then pick a specific model activation, run the model on the corrupted prompt, but then intervene on that activation and patch in its value when run on the clean prompt. We then apply the metric, and see how much this patch has recovered the clean performance. Essentially, we ask: given a corrupted input, if we inject certain representations from the “correct” run, can we “fix” the performance? If we can, we know that the injected representations (causally!) contribute to producing the correct output.
The code below is taken from this demo.
Exercise 8.1.3: Activation patching
(For yourself) Read through the code above and make sure that you understand what it does and why.
#!pip install transformer_lens plotly
from transformer_lens import HookedTransformer
import plotly.express as px
import transformer_lens.utils as utils
import tqdm
from functools import partial
import torch
# load the model within the wrapper of the library which allows to easily access and patch activations
# NOTE: this rebinds the global name `model` that the earlier cells used for
# the Hugging Face model.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = HookedTransformer.from_pretrained("gpt2-small", device=device)
/opt/anaconda3/envs/understanding_llms/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
Loaded pretrained model gpt2-small into HookedTransformer
# first, we check if the model can do the task at all
# i.e., we compare the difference in logits for the correct and incorrect answer
# given different inputs without any interventions
# The two prompts differ only in the name of the person who gives the milk.
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
# Convert both prompts to token tensors with the model's own tokenizer.
clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)
def logits_to_logit_diff(logits, correct_answer=" John", incorrect_answer=" Mary"):
    """Return the correct-minus-incorrect logit at the final position.

    Uses the module-level ``model`` to look up token ids.

    Args:
        logits: Logits indexed as ``[batch, position, vocab]``; only batch
            element 0 and the last position are read.
        correct_answer: String of the correct next token.
        incorrect_answer: String of the incorrect next token.

    Returns:
        ``logits[0, -1, correct] - logits[0, -1, incorrect]``.
    """
    # model.to_single_token maps the string of a single token to its index;
    # if the string is not a single token, it raises an error.
    answer_ids = [
        model.to_single_token(answer)
        for answer in (correct_answer, incorrect_answer)
    ]
    last_position_logits = logits[0, -1]
    return last_position_logits[answer_ids[0]] - last_position_logits[answer_ids[1]]
# These are inference-only runs: disable autograd so that neither the cached
# activations nor the logit differences keep a computation graph alive.
with torch.no_grad():
    # We run on the clean prompt with the cache so we store activations to patch in later.
    clean_logits, clean_cache = model.run_with_cache(clean_tokens)
    clean_logit_diff = logits_to_logit_diff(clean_logits)
    # We don't need to cache on the corrupted prompt.
    corrupted_logits = model(corrupted_tokens)
    corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")
Clean logit difference: 4.276
Corrupted logit difference: -2.738
# plotting helper
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show ``tensor`` as a heatmap on a red/blue scale centred at zero.

    Args:
        tensor: Values to plot; converted with ``utils.to_numpy``.
        renderer: Passed to the figure's ``show`` method.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Forwarded to ``px.imshow``.
    """
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": xaxis, "y": yaxis},
        **kwargs,
    )
    figure.show(renderer)
# Residual stream patching hook.
# It acts on the residual stream at the start of a layer, hence "resid_pre".
def residual_stream_patching_hook(resid_pre, hook, position):
    """Overwrite one position of ``resid_pre`` with its clean-run value.

    Args:
        resid_pre: Activation passed to the hook; modified in place. It is
            indexed as ``[:, position, :]``.
        hook: The hook point; its ``name`` attribute is the key used to look
            up the matching activation in the module-level ``clean_cache``.
        position: Index along the second axis to patch.

    Returns:
        The patched ``resid_pre``.
    """
    clean_activation = clean_cache[hook.name]
    # Key step of patching: replace the activation at this position with the
    # one recorded during the clean run.
    resid_pre[:, position, :] = clean_activation[:, position, :]
    return resid_pre
# We make a tensor to store the results for each patching run.
# We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
num_positions = len(clean_tokens[0])
ioi_patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)

# The normalisation terms do not change inside the loop, so compute them once.
# Detaching keeps the result tensor out of any autograd graph the baseline
# logit differences may still carry.
baseline_logit_diff = corrupted_logit_diff.detach()
logit_diff_range = (clean_logit_diff - corrupted_logit_diff).detach()

# Patching is inference only: no autograd graph is needed for any of the runs.
with torch.no_grad():
    for layer in tqdm.tqdm(range(model.cfg.n_layers)):
        hook_name = utils.get_act_name("resid_pre", layer)
        for position in range(num_positions):
            # Use functools.partial to create a temporary hook function with the position fixed
            temp_hook_fn = partial(residual_stream_patching_hook, position=position)
            # Run the model with the patching hook
            patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[
                (hook_name, temp_hook_fn)
            ])
            # Calculate the logit difference
            patched_logit_diff = logits_to_logit_diff(patched_logits).detach()
            # Store the result, normalizing by the clean and corrupted logit difference so it's between 0 and 1 (ish)
            ioi_patching_result[layer, position] = (patched_logit_diff - baseline_logit_diff) / logit_diff_range
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 12/12 [00:23<00:00, 1.97s/it]
# plotly doesn't like duplicate labels, so each token label gets its position
# index appended to make it unique.
token_labels = [
    f"{token}_{index}"
    for index, token in enumerate(model.to_str_tokens(clean_tokens))
]
imshow(
    ioi_patching_result,
    x=token_labels,
    xaxis="Position",
    yaxis="Layer",
    title="Normalized Logit Difference After Patching Residual Stream on the IOI Task",
)