
GitHub Repository: huggingface/notebooks
Path: blob/main/examples/text_classification.ipynb
Kernel: Python 3 (ipykernel)

If you're opening this notebook on Colab, you will probably need to install 🤗 Transformers, 🤗 Datasets and 🤗 Accelerate. Uncomment the following cell and run it.

#! pip install datasets transformers accelerate

If you're opening this notebook locally, make sure your environment has the latest version of those libraries installed.

To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow.

First you have to store your authentication token from the Hugging Face website (sign up here if you haven't already!), then execute the following cell and input your credentials:

from huggingface_hub import notebook_login

notebook_login()

Then you need to install Git-LFS. Uncomment the following instructions:

# !apt install git-lfs

Make sure your version of Transformers is at least 4.11.0 since the functionality was introduced in that version:

import transformers

print(transformers.__version__)

You can find a script version of this notebook to fine-tune your model in a distributed fashion using multiple GPUs or TPUs here.

We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.

from transformers.utils import send_example_telemetry

send_example_telemetry("text_classification_notebook", framework="pytorch")

Fine-tuning a model on a text classification task

In this notebook, we will see how to fine-tune one of the 🤗 Transformers models on a text classification task of the GLUE Benchmark.

Widget inference on a text classification task

The GLUE Benchmark is a group of nine classification tasks on sentences or pairs of sentences which are:

  • CoLA (Corpus of Linguistic Acceptability) Determine if a sentence is grammatically correct or not.

  • MNLI (Multi-Genre Natural Language Inference) Determine if a sentence entails, contradicts or is unrelated to a given hypothesis. (This dataset has two versions, one with the validation and test set coming from the same distribution, another called mismatched where the validation and test use out-of-domain data.)

  • MRPC (Microsoft Research Paraphrase Corpus) Determine if two sentences are paraphrases from one another or not.

  • QNLI (Question-answering Natural Language Inference) Determine if the answer to a question is in the second sentence or not. (This dataset is built from the SQuAD dataset.)

  • QQP (Quora Question Pairs) Determine if two questions are semantically equivalent or not.

  • RTE (Recognizing Textual Entailment) Determine if a sentence entails a given hypothesis or not.

  • SST-2 (Stanford Sentiment Treebank) Determine if the sentence has a positive or negative sentiment.

  • STS-B (Semantic Textual Similarity Benchmark) Determine the similarity of two sentences with a score from 1 to 5.

  • WNLI (Winograd Natural Language Inference) Determine if a sentence with an anonymous pronoun and a sentence with this pronoun replaced are entailed or not. (This dataset is built from the Winograd Schema Challenge dataset.)

We will see how to easily load the dataset for each one of those tasks and use the Trainer API to fine-tune a model on it. Each task is named by its acronym, with mnli-mm standing for the mismatched version of MNLI (so same training set as mnli but different validation and test sets):

GLUE_TASKS = ["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]

This notebook is built to run on any of the tasks in the list above, with any model checkpoint from the Model Hub as long as that model has a version with a classification head. Depending on your model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set those three parameters, then the rest of the notebook should run smoothly:

task = "cola" model_checkpoint = "distilbert-base-uncased" batch_size = 16

Loading the dataset

We will use the 🤗 Datasets library to download the data and get the metric we need to use for evaluation (to compare our model to the benchmark). This can be easily done with the functions load_dataset and load_metric.

from datasets import load_dataset, load_metric

Apart from mnli-mm, which is a special case, we can directly pass our task name to those functions. load_dataset will cache the dataset to avoid downloading it again the next time you run this cell.

actual_task = "mnli" if task == "mnli-mm" else task dataset = load_dataset("glue", actual_task) metric = load_metric('glue', actual_task)
Reusing dataset glue (/home/sgugger/.cache/huggingface/datasets/glue/cola/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)

The dataset object itself is a DatasetDict, which contains one key each for the training, validation and test set (with more keys for the mismatched validation and test sets in the special case of mnli).

dataset
DatasetDict({'train': Dataset(features: {'sentence': Value(dtype='string', id=None), 'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None), 'idx': Value(dtype='int32', id=None)}, num_rows: 8551), 'validation': Dataset(features: {'sentence': Value(dtype='string', id=None), 'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None), 'idx': Value(dtype='int32', id=None)}, num_rows: 1043), 'test': Dataset(features: {'sentence': Value(dtype='string', id=None), 'label': ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], names_file=None, id=None), 'idx': Value(dtype='int32', id=None)}, num_rows: 1063)})

To access an actual element, you need to select a split first, then give an index:

dataset["train"][0]
{'idx': 0, 'label': 1, 'sentence': "Our friends won't buy this analysis, let alone the next one we propose."}

To get a sense of what the data looks like, the following function will show some examples picked randomly in the dataset.

import datasets
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
show_random_elements(dataset["train"])

The metric is an instance of datasets.Metric:

metric
Metric(name: "glue", features: {'predictions': Value(dtype='int64', id=None), 'references': Value(dtype='int64', id=None)}, usage: """ Compute GLUE evaluation metric associated to each GLUE dataset. Args: predictions: list of predictions to score. Each translation should be tokenized into a list of tokens. references: list of lists of references for each translation. Each reference should be tokenized into a list of tokens. Returns: depending on the GLUE subset, one or several of: "accuracy": Accuracy "f1": F1 score "pearson": Pearson Correlation "spearmanr": Spearman Correlation "matthews_correlation": Matthew Correlation """, stored examples: 0)

You can call its compute method with your predictions and labels directly and it will return a dictionary with the metric(s) value:

import numpy as np

fake_preds = np.random.randint(0, 2, size=(64,))
fake_labels = np.random.randint(0, 2, size=(64,))
metric.compute(predictions=fake_preds, references=fake_labels)
{'matthews_correlation': -0.12156862745098039}

Note that load_metric has loaded the proper metric associated with your task, which is:

  • for CoLA: Matthews Correlation Coefficient

  • for MNLI (matched or mismatched): Accuracy

  • for MRPC: Accuracy and F1 score

  • for QNLI: Accuracy

  • for QQP: Accuracy and F1 score

  • for RTE: Accuracy

  • for SST-2: Accuracy

  • for STS-B: Pearson Correlation Coefficient and Spearman's Rank Correlation Coefficient

  • for WNLI: Accuracy

so the metric object only computes the one(s) needed for your task.
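For instance (a quick illustrative check, not part of the original notebook), loading the metric for a different task returns different values; the hypothetical snippet below would report Pearson and Spearman correlations for STS-B instead of the Matthews correlation:

# Illustrative only: the STS-B metric reports correlation scores rather than Matthews correlation
stsb_metric = load_metric('glue', 'stsb')
stsb_metric.compute(predictions=[0.5, 1.2, 3.0], references=[0.0, 1.0, 3.0])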

Preprocessing the data

Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers Tokenizer, which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary) and put them in a format the model expects, as well as generate the other inputs that the model requires.

To do all of this, we instantiate our tokenizer with the AutoTokenizer.from_pretrained method, which will ensure:

  • we get a tokenizer that corresponds to the model architecture we want to use,

  • we download the vocabulary used when pretraining this specific checkpoint.

That vocabulary will be cached, so it's not downloaded again the next time we run the cell.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

We pass along use_fast=True to the call above to use one of the fast tokenizers (backed by Rust) from the 🤗 Tokenizers library. Those fast tokenizers are available for almost all models, but if you got an error with the previous call, remove that argument.
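In that case, the fallback (a minimal sketch mirroring the advice above, not something you normally need) is simply to drop the argument:

# Fallback if the previous call raised an error for your checkpoint
# tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)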

You can directly call this tokenizer on one sentence or a pair of sentences:

tokenizer("Hello, this one sentence!", "And this sentence goes with it.")
{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

Depending on the model you selected, you will see different keys in the dictionary returned by the cell above. They don't matter much for what we're doing here (just know they are required by the model we will instantiate later); you can learn more about them in this tutorial if you're interested.
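As a quick aside (an illustrative check, not in the original notebook), you can map the input_ids back to tokens to see exactly what the tokenizer produced:

encoding = tokenizer("Hello, this one sentence!", "And this sentence goes with it.")
# Convert the numeric ids back to the subword tokens they stand for
print(tokenizer.convert_ids_to_tokens(encoding["input_ids"]))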

To preprocess our dataset, we will thus need the names of the columns containing the sentence(s). The following dictionary keeps track of the correspondence between tasks and column names:

task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mnli-mm": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

We can double check it does work on our current dataset:

sentence1_key, sentence2_key = task_to_keys[task]
if sentence2_key is None:
    print(f"Sentence: {dataset['train'][0][sentence1_key]}")
else:
    print(f"Sentence 1: {dataset['train'][0][sentence1_key]}")
    print(f"Sentence 2: {dataset['train'][0][sentence2_key]}")
Sentence: Our friends won't buy this analysis, let alone the next one we propose.

We can then write the function that will preprocess our samples. We just feed them to the tokenizer with the argument truncation=True. This will ensure that an input longer than what the selected model can handle is truncated to the maximum length the model accepts.

def preprocess_function(examples):
    if sentence2_key is None:
        return tokenizer(examples[sentence1_key], truncation=True)
    return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)

This function works with one or several examples. In the case of several examples, the tokenizer will return a list of lists for each key:

preprocess_function(dataset['train'][:5])
{'input_ids': [[101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], [101, 2028, 2062, 18404, 2236, 3989, 1998, 1045, 1005, 1049, 3228, 2039, 1012, 102], [101, 2028, 2062, 18404, 2236, 3989, 2030, 1045, 1005, 1049, 3228, 2039, 1012, 102], [101, 1996, 2062, 2057, 2817, 16025, 1010, 1996, 13675, 16103, 2121, 2027, 2131, 1012, 102], [101, 2154, 2011, 2154, 1996, 8866, 2024, 2893, 14163, 8024, 3771, 1012, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

To apply this function to all the sentences (or pairs of sentences) in our dataset, we just use the map method of the dataset object we created earlier. This will apply the function to all the elements of all the splits in dataset, so our training, validation and test data will be preprocessed in one single command.

encoded_dataset = dataset.map(preprocess_function, batched=True)

Even better, the results are automatically cached by the 🤗 Datasets library to avoid spending time on this step the next time you run your notebook. The 🤗 Datasets library is normally smart enough to detect when the function you pass to map has changed (and thus that the cached data should not be reused). For instance, it will properly detect if you change the task in the first cell and rerun the notebook. 🤗 Datasets warns you when it uses cached files; you can pass load_from_cache_file=False in the call to map to not use the cached files and force the preprocessing to be applied again.
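For example (an optional variant you only need if you want to bypass the cache), the call would look like this:

# Force the preprocessing to run again instead of loading the cached result
encoded_dataset = dataset.map(preprocess_function, batched=True, load_from_cache_file=False)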

Note that we passed batched=True to encode the texts by batches together. This is to leverage the full benefit of the fast tokenizer we loaded earlier, which will use multi-threading to treat the texts in a batch concurrently.

Fine-tuning the model

Now that our data is ready, we can download the pretrained model and fine-tune it. Since all our tasks are about sentence classification, we use the AutoModelForSequenceClassification class. Like with the tokenizer, the from_pretrained method will download and cache the model for us. The only thing we have to specify is the number of labels for our problem (which is always 2, except for STS-B which is a regression problem and MNLI where we have 3 labels):

from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

num_labels = 3 if task.startswith("mnli") else 1 if task=="stsb" else 2
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias'] - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model). - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

The warning is telling us we are throwing away some weights (the vocab_transform and vocab_layer_norm layers) and randomly initializing some others (the pre_classifier and classifier layers). This is absolutely normal in this case, because we are removing the head used to pretrain the model on a masked language modeling objective and replacing it with a new head for which we don't have pretrained weights, so the library warns us that we should fine-tune this model before using it for inference, which is exactly what we are going to do.

To instantiate a Trainer, we will need to define two more things. The most important is the TrainingArguments, which is a class that contains all the attributes to customize the training. It requires one folder name, which will be used to save the checkpoints of the model, and all other arguments are optional:

metric_name = "pearson" if task == "stsb" else "matthews_correlation" if task == "cola" else "accuracy" model_name = model_checkpoint.split("/")[-1] args = TrainingArguments( f"{model_name}-finetuned-{task}", evaluation_strategy = "epoch", save_strategy = "epoch", learning_rate=2e-5, per_device_train_batch_size=batch_size, per_device_eval_batch_size=batch_size, num_train_epochs=5, weight_decay=0.01, load_best_model_at_end=True, metric_for_best_model=metric_name, push_to_hub=True, )

Here we set the evaluation to be done at the end of each epoch, tweak the learning rate, use the batch_size defined at the top of the notebook and customize the number of epochs for training, as well as the weight decay. Since the best model might not be the one at the end of training, we ask the Trainer to load the best model it saved (according to metric_name) at the end of training.

The last argument sets up everything needed to push the model to the Hub regularly during training. Remove it if you didn't follow the installation steps at the top of the notebook. If you want to save your model locally under a name different from the name of the repository it will be pushed to, or if you want to push your model under an organization and not your namespace, use the hub_model_id argument to set the repo name (it needs to be the full name, including your namespace: for instance "sgugger/bert-finetuned-mrpc" or "huggingface/bert-finetuned-mrpc").
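For instance (a hedged sketch with a hypothetical repository name, not part of the original notebook), the relevant arguments would look like this:

# Illustrative only: push under an organization instead of your own namespace
# args = TrainingArguments(
#     f"{model_name}-finetuned-{task}",
#     ...,
#     push_to_hub=True,
#     hub_model_id="my-org/distilbert-finetuned-cola",  # hypothetical full repo name
# )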

The last thing to define for our Trainer is how to compute the metrics from the predictions. We need to define a function for this, which will just use the metric we loaded earlier; the only preprocessing we have to do is to take the argmax of our predicted logits (or just squeeze the last axis in the case of STS-B):

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    if task != "stsb":
        predictions = np.argmax(predictions, axis=1)
    else:
        predictions = predictions[:, 0]
    return metric.compute(predictions=predictions, references=labels)

Then we just need to pass all of this along with our datasets to the Trainer:

validation_key = "validation_mismatched" if task == "mnli-mm" else "validation_matched" if task == "mnli" else "validation" trainer = Trainer( model, args, train_dataset=encoded_dataset["train"], eval_dataset=encoded_dataset[validation_key], tokenizer=tokenizer, compute_metrics=compute_metrics )

You might wonder why we pass along the tokenizer when we already preprocessed our data. This is because we will use it one last time to make all the samples we gather the same length by applying padding, which requires knowing the model's preferences regarding padding (to the left or right? with which token?). The tokenizer has a pad method that will do all of this right for us, and the Trainer will use it. You can customize this part by defining and passing your own data_collator, which will receive the samples like the dictionaries seen above and will need to return a dictionary of tensors.
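As an illustration (a minimal sketch of what the default behavior amounts to when a tokenizer is passed, not something you need to run here), an explicit data collator could look like this:

from transformers import DataCollatorWithPadding

# Pads each batch dynamically to the length of its longest sample, using the tokenizer's padding token and side
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# trainer = Trainer(model, args, ..., data_collator=data_collator, ...)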

We can now fine-tune our model by just calling the train method:

trainer.train()
TrainOutput(global_step=2675, training_loss=0.26824118355724297)

We can check with the evaluate method that our Trainer did reload the best model properly (if it was not the last one):

trainer.evaluate()
{'eval_loss': 0.6563718318939209, 'eval_matthews_correlation': 0.5432575763528743, 'epoch': 5.0, 'total_flos': 354601098011100}

To see how your model fared you can compare it to the GLUE Benchmark leaderboard.

You can now upload the result of the training to the Hub, just execute this instruction:

trainer.push_to_hub()

You can now share this model with all your friends, family, favorite pets: they can all load it with the identifier "your-username/the-name-you-picked" so for instance:

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("sgugger/my-awesome-model")

Hyperparameter search

The Trainer supports hyperparameter search using optuna or Ray Tune. For this last section, you will need either of those libraries installed; just uncomment the line you want in the next cell and run it.

# ! pip install optuna
# ! pip install ray[tune]

During hyperparameter search, the Trainer will run several trainings, so it needs to have the model defined via a function (so it can be reinitialized at each new run) instead of just having it passed. We just wrap the same instantiation as before in a function:

def model_init():
    return AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

And we can instantiate our Trainer like before:

trainer = Trainer(
    model_init=model_init,
    args=args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias'] - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model). - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

The method we call this time is hyperparameter_search. Note that it can take a long time to run on the full dataset for some of the tasks. You can try to find some good hyperparameters on a portion of the training dataset by replacing the train_dataset line above by:

train_dataset = encoded_dataset["train"].shard(index=1, num_shards=10)

for 1/10th of the dataset. Then you can run a full training on the best hyperparameters picked by the search.

best_run = trainer.hyperparameter_search(n_trials=10, direction="maximize")
[I 2020-10-15 17:06:50,302] A new study created in memory with name: no-name-5ca133fa-884f-4b0b-beb6-82fb19c9da06 Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias'] - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model). - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[I 2020-10-15 17:08:54,670] Trial 0 finished with value: 0.4276247641425646 and parameters: {'learning_rate': 9.066394855554376e-05, 'num_train_epochs': 4, 'seed': 3, 'per_device_train_batch_size': 8}. Best is trial 0 with value: 0.4276247641425646. Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias'] - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model). - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[I 2020-10-15 17:09:27,505] Trial 1 finished with value: 0.4806486312161432 and parameters: {'learning_rate': 4.70956471834944e-05, 'num_train_epochs': 1, 'seed': 40, 'per_device_train_batch_size': 8}. Best is trial 1 with value: 0.4806486312161432. Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias'] - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model). - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[I 2020-10-15 17:09:51,733] Trial 2 finished with value: 0.5215162259225145 and parameters: {'learning_rate': 4.357724525964853e-05, 'num_train_epochs': 2, 'seed': 38, 'per_device_train_batch_size': 32}. Best is trial 2 with value: 0.5215162259225145. Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias'] - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model). - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[I 2020-10-15 17:10:42,034] Trial 3 finished with value: 0.5146251651052666 and parameters: {'learning_rate': 1.966064728684351e-05, 'num_train_epochs': 3, 'seed': 4, 'per_device_train_batch_size': 16}. Best is trial 2 with value: 0.5215162259225145. Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias'] - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model). - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
/home/sgugger/.pyenv/versions/3.7.9/envs/base/lib/python3.7/site-packages/sklearn/metrics/_classification.py:846: RuntimeWarning: invalid value encountered in double_scalars mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp) [I 2020-10-15 17:11:00,530] Trial 4 finished with value: 0.0 and parameters: {'learning_rate': 1.1681626444738994e-06, 'num_train_epochs': 1, 'seed': 35, 'per_device_train_batch_size': 16}. Best is trial 2 with value: 0.5215162259225145. Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias'] - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model). - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
/home/sgugger/.pyenv/versions/3.7.9/envs/base/lib/python3.7/site-packages/sklearn/metrics/_classification.py:846: RuntimeWarning: invalid value encountered in double_scalars mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp) [I 2020-10-15 17:11:16,708] Trial 5 pruned. Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias'] - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model). - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[I 2020-10-15 17:11:27,293] Trial 6 pruned. Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias'] - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model). - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[I 2020-10-15 17:11:39,998] Trial 7 finished with value: 0.47750559698251505 and parameters: {'learning_rate': 8.294058232883187e-05, 'num_train_epochs': 1, 'seed': 14, 'per_device_train_batch_size': 64}. Best is trial 2 with value: 0.5215162259225145. Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias'] - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model). - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
/home/sgugger/.pyenv/versions/3.7.9/envs/base/lib/python3.7/site-packages/sklearn/metrics/_classification.py:846: RuntimeWarning: invalid value encountered in double_scalars mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp) [I 2020-10-15 17:11:50,629] Trial 8 pruned. Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias'] - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model). - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[I 2020-10-15 17:12:09,216] Trial 9 finished with value: 0.4496077935780946 and parameters: {'learning_rate': 3.983123210032869e-05, 'num_train_epochs': 1, 'seed': 35, 'per_device_train_batch_size': 16}. Best is trial 2 with value: 0.5215162259225145.

The hyperparameter_search method returns a BestRun object, which contains the value of the objective maximized (by default the sum of all metrics) and the hyperparameters it used for that run.

best_run
BestRun(run_id='2', objective=0.5215162259225145, hyperparameters={'learning_rate': 4.357724525964853e-05, 'num_train_epochs': 2, 'seed': 38, 'per_device_train_batch_size': 32})

You can customize the objective to maximize by passing along a compute_objective function to the hyperparameter_search method, and you can customize the search space by passing a hp_space argument to hyperparameter_search. See this forum post for some examples.
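For example (a hedged sketch of a custom optuna search space; the ranges and names below are illustrative, not from the original notebook):

def my_hp_space(trial):
    # Each call proposes one set of hyperparameters for the next trial
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-4, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32, 64]),
    }

# best_run = trainer.hyperparameter_search(n_trials=10, direction="maximize", hp_space=my_hp_space)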

To reproduce the best training, just set the hyperparameters in your TrainingArguments before creating a Trainer:

for n, v in best_run.hyperparameters.items():
    setattr(trainer.args, n, v)

trainer.train()
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias'] - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model). - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.weight', 'classifier.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
TrainOutput(global_step=536, training_loss=0.4191349371155696)