Sheet 6.1 LLM probing & attribution#

Author: Polina Tsvilodub

In this sheet, we will familiarize ourselves with some methods of looking “under the hood” of transformers. In particular, we will see how we can visualize and trace which parts of the input are processed where in the model, how they contribute to the output, and what kinds of information are processed. Specifically, the learning goals for this sheet are:

  • familiarization with transformer attention visualization for inspecting attention patterns

  • understanding how to extract representations of a model from different layers

  • familiarization with probing of a transformer’s syntactic ‘knowledge’.

Attention visualization#

One of the core processing mechanisms in the transformer is the attention mechanism. As discussed in the lecture on transformers, depending on the architecture of the model, there might be various attention blocks:

  • if the model is an encoder-only model (e.g., BERT), it has encoder self-attention;

  • if it is a decoder-only model (e.g., all GPT models), it has decoder (i.e., causal) self-attention;

  • if it is an encoder-decoder model (e.g., translation models, architectures like T5), it has both of these and, additionally, cross-attention between the encoder and the decoder.

First, we will inspect attention visualizations, which indicate the magnitudes of attention scores between a specific token \(i\) and the other tokens. (Reminder: the scores are computed as the dot product of token \(i\)’s query vector and the other tokens’ key vectors.) Intuitively, the larger a score, the more the representation of the respective other token will contribute to predicting the output at position \(i\).
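To make the reminder above concrete, here is a minimal sketch (our own toy example, not BertViz or T5 code) of how attention weights are obtained from query and key vectors; the dimensions and random values are arbitrary, and the scaling by \(\sqrt{d}\) is the standard transformer convention.

# toy example: 4 tokens with hidden size 8 (random values, for illustration only)
import torch
import torch.nn.functional as F

torch.manual_seed(0)
queries = torch.randn(4, 8)  # one query vector per token
keys = torch.randn(4, 8)     # one key vector per token

# raw scores: dot product of token i's query with every token's key
scores = queries @ keys.T                       # shape (4, 4)
# scale and normalize into attention weights (one row per query token)
weights = F.softmax(scores / 8 ** 0.5, dim=-1)
print(weights)  # row i shows how strongly token i attends to each token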

First, we will explore the example from the lecture (slide 46) hands-on. In the example, a sequence-to-sequence (i.e., encoder-decoder) model is used to translate the English sentence “The brown dog ran.” into the French sentence “Le chien brun a couru.”. We will load the FLAN-T5 small model, a seq2seq model fine-tuned to follow various task instructions (including translation).

We will use the package BertViz for the visualization. It allows us to explore parts of the model interactively, i.e., to select specific model parts (e.g., encoder or decoder), specific layers (i.e., attention layers in transformer blocks), and attention heads.

# install the packages required for running the visualization
#!pip install bertviz ipywidgets
# import packages
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from bertviz import model_view, head_view
# load the model
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model_t5 = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
# define input and target
input_ids = tokenizer.encode("Translate to French: The brown dog ran.", return_tensors="pt")
target_ids = tokenizer.encode("Le chien brun a couru.", return_tensors="pt")
# Run model and get the attentions by setting output_attentions=True
output = model_t5(input_ids=input_ids, labels=target_ids, output_attentions=True, return_dict=True) 


# we will need to pass the attention scores to the visualization function
# therefore, we look at the output of the model to see how to access the attention scores
print(output.keys())
# we retrieve various attention scores from the output
encoder_attention = output.encoder_attentions
cross_attention = output.cross_attentions
decoder_attention = output.decoder_attentions

# furthermore, for ease of interpreting the visualization, we convert the token ids to the strings corresponding to those tokens
input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0]) 
decoder_tokens = tokenizer.convert_ids_to_tokens(target_ids[0])
# now we use the overall model attention visualization
# select the attention parts you want to look at via the drop-down
# click on the facets to zoom in on the attention heads in a specific layer
model_view(
    encoder_attention=encoder_attention,
    decoder_attention=decoder_attention,
    cross_attention=cross_attention,
    encoder_tokens=input_tokens,
    decoder_tokens=decoder_tokens
)

Now, we zoom in on the encoder attention, which is used to create representations of the instruction + source sentence. Therefore, we inspect the encoder_attention below.

# there is also a head view that allows you to look at the attention of a single head
# which can be selected by double-clicking on the colored tile
# for a single layer (can be selected via the drop-down)
head_view(encoder_attention, input_tokens)

Next, we look at the cross-attention, i.e., the attention weights computed based on query vectors of the decoder representations and key vectors from the encoder. Intuitively, these represent the importance of input representations (the English sentence) for computing the output (French sentence).

# by default, the head view visualizes self-attention (i.e., attention among the tokens of one and the same sequence).
# For cross-attention, one should explicitly specify the cross-attention parameters below
head_view(
    cross_attention=cross_attention, 
    encoder_tokens=input_tokens, 
    decoder_tokens=decoder_tokens
)

Exercise 6.1.1: Interpreting attention scores

  1. Inspect the visualization above. How many layers does the model have? How many attention heads per layer are there? Access visualizations of scores of a single attention head. Do you observe any interesting patterns across layers and / or attention heads?

  2. Consider the input “What is the capital of France?” and output “The capital is Paris”. Intuitively, which token do you think will receive high attention scores in which part of the model, from which tokens? Complete the code below and inspect the output. Do the results match your intuition?

  3. Use the functions above to inspect decoder attention. Make sure you identify the causal part of the attention scores.
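A possible way to set up point 2, mirroring the translation example above, is sketched below; the exact prompt wording and the variable names (qa_input_ids, etc.) are our own choices, so adapt them as you see fit.

# possible starter for point 2 of the exercise (reuses tokenizer, model_t5, and model_view from above)
qa_input_ids = tokenizer.encode("What is the capital of France?", return_tensors="pt")
qa_target_ids = tokenizer.encode("The capital is Paris", return_tensors="pt")
qa_output = model_t5(
    input_ids=qa_input_ids,
    labels=qa_target_ids,
    output_attentions=True,
    return_dict=True,
)
model_view(
    encoder_attention=qa_output.encoder_attentions,
    decoder_attention=qa_output.decoder_attentions,
    cross_attention=qa_output.cross_attentions,
    encoder_tokens=tokenizer.convert_ids_to_tokens(qa_input_ids[0]),
    decoder_tokens=tokenizer.convert_ids_to_tokens(qa_target_ids[0]),
)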

Finally, the package also offers a neuron view which allows inspecting the query and key vectors, i.e., the full process of computing the attention scores. It is only supported for a few models; for an example, please inspect the notebook provided by BertViz here.

Attribution methods#

Looking at attention as we did above provides a window onto “post-hoc” traces of how a model arrives at a given output, given the input. Yet, as discussed in the lecture, these results should be treated carefully and cannot be fully seen as explanations of why a model arrived at the given output.

To address this question more carefully, alternative methods have been developed which attempt to identify the aspects of the input that are crucial for generating a particular prediction. This might help gain insights into, e.g., whether the model is sensitive to spurious aspects of the prompt. There are different methods for doing so, and the lecture mentioned integrated gradients as one method that would be applicable to LMs. The code below provides examples of a few slightly different approaches for attributing predictions to input features, specifically:

  • gradient-based attribution (integrated gradients and a saliency-based contrastive variant)

We will use the package inseq to look at these different attribution techniques. It works with seq2seq and causal models available through HuggingFace and supports:

  • gradient-based methods

    • Gradient-based methods use backpropagation through the network and the resulting gradients to assess the contribution of individual features (a minimal hands-on sketch of this idea follows this list).

  • perturbation-based methods

    • Perturbation-based methods change or obscure the input and observe the changes in the output.

  • as well as attention weight extraction methods (similar to what we have seen above).
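To illustrate the gradient-based idea from the list above independently of inseq, here is a minimal, hand-rolled sketch using GPT-2 directly: we take the gradient of the model’s score for a continuation token with respect to the input embeddings and summarize it per input token. The variable names (e.g., grad_l2_per_token) and the choice of the L2 norm are our own; inseq’s implementations are more elaborate.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
lm = GPT2LMHeadModel.from_pretrained("gpt2")
lm.eval()

input_ids = tok("The capital of France is", return_tensors="pt").input_ids
# embed the tokens ourselves so we can ask for gradients w.r.t. the embeddings
embeds = lm.transformer.wte(input_ids).detach().requires_grad_(True)

logits = lm(inputs_embeds=embeds).logits        # shape (1, seq_len, vocab_size)
target_id = tok(" Paris").input_ids[0]          # id of the continuation we attribute
score = logits[0, -1, target_id]                # model's logit for " Paris" after the prompt
score.backward()

# one importance value per input token: L2 norm of the gradient of the score
# with respect to that token's embedding
grad_l2_per_token = embeds.grad[0].norm(dim=-1)
for token, g in zip(tok.convert_ids_to_tokens(input_ids[0]), grad_l2_per_token):
    print(f"{token:>12}  {g.item():.4f}")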

To use these various attribution methods, the core entry point of the package is the call inseq.load_model(<HF model name>, <method>), which hooks a specific attribution method onto a model from HF.

The code below walks through an example of using integrated gradients (discussed in the lecture) as well as a contrastive explanation method, using GPT-2. Contrastive explanation refers to the idea of comparing the attributions for a target output text A to those for a different, contrastive output B, in order to answer the question “How much is feature X contributing to predicting A rather than B?” The latter will use a saliency attribution method, which simply returns the absolute value of the gradients with respect to the inputs.
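To build an intuition for the quantity that the contrastive attribution further below explains (roughly, the difference between the probability of the factual continuation and that of the contrastive one at a given generation step), here is a small standalone sketch. It simplifies things to the single next-token step where the two targets differ; inseq computes this per generation step internally.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
lm = GPT2LMHeadModel.from_pretrained("gpt2")
lm.eval()

prompt_ids = tok("The capital of France is", return_tensors="pt").input_ids
with torch.no_grad():
    # probability distribution over the next token after the prompt
    next_token_probs = torch.softmax(lm(prompt_ids).logits[0, -1], dim=-1)

p_paris = next_token_probs[tok(" Paris").input_ids[0]]
p_berlin = next_token_probs[tok(" Berlin").input_ids[0]]
# the kind of score whose gradients the contrastive attribution explains
print(f"p(Paris) - p(Berlin) = {(p_paris - p_berlin).item():.4f}")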

# install the package
#!pip install inseq
import inseq
# load GPT-2 and hook it with the integrated gradients method
model = inseq.load_model("gpt2", "integrated_gradients")

# Generate the output for input_texts and attribute the inputs at every step of the generation
out = model.attribute(input_texts="The capital of France is ")

# Visualize the attributions and step scores
out.show()
# we can also pass a generated text for a given input text to answer the question
# “How would the following output be justified in light of the inputs by the model?”

out_with_generated = model.attribute(
    input_texts="The capital of France is", 
    generated_texts="The capital of France is Paris."
)
out_with_generated.show()
# there are more parameters that allow customizing the attributions even further
# see, e.g., here: https://inseq.org/en/latest/examples/quickstart.html
# Perform the contrastive attribution:
# Regular (forced) target -> "The capital of France is Paris."
# Contrastive (incorrect) target -> "The capital of France is Berlin."

# for contrastive attribution, integrated gradients are not supported yet
# therefore, we use the saliency-based attribution method here
attribution_model = inseq.load_model("gpt2", "saliency")

out_contrastive = attribution_model.attribute(
    "The capital of France is",
    "The capital of France is Paris.",
    attributed_fn="contrast_prob_diff",
    # Special argument to specify the contrastive target, used by the contrast_prob_diff function
    contrast_targets="The capital of France is Berlin.",
    attribute_target=True,
    # We also visualize the score used as target using the same function as step score
    step_scores=["contrast_prob_diff"]
)

out_contrastive.show()

Exercise 6.1.2: Feature attribution

  1. Try the examples above for a few other inputs. Do the results match your intuition? I.e., do those features contribute to a particular prediction that you would expect?

  2. Implement the examples above for the FLAN-T5 small model that we have seen above and run the attribution for the same translation example. How do the results compare to your attention visualization results?

Probing#

This approach attempts to identify whether certain kinds of (linguistic) information are contained in a pre-trained model’s representations. One of the ideas behind this approach is to identify whether information that humans deem critical for completing linguistic tasks (e.g., knowing which part of speech (POS) a certain word is) is actually represented, and therefore potentially used, by the model (rather than the model relying on some spurious correlations). Another motivation is to identify which information is represented in which layers of the model (Tenney et al., 2019).

Below, we will look at probing BERT for POS representations (as demonstrated on slide 75).

The following exercise code is largely taken from this notebook.

For training the classifier and evaluating the representations we will use data files called “en-ud*” which can be found here. If you are using this notebook on Colab, please upload these to a local directory (same as notebook location) named files.

# !pip install spacy ftfy==4.4.3
# !python -m spacy download en
import torch
from transformers import BertTokenizer, BertModel
import numpy as np
import sys
import os
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device('cpu')
# utils

DATA_DIR = "files"
UD_EN_PREF = "en-ud-"

def get_model_and_tokenizer(model_name, device, random_weights=False):

    if model_name.startswith('bert'):
        model = BertModel.from_pretrained(model_name, output_hidden_states=True).to(device)
        tokenizer = BertTokenizer.from_pretrained(model_name)
        sep = '##'
        emb_dim = 1024 if "large" in model_name else 768
    else:
        raise ValueError('Unrecognized model name:', model_name)

    if random_weights:
        print('Randomizing weights')
        model.init_weights()

    return model, tokenizer, sep, emb_dim

# this follows the HuggingFace API for pytorch-transformers
def get_sentence_repr(sentence, model, tokenizer, sep, model_name, device):
    """
    Get representations for one sentence
    """

    with torch.no_grad():
        ids = tokenizer.encode(sentence)
        input_ids = torch.tensor([ids]).to(device)
        # Hugging Face format: list of torch.FloatTensor of shape (batch_size, sequence_length, hidden_size) (hidden_states at output of each layer plus initial embedding outputs)
        all_hidden_states = model(input_ids)[-1]
        # convert to format required for contexteval: numpy array of shape (num_layers, sequence_length, representation_dim)
        all_hidden_states = [hidden_states[0].cpu().numpy() for hidden_states in all_hidden_states]
        all_hidden_states = np.array(all_hidden_states)

    # For each word, take the representation of its last sub-word
    segmented_tokens = tokenizer.convert_ids_to_tokens(ids)
    assert len(segmented_tokens) == all_hidden_states.shape[1], 'incompatible tokens and states'
    mask = np.full(len(segmented_tokens), False)

    if model_name.startswith('bert'):
        # if next token is not a continuation, take current token's representation
        for i in range(len(segmented_tokens)-1):
            if segmented_tokens[i] == "[CLS]" or segmented_tokens[i] == "[SEP]":
                continue
            if not segmented_tokens[i+1].startswith(sep):
                mask[i] = True
    else:
        raise ValueError('Unrecognized model name:', model_name)

    all_hidden_states = all_hidden_states[:, mask]

    return all_hidden_states


def get_pos_data(probing_dir=".", frac=1.0, device='cpu'):

    return get_data("pos", probing_dir=probing_dir, frac=frac, device=device)


def get_data(data_type, probing_dir=".", data_pref=UD_EN_PREF, frac=1.0, device='cpu'):
    with open(os.path.join(probing_dir, DATA_DIR, data_pref + "train.txt")) as f:
        train_sentences = [line.strip().split() for line in f.readlines()]
    with open(os.path.join(probing_dir, DATA_DIR, data_pref + "test.txt")) as f:
        test_sentences = [line.strip().split() for line in f.readlines()]

    with open(os.path.join(probing_dir, DATA_DIR, data_pref + "train." + data_type)) as f:
        train_labels = [line.strip().split() for line in f.readlines()]
    with open(os.path.join(probing_dir, DATA_DIR, data_pref + "test." + data_type)) as f:
        test_labels = [line.strip().split() for line in f.readlines()]

    # take a fraction of the data
    train_sentences = train_sentences[:round(len(train_sentences)*frac)]
    test_sentences = test_sentences[:round(len(test_sentences)*frac)]
    train_labels = train_labels[:round(len(train_labels)*frac)]
    test_labels = test_labels[:round(len(test_labels)*frac)]

    unique_labels = list(set.union(*[set(l) for l in train_labels + test_labels ]))
    label2index = dict()
    for label in unique_labels:
        label2index[label] = label2index.get(label, len(label2index))

    train_labels = [[label2index[l] for l in labels] for labels in train_labels]
    test_labels = [[label2index[l] for l in labels] for labels in test_labels]
    
    return train_sentences, train_labels, test_sentences, test_labels, label2index
# load the data
train_sentences, train_labels, test_sentences, test_labels, label2index = get_pos_data() # frac=0.1
num_labels = len(label2index)
print("Training sentences:", len(train_sentences))
print("Unique labels:", num_labels)
# inspect 
label2index

A probing experiment additionally requires a probing model, also known as an auxiliary classifier. Here we define a simple linear classifier, which takes a word representation as input and applies a linear transformation to map it to the label space.

class Classifier(torch.nn.Module):
    
    def __init__(self, input_dim, output_dim):
        super(Classifier, self).__init__()
        
        self.linear = torch.nn.Linear(input_dim, output_dim)
        
    def forward(self, input):
        output = self.linear(input)
        return output
    
    
def build_classifier(emb_dim, num_labels, device='cpu'):

    classifier = Classifier(emb_dim, num_labels).to(device)
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(classifier.parameters())

    return classifier, criterion, optimizer



model_name = 'bert-base-cased'
# get model and tokenizer from Transformers
model, tokenizer, sep, emb_dim = get_model_and_tokenizer(model_name, device)
# build classifier
classifier, criterion, optimizer = build_classifier(emb_dim, num_labels, device)
print(model)
print(classifier)

Training#

Given a pre-trained model, a probing classifier, and supervised linguistic annotations (i.e., POS tags), we can run a probing experiment. First, we’ll define a training function that trains the classifier on the tags. This is deliberately kept simple; one could add various refinements such as early stopping on a development set (a sketch of that idea follows the function below).

def train(
    num_epochs, 
    train_representations, 
    train_labels, 
    model, 
    tokenizer, 
    sep, 
    model_name, 
    device, 
    classifier, 
    criterion, 
    optimizer, 
    batch_size=32
):
    
    num_total = train_representations.shape[0] 
    for i in range(num_epochs):
        total_loss = 0.
        num_correct = 0.
        for batch in range(0, num_total, batch_size):
            batch_repr = train_representations[batch: batch+batch_size]
            batch_labels = train_labels[batch: batch+batch_size]

            optimizer.zero_grad()
            
            out = classifier(batch_repr)
            pred = out.max(1)[1]
            num_correct += pred.long().eq(batch_labels.long()).cpu().sum().item()
            loss = criterion(out, batch_labels)
            total_loss += loss.item()

            loss.backward()
            optimizer.step()
#         print('Training epoch: {}, loss: {}, accuracy: {}'.format(i, total_loss/num_total, num_correct/num_total))
    return total_loss/num_total, num_correct/num_total
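As mentioned above, one refinement would be early stopping on a development set. The following is a rough, hypothetical sketch of that idea (our own variant): it holds out the last fraction of the training examples as a development split and assumes representations and labels are tensors shaped as in the experiment below.

def train_with_early_stopping(num_epochs, representations, labels, classifier, criterion,
                              optimizer, batch_size=32, dev_frac=0.1, patience=2):
    # hold out the last dev_frac of the examples as a development set
    n_dev = max(1, int(representations.shape[0] * dev_frac))
    train_repr, dev_repr = representations[:-n_dev], representations[-n_dev:]
    train_lab, dev_lab = labels[:-n_dev], labels[-n_dev:]

    best_acc, epochs_without_improvement = 0.0, 0
    for epoch in range(num_epochs):
        for batch in range(0, train_repr.shape[0], batch_size):
            optimizer.zero_grad()
            out = classifier(train_repr[batch: batch + batch_size])
            loss = criterion(out, train_lab[batch: batch + batch_size])
            loss.backward()
            optimizer.step()
        # check accuracy on the held-out split after each epoch
        with torch.no_grad():
            dev_acc = (classifier(dev_repr).max(1)[1] == dev_lab).float().mean().item()
        if dev_acc > best_acc:
            best_acc, epochs_without_improvement = dev_acc, 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    return best_acc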

Evaluation#

Given the trained classifier, we’ll evaluate its performance on the test set.

def evaluate(
    test_representations, 
    test_labels, 
    model, 
    tokenizer, 
    sep, 
    model_name, 
    device, 
    classifier, 
    criterion, 
    batch_size=32
):
    
    num_correct = 0.
    num_total = test_representations.shape[0]
    total_loss = 0.
    with torch.no_grad():
        for batch in range(0, num_total, batch_size):
            batch_repr = test_representations[batch: batch+batch_size]
            batch_labels = test_labels[batch: batch+batch_size]
            
            out = classifier(batch_repr)
            pred = out.max(1)[1]
            num_correct += pred.long().eq(batch_labels.long()).cpu().sum().item()
            total_loss += criterion(out, batch_labels).item()

#     print('Testing loss: {}, accuracy: {}'.format(total_loss/num_total, num_correct/num_total))
    return total_loss/num_total, num_correct/num_total

Now we put together the functions and perform a probing experiment:

  1. We retrieve word representations from each layer of the model.

  2. We train and evaluate the linear classifier, first only on the last-layer representations.

# top-level list: sentences, second-level lists: layers, third-level tensors of num_words x representation_dim
train_sentence_representations = [get_sentence_repr(sentence, model, tokenizer, sep, model_name, device) 
                                  for sentence in train_sentences]
test_sentence_representations = [get_sentence_repr(sentence, model, tokenizer, sep, model_name, device) 
                                  for sentence in test_sentences]

# top-level list: layers, second-level lists: sentences
train_sentence_representations = [list(l) for l in zip(*train_sentence_representations)]
test_sentence_representations = [list(l) for l in zip(*test_sentence_representations)]                           

# concatenate all word representations
train_representations_all = [torch.tensor(np.concatenate(train_layer_representations, 0)).to(device) for train_layer_representations in train_sentence_representations]
test_representations_all = [torch.tensor(np.concatenate(test_layer_representations, 0)).to(device) for test_layer_representations in test_sentence_representations]
# concatenate all labels
train_labels_all = torch.tensor(np.concatenate(train_labels, 0)).to(device)
test_labels_all = torch.tensor(np.concatenate(test_labels, 0)).to(device)
# Take final layer representations
train_representations = train_representations_all[-1]
test_representations = test_representations_all[-1]

# train
train_loss, train_accuracy = train(10, train_representations, train_labels_all, 
          model, tokenizer, sep, model_name, device, 
          classifier, criterion, optimizer)
# test
test_loss, test_accuracy = evaluate(test_representations, test_labels_all, 
         model, tokenizer, sep, model_name, device, 
         classifier, criterion)
print("Train accuracy: {}, Test accuracy: {}".format(train_accuracy, test_accuracy))

Exercise 6.1.3: Probing

  1. Run the code above and inspect the testing results. Based on these, do you think the model learned representations of parts-of-speech?

  2. Run another round of training based on representations from some earlier layer of the model. Do you observe any differences? If yes, interpret them with respect to what the model’s representations seem to encode.

  3. [Optional] Instead of using a linear classifier, try to set up a (simple) non-linear classifier; one possible starting point is sketched below. Do the results change?
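For the optional part, a hedged starting point (our own suggestion, with an arbitrary hidden size) is to add a single hidden layer and a non-linearity to the Classifier defined above; the training and evaluation functions can be reused unchanged.

class NonLinearClassifier(torch.nn.Module):
    """Same interface as the linear Classifier above, but with one hidden layer and a ReLU."""

    def __init__(self, input_dim, output_dim, hidden_dim=256):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, input):
        return self.net(input)

# build it analogously to build_classifier above (hidden_dim=256 is an arbitrary choice)
classifier_nl = NonLinearClassifier(emb_dim, num_labels).to(device)
criterion_nl = torch.nn.CrossEntropyLoss().to(device)
optimizer_nl = torch.optim.Adam(classifier_nl.parameters())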