GitHub Repository: huggingface/notebooks
Path: blob/main/examples/text_classification-tf.ipynb
Kernel: Python 3 (ipykernel)

If you're opening this notebook on Colab, you will probably need to install the most recent versions of 🤗 Transformers, 🤗 Datasets and 🤗 Evaluate. We will also need scipy and scikit-learn for some of the metrics. Uncomment the following cell and run it.

#! pip install transformers
#! pip install datasets evaluate
#! pip install scipy scikit-learn
#! pip install huggingface_hub

If you're opening this notebook locally, make sure your environment has the latest versions of those libraries installed.

To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow:

First, create an access token on the Hugging Face website (sign up here if you haven't already!), then uncomment the following cell and enter your token.

from huggingface_hub import notebook_login

notebook_login()

Then you need to install Git-LFS and set up Git if you haven't already. Uncomment the following instructions and adapt them with your name and email:

# !apt install git-lfs
# !git config --global user.email "you@example.com"
# !git config --global user.name "Your Name"

Make sure your version of Transformers is at least 4.16.0 since the functionality was introduced in that version:

import transformers

print(transformers.__version__)
4.22.0.dev0

You can find a script version of this notebook to fine-tune your model in a distributed fashion using multiple GPUs or TPUs here.

We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.

from transformers.utils import send_example_telemetry

send_example_telemetry("text_classification_notebook", framework="tensorflow")

Fine-tuning a model on a text classification task

In this notebook, we will see how to fine-tune one of the 🤗 Transformers models on a text classification task from the GLUE Benchmark.

Widget inference on a text classification task

The GLUE Benchmark is a group of nine sentence- or sentence-pair language understanding tasks:

  • CoLA (Corpus of Linguistic Acceptability) Determine if a sentence is grammatically correct or not.

  • MNLI (Multi-Genre Natural Language Inference) Determine if a sentence entails, contradicts or is unrelated to a given hypothesis. (This dataset has two versions, one with the validation and test set coming from the same distribution, another called mismatched where the validation and test use out-of-domain data.)

  • MRPC (Microsoft Research Paraphrase Corpus) Determine if two sentences are paraphrases of one another or not.

  • QNLI (Question-answering Natural Language Inference) Determine if the answer to a question is in the second sentence or not. (This dataset is built from the SQuAD dataset.)

  • QQP (Quora Question Pairs) Determine if two questions are semantically equivalent or not.

  • RTE (Recognizing Textual Entailment) Determine if a sentence entails a given hypothesis or not.

  • SST-2 (Stanford Sentiment Treebank) Determine if the sentence has a positive or negative sentiment.

  • STS-B (Semantic Textual Similarity Benchmark) Determine the similarity of two sentences with a score from 0 to 5.

  • WNLI (Winograd Natural Language Inference) Determine if a sentence with an anonymous pronoun and a sentence with this pronoun replaced are entailed or not. (This dataset is built from the Winograd Schema Challenge dataset.)

We will see how to easily load the dataset for each one of those tasks and use Keras to fine-tune a model on it. Each task is named by its acronym, with mnli-mm standing for the mismatched version of MNLI (a task with the same training set as mnli but different validation and test sets):

GLUE_TASKS = [
    "cola",
    "mnli",
    "mnli-mm",
    "mrpc",
    "qnli",
    "qqp",
    "rte",
    "sst2",
    "stsb",
    "wnli",
]

This notebook is built to run on any of the tasks in the list above, with any model checkpoint from the Model Hub as long as that model has a version with a classification head. Depending on your model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set these three parameters, then the rest of the notebook should run smoothly:

task = "cola" model_checkpoint = "distilbert-base-uncased" batch_size = 16

Loading the dataset

We will use the 🤗 Datasets library to download the data and the 🤗 Evaluate library to get the metric we need for evaluation (to compare our model to the benchmark). This can be easily done with the load_dataset function from datasets and the load function from evaluate.

from datasets import load_dataset
from evaluate import load

With the exception of mnli-mm, we can directly pass our task name to those functions. load_dataset will cache the dataset to avoid downloading it again the next time you run this cell.

actual_task = "mnli" if task == "mnli-mm" else task dataset = load_dataset("glue", actual_task) metric = load("glue", actual_task)
Reusing dataset glue (/home/matt/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)

The dataset object itself is a DatasetDict, which contains one key per split (training, validation and test), with more keys for the mismatched validation and test sets in the special case of mnli.

dataset
DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 8551
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1043
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1063
    })
})

To access an actual element, you need to select a split first, then give an index:

dataset["train"][0]
{'sentence': "Our friends won't buy this analysis, let alone the next one we propose.", 'label': 1, 'idx': 0}

To get a sense of what the data looks like, the following function will show some examples picked at random from the dataset.

import datasets
import random
import pandas as pd
from IPython.display import display, HTML


def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset) - 1)
        while pick in picks:
            pick = random.randint(0, len(dataset) - 1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
show_random_elements(dataset["train"])

The metric is an instance of evaluate.EvaluationModule:

metric
EvaluationModule(name: "glue", module_type: "metric", features: {'predictions': Value(dtype='int64', id=None), 'references': Value(dtype='int64', id=None)}, usage: """ Compute GLUE evaluation metric associated to each GLUE dataset. Args: predictions: list of predictions to score. Each translation should be tokenized into a list of tokens. references: list of lists of references for each translation. Each reference should be tokenized into a list of tokens. Returns: depending on the GLUE subset, one or several of: "accuracy": Accuracy "f1": F1 score "pearson": Pearson Correlation "spearmanr": Spearman Correlation "matthews_correlation": Matthew Correlation Examples: >>> glue_metric = evaluate.load('glue', 'sst2') # 'sst2' or any of ["mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"] >>> references = [0, 1] >>> predictions = [0, 1] >>> results = glue_metric.compute(predictions=predictions, references=references) >>> print(results) {'accuracy': 1.0} >>> glue_metric = evaluate.load('glue', 'mrpc') # 'mrpc' or 'qqp' >>> references = [0, 1] >>> predictions = [0, 1] >>> results = glue_metric.compute(predictions=predictions, references=references) >>> print(results) {'accuracy': 1.0, 'f1': 1.0} >>> glue_metric = evaluate.load('glue', 'stsb') >>> references = [0., 1., 2., 3., 4., 5.] >>> predictions = [0., 1., 2., 3., 4., 5.] >>> results = glue_metric.compute(predictions=predictions, references=references) >>> print({"pearson": round(results["pearson"], 2), "spearmanr": round(results["spearmanr"], 2)}) {'pearson': 1.0, 'spearmanr': 1.0} >>> glue_metric = evaluate.load('glue', 'cola') >>> references = [0, 1] >>> predictions = [0, 1] >>> results = glue_metric.compute(predictions=predictions, references=references) >>> print(results) {'matthews_correlation': 1.0} """, stored examples: 0)

You can call its compute method with your predictions and labels directly and it will return a dictionary with the metric value(s):

import numpy as np

fake_preds = np.random.randint(0, 2, size=(64,))
fake_labels = np.random.randint(0, 2, size=(64,))
metric.compute(predictions=fake_preds, references=fake_labels)
{'matthews_correlation': 0.061541083462958945}

Note that load has loaded the proper metric associated with your task: the Matthews correlation coefficient for CoLA, Pearson and Spearman correlations for STS-B, accuracy and F1 for MRPC and QQP, and plain accuracy for the remaining tasks (see the usage string printed above). The metric object therefore only computes the value(s) needed for your task.
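For instance, loading a different GLUE configuration gives different metric keys; here is a minimal sketch with made-up predictions and references (the numbers are arbitrary):

from evaluate import load

# Hypothetical example: the "stsb" config reports Pearson and Spearman correlations
stsb_metric = load("glue", "stsb")
print(stsb_metric.compute(predictions=[0.5, 1.5, 3.0], references=[0.0, 2.0, 3.0]))
# -> a dict with 'pearson' and 'spearmanr' keys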

Preprocessing the data

Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers Tokenizer, which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary) and put them in a format the model expects, as well as generate the other inputs the model requires.

To do all of this, we instantiate our tokenizer with the AutoTokenizer.from_pretrained method, which will ensure:

  • we get a tokenizer that corresponds to the model architecture we want to use,

  • we download the vocabulary used when pretraining this specific checkpoint.

That vocabulary will be cached, so it's not downloaded again the next time we run the cell.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

You can directly call this tokenizer on one sentence or a pair of sentences:

tokenizer("Hello, this is a sentence!", "And this sentence goes with it.")
{'input_ids': [101, 7592, 1010, 2023, 2003, 1037, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

Depending on the model you selected, you will see different keys in the dictionary returned by the cell above. They don't matter much for what we're doing here (just know they are required by the model we will instantiate later); you can learn more about them in this tutorial if you're interested.
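For instance, a BERT checkpoint would also return token_type_ids to tell the model which tokens belong to the first sentence and which to the second; a minimal sketch, using bert-base-cased purely as an illustrative alternative checkpoint:

from transformers import AutoTokenizer

# Illustrative alternative checkpoint: BERT tokenizers also produce token_type_ids
bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
encoded = bert_tokenizer("Hello, this is a sentence!", "And this sentence goes with it.")
print(list(encoded.keys()))  # ['input_ids', 'token_type_ids', 'attention_mask']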

To preprocess our dataset, we will thus need the names of the columns containing the sentence(s). The following dictionary keeps track of the correspondence between tasks and column names:

task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mnli-mm": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

We can double-check that it works on our current dataset:

sentence1_key, sentence2_key = task_to_keys[task]
if sentence2_key is None:
    print(f"Sentence: {dataset['train'][0][sentence1_key]}")
else:
    print(f"Sentence 1: {dataset['train'][0][sentence1_key]}")
    print(f"Sentence 2: {dataset['train'][0][sentence2_key]}")
Sentence: Our friends won't buy this analysis, let alone the next one we propose.

We can then write the function that will preprocess our samples. We just feed them to the tokenizer with the argument truncation=True, which ensures that an input longer than what the selected model can handle will be truncated to the maximum length the model accepts. We could also pass padding='longest' to pad all inputs to the same length and get a single input array, but a more performant method that reduces the number of padding tokens is to let a generator or tf.data.Dataset pad each batch only to the maximum length in that batch, which is what we will do below; most GLUE tasks are relatively quick on modern GPUs either way.

def preprocess_function(examples):
    if sentence2_key is None:
        return tokenizer(examples[sentence1_key], truncation=True)
    return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)

This function works with one or several examples. In the case of several examples, the tokenizer will return a list of lists for each key:

preprocess_function(dataset["train"][:5])
{'input_ids': [[101, 2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012, 102], [101, 2028, 2062, 18404, 2236, 3989, 1998, 1045, 1005, 1049, 3228, 2039, 1012, 102], [101, 2028, 2062, 18404, 2236, 3989, 2030, 1045, 1005, 1049, 3228, 2039, 1012, 102], [101, 1996, 2062, 2057, 2817, 16025, 1010, 1996, 13675, 16103, 2121, 2027, 2131, 1012, 102], [101, 2154, 2011, 2154, 1996, 8866, 2024, 2893, 14163, 8024, 3771, 1012, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

To apply this function to all the sentences (or pairs of sentences) in our dataset, we just use the map method of the dataset object we created earlier. This will apply the function to all the elements of all the splits in dataset, so our training, validation and testing data will be preprocessed in a single command.

pre_tokenizer_columns = set(dataset["train"].features)
encoded_dataset = dataset.map(preprocess_function, batched=True)
tokenizer_columns = list(set(encoded_dataset["train"].features) - pre_tokenizer_columns)
print("Columns added by tokenizer:", tokenizer_columns)
Loading cached processed dataset at /home/matt/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-53ea0538c5398ddb.arrow Loading cached processed dataset at /home/matt/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-b7cb36983237028d.arrow Loading cached processed dataset at /home/matt/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-26a3f487bad313e7.arrow
Columns added by tokenizer: ['input_ids', 'attention_mask']
encoded_dataset["train"].features["label"]
ClassLabel(num_classes=2, names=['unacceptable', 'acceptable'], id=None)

Even better, the results are automatically cached by the 🤗 Datasets library to avoid spending time on this step the next time you run your notebook. The 🤗 Datasets library is normally smart enough to detect when the function you pass to map has changed (and thus that the cached data should not be reused). For instance, it will properly detect if you change the task in the first cell and rerun the notebook. 🤗 Datasets warns you when it uses cached files; you can pass load_from_cache_file=False in the call to map to ignore the cached files and force the preprocessing to be applied again.
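If you do want to force the preprocessing to run again, a minimal sketch of that same map call with the cache disabled:

# Ignore any cached Arrow files and recompute the preprocessing from scratch
encoded_dataset = dataset.map(
    preprocess_function, batched=True, load_from_cache_file=False
)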

Note that we passed batched=True to encode the texts in batches. This takes full advantage of the fast tokenizer we loaded earlier, which will use multi-threading to process the texts in a batch concurrently.

Fine-tuning the model

Now that our data is ready, we can download the pretrained model and fine-tune it. Since all our tasks are about sentence classification, we use the TFAutoModelForSequenceClassification class. Like with the tokenizer, the from_pretrained method will download and cache the model for us. The only thing we have to specify is the number of labels for our problem (which is always 2, except for STS-B which is a regression problem and MNLI where we have 3 labels).

from transformers import TFAutoModelForSequenceClassification
import tensorflow as tf

if task == "stsb":
    num_labels = 1
elif task.startswith("mnli"):
    num_labels = 3
else:
    num_labels = 2

# This next little bit is optional, but will give us cleaner label outputs later
# If you're using a task other than CoLA, you will probably need to change these
# to match the label names for your task!
id2label = {0: "Invalid", 1: "Valid"}
label2id = {val: key for key, val in id2label.items()}

model = TFAutoModelForSequenceClassification.from_pretrained(
    model_checkpoint, num_labels=num_labels, id2label=id2label, label2id=label2id
)
2022-08-03 13:07:25.935388: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE: forward compatibility was attempted on non supported HW 2022-08-03 13:07:25.935426: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: matt-TRX40-AORUS-PRO-WIFI 2022-08-03 13:07:25.935434: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: matt-TRX40-AORUS-PRO-WIFI 2022-08-03 13:07:25.935556: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 470.141.3 2022-08-03 13:07:25.935580: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 470.129.6 2022-08-03 13:07:25.935586: E tensorflow/stream_executor/cuda/cuda_diagnostics.cc:313] kernel version 470.129.6 does not match DSO version 470.141.3 -- cannot find working devices in this configuration 2022-08-03 13:07:25.935836: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags. Some layers from the model checkpoint at distilbert-base-uncased were not used when initializing TFDistilBertForSequenceClassification: ['vocab_transform', 'vocab_projector', 'activation_13', 'vocab_layer_norm'] - This IS expected if you are initializing TFDistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing TFDistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some layers of TFDistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier', 'dropout_19', 'pre_classifier'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

The warning is telling us we are throwing away some weights (the vocab_transform and vocab_layer_norm layers) and randomly initializing some others (the pre_classifier and classifier layers). This is absolutely normal in this case, because we are removing the head used to pretrain the model on a masked language modeling objective and replacing it with a new head for which we don't have pretrained weights. The library warns us that we should fine-tune this model before using it for inference, which is exactly what we are going to do.

Next, we convert our datasets to tf.data.Dataset, which Keras understands natively. There are two ways to do this - we can use the slightly lower-level Dataset.to_tf_dataset() method, or we can use Model.prepare_tf_dataset(). The main difference between the two is that the Model method can inspect the model to determine which column names it can use as input, which means you don't need to specify them yourself. Unless our samples are all the same length, we will also need to pass a tokenizer or collate_fn so that the tf.data.Dataset knows how to pad and combine samples into a batch.

validation_key = (
    "validation_mismatched"
    if task == "mnli-mm"
    else "validation_matched"
    if task == "mnli"
    else "validation"
)
tf_train_dataset = model.prepare_tf_dataset(
    encoded_dataset["train"], shuffle=True, batch_size=16, tokenizer=tokenizer
)
tf_validation_dataset = model.prepare_tf_dataset(
    encoded_dataset[validation_key],
    shuffle=False,
    batch_size=16,
    tokenizer=tokenizer,
)

Next, we need to set up our optimizer and compile() our model. The create_optimizer function in the Transformers library creates a very useful AdamW optimizer with weight decay and learning rate decay. This performs very well for training most transformer networks - we recommend using it as your default unless you have a good reason not to! Note, however, that because it decays the learning rate over the course of training, it needs to know how many batches it will see during training.

Note that all models in transformers can pick a sensible loss function by default. To use this loss, simply do not pass a loss argument to compile(). Although the losses for GLUE tasks are usually just simple cross-entropy, this can be very helpful for models where the loss is intricate and contains multiple terms.

In some of our other examples, we use jit_compile to compile the model with XLA. In this case, we should be careful about that - because our inputs have variable sequence lengths, we may end up having to do a new XLA compilation for each possible length, because XLA compilation expects a static input shape! For small datasets, this will probably result in spending more time on XLA compilation than actually training, which isn't very helpful.

If you really want to use XLA without these problems (for example, if you're training on TPU), you can create a tokenizer with padding="max_length". This will pad all of your samples to the same length, ensuring that a single XLA compilation will suffice for your entire dataset. Note that depending on the nature of your dataset, this may result in a lot of wasted computation on padding tokens!
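A minimal sketch of what that fixed-length preprocessing could look like - the function name and the max_length of 128 are assumptions for illustration, not values used elsewhere in this notebook:

def preprocess_function_padded(examples):
    # Pad every sample to a fixed length so XLA only sees one static input shape.
    # max_length=128 is an assumed value; pick one that fits your data.
    texts = (
        (examples[sentence1_key],)
        if sentence2_key is None
        else (examples[sentence1_key], examples[sentence2_key])
    )
    return tokenizer(*texts, truncation=True, padding="max_length", max_length=128)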

from transformers import create_optimizer

num_epochs = 3
batches_per_epoch = len(encoded_dataset["train"]) // batch_size
total_train_steps = int(batches_per_epoch * num_epochs)
optimizer, schedule = create_optimizer(
    init_lr=2e-5, num_warmup_steps=0, num_train_steps=total_train_steps
)
model.compile(optimizer=optimizer)
No loss specified in compile() - the model's internal loss computation will be used as the loss. Don't panic - this is a common way to train TensorFlow models in Transformers! To disable this behaviour please pass a loss argument, or explicitly pass `loss=None` if you do not want your model to compute a loss.

The last thing to define is how to compute the metrics from the predictions. We need to define a function for this, which will just use the metric we loaded earlier. The only preprocessing we have to do is to take the argmax of our predicted logits (or just squeeze the last axis in the case of STS-B).

In addition, let's wrap this metric computation function in a KerasMetricCallback. This callback will compute the metric on the validation set each epoch, including printing it and logging it for other callbacks like TensorBoard and EarlyStopping.

Why do it this way, though, and not just use a straightforward Keras Metric object? This is a good question - on this task, several of the metrics such as Accuracy are very straightforward, and it would probably make more sense to just use a Keras metric for those instead. However, we want to demonstrate the use of KerasMetricCallback here, because it can handle any arbitrary Python function for the metric computation. This turns out to be very important for other NLP tasks like summarization and translation, where standard metrics like BLEU and ROUGE are much more complex to compute, and often involve decoding tokens generated by the model to strings and comparing their similarity to target sentences. If you want to stop training once ROUGE scores on the validation set start to decline, then KerasMetricCallback is essential.

That said, if you're only interested in tasks like text classification with straightforward metrics, then by all means remove the KerasMetricCallback and use a Keras Accuracy metric instead!

With that out of the way, how do we actually use KerasMetricCallback? It's straightforward: we simply define a function that computes metrics given a tuple of numpy arrays of predictions and labels, then we pass that, along with the validation set to compute metrics on, to the callback:

from transformers.keras_callbacks import KerasMetricCallback


def compute_metrics(eval_predictions):
    predictions, labels = eval_predictions
    if task != "stsb":
        predictions = np.argmax(predictions, axis=1)
    else:
        predictions = predictions[:, 0]
    return metric.compute(predictions=predictions, references=labels)


metric_callback = KerasMetricCallback(
    metric_fn=compute_metrics, eval_dataset=tf_validation_dataset
)

We can now fine-tune our model by just calling the fit method. Be sure to pass the TF datasets, and not the original datasets! We can also add a callback to sync up our model with the Hub - this allows us to resume training from other machines and even test the model's inference quality midway through training! If you do push your model, make sure to use your own username in place of the one shown in the inference cells below. If you don't want to push to the Hub, simply remove the PushToHubCallback from the callbacks argument in the call to fit().

from transformers.keras_callbacks import PushToHubCallback
from tensorflow.keras.callbacks import TensorBoard

model_name = model_checkpoint.split("/")[-1]
push_to_hub_model_id = f"{model_name}-finetuned-{task}"

tensorboard_callback = TensorBoard(log_dir="./text_classification_model_save/logs")

push_to_hub_callback = PushToHubCallback(
    output_dir="./text_classification_model_save",
    tokenizer=tokenizer,
    hub_model_id=push_to_hub_model_id,
)

callbacks = [metric_callback, tensorboard_callback, push_to_hub_callback]

model.fit(
    tf_train_dataset,
    validation_data=tf_validation_dataset,
    epochs=num_epochs,
    callbacks=callbacks,
)
/home/matt/PycharmProjects/notebooks/examples/text_classification_model_save is already a clone of https://huggingface.co/Rocketknight1/distilbert-base-uncased-finetuned-cola. Make sure you pull the latest changes with `repo.git_pull()`.
Epoch 1/3 534/534 [==============================] - ETA: 0s - loss: 0.5126
Several commits (2) will be pushed upstream.
534/534 [==============================] - 181s 330ms/step - loss: 0.5126 - val_loss: 0.4638 - matthews_correlation: 0.4555 Epoch 2/3 534/534 [==============================] - 174s 327ms/step - loss: 0.3182 - val_loss: 0.4914 - matthews_correlation: 0.5056 Epoch 3/3 534/534 [==============================] - 175s 327ms/step - loss: 0.1864 - val_loss: 0.5599 - matthews_correlation: 0.5285
<keras.callbacks.History at 0x7fb8d86573d0>

To see how your model fared, you can compare it to the GLUE Benchmark leaderboard.

If you used the callback above, you can now share this model with all your friends, family or favorite pets: they can all load it with the identifier "your-username/the-name-you-picked" so for instance:

from transformers import TFAutoModelForSequenceClassification

model = TFAutoModelForSequenceClassification.from_pretrained("your-username/my-awesome-model")

Inference

Training a model is fun, but once it's trained you usually want to use it to get predictions on new data. Let's take a look at how to do that. Firstly, we'll load our trained model from the hub - this lets us resume the code from here without needing to rerun all the training above every time.

from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

# You can, of course, use your own username and model name here
# once you've pushed your model using the code above!
model = TFAutoModelForSequenceClassification.from_pretrained(
    "Rocketknight1/distilbert-base-uncased-finetuned-cola"
)
tokenizer = AutoTokenizer.from_pretrained(
    "Rocketknight1/distilbert-base-uncased-finetuned-cola"
)
Some layers from the model checkpoint at Rocketknight1/distilbert-base-uncased-finetuned-cola were not used when initializing TFDistilBertForSequenceClassification: ['dropout_19'] - This IS expected if you are initializing TFDistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing TFDistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some layers of TFDistilBertForSequenceClassification were not initialized from the model checkpoint at Rocketknight1/distilbert-base-uncased-finetuned-cola and are newly initialized: ['dropout_39'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Now, let's make up some sentences and see if the model can classify them properly! The first sentence is valid English, but the second one makes a grammatical mistake.

sentences = [
    "The judge told the jurors to think carefully.",
    "The judge told that the jurors to think carefully.",
]

To feed them into our model, we'll need to tokenize them and then get our model's predictions:

tokenized = tokenizer(sentences, return_tensors="np", padding="longest")
outputs = model(tokenized).logits
classifications = np.argmax(outputs, axis=1)
print(classifications)
[1 0]

What do those label values mean? Let's use the id2label property set on our model to make them a little more comprehensible:

classifications = [model.config.id2label[output] for output in classifications]
print(classifications)
['Valid', 'Invalid']

Looks right to me!

Pipeline API

An alternative way to quickly perform inference with any model on the hub is to use the Pipeline API, which abstracts away all the steps we did manually above. It will perform the preprocessing, forward pass and postprocessing all in a single object.

Let's showcase this for our trained model:

from transformers import pipeline

classifier = pipeline(
    "text-classification",
    "Rocketknight1/distilbert-base-uncased-finetuned-cola",
    framework="tf",
)
Some layers from the model checkpoint at Rocketknight1/distilbert-base-uncased-finetuned-cola were not used when initializing TFDistilBertForSequenceClassification: ['dropout_19'] - This IS expected if you are initializing TFDistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing TFDistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some layers of TFDistilBertForSequenceClassification were not initialized from the model checkpoint at Rocketknight1/distilbert-base-uncased-finetuned-cola and are newly initialized: ['dropout_59'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
classifier(sentences)
[{'label': 'Valid', 'score': 0.9861387610435486}, {'label': 'Invalid', 'score': 0.8175984025001526}]

And that's it - the code above is all you need to get classifications from your model in the future! Note how the id2label property we set during training is automatically read by our pipeline to assign sensible names to the output classes.
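If you want the score for every label rather than just the top one, recent versions of the text-classification pipeline accept a top_k argument (setting it to None returns all labels); a small sketch, assuming your installed version supports it:

# Ask the pipeline for the scores of all labels, not just the best one
all_scores = classifier(sentences, top_k=None)
print(all_scores)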