GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/en-snapshot/tutorials/load_data/text.ipynb
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Load text

This tutorial demonstrates two ways to load and preprocess text.

  • First, you will use Keras utilities and preprocessing layers. These include tf.keras.utils.text_dataset_from_directory to turn data into a tf.data.Dataset and tf.keras.layers.TextVectorization for data standardization, tokenization, and vectorization. If you are new to TensorFlow, you should start with these.

  • Then, you will use lower-level utilities like tf.data.TextLineDataset to load text files, and TensorFlow Text APIs, such as text.UnicodeScriptTokenizer and text.case_fold_utf8, to preprocess the data for finer-grained control.

!pip install "tensorflow-text==2.11.*"
import collections
import pathlib

import tensorflow as tf

from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import utils
from tensorflow.keras.layers import TextVectorization

import tensorflow_datasets as tfds
import tensorflow_text as tf_text

Example 1: Predict the tag for a Stack Overflow question

As a first example, you will download a dataset of programming questions from Stack Overflow. Each question ("How do I sort a dictionary by value?") is labeled with exactly one tag (Python, CSharp, JavaScript, or Java). Your task is to develop a model that predicts the tag for a question. This is an example of multi-class classification—an important and widely applicable kind of machine learning problem.

Download and explore the dataset

Begin by downloading the Stack Overflow dataset using tf.keras.utils.get_file, and exploring the directory structure:

data_url = 'https://storage.googleapis.com/download.tensorflow.org/data/stack_overflow_16k.tar.gz'

dataset_dir = utils.get_file(
    origin=data_url,
    untar=True,
    cache_dir='stack_overflow',
    cache_subdir='')

dataset_dir = pathlib.Path(dataset_dir).parent
list(dataset_dir.iterdir())
train_dir = dataset_dir/'train'
list(train_dir.iterdir())

The train/csharp, train/java, train/python and train/javascript directories contain many text files, each of which is a Stack Overflow question.

Print an example file and inspect the data:

sample_file = train_dir/'python/1755.txt'

with open(sample_file) as f:
  print(f.read())

Load the dataset

Next, you will load the data off disk and prepare it into a format suitable for training. To do so, you will use the tf.keras.utils.text_dataset_from_directory utility to create a labeled tf.data.Dataset. If you're new to tf.data, it's a powerful collection of tools for building input pipelines. (Learn more in the tf.data: Build TensorFlow input pipelines guide.)

The tf.keras.utils.text_dataset_from_directory API expects a directory structure as follows:

train/
...csharp/
......1.txt
......2.txt
...java/
......1.txt
......2.txt
...javascript/
......1.txt
......2.txt
...python/
......1.txt
......2.txt

When running a machine learning experiment, it is a best practice to divide your dataset into three splits: training, validation, and test.

The Stack Overflow dataset has already been divided into training and test sets, but it lacks a validation set.

Create a validation set using an 80:20 split of the training data by using tf.keras.utils.text_dataset_from_directory with validation_split set to 0.2 (i.e. 20%):

batch_size = 32
seed = 42

raw_train_ds = utils.text_dataset_from_directory(
    train_dir,
    batch_size=batch_size,
    validation_split=0.2,
    subset='training',
    seed=seed)

As the previous cell output suggests, there are 8,000 examples in the training folder, of which you will use 80% (or 6,400) for training. You will learn in a moment that you can train a model by passing a tf.data.Dataset directly to Model.fit.

First, iterate over the dataset and print out a few examples, to get a feel for the data.

Note: To increase the difficulty of the classification problem, the dataset author replaced occurrences of the words Python, CSharp, JavaScript, or Java in the programming question with the word blank.

for text_batch, label_batch in raw_train_ds.take(1):
  for i in range(10):
    print("Question: ", text_batch.numpy()[i])
    print("Label:", label_batch.numpy()[i])

The labels are 0, 1, 2 or 3. To check which of these correspond to which string label, you can inspect the class_names property on the dataset:

for i, label in enumerate(raw_train_ds.class_names):
  print("Label", i, "corresponds to", label)

Next, you will create a validation and a test set using tf.keras.utils.text_dataset_from_directory. You will use the remaining 1,600 questions from the training set for validation.

Note: When using the validation_split and subset arguments of tf.keras.utils.text_dataset_from_directory, make sure to either specify a random seed or pass shuffle=False, so that the validation and training splits have no overlap.

# Create a validation set.
raw_val_ds = utils.text_dataset_from_directory(
    train_dir,
    batch_size=batch_size,
    validation_split=0.2,
    subset='validation',
    seed=seed)
test_dir = dataset_dir/'test'

# Create a test set.
raw_test_ds = utils.text_dataset_from_directory(
    test_dir,
    batch_size=batch_size)

Prepare the dataset for training

Next, you will standardize, tokenize, and vectorize the data using the tf.keras.layers.TextVectorization layer.

  • Standardization refers to preprocessing the text, typically to remove punctuation or HTML elements to simplify the dataset.

  • Tokenization refers to splitting strings into tokens (for example, splitting a sentence into individual words by splitting on whitespace).

  • Vectorization refers to converting tokens into numbers so they can be fed into a neural network.

All of these tasks can be accomplished with this layer. (You can learn more about each of these in the tf.keras.layers.TextVectorization API docs.)
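
To make these three steps concrete, here is a minimal hand-rolled sketch using low-level tf.strings ops and a toy vocabulary. It is illustrative only: the sentence and toy_vocab are made-up examples, and the TextVectorization layer used below performs all of this for you.

sentence = tf.constant("How do I sort a dictionary by value?")

# Standardization: lowercase the text and strip punctuation.
standardized = tf.strings.regex_replace(
    tf.strings.lower(sentence), r"[^\w\s]", "")

# Tokenization: split the string into word tokens on whitespace.
tokens = tf.strings.split(standardized)

# Vectorization: map each token to an integer index using a toy vocabulary.
# `StringLookup` reserves index 0 for out-of-vocabulary tokens by default.
toy_vocab = ["how", "do", "i", "sort", "a", "dictionary", "by", "value"]
lookup = layers.StringLookup(vocabulary=toy_vocab)
print(lookup(tokens))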

Note that:

  • The default standardization converts text to lowercase and removes punctuation (standardize='lower_and_strip_punctuation').

  • The default tokenizer splits on whitespace (split='whitespace').

  • The default vectorization mode is 'int' (output_mode='int'). This outputs integer indices (one per token). This mode can be used to build models that take word order into account. You can also use other modes—like 'binary'—to build bag-of-words models.
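
To make these defaults visible, here is the layer with its default arguments spelled out explicitly. This is a sketch only; default_vectorize_layer is an illustrative name, and the call is equivalent to TextVectorization() with no arguments.

default_vectorize_layer = TextVectorization(
    standardize='lower_and_strip_punctuation',
    split='whitespace',
    output_mode='int')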

You will build two models to learn more about standardization, tokenization, and vectorization with TextVectorization:

  • First, you will use the 'binary' vectorization mode to build a bag-of-words model.

  • Then, you will use the 'int' mode with a 1D ConvNet.

VOCAB_SIZE = 10000

binary_vectorize_layer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode='binary')

For the 'int' mode, in addition to maximum vocabulary size, you need to set an explicit maximum sequence length (MAX_SEQUENCE_LENGTH), which will cause the layer to pad or truncate sequences to exactly output_sequence_length values:

MAX_SEQUENCE_LENGTH = 250

int_vectorize_layer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode='int',
    output_sequence_length=MAX_SEQUENCE_LENGTH)

Next, call TextVectorization.adapt to fit the state of the preprocessing layer to the dataset. This will cause the layer to build an index mapping strings to integers.

Note: It's important to only use your training data when calling TextVectorization.adapt, as using the test set would leak information.

# Make a text-only dataset (without labels), then call `TextVectorization.adapt`.
train_text = raw_train_ds.map(lambda text, labels: text)
binary_vectorize_layer.adapt(train_text)
int_vectorize_layer.adapt(train_text)

Print the result of using these layers to preprocess data:

def binary_vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return binary_vectorize_layer(text), label
def int_vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return int_vectorize_layer(text), label
# Retrieve a batch (of 32 questions and labels) from the dataset.
text_batch, label_batch = next(iter(raw_train_ds))
first_question, first_label = text_batch[0], label_batch[0]
print("Question", first_question)
print("Label", first_label)
print("'binary' vectorized question:", binary_vectorize_text(first_question, first_label)[0])
print("'int' vectorized question:", int_vectorize_text(first_question, first_label)[0])

As shown above, TextVectorization's 'binary' mode returns an array denoting which tokens exist at least once in the input, while the 'int' mode replaces each token by an integer, thus preserving their order.

You can look up the token (string) that each integer corresponds to by calling TextVectorization.get_vocabulary on the layer:

print("1289 ---> ", int_vectorize_layer.get_vocabulary()[1289]) print("313 ---> ", int_vectorize_layer.get_vocabulary()[313]) print("Vocabulary size: {}".format(len(int_vectorize_layer.get_vocabulary())))

You are nearly ready to train your model.

As a final preprocessing step, you will apply the TextVectorization layers you created earlier to the training, validation, and test sets:

binary_train_ds = raw_train_ds.map(binary_vectorize_text)
binary_val_ds = raw_val_ds.map(binary_vectorize_text)
binary_test_ds = raw_test_ds.map(binary_vectorize_text)

int_train_ds = raw_train_ds.map(int_vectorize_text)
int_val_ds = raw_val_ds.map(int_vectorize_text)
int_test_ds = raw_test_ds.map(int_vectorize_text)

Configure the dataset for performance

There are two important methods you should use when loading data to make sure that I/O does not become blocking:

  • Dataset.cache keeps data in memory after it's loaded off disk. This will ensure the dataset does not become a bottleneck while training your model. If your dataset is too large to fit into memory, you can also use this method to create a performant on-disk cache, which is more efficient to read than many small files.

  • Dataset.prefetch overlaps data preprocessing and model execution while training.

You can learn more about both methods, as well as how to cache data to disk in the Prefetching section of the Better performance with the tf.data API guide.
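
If the vectorized dataset were too large to cache in memory, you could pass a file path to Dataset.cache to create the on-disk cache mentioned above. This is a sketch only; the cache file path is an arbitrary example.

# On-disk cache: pass a filename instead of caching in memory.
# TensorFlow writes the cache files on the first pass over the data
# and reads from them on subsequent epochs.
disk_cached_ds = int_train_ds.cache('/tmp/int_train.cache')
disk_cached_ds = disk_cached_ds.prefetch(buffer_size=tf.data.AUTOTUNE)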

AUTOTUNE = tf.data.AUTOTUNE

def configure_dataset(dataset):
  return dataset.cache().prefetch(buffer_size=AUTOTUNE)
binary_train_ds = configure_dataset(binary_train_ds)
binary_val_ds = configure_dataset(binary_val_ds)
binary_test_ds = configure_dataset(binary_test_ds)

int_train_ds = configure_dataset(int_train_ds)
int_val_ds = configure_dataset(int_val_ds)
int_test_ds = configure_dataset(int_test_ds)

Train the model

It's time to create your neural network.

For the 'binary' vectorized data, define a simple bag-of-words linear model, then configure and train it:

binary_model = tf.keras.Sequential([layers.Dense(4)])

binary_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])

history = binary_model.fit(
    binary_train_ds, validation_data=binary_val_ds, epochs=10)

Next, you will use the 'int' vectorized layer to build a 1D ConvNet:

def create_model(vocab_size, num_labels):
  model = tf.keras.Sequential([
      layers.Embedding(vocab_size, 64, mask_zero=True),
      layers.Conv1D(64, 5, padding="valid", activation="relu", strides=2),
      layers.GlobalMaxPooling1D(),
      layers.Dense(num_labels)
  ])
  return model
# `vocab_size` is `VOCAB_SIZE + 1` since `0` is used additionally for padding.
int_model = create_model(vocab_size=VOCAB_SIZE + 1, num_labels=4)

int_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])

history = int_model.fit(int_train_ds, validation_data=int_val_ds, epochs=5)

Compare the two models:

print("Linear model on binary vectorized data:") print(binary_model.summary())
print("ConvNet model on int vectorized data:") print(int_model.summary())

Evaluate both models on the test data:

binary_loss, binary_accuracy = binary_model.evaluate(binary_test_ds)
int_loss, int_accuracy = int_model.evaluate(int_test_ds)

print("Binary model accuracy: {:2.2%}".format(binary_accuracy))
print("Int model accuracy: {:2.2%}".format(int_accuracy))

Note: This example dataset represents a rather simple classification problem. More complex datasets and problems bring out subtle but significant differences in preprocessing strategies and model architectures. Be sure to try out different hyperparameters and epochs to compare various approaches.

Export the model

In the code above, you applied tf.keras.layers.TextVectorization to the dataset before feeding text to the model. If you want to make your model capable of processing raw strings (for example, to simplify deploying it), you can include the TextVectorization layer inside your model.

To do so, you can create a new model using the weights you have just trained:

export_model = tf.keras.Sequential(
    [binary_vectorize_layer, binary_model,
     layers.Activation('sigmoid')])

export_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer='adam',
    metrics=['accuracy'])

# Test it with `raw_test_ds`, which yields raw strings.
loss, accuracy = export_model.evaluate(raw_test_ds)
print("Accuracy: {:2.2%}".format(accuracy))

Now, your model can take raw strings as input and predict a score for each label using Model.predict. Define a function to find the label with the maximum score:

def get_string_labels(predicted_scores_batch):
  predicted_int_labels = tf.math.argmax(predicted_scores_batch, axis=1)
  predicted_labels = tf.gather(raw_train_ds.class_names, predicted_int_labels)
  return predicted_labels

Run inference on new data

inputs = [ "how do I extract keys from a dict into a list?", # 'python' "debug public static void main(string[] args) {...}", # 'java' ] predicted_scores = export_model.predict(inputs) predicted_labels = get_string_labels(predicted_scores) for input, label in zip(inputs, predicted_labels): print("Question: ", input) print("Predicted label: ", label.numpy())

Including the text preprocessing logic inside your model enables you to export a model for production that simplifies deployment, and reduces the potential for train/test skew.

There is a performance difference to keep in mind when choosing where to apply tf.keras.layers.TextVectorization. Using it outside of your model enables you to do asynchronous CPU processing and buffering of your data when training on GPU. So, if you're training your model on the GPU, you probably want to go with this option to get the best performance while developing your model, then switch to including the TextVectorization layer inside your model when you're ready to prepare for deployment.
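
For reference, here is a minimal sketch of the "outside the model" option: keep TextVectorization in the tf.data pipeline so preprocessing runs on the CPU and overlaps with training steps. It reuses binary_vectorize_text from above; num_parallel_calls and prefetch are what make the preprocessing asynchronous.

# Vectorize in the input pipeline (CPU), overlapping with model execution.
async_train_ds = raw_train_ds.map(
    binary_vectorize_text, num_parallel_calls=tf.data.AUTOTUNE)
async_train_ds = async_train_ds.prefetch(buffer_size=tf.data.AUTOTUNE)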

Visit the Save and load models tutorial to learn more about saving models.
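
For example, because the exported model already contains the TextVectorization layer, it can be saved and reloaded as a single artifact that accepts raw strings. This is a sketch only; the directory name below is an arbitrary example.

# Save the end-to-end model and reload it for inference on raw strings.
export_model.save('exported_model')
reloaded_model = tf.keras.models.load_model('exported_model')
print(reloaded_model.predict(["how do I sort a dictionary by value in blank?"]))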

Example 2: Predict the author of Iliad translations

The following provides an example of using tf.data.TextLineDataset to load examples from text files, and TensorFlow Text to preprocess the data. You will use three different English translations of the same work, Homer's Iliad, and train a model to identify the translator given a single line of text.

Download and explore the dataset

The texts of the three translations are by:

  • William Cowper

  • Edward, Earl of Derby

  • Samuel Butler

The text files used in this tutorial have undergone some typical preprocessing tasks, such as removing document headers and footers, line numbers, and chapter titles.

Download these lightly munged files locally:

DIRECTORY_URL = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'

FILE_NAMES = ['cowper.txt', 'derby.txt', 'butler.txt']

for name in FILE_NAMES:
  text_dir = utils.get_file(name, origin=DIRECTORY_URL + name)

parent_dir = pathlib.Path(text_dir).parent
list(parent_dir.iterdir())

Load the dataset

Previously, with tf.keras.utils.text_dataset_from_directory all contents of a file were treated as a single example. Here, you will use tf.data.TextLineDataset, which is designed to create a tf.data.Dataset from a text file where each example is a line of text from the original file. TextLineDataset is useful for text data that is primarily line-based (for example, poetry or error logs).

Iterate through these files, loading each one into its own dataset. Each example needs to be individually labeled, so use Dataset.map to apply a labeler function to each one. This will iterate over every example in the dataset, returning (example, label) pairs.

def labeler(example, index):
  return example, tf.cast(index, tf.int64)
labeled_data_sets = []

for i, file_name in enumerate(FILE_NAMES):
  lines_dataset = tf.data.TextLineDataset(str(parent_dir/file_name))
  labeled_dataset = lines_dataset.map(lambda ex: labeler(ex, i))
  labeled_data_sets.append(labeled_dataset)

Next, you'll combine these labeled datasets into a single dataset using Dataset.concatenate, and shuffle it with Dataset.shuffle:

BUFFER_SIZE = 50000
BATCH_SIZE = 64
VALIDATION_SIZE = 5000
all_labeled_data = labeled_data_sets[0]
for labeled_dataset in labeled_data_sets[1:]:
  all_labeled_data = all_labeled_data.concatenate(labeled_dataset)

all_labeled_data = all_labeled_data.shuffle(
    BUFFER_SIZE, reshuffle_each_iteration=False)

Print out a few examples as before. The dataset hasn't been batched yet, hence each entry in all_labeled_data corresponds to one data point:

for text, label in all_labeled_data.take(10):
  print("Sentence: ", text.numpy())
  print("Label:", label.numpy())

Prepare the dataset for training

Instead of using tf.keras.layers.TextVectorization to preprocess the text dataset, you will now use the TensorFlow Text APIs to standardize and tokenize the data, build a vocabulary and use tf.lookup.StaticVocabularyTable to map tokens to integers to feed to the model. (Learn more about TensorFlow Text).

Define a function to convert the text to lower-case and tokenize it:

  • TensorFlow Text provides various tokenizers. In this example, you will use the text.UnicodeScriptTokenizer to tokenize the dataset.

  • You will use Dataset.map to apply the tokenization to the dataset.

tokenizer = tf_text.UnicodeScriptTokenizer()
def tokenize(text, unused_label):
  lower_case = tf_text.case_fold_utf8(text)
  return tokenizer.tokenize(lower_case)
tokenized_ds = all_labeled_data.map(tokenize)

You can iterate over the dataset and print out a few tokenized examples:

for text_batch in tokenized_ds.take(5):
  print("Tokens: ", text_batch.numpy())

Next, you will build a vocabulary by sorting tokens by frequency and keeping the top VOCAB_SIZE tokens:

tokenized_ds = configure_dataset(tokenized_ds)

vocab_dict = collections.defaultdict(lambda: 0)
for toks in tokenized_ds.as_numpy_iterator():
  for tok in toks:
    vocab_dict[tok] += 1

vocab = sorted(vocab_dict.items(), key=lambda x: x[1], reverse=True)
vocab = [token for token, count in vocab]
vocab = vocab[:VOCAB_SIZE]
vocab_size = len(vocab)
print("Vocab size: ", vocab_size)
print("First five vocab entries:", vocab[:5])

To convert the tokens into integers, use the vocab set to create a tf.lookup.StaticVocabularyTable. You will map tokens to integers in the range [2, vocab_size + 2) (that is, 2 through vocab_size + 1). As with the TextVectorization layer, 0 is reserved to denote padding and 1 is reserved to denote an out-of-vocabulary (OOV) token.

keys = vocab
values = range(2, len(vocab) + 2)  # Reserve `0` for padding, `1` for OOV tokens.

init = tf.lookup.KeyValueTensorInitializer(
    keys, values, key_dtype=tf.string, value_dtype=tf.int64)

num_oov_buckets = 1
vocab_table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets)

Finally, define a function to standardize, tokenize and vectorize the dataset using the tokenizer and lookup table:

def preprocess_text(text, label):
  standardized = tf_text.case_fold_utf8(text)
  tokenized = tokenizer.tokenize(standardized)
  vectorized = vocab_table.lookup(tokenized)
  return vectorized, label

You can try this on a single example to print the output:

example_text, example_label = next(iter(all_labeled_data))
print("Sentence: ", example_text.numpy())
vectorized_text, example_label = preprocess_text(example_text, example_label)
print("Vectorized sentence: ", vectorized_text.numpy())

Now run the preprocess function on the dataset using Dataset.map:

all_encoded_data = all_labeled_data.map(preprocess_text)

Split the dataset into training and test sets

The Keras TextVectorization layer also batches and pads the vectorized data. Padding is required because the examples inside of a batch need to be the same size and shape, but the examples in these datasets are not all the same size—each line of text has a different number of words.

tf.data.Dataset supports splitting and padded-batching datasets:

train_data = all_encoded_data.skip(VALIDATION_SIZE).shuffle(BUFFER_SIZE)
validation_data = all_encoded_data.take(VALIDATION_SIZE)
train_data = train_data.padded_batch(BATCH_SIZE)
validation_data = validation_data.padded_batch(BATCH_SIZE)

Now, validation_data and train_data are not collections of (example, label) pairs, but collections of batches. Each batch is a pair of (many examples, many labels) represented as arrays.

To illustrate this:

sample_text, sample_labels = next(iter(validation_data))
print("Text batch shape: ", sample_text.shape)
print("Label batch shape: ", sample_labels.shape)
print("First text example: ", sample_text[0])
print("First label example: ", sample_labels[0])

Since you use 0 for padding and 1 for out-of-vocabulary (OOV) tokens, the vocabulary size has increased by two:

vocab_size += 2

Configure the datasets for better performance as before:

train_data = configure_dataset(train_data)
validation_data = configure_dataset(validation_data)

Train the model

You can train a model on this dataset as before:

model = create_model(vocab_size=vocab_size, num_labels=3)

model.compile(
    optimizer='adam',
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])

history = model.fit(train_data, validation_data=validation_data, epochs=3)
loss, accuracy = model.evaluate(validation_data)

print("Loss: ", loss)
print("Accuracy: {:2.2%}".format(accuracy))

Export the model

To make the model capable of taking raw strings as input, you will create a Keras TextVectorization layer that performs the same steps as your custom preprocessing function. Since you have already trained a vocabulary, you can use TextVectorization.set_vocabulary instead of TextVectorization.adapt, which would train a new vocabulary.

preprocess_layer = TextVectorization(
    max_tokens=vocab_size,
    standardize=tf_text.case_fold_utf8,
    split=tokenizer.tokenize,
    output_mode='int',
    output_sequence_length=MAX_SEQUENCE_LENGTH)

preprocess_layer.set_vocabulary(vocab)
export_model = tf.keras.Sequential(
    [preprocess_layer, model,
     layers.Activation('sigmoid')])

export_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer='adam',
    metrics=['accuracy'])
# Create a test dataset of raw strings.
test_ds = all_labeled_data.take(VALIDATION_SIZE).batch(BATCH_SIZE)
test_ds = configure_dataset(test_ds)

loss, accuracy = export_model.evaluate(test_ds)

print("Loss: ", loss)
print("Accuracy: {:2.2%}".format(accuracy))

The loss and accuracy for the model on the encoded validation set and for the exported model on the raw validation set are the same, as expected.

Run inference on new data

inputs = [ "Join'd to th' Ionians with their flowing robes,", # Label: 1 "the allies, and his armour flashed about him so that he seemed to all", # Label: 2 "And with loud clangor of his arms he fell.", # Label: 0 ] predicted_scores = export_model.predict(inputs) predicted_labels = tf.math.argmax(predicted_scores, axis=1) for input, label in zip(inputs, predicted_labels): print("Question: ", input) print("Predicted label: ", label.numpy())

Download more datasets using TensorFlow Datasets (TFDS)

You can download many more datasets from TensorFlow Datasets.

In this example, you will use the IMDB Large Movie Review dataset to train a model for sentiment classification:

# Training set.
train_ds = tfds.load(
    'imdb_reviews',
    split='train[:80%]',
    batch_size=BATCH_SIZE,
    shuffle_files=True,
    as_supervised=True)
# Validation set.
val_ds = tfds.load(
    'imdb_reviews',
    split='train[80%:]',
    batch_size=BATCH_SIZE,
    shuffle_files=True,
    as_supervised=True)

Print a few examples:

for review_batch, label_batch in val_ds.take(1):
  for i in range(5):
    print("Review: ", review_batch[i].numpy())
    print("Label: ", label_batch[i].numpy())

You can now preprocess the data and train a model as before.

Note: You will use tf.keras.losses.BinaryCrossentropy instead of tf.keras.losses.SparseCategoricalCrossentropy for your model, since this is a binary classification problem.

Prepare the dataset for training

vectorize_layer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode='int',
    output_sequence_length=MAX_SEQUENCE_LENGTH)

# Make a text-only dataset (without labels), then call `TextVectorization.adapt`.
train_text = train_ds.map(lambda text, labels: text)
vectorize_layer.adapt(train_text)
def vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return vectorize_layer(text), label
train_ds = train_ds.map(vectorize_text)
val_ds = val_ds.map(vectorize_text)
# Configure datasets for performance as before.
train_ds = configure_dataset(train_ds)
val_ds = configure_dataset(val_ds)

Create, configure and train the model

model = create_model(vocab_size=VOCAB_SIZE + 1, num_labels=1)
model.summary()
model.compile(
    loss=losses.BinaryCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])
history = model.fit(train_ds, validation_data=val_ds, epochs=3)
loss, accuracy = model.evaluate(val_ds)

print("Loss: ", loss)
print("Accuracy: {:2.2%}".format(accuracy))

Export the model

export_model = tf.keras.Sequential(
    [vectorize_layer, model,
     layers.Activation('sigmoid')])

# Use a binary loss here, matching the single-unit sigmoid output
# (as noted above, this is a binary classification problem).
export_model.compile(
    loss=losses.BinaryCrossentropy(from_logits=False),
    optimizer='adam',
    metrics=['accuracy'])
# 0 --> negative review
# 1 --> positive review
inputs = [
    "This is a fantastic movie.",
    "This is a bad movie.",
    "This movie was so bad that it was good.",
    "I will never say yes to watching this movie.",
]

predicted_scores = export_model.predict(inputs)
predicted_labels = [int(round(x[0])) for x in predicted_scores]

for input, label in zip(inputs, predicted_labels):
  print("Review: ", input)
  print("Predicted label: ", label)

Conclusion

This tutorial demonstrated several ways to load and preprocess text. As a next step, you can explore the additional text preprocessing tutorials available in TensorFlow Text.

You can also find new datasets on TensorFlow Datasets. And, to learn more about tf.data, check out the guide on building input pipelines.