Prompt tuning for causal language modeling
Prompting helps guide language model behavior by adding some input text specific to a task. Prompt tuning is an additive method that trains and updates only the newly added prompt tokens on a pretrained model. This way, you can use one pretrained model whose weights are frozen, and train and update a smaller set of prompt parameters for each downstream task instead of fully finetuning a separate model. As models grow larger and larger, prompt tuning can be more efficient, and results are even better as model parameters scale.
💡 Read The Power of Scale for Parameter-Efficient Prompt Tuning to learn more about prompt tuning.
This guide will show you how to apply prompt tuning to train a bloomz-560m model on the twitter_complaints subset of the RAFT dataset.
Before you begin, make sure you have all the necessary libraries installed:
Setup
Start by defining the model and tokenizer, the dataset and the dataset columns to train on, some training hyperparameters, and the PromptTuningConfig. The PromptTuningConfig contains information about the task type, the text to initialize the prompt embedding, the number of virtual tokens, and the tokenizer to use:
Load dataset
For this guide, you'll load the twitter_complaints subset of the RAFT dataset. This subset contains tweets that are labeled either complaint or no complaint:
To make the Label column more readable, replace each Label value with its corresponding label text and store the result in a text_label column. You can use the map function to apply this change over the entire dataset in one step:
Preprocess dataset
Next, you'll set up a tokenizer, configure the appropriate padding token to use for padding sequences, and determine the maximum length of the tokenized labels:
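A sketch of this step; the class names are hardcoded here as an assumption (in the full pipeline they come from the dataset's Label feature as shown above):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
if tokenizer.pad_token_id is None:
    # Fall back to the end-of-sequence token for padding
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Assumed label names for the twitter_complaints subset
classes = ["Unlabeled", "complaint", "no complaint"]

# Longest tokenized label, used when budgeting sequence lengths
target_max_length = max(len(tokenizer(c)["input_ids"]) for c in classes)
```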
Create a preprocess_function to:

1. Tokenize the input text and labels.
2. For each example in a batch, pad the labels with the tokenizer's pad_token_id.
3. Concatenate the input text and labels into the model_inputs.
4. Create a separate attention mask for labels and model_inputs.
5. Loop through each example in the batch again to pad the input ids, labels, and attention mask to the max_length, and convert them to PyTorch tensors.
Use the map function to apply the preprocess_function to the entire dataset. You can remove the unprocessed columns since the model won't need them:
Create a DataLoader from the train and eval datasets. Set pin_memory=True to speed up the data transfer to the GPU during training if the samples in your dataset are on a CPU.
Train
You're almost ready to set up your model and start training!
Initialize a base model from AutoModelForCausalLM, and pass it and peft_config to the get_peft_model() function to create a PeftModel. You can print the new PeftModel's trainable parameters to see how much more efficient it is than training the full parameters of the original model!
Set up an optimizer and learning rate scheduler:
Move the model to the GPU, then write a training loop to start training!
Share model
You can store and share your model on the Hub if you'd like. Log in to your Hugging Face account and enter your token when prompted:
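In a notebook this is typically done with notebook_login (from a terminal, the huggingface-cli login command serves the same purpose):

```python
from huggingface_hub import notebook_login

# Opens an interactive prompt for your Hugging Face access token
notebook_login()
```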
Use the push_to_hub function to upload your model to a model repository on the Hub:
Once the model is uploaded, you'll see the model file size is only 33.5kB! 🤏
Inference
Let's try the model on a sample input for inference. If you look at the repository you uploaded the model to, you'll see an adapter_config.json file. Load this file into PeftConfig to specify the peft_type and task_type. Then you can load the prompt-tuned model weights and the configuration into from_pretrained() to create the PeftModel:
Grab a tweet and tokenize it:
Put the model on a GPU and generate the predicted label: