Classification heads
Any 🤗 SetFit model consists of two parts: a SentenceTransformer embedding body and a classification head.
This guide will show you:
The built-in logistic regression classification head
The built-in differentiable classification head
The requirements for a custom classification head
Logistic Regression classification head
When a new SetFit model is initialized, a scikit-learn logistic regression head is chosen by default. This has been shown to be highly effective when applied on top of a finetuned sentence transformer body, and it remains the recommended classification head. Initializing a new SetFit model with a Logistic Regression head is simple:
To initialize the Logistic Regression head (or any other head) with additional parameters, use the head_params argument on SetFitModel.from_pretrained():
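Both initializations described above can be sketched as follows. The checkpoint name and the values in HEAD_PARAMS are illustrative assumptions; the keys are ordinary scikit-learn LogisticRegression keyword arguments. The setfit import is deferred into the functions so that nothing is downloaded until one of them is called.

```python
# Illustrative assumptions: the checkpoint name and the parameter values.
MODEL_ID = "BAAI/bge-small-en-v1.5"
HEAD_PARAMS = {"solver": "liblinear", "max_iter": 300}


def load_default_model(model_id: str = MODEL_ID):
    """Load a SetFit model with the default logistic regression head."""
    from setfit import SetFitModel  # deferred import: requires `pip install setfit`

    return SetFitModel.from_pretrained(model_id)


def load_model_with_head_params(model_id: str = MODEL_ID):
    """Load a SetFit model and forward extra keyword arguments to the head."""
    from setfit import SetFitModel  # deferred import, as above

    # A copy is passed so the module-level defaults are never mutated.
    return SetFitModel.from_pretrained(model_id, head_params=dict(HEAD_PARAMS))
```

After loading, the head is available as model.model_head, which is a convenient place to check that the parameters were applied.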
Differentiable classification head
SetFit also provides SetFitHead, a classification head implemented purely in torch. It uses a linear layer to map the embeddings to the classes. It can be used by setting the use_differentiable_head argument on SetFitModel.from_pretrained() to True:
By default, this will assume binary classification. To change that, also set the out_features via head_params to the number of classes that you are using.
Unlike the default Logistic Regression head, the differentiable classification head only supports integer labels in the following range: [0, num_classes).
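The label constraint and the out_features setting can be combined in one small sketch. The helper names encode_labels and load_differentiable_model, the checkpoint name, and the sample labels are assumptions made for illustration; only use_differentiable_head, head_params and out_features come from the text above.

```python
from typing import Dict, List, Sequence, Tuple


def encode_labels(labels: Sequence[str]) -> Tuple[List[int], Dict[str, int]]:
    """Map arbitrary labels to consecutive integers in [0, num_classes)."""
    label2id = {label: idx for idx, label in enumerate(sorted(set(labels)))}
    return [label2id[label] for label in labels], label2id


def load_differentiable_model(num_classes: int, model_id: str = "BAAI/bge-small-en-v1.5"):
    """Load a SetFit model whose SetFitHead has one output per class."""
    from setfit import SetFitModel  # deferred import: requires `pip install setfit`

    return SetFitModel.from_pretrained(
        model_id,
        use_differentiable_head=True,
        head_params={"out_features": num_classes},
    )


raw_labels = ["positive", "negative", "neutral", "positive"]
encoded, label2id = encode_labels(raw_labels)
print(encoded)   # [2, 0, 1, 2]
print(label2id)  # {'negative': 0, 'neutral': 1, 'positive': 2}
```

Calling load_differentiable_model(len(label2id)) would then size the head to match the encoded labels.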
Training with a differentiable classification head
Using the SetFitHead unlocks some new TrainingArguments that are not used with a sklearn-based head. Note that training with SetFit consists of two phases behind the scenes: finetuning embeddings and training a classification head. As a result, some of the training arguments can be tuples, where the two values are used for each of the two phases, respectively. In many of these cases, the second value is only used if the classification head is differentiable. For example:
batch_size (Union[int, Tuple[int, int]], defaults to (16, 2)) - The second value in the tuple determines the batch size when training the differentiable SetFitHead.
num_epochs (Union[int, Tuple[int, int]], defaults to (1, 16)) - The second value in the tuple determines the number of epochs when training the differentiable SetFitHead. In practice, num_epochs is usually larger for training the classification head. There are two reasons for this:
  This training phase does not train with contrastive pairs, so unlike when finetuning the embedding model, you only get one training sample per labeled training text.
  This training phase involves training a classifier from scratch, not finetuning an already capable model. We need more training steps for this.
end_to_end (bool, defaults to False) - If True, train the entire model end-to-end during the classifier training phase. Otherwise, freeze the Sentence Transformer body and only train the head.
body_learning_rate (Union[float, Tuple[float, float]], defaults to (2e-5, 1e-5)) - The second value in the tuple determines the learning rate of the Sentence Transformer body during the classifier training phase. This is only relevant if end_to_end is True, as otherwise the Sentence Transformer body is frozen when training the classifier.
head_learning_rate (float, defaults to 1e-2) - This value determines the learning rate of the differentiable head during the classifier training phase. It is only used if the differentiable head is used.
l2_weight (float, optional) - Optional l2 weight for both the model body and head, passed to the AdamW optimizer in the classifier training phase only if a differentiable head is used.
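The tuple convention used by these arguments can be pictured with a small helper. This is a sketch of the idea, not SetFit's own code: split_phases is a hypothetical name, and it assumes that a single value is applied to both phases.

```python
from typing import Tuple, TypeVar, Union

T = TypeVar("T", int, float)


def split_phases(value: Union[T, Tuple[T, T]]) -> Tuple[T, T]:
    """Return (embedding_phase_value, classifier_phase_value).

    Assumption: a bare int or float is used for both phases.
    """
    if isinstance(value, tuple):
        embedding_value, classifier_value = value
        return embedding_value, classifier_value
    return value, value


print(split_phases((16, 2)))  # (16, 2)
print(split_phases(8))        # (8, 8)
```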
For example, a full training script using a differentiable classification head may look something like this:
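A minimal sketch of such a script follows. The checkpoint, the SetFit/sst2 dataset, the sample size and every value in TRAINING_KWARGS are illustrative assumptions rather than recommendations. The heavy imports live inside run_training so the snippet can be loaded without downloading anything.

```python
# Illustrative values only; tune them for your own data.
TRAINING_KWARGS = {
    "batch_size": (32, 16),              # (embedding phase, classifier phase)
    "num_epochs": (3, 8),                # more epochs for the head
    "end_to_end": True,                  # also update the body while training the head
    "body_learning_rate": (2e-5, 5e-6),  # second value applies because end_to_end=True
    "head_learning_rate": 0.5,
    "l2_weight": 0.01,
}


def run_training():
    """Train and evaluate a SetFit model with a differentiable head."""
    from datasets import load_dataset
    from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset

    model = SetFitModel.from_pretrained(
        "BAAI/bge-small-en-v1.5",
        use_differentiable_head=True,
        head_params={"out_features": 2},  # binary sentiment labels: 0 and 1
    )

    dataset = load_dataset("SetFit/sst2")
    # Few-shot setting: keep 32 labeled examples per class.
    train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=32)
    test_dataset = dataset["test"]

    args = TrainingArguments(**TRAINING_KWARGS)
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
    trainer.train()
    return trainer.evaluate(test_dataset)
```

Call run_training() to execute both phases and get back the evaluation metrics.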
Custom classification head
Alongside the two built-in options, SetFit allows you to specify a custom classification head. There are two forms of supported heads: a custom differentiable head or a custom non-differentiable head. Both must implement a predict method and a predict_proba method; the full requirements for each form are listed below.
Custom differentiable head
A custom differentiable head must follow these requirements:
Must subclass nn.Module.
A predict method: (self, torch.Tensor with shape [num_inputs, embedding_size]) -> torch.Tensor with shape [num_inputs] - This method classifies the embeddings. The output must be integers in the range [0, num_classes).
A predict_proba method: (self, torch.Tensor with shape [num_inputs, embedding_size]) -> torch.Tensor with shape [num_inputs, num_classes] - This method classifies the embeddings into probabilities for each class. For each input, the tensor of size num_classes must sum to 1. Applying torch.argmax(output, dim=-1) should result in the output for predict.
A get_loss_fn method: (self) -> nn.Module - Returns an initialized loss function, e.g. torch.nn.CrossEntropyLoss().
A forward method: (self, Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor] - Given the output from the Sentence Transformer body, i.e. a dictionary of 'input_ids', 'token_type_ids', 'attention_mask', 'token_embeddings' and 'sentence_embedding' keys, return a dictionary with a 'logits' key and a torch.Tensor value with shape [batch_size, num_classes].
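The requirements above can be met with very little code. The following is a sketch of a hypothetical LinearHead that follows the listed signatures; it has not been checked against a particular SetFit release, and the embedding size of 384 and the class count are arbitrary values chosen for the demonstration.

```python
from typing import Dict

import torch
from torch import nn


class LinearHead(nn.Module):
    """Hypothetical custom head: a single linear layer over sentence embeddings."""

    def __init__(self, embedding_size: int, num_classes: int) -> None:
        super().__init__()
        self.linear = nn.Linear(embedding_size, num_classes)

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # The body output carries the pooled embedding under 'sentence_embedding'.
        logits = self.linear(features["sentence_embedding"])
        return {"logits": logits}

    def predict_proba(self, embeddings: torch.Tensor) -> torch.Tensor:
        # Softmax makes each row sum to 1, as required.
        with torch.no_grad():
            return torch.softmax(self.linear(embeddings), dim=-1)

    def predict(self, embeddings: torch.Tensor) -> torch.Tensor:
        # Defined through predict_proba so the two methods always agree.
        return torch.argmax(self.predict_proba(embeddings), dim=-1)

    def get_loss_fn(self) -> nn.Module:
        return nn.CrossEntropyLoss()


head = LinearHead(embedding_size=384, num_classes=3)
embeddings = torch.randn(4, 384)
print(head.predict_proba(embeddings).shape)  # torch.Size([4, 3])
print(head.predict(embeddings).shape)        # torch.Size([4])
```

Deriving predict from predict_proba is a design choice that guarantees the argmax consistency requirement without extra checks.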
Custom non-differentiable head
A custom non-differentiable head must follow these requirements:
A predict method: (self, np.array with shape [num_inputs, embedding_size]) -> np.array with shape [num_inputs] - This method classifies the embeddings. The output must be integers in the range [0, num_classes).
A predict_proba method: (self, np.array with shape [num_inputs, embedding_size]) -> np.array with shape [num_inputs, num_classes] - This method classifies the embeddings into probabilities for each class. For each input, the array of size num_classes must sum to 1. Applying np.argmax(output, axis=-1) should result in the output for predict.
A fit method: (self, np.array with shape [num_inputs, embedding_size], List[Any]) -> None - This method must take a numpy array of embeddings and a list of corresponding labels. The labels need not be integers per se.
Many classifiers from sklearn already fit these requirements, such as RandomForestClassifier, MLPClassifier, KNeighborsClassifier, etc.
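If none of the sklearn classifiers fit, the same interface can be written by hand. The sketch below is a hypothetical NearestCentroidHead that scores each input by its distance to per-class mean embeddings; the class name, the toy embeddings and the labels are all assumptions made for illustration.

```python
from typing import Any, List

import numpy as np


class NearestCentroidHead:
    """Hypothetical non-differentiable head: softmax over negative centroid distances."""

    def fit(self, embeddings: np.ndarray, labels: List[Any]) -> None:
        # Labels need not be integers; class indices follow their sorted order.
        self.classes_ = sorted(set(labels))
        label_array = np.array(labels)
        self.centroids_ = np.stack(
            [embeddings[label_array == cls].mean(axis=0) for cls in self.classes_]
        )

    def predict_proba(self, embeddings: np.ndarray) -> np.ndarray:
        # Distance from every input to every centroid: [num_inputs, num_classes].
        diffs = embeddings[:, None, :] - self.centroids_[None, :, :]
        scores = -np.linalg.norm(diffs, axis=-1)
        scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
        exp_scores = np.exp(scores)
        return exp_scores / exp_scores.sum(axis=-1, keepdims=True)

    def predict(self, embeddings: np.ndarray) -> np.ndarray:
        # Integer class indices in [0, num_classes), consistent with predict_proba.
        return np.argmax(self.predict_proba(embeddings), axis=-1)


train_embeddings = np.array([[0.0, 0.0], [0.2, 0.0], [5.0, 5.0], [5.0, 5.2]])
train_labels = ["negative", "negative", "positive", "positive"]

head = NearestCentroidHead()
head.fit(train_embeddings, train_labels)
queries = np.array([[0.1, 0.1], [4.9, 5.0]])
print(head.predict(queries))  # [0 1]
```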
When initializing a SetFit model using your custom (non-)differentiable classification head, it is recommended to use the regular __init__ method:
Then, training and inference can commence like normal, e.g.:
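A sketch covering both steps, construction through __init__ and then training and inference, might look like this. The function names, the checkpoint, the choice of KNeighborsClassifier with n_neighbors=3 and the training argument values are assumptions; train_dataset is expected to be a dataset with text and label columns that you supply.

```python
def build_model_with_custom_head(model_id: str = "BAAI/bge-small-en-v1.5"):
    """Combine a Sentence Transformer body with a custom head via SetFitModel.__init__."""
    from sentence_transformers import SentenceTransformer
    from setfit import SetFitModel
    from sklearn.neighbors import KNeighborsClassifier

    model_body = SentenceTransformer(model_id)
    # Any object meeting the custom head requirements could be used here.
    model_head = KNeighborsClassifier(n_neighbors=3)
    return SetFitModel(model_body=model_body, model_head=model_head)


def train_and_predict(model, train_dataset, texts):
    """Finetune the body, fit the custom head, then classify new texts."""
    from setfit import Trainer, TrainingArguments

    args = TrainingArguments(batch_size=16, num_epochs=1)
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
    trainer.train()
    return model.predict(texts)
```

Because the head here is not differentiable, only the first value of any tuple-valued training argument would be relevant.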