#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Migrate from TPU embedding_columns to TPUEmbedding layer

This guide demonstrates how to migrate embedding training on TPUs from TensorFlow 1's embedding_column API with TPUEstimator to TensorFlow 2's TPUEmbedding layer API with TPUStrategy.

Embeddings are (large) matrices. They are lookup tables that map from a sparse feature space to dense vectors. Embeddings provide efficient and dense representations, capturing complex similarities and relationships between features.
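
For intuition, here is a small illustrative sketch (not part of the original notebook, and not TPU-specific) of an embedding lookup in plain TensorFlow, using a toy table with a vocabulary of 10 ids and 5-dimensional vectors:

import tensorflow as tf

# A toy embedding table: 10 vocabulary entries, each a 5-dimensional dense vector.
embedding_table = tf.random.normal([10, 5])

# Looking up integer ids returns the corresponding dense rows.
ids = tf.constant([0, 5])
dense_vectors = tf.nn.embedding_lookup(embedding_table, ids)
print(dense_vectors.shape)  # (2, 5)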

TensorFlow includes specialized support for training embeddings on TPUs. This TPU-specific embedding support allows you to train embeddings that are larger than the memory of a single TPU device, and to use sparse and ragged inputs on TPUs.

  • In TensorFlow 1, tf.compat.v1.estimator.tpu.TPUEstimator is a high level API that encapsulates training, evaluation, prediction, and exporting for serving with TPUs. It has special support for tf.compat.v1.tpu.experimental.embedding_column.

  • To implement this in TensorFlow 2, use the TensorFlow Recommenders' tfrs.layers.embedding.TPUEmbedding layer. For training and evaluation, use a TPU distribution strategy, tf.distribute.TPUStrategy, which works with the Keras APIs for, among other things, model building (tf.keras.Model), optimizers (tf.keras.optimizers.Optimizer), and training with Model.fit or a custom training loop with tf.function and tf.GradientTape.

For additional information, refer to the tfrs.layers.embedding.TPUEmbedding layer's API documentation, as well as the tf.tpu.experimental.embedding.TableConfig and tf.tpu.experimental.embedding.FeatureConfig docs. For an overview of tf.distribute.TPUStrategy, check out the Distributed training guide and the Use TPUs guide. If you're migrating from TPUEstimator to TPUStrategy, check out the TPU migration guide.

Setup

Start by installing TensorFlow Recommenders and importing some necessary packages:

!pip install tensorflow-recommenders
import tensorflow as tf
import tensorflow.compat.v1 as tf1

# TPUEmbedding layer is not part of TensorFlow.
import tensorflow_recommenders as tfrs

And prepare a simple dataset for demonstration purposes:

features = [[1., 1.5]]
embedding_features_indices = [[0, 0], [0, 1]]
embedding_features_values = [0, 5]
labels = [[0.3]]
eval_features = [[4., 4.5]]
eval_embedding_features_indices = [[0, 0], [0, 1]]
eval_embedding_features_values = [4, 3]
eval_labels = [[0.8]]

TensorFlow 1: Train embeddings on TPUs with TPUEstimator

In TensorFlow 1, you set up TPU embeddings using the tf.compat.v1.tpu.experimental.embedding_column API and train/evaluate the model on TPUs with tf.compat.v1.estimator.tpu.TPUEstimator.

The inputs to the TPU embedding table are integers ranging from zero to the vocabulary size. Begin by encoding the inputs as categorical IDs with tf.feature_column.categorical_column_with_identity. Use "sparse_feature" for the key parameter, since the input features are integer-valued, while num_buckets is the vocabulary size for the embedding table (10).

embedding_id_column = (
    tf1.feature_column.categorical_column_with_identity(
        key="sparse_feature", num_buckets=10))

Next, convert the sparse categorical inputs to a dense representation with tpu.experimental.embedding_column, where dimension is the width of the embedding table. It will store an embedding vector for each of the num_buckets.

embedding_column = tf1.tpu.experimental.embedding_column(
    embedding_id_column, dimension=5)

Now, define the TPU-specific embedding configuration via tf.estimator.tpu.experimental.EmbeddingConfigSpec. You will pass it later to tf.estimator.tpu.TPUEstimator as an embedding_config_spec parameter.

embedding_config_spec = tf1.estimator.tpu.experimental.EmbeddingConfigSpec(
    feature_columns=(embedding_column,),
    optimization_parameters=(
        tf1.tpu.experimental.AdagradParameters(0.05)))

Next, to use a TPUEstimator, define:

  • An input function for the training data

  • An evaluation input function for the evaluation data

  • A model function that instructs the TPUEstimator how to define the training op from the features and labels

def _input_fn(params):
  dataset = tf1.data.Dataset.from_tensor_slices((
      {"dense_feature": features,
       "sparse_feature": tf1.SparseTensor(
           embedding_features_indices,
           embedding_features_values, [1, 2])},
           labels))
  dataset = dataset.repeat()
  return dataset.batch(params['batch_size'], drop_remainder=True)

def _eval_input_fn(params):
  dataset = tf1.data.Dataset.from_tensor_slices((
      {"dense_feature": eval_features,
       "sparse_feature": tf1.SparseTensor(
           eval_embedding_features_indices,
           eval_embedding_features_values, [1, 2])},
           eval_labels))
  dataset = dataset.repeat()
  return dataset.batch(params['batch_size'], drop_remainder=True)

def _model_fn(features, labels, mode, params):
  embedding_features = tf1.keras.layers.DenseFeatures(embedding_column)(features)
  concatenated_features = tf1.keras.layers.Concatenate(axis=1)(
      [embedding_features, features["dense_feature"]])
  logits = tf1.layers.Dense(1)(concatenated_features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  optimizer = tf1.tpu.CrossShardOptimizer(optimizer)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.tpu.TPUEstimatorSpec(mode, loss=loss, train_op=train_op)

With those functions defined, create a tf.distribute.cluster_resolver.TPUClusterResolver that provides the cluster information, and a tf.compat.v1.estimator.tpu.RunConfig object.

Along with the model function you have defined, you can now create a TPUEstimator. Here, you will simplify the flow by skipping checkpoint saving. Then, you will specify the batch size for both training and evaluation for the TPUEstimator.

cluster_resolver = tf1.distribute.cluster_resolver.TPUClusterResolver(tpu='')
print("All devices: ", tf1.config.list_logical_devices('TPU'))
tpu_config = tf1.estimator.tpu.TPUConfig(
    iterations_per_loop=10,
    per_host_input_for_training=tf1.estimator.tpu.InputPipelineConfig
        .PER_HOST_V2)
config = tf1.estimator.tpu.RunConfig(
    cluster=cluster_resolver,
    save_checkpoints_steps=None,
    tpu_config=tpu_config)
estimator = tf1.estimator.tpu.TPUEstimator(
    model_fn=_model_fn, config=config, train_batch_size=8, eval_batch_size=8,
    embedding_config_spec=embedding_config_spec)

Call TPUEstimator.train to begin training the model:

estimator.train(_input_fn, steps=1)

Then, call TPUEstimator.evaluate to evaluate the model using the evaluation data:

estimator.evaluate(_eval_input_fn, steps=1)

TensorFlow 2: Train embeddings on TPUs with TPUStrategy

In TensorFlow 2, to train on the TPU workers, use tf.distribute.TPUStrategy together with the Keras APIs for model definition and training/evaluation. (Refer to the Use TPUs guide for more examples of training with Keras Model.fit and a custom training loop (with tf.function and tf.GradientTape).)

Since you need to perform some initialization work to connect to the remote cluster and initialize the TPU workers, start by creating a TPUClusterResolver to provide the cluster information and connect to the cluster. (Learn more in the TPU initialization section of the Use TPUs guide.)

cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))

Next, prepare your data. This is similar to how you created a dataset in the TensorFlow 1 example, except the dataset function is now passed a tf.distribute.InputContext object rather than a params dict. You can use this object to determine the local batch size (and which host this pipeline is for, so you can properly partition your data).
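
For example, an input function that reads from files could use the same InputContext to shard the data per host. Below is a hypothetical sketch (the file pattern, TFRecord format, and the _sharded_dataset name are illustrative and not part of this guide's in-memory example):

global_batch_size = 8  # same global batch size used later in this guide

def _sharded_dataset(context: tf.distribute.InputContext):
  # Each input pipeline (host) reads a disjoint shard of the files.
  file_dataset = tf.data.Dataset.list_files("gs://my-bucket/train-*.tfrecord")  # hypothetical path
  file_dataset = file_dataset.shard(context.num_input_pipelines,
                                    context.input_pipeline_id)
  dataset = tf.data.TFRecordDataset(file_dataset)
  # Batch with the per-replica size derived from the global batch size.
  return dataset.batch(context.get_per_replica_batch_size(global_batch_size),
                       drop_remainder=True)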

  • When using the tfrs.layers.embedding.TPUEmbedding API, it is important to include the drop_remainder=True option when batching the dataset with Dataset.batch, since TPUEmbedding requires a fixed batch size.

  • Additionally, the same batch size must be used for evaluation and training if they are taking place on the same set of devices.

  • Finally, you should use tf.keras.utils.experimental.DatasetCreator along with the special input option—experimental_fetch_to_device=False—in tf.distribute.InputOptions (which holds strategy-specific configurations). This is demonstrated below:

global_batch_size = 8

def _input_dataset(context: tf.distribute.InputContext):
  dataset = tf.data.Dataset.from_tensor_slices((
      {"dense_feature": features,
       "sparse_feature": tf.SparseTensor(
           embedding_features_indices,
           embedding_features_values, [1, 2])},
           labels))
  dataset = dataset.shuffle(10).repeat()
  dataset = dataset.batch(
      context.get_per_replica_batch_size(global_batch_size),
      drop_remainder=True)
  return dataset.prefetch(2)

def _eval_dataset(context: tf.distribute.InputContext):
  dataset = tf.data.Dataset.from_tensor_slices((
      {"dense_feature": eval_features,
       "sparse_feature": tf.SparseTensor(
           eval_embedding_features_indices,
           eval_embedding_features_values, [1, 2])},
           eval_labels))
  dataset = dataset.repeat()
  dataset = dataset.batch(
      context.get_per_replica_batch_size(global_batch_size),
      drop_remainder=True)
  return dataset.prefetch(2)

input_options = tf.distribute.InputOptions(
    experimental_fetch_to_device=False)

input_dataset = tf.keras.utils.experimental.DatasetCreator(
    _input_dataset, input_options=input_options)

eval_dataset = tf.keras.utils.experimental.DatasetCreator(
    _eval_dataset, input_options=input_options)

Next, once the data is prepared, you will create a TPUStrategy, and define a model, metrics, and an optimizer under the scope of this strategy (Strategy.scope).

You should pick a number for steps_per_execution in Model.compile since it specifies the number of batches to run during each tf.function call, and is critical for performance. This argument is similar to iterations_per_loop used in TPUEstimator.

The features and table configuration that were specified in TensorFlow 1 via the tf.tpu.experimental.embedding_column (and tf.tpu.experimental.shared_embedding_column) can be specified directly in TensorFlow 2 via a pair of configuration objects:

  • tf.tpu.experimental.embedding.FeatureConfig

  • tf.tpu.experimental.embedding.TableConfig

(Refer to the associated API documentation for more details.)
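
As a rough mapping from the TensorFlow 1 arguments, num_buckets corresponds to TableConfig's vocabulary_size and dimension corresponds to dim. Here is a standalone sketch of the two configuration objects (the table name "sparse_table" is illustrative; the code below builds the equivalent objects inline inside the TPUEmbedding layer):

table_config = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=10,  # num_buckets in the TF1 categorical column
    dim=5,               # dimension in the TF1 embedding_column
    name="sparse_table")
feature_config = tf.tpu.experimental.embedding.FeatureConfig(
    table=table_config, name="sparse_input")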

strategy = tf.distribute.TPUStrategy(cluster_resolver)

with strategy.scope():
  if hasattr(tf.keras.optimizers, "legacy"):
    optimizer = tf.keras.optimizers.legacy.Adagrad(learning_rate=0.05)
  else:
    optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
  dense_input = tf.keras.Input(shape=(2,), dtype=tf.float32, batch_size=global_batch_size)
  sparse_input = tf.keras.Input(shape=(), dtype=tf.int32, batch_size=global_batch_size)
  embedded_input = tfrs.layers.embedding.TPUEmbedding(
      feature_config=tf.tpu.experimental.embedding.FeatureConfig(
          table=tf.tpu.experimental.embedding.TableConfig(
              vocabulary_size=10,
              dim=5,
              initializer=tf.initializers.TruncatedNormal(mean=0.0, stddev=1)),
          name="sparse_input"),
      optimizer=optimizer)(sparse_input)
  input = tf.keras.layers.Concatenate(axis=1)([dense_input, embedded_input])
  result = tf.keras.layers.Dense(1)(input)
  model = tf.keras.Model(inputs={"dense_feature": dense_input,
                                 "sparse_feature": sparse_input},
                         outputs=result)
  model.compile(optimizer, "mse", steps_per_execution=10)

With that, you are ready to train the model with the training dataset:

model.fit(input_dataset, epochs=5, steps_per_epoch=10)

Finally, evaluate the model using the evaluation dataset:

model.evaluate(eval_dataset, steps=1, return_dict=True)

Next steps

Learn more about setting up TPU-specific embeddings in the API docs:

  • tfrs.layers.embedding.TPUEmbedding: particularly about feature and table configuration, setting the optimizer, creating a model (using the Keras functional API or via subclassing tf.keras.Model), training/evaluation, and model serving with tf.saved_model

  • tf.tpu.experimental.embedding.TableConfig

  • tf.tpu.experimental.embedding.FeatureConfig

For more information about TPUStrategy in TensorFlow 2, check out the Distributed training guide and the Use TPUs guide.

To learn more about customizing your training, refer to the Keras guides on customizing what happens in Model.fit and writing a training loop from scratch.

TPUs—Google's specialized ASICs for machine learning—are available through Google Colab, the TPU Research Cloud, and Cloud TPU.