Bonus. Messing with a GPT Model
In lecture, we have seen the basic architecture of a transformer model:
Let's try to work with these components in real life!
In this final mini-lab of the semester, we will delve into a very lightweight (toy) GPT model, distilgpt2 (Distilled-GPT2). It contains only 82 million parameters, so it will run comfortably on just your local CPU (actual SoTA models have hundreds or even thousands of billions of parameters!).
Training a model from scratch is a difficult task (and requires LOTSSS of computing resources), so we will load a pretrained model instead. We will then mess around with the model's weights and observe how its behaviour changes. Hopefully, this will give us a deeper understanding of how the model is structured internally.
Note: You will need the transformers package. Install it using pip install transformers. Note that you can't really do this on the Cocalc server since Internet access is required for you to download the model.
Your task: There are 4 tasks in this mini-lab. Complete them in sequence.
Submission: Submit your writeup and/or implementation before/during the tutorial for extra EXP! (If you are unable to attend my tutorial, you can choose to replace EXP by free bubble tea!)
Warm-up: Loading the Model
The first code snippet downloads and loads the distilgpt2 model from the Internet. The second code snippet runs an inference with the input text Once upon a time. See how the model generates a story! (about the US :0)
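For reference, the first two snippets boil down to roughly the following (a sketch assuming the Hugging Face transformers API; the notebook's own variable names may differ):

```python
# Sketch of loading distilgpt2 and running a single generation.
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
model.eval()

inputs = tokenizer("Once upon a time", return_tensors="pt")
output = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    pad_token_id=tokenizer.eos_token_id,
    max_length=25,
    return_dict_in_generate=True,  # lets us read output.sequences, as the test cells do
)
print(tokenizer.decode(output.sequences[0], skip_special_tokens=True))
```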
The third code snippet demonstrates how you can tamper with the weights of the model. The basic idea is to first extract the weights by calling the method state_dict(), tamper with them, then load them back using the method load_state_dict().
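As a concrete illustration of that round trip, a tampering sketch might look like this (continuing from the loading snippet above; transformer.wte.weight is the standard GPT-2 key name for the token-embedding matrix, so check it against the keys in your own notebook):

```python
import torch

# Extract all weights as a dict mapping parameter names to tensors.
weights = model.state_dict()

# Example tampering: add a little Gaussian noise to the token-embedding matrix.
# "transformer.wte.weight" is the usual GPT-2 key name -- verify it in your own key list.
weights["transformer.wte.weight"] += 0.02 * torch.randn_like(weights["transformer.wte.weight"])

# Push the tampered weights back into the model.
model.load_state_dict(weights)
```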
Note that the following tasks will instruct you to mess around with the weights. You should re-load the model at the beginning of each task so that the tampered weights will not carry over.
Task 1: Ruin the model!
As a warm-up, try to implement your own weight tampering so that the model no longer outputs "United States" for the input "Once upon a time". You are free to tamper with anything you wish.
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[10], line 10
7 output = model.generate(input_ids, attention_mask=attention_mask, pad_token_id=tokenizer.eos_token_id, max_length=25)
8 output_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
---> 10 assert "United States" not in output_text, \
11 "Output still contains 'United States'"
13 print("Test case passed!")
AssertionError: Output still contains 'United States'
Task 2: Once a upon time
Let's try to run another inference with the input text Once a upon time.
Notice that the model now outputs gibberish. (What does "the world is a little more like a place where you can see the world" even mean?)
The reason the model produces different outputs for these two inputs is that the tokens a and upon correspond to different embeddings (in this model, an embedding is simply a vector of size 768, the hidden size). In addition to token embeddings, we also have position embeddings, which add position-related information (the first token has a different position embedding from the second token). The model adds these two types of embeddings together.
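To get a feel for where these live, you can peek at the two embedding tables in the weight list (the key names below are the standard GPT-2 ones; confirm them against your own state_dict() keys):

```python
# Continuing from the loaded model above.
weights = model.state_dict()

print(weights["transformer.wte.weight"].shape)  # token embeddings, roughly (50257, 768)
print(weights["transformer.wpe.weight"].shape)  # position embeddings, roughly (1024, 768)

# The tokens " a" and " upon" (note GPT-2's leading-space convention) map to different ids,
# hence different rows of the token-embedding table.
print(tokenizer(" a")["input_ids"], tokenizer(" upon")["input_ids"])
```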
Task: Modify the embedding weights such that the model outputs "United States" for the input Once a upon time.
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[15], line 10
7 output = model.generate(input_ids, attention_mask=attention_mask, pad_token_id=tokenizer.eos_token_id, max_length=25)
8 output_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
---> 10 assert "United States" in output_text, \
11 "Output does not contain 'United States'"
13 print("Test case passed!")
AssertionError: Output does not contain 'United States'
Task 3: What is the model seeing?
We know that the GPT model utilizes the self-attention mechanism. Let's try to visualize the attention scores of the model!
Run the following cells to generate the visualization. Observe how different attention heads draw different connections between words.
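For reference, the visualization cells typically do something along these lines (a sketch assuming matplotlib is available; the notebook's own plotting code may differ):

```python
import matplotlib.pyplot as plt

# Run a forward pass and ask the model to return the attention matrices.
inputs = tokenizer("Once upon a time", return_tensors="pt")
output = model(**inputs, output_attentions=True)

# output.attentions: one tensor per layer, each of shape (batch, num_heads, seq_len, seq_len).
first_layer = output.attentions[0][0].detach()  # (12, 4, 4) for this 4-token prompt
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())

fig, axes = plt.subplots(2, 6, figsize=(18, 6))
for head, ax in enumerate(axes.flat):
    ax.imshow(first_layer[head])
    ax.set_title(f"head {head}")
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks(range(len(tokens)))
    ax.set_yticklabels(tokens)
plt.tight_layout()
plt.show()
```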
Now, let's try to remove some of the attention heads!
We cannot really remove the attention heads by only tampering with the weights. However, recall that the attention scores are calculated using this formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

By zeroing out either $Q$ or $K$, the product $QK^\top$ becomes all zeros, so after the softmax all attention scores will essentially become the same!
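To see why, here is a tiny worked example: if $QK^\top$ is all zeros, every unmasked logit is equal, and the causally-masked softmax spreads the weight evenly over the positions each token is allowed to attend to (exactly the pattern the test case below checks for):

```python
import torch

seq_len = 4
logits = torch.zeros(seq_len, seq_len)                    # what Q K^T / sqrt(d_k) becomes
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
logits = logits.masked_fill(~causal_mask, float("-inf"))  # a token cannot attend to the future

print(torch.softmax(logits, dim=-1))
# tensor([[1.0000, 0.0000, 0.0000, 0.0000],
#         [0.5000, 0.5000, 0.0000, 0.0000],
#         [0.3333, 0.3333, 0.3333, 0.0000],
#         [0.2500, 0.2500, 0.2500, 0.2500]])
```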
Task: Implement this wiping by zeroing out either $Q$ or $K$, but only for attention heads 0–7 (out of the 12 attention heads) of the first layer. Observe how the attention scores change.
Hint 1: You should search for the weight tensors $W_Q$, $W_K$ and $W_V$ in the list of weights above. In this model, $W_Q$, $W_K$ and $W_V$ are concatenated as a single tensor. Hint 2: Don't forget to wipe out the bias as well.
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[22], line 12
9 attention_values = output.attentions[0][0][0]
11 for i in range(0, 8):
---> 12 assert torch.all(torch.isclose(output.attentions[0][0][0][i], torch.tensor(
13 [[1, 0, 0, 0], [1/2, 1/2, 0, 0], [1/3, 1/3, 1/3, 0], [1/4, 1/4, 1/4, 1/4]]
14 ))), f"Attention scores in head {i} have not been wiped out"
16 for i in range(8, 12):
17 assert not torch.all(torch.isclose(output.attentions[0][0][0][i], torch.tensor(
18 [[1, 0, 0, 0], [1/2, 1/2, 0, 0], [1/3, 1/3, 1/3, 0], [1/4, 1/4, 1/4, 1/4]]
19 ))), f"Attention scores in head {i} are incorrectly wiped out"
AssertionError: Attention scores in head 0 have not been wiped out
Task 4: Why not positional encoding?
In Task 2, we swapped the locations of the tokens a and upon. You might have tried swapping the position encodings of the second and third tokens.
However, while the model's output for the input Once a upon time does change, you might have noticed that it is not exactly the same as if we had given it the prompt Once upon a time. This implies that, beyond the position encodings, some other parts of the model also leverage the positional information of the tokens.
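Concretely, that attempt might look like this (again assuming the standard GPT-2 key name transformer.wpe.weight; positions are 0-indexed, so the second and third tokens use rows 1 and 2):

```python
# The "swap the position encodings" attempt, for concreteness.
weights = model.state_dict()
wpe = weights["transformer.wpe.weight"]
wpe[[1, 2]] = wpe[[2, 1]]          # swap the embeddings for positions 1 and 2
model.load_state_dict(weights)
```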
With reference to your observations in Task 3, explain why this is the case.