Zero-shot Text Classification
Although SetFit was designed for few-shot learning, the method can also be applied in scenarios where no labeled data is available. The main trick is to create synthetic examples that resemble the classification task, and then train a SetFit model on them.
Remarkably, this simple technique typically outperforms the zero-shot pipeline in 🤗 Transformers, and generates predictions 5x (or more) faster!
In this tutorial, we'll explore how:
SetFit can be applied for zero-shot classification
Adding synthetic examples can also provide a performance boost to few-shot classification.
Setup
If you're running this Notebook on Colab or some other cloud platform, you will need to install the setfit library. Uncomment the following cell and run it:
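A minimal install cell might look like the following (left commented out, as described above):

```python
# Uncomment and run this cell if setfit is not already installed
# %pip install setfit
```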
To benchmark the performance of the "zero-shot" method, we'll use the following dataset and pretrained model:
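As a sketch, assumed here for illustration: a six-class emotion dataset (matching the class count mentioned later) and a small pretrained Sentence Transformer checkpoint. Both identifiers are assumptions, not fixed by the text above:

```python
# Hypothetical choices for this walkthrough
dataset_id = "emotion"  # a 6-class text classification dataset
model_id = "sentence-transformers/paraphrase-mpnet-base-v2"  # a small Sentence Transformer
```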
Next, we'll download the reference dataset from the Hugging Face Hub:
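Using the datasets library and the dataset_id assumed above, this could look like:

```python
from datasets import load_dataset

# Download the full reference dataset (train/validation/test splits)
reference_dataset = load_dataset(dataset_id)
reference_dataset
```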
Now that we're set up, let's create some synthetic data to train on!
Creating a synthetic dataset
The first thing we need to do is create a dataset of synthetic examples. In setfit, we can do this by applying the get_templated_dataset() function to a dummy dataset. This function expects a few main things:
A list of candidate labels to classify with. We'll use the labels from the reference dataset here, but this could be anything that's relevant to the task and dataset at hand.
A template to generate examples with. By default, it is "This sentence is {}", where the {} will be filled by one of the candidate labels.
A sample size N, which will create N synthetic examples per class. We find N = 8 usually works best.
Armed with this information, let's first extract some candidate labels from the dataset:
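Assuming the label column is a ClassLabel feature, one way to do this is:

```python
# Read the human-readable label names from the ClassLabel feature
candidate_labels = reference_dataset["train"].features["label"].names
candidate_labels
```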
Some datasets on the Hugging Face Hub don't have a ClassLabel feature for the label column. In these cases, you should compute the candidate labels manually by first computing the id2label mapping as follows:
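A sketch of one way to do this, assuming the dataset stores the string labels in a label_text column (a hypothetical column name used here for illustration):

```python
def get_id2label(dataset):
    # Pair each distinct integer label id with its distinct string name
    # (assumes a `label_text` column aligned row-by-row with `label`)
    label_ids = dataset.unique("label")
    label_names = dataset.unique("label_text")
    return dict(zip(label_ids, label_names))

id2label = get_id2label(reference_dataset["train"])
candidate_labels = list(id2label.values())
```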
Now that we have the labels, it's a simple matter to create synthetic examples:
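A sketch using get_templated_dataset() with the sample size of 8 mentioned below:

```python
from datasets import Dataset
from setfit import get_templated_dataset

# Start from an empty dummy dataset and fill it with templated examples
dummy_dataset = Dataset.from_dict({})
train_dataset = get_templated_dataset(
    dummy_dataset, candidate_labels=candidate_labels, sample_size=8
)
train_dataset
```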
You might find you can get better performance by tweaking the template argument from the default of "This sentence is {}" to variants like "This sentence is about {}" or "This example is {}".
Since our dataset has 6 classes and we chose a sample size of 8, our synthetic dataset contains 48 examples. If we take a look at a few of the examples:
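For instance, shuffling the synthetic dataset and slicing off a few rows (a sketch):

```python
# Peek at a few randomly chosen synthetic examples
train_dataset.shuffle()[:3]
```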
We can see that each input takes the form of the template and has a corresponding label associated with it.
Let's now train a SetFit model on these examples!
Fine-tuning the model
To train a SetFit model, the first thing to do is download a pretrained checkpoint from the Hub. We can do so by using the SetFitModel.from_pretrained() method:
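A sketch, reusing the model_id assumed earlier:

```python
from setfit import SetFitModel

# Download the pretrained Sentence Transformer body; a classification head is added on top
model = SetFitModel.from_pretrained(model_id)
```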
Here, we've downloaded a pretrained Sentence Transformer from the Hub and added a logistic classification head to create the SetFit model. As indicated in the message, we need to train this model on some labeled examples. We can do so by using the Trainer class as follows:
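A minimal sketch of the trainer setup, assuming a recent setfit release that exposes the Trainer class, with evaluation on the reference test split:

```python
from setfit import Trainer

trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=reference_dataset["test"],
)
```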
Now that we've created a trainer, we can train it! While we're at it, let's time how long it takes to train and evaluate the model:
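In a notebook, the run can be timed with the %%time cell magic; the zeroshot_metrics variable name is our own choice here:

```python
%%time
trainer.train()
zeroshot_metrics = trainer.evaluate()
zeroshot_metrics
```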
Great, now that we have a reference score let's compare against the zero-shot pipeline from 🤗 Transformers.
Comparing against the zero-shot pipeline from 🤗 Transformers
🤗 Transformers provides a zero-shot pipeline that frames text classification as a natural language inference task. Let's load the pipeline and place it on the GPU for fast inference:
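A sketch of loading the pipeline; device=0 places it on the first GPU:

```python
from transformers import pipeline

# Zero-shot classification framed as NLI under the hood
pipe = pipeline("zero-shot-classification", device=0)
```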
Now that we have the model, let's generate some predictions. We'll use the same candidate labels as we did with SetFit and increase the batch size to speed things up:
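A sketch, assuming the inputs live in a text column of the test split:

```python
%%time
zeroshot_preds = pipe(
    reference_dataset["test"]["text"],
    batch_size=16,
    candidate_labels=candidate_labels,
)
```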
Note that this took almost 5x longer to generate predictions than SetFit! OK, so how well does it perform? Since each prediction is a dictionary of label names ranked by score:
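For instance:

```python
# Each prediction holds the candidate labels sorted from highest to lowest score
zeroshot_preds[0]
```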
We can use the str2int() function from the label column to convert them to integers.
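A sketch, taking the top-ranked label name for each prediction and mapping it back to an integer id:

```python
# The ClassLabel feature of the reference dataset knows the name -> id mapping
label_features = reference_dataset["train"].features["label"]
preds = [label_features.str2int(pred["labels"][0]) for pred in zeroshot_preds]
```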
Note: As noted earlier, if you're using a dataset that doesn't have a ClassLabel feature for the label column, you'll need to compute the label mapping manually with something like:
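A sketch using the id2label mapping computed earlier, inverted to go from label names to ids:

```python
# Invert id2label and look up each top-ranked label name
label2id = {label: idx for idx, label in id2label.items()}
preds = [label2id[pred["labels"][0]] for pred in zeroshot_preds]
```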
The last step is to compute accuracy using 🤗 Evaluate:
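For example, with the accuracy metric from the evaluate library:

```python
import evaluate

metric = evaluate.load("accuracy")
transformers_metrics = metric.compute(
    predictions=preds,
    references=reference_dataset["test"]["label"],
)
transformers_metrics
```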
Compared to SetFit, this approach performs significantly worse. Let's wrap up our analysis by combining synthetic examples with a few labeled ones.
Augmenting labeled data with synthetic examples
If you have a few labeled examples, adding synthetic data can often boost performance. To simulate this, let's first sample 8 labeled examples from our reference dataset:
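setfit provides a sample_dataset() helper for this; a sketch:

```python
from setfit import sample_dataset

# Sample 8 labeled examples per class from the training split
train_dataset = sample_dataset(reference_dataset["train"], num_samples=8)
```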
To warm up, we'll train a SetFit model on these true labels:
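As before, a sketch of training and evaluating on the true labels (fewshot_metrics is our own variable name):

```python
model = SetFitModel.from_pretrained(model_id)

trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=reference_dataset["test"],
)
trainer.train()
fewshot_metrics = trainer.evaluate()
fewshot_metrics
```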
Note that for this particular dataset, the performance with true labels is worse than training on synthetic examples! In our experiments, we found that the difference depends strongly on the dataset in question. Since SetFit models are fast to train, you can always try both approaches and pick the best one.
In any case, let's now add some synthetic examples to our training set:
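Passing the labeled samples to get_templated_dataset() appends synthetic examples to them; a sketch:

```python
# Augment the 8-shot training set with 8 synthetic examples per class
augmented_dataset = get_templated_dataset(
    train_dataset, candidate_labels=candidate_labels, sample_size=8
)
```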
As before, we can train and evaluate SetFit with the augmented dataset:
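A sketch of the final training run on the augmented data:

```python
model = SetFitModel.from_pretrained(model_id)

trainer = Trainer(
    model=model,
    train_dataset=augmented_dataset,
    eval_dataset=reference_dataset["test"],
)
trainer.train()
augmented_metrics = trainer.evaluate()
augmented_metrics
```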
Great, this gives us a significant boost in performance, a few percentage points above training on purely synthetic examples.
Let's plot the final results for comparison:
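A sketch using pandas, assuming the metric dictionaries defined above each contain an "accuracy" key:

```python
import pandas as pd

# Collect the accuracy of each approach into a small dataframe and plot it
df = pd.DataFrame.from_dict(
    {
        "Method": ["Transformers (zero-shot)", "SetFit (zero-shot)", "SetFit (augmented)"],
        "Accuracy": [
            transformers_metrics["accuracy"],
            zeroshot_metrics["accuracy"],
            augmented_metrics["accuracy"],
        ],
    }
)
df.plot(kind="barh", x="Method");
```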