tensorflow
GitHub Repository: tensorflow/docs-l10n
Path: blob/master/site/en-snapshot/guide/migrate/migrating_feature_columns.ipynb
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Migrate tf.feature_columns to Keras preprocessing layers

Training a model usually comes with some amount of feature preprocessing, particularly when dealing with structured data. When training a tf.estimator.Estimator in TensorFlow 1, you usually perform feature preprocessing with the tf.feature_column API. In TensorFlow 2, you can do this directly with Keras preprocessing layers.

This migration guide demonstrates common feature transformations using both feature columns and preprocessing layers, followed by training a complete model with both APIs.

First, start with a couple of necessary imports:

import tensorflow as tf
import tensorflow.compat.v1 as tf1
import math

Now, add a utility function for calling a feature column for demonstration:

def call_feature_columns(feature_columns, inputs):
  # This is a convenient way to call a `feature_column` outside of an estimator
  # to display its output.
  feature_layer = tf1.keras.layers.DenseFeatures(feature_columns)
  return feature_layer(inputs)

Input handling

To use feature columns with an estimator, model inputs are always expected to be a dictionary of tensors:

input_dict = { 'foo': tf.constant([1]), 'bar': tf.constant([0]), 'baz': tf.constant([-1]) }

Each feature column needs to be created with a key to index into the source data. The output of all feature columns is concatenated and used by the estimator model.

columns = [
    tf1.feature_column.numeric_column('foo'),
    tf1.feature_column.numeric_column('bar'),
    tf1.feature_column.numeric_column('baz'),
]
call_feature_columns(columns, input_dict)

In Keras, model input is much more flexible. A tf.keras.Model can handle a single tensor input, a list of tensor features, or a dictionary of tensor features. You can handle dictionary input by passing a dictionary of tf.keras.Input on model creation. Inputs will not be concatenated automatically, which allows them to be used in much more flexible ways. They can be concatenated with tf.keras.layers.Concatenate.

inputs = {
    'foo': tf.keras.Input(shape=()),
    'bar': tf.keras.Input(shape=()),
    'baz': tf.keras.Input(shape=()),
}
# Inputs are typically transformed by preprocessing layers before concatenation.
outputs = tf.keras.layers.Concatenate()(inputs.values())
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model(input_dict)

One-hot encoding integer IDs

A common feature transformation is one-hot encoding integer inputs of a known range. Here is an example using feature columns:

categorical_col = tf1.feature_column.categorical_column_with_identity(
    'type', num_buckets=3)
indicator_col = tf1.feature_column.indicator_column(categorical_col)
call_feature_columns(indicator_col, {'type': [0, 1, 2]})

Using Keras preprocessing layers, these columns can be replaced by a single tf.keras.layers.CategoryEncoding layer with output_mode set to 'one_hot':

one_hot_layer = tf.keras.layers.CategoryEncoding(
    num_tokens=3, output_mode='one_hot')
one_hot_layer([0, 1, 2])
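To make the arithmetic concrete, here is a minimal NumPy sketch of what one-hot encoding integer IDs of a known range produces. The helper `one_hot` is hypothetical and is not how Keras implements `CategoryEncoding`. Like the layer, it rejects IDs outside `[0, num_tokens)`.

```python
import numpy as np

def one_hot(ids, num_tokens):
    """Hypothetical helper: return a (len(ids), num_tokens) float array with 1.0 at each id."""
    ids = np.asarray(ids)
    # The "known range" assumption: every id must lie in [0, num_tokens).
    if ids.min() < 0 or ids.max() >= num_tokens:
        raise ValueError("ids must lie in [0, num_tokens)")
    # Row i of the identity matrix is the one-hot vector for id i.
    return np.eye(num_tokens, dtype=np.float32)[ids]

encoded = one_hot([0, 1, 2], num_tokens=3)
print(encoded)
```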

Note: For large one-hot encodings, it is much more efficient to use a sparse representation of the output. If you pass sparse=True to the CategoryEncoding layer, the output of the layer will be a tf.sparse.SparseTensor, which can be efficiently handled as input to a tf.keras.layers.Dense layer.

Normalizing numeric features

When handling continuous, floating-point features with feature columns, you need to use a tf.feature_column.numeric_column. If the input is already normalized, converting this to Keras is trivial: you can feed it straight into your model through a tf.keras.Input, as shown above.

A numeric_column can also be used to normalize input:

def normalize(x):
  mean, variance = (2.0, 1.0)
  return (x - mean) / math.sqrt(variance)
numeric_col = tf1.feature_column.numeric_column('col', normalizer_fn=normalize)
call_feature_columns(numeric_col, {'col': tf.constant([[0.], [1.], [2.]])})

In contrast, with Keras, this normalization can be done with tf.keras.layers.Normalization.

normalization_layer = tf.keras.layers.Normalization(mean=2.0, variance=1.0)
normalization_layer(tf.constant([[0.], [1.], [2.]]))
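Both snippets above apply the same formula, (x - mean) / sqrt(variance). The NumPy sketch below reproduces that arithmetic for the sample input. The function name `normalize_np` is hypothetical. The sketch also ignores the small epsilon that a real layer may use to guard against a zero variance.

```python
import math
import numpy as np

def normalize_np(x, mean, variance):
    # Shift by the mean, then divide by the standard deviation (sqrt of the variance).
    return (np.asarray(x, dtype=np.float32) - mean) / math.sqrt(variance)

out = normalize_np([[0.], [1.], [2.]], mean=2.0, variance=1.0)
print(out)
```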

Bucketizing and one-hot encoding numeric features

Another common transformation of continuous, floating-point inputs is to bucketize them into integers of a fixed range.

In feature columns, this can be achieved with a tf.feature_column.bucketized_column:

numeric_col = tf1.feature_column.numeric_column('col')
bucketized_col = tf1.feature_column.bucketized_column(numeric_col, [1, 4, 5])
call_feature_columns(bucketized_col, {'col': tf.constant([1., 2., 3., 4., 5.])})

In Keras, this can be replaced by tf.keras.layers.Discretization:

discretization_layer = tf.keras.layers.Discretization(bin_boundaries=[1, 4, 5])
one_hot_layer = tf.keras.layers.CategoryEncoding(
    num_tokens=4, output_mode='one_hot')
one_hot_layer(discretization_layer([1., 2., 3., 4., 5.]))
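Three boundaries split the real line into four buckets: (-inf, 1), [1, 4), [4, 5) and [5, inf). That is why the one-hot layer above uses `num_tokens=4`. A value equal to a boundary falls into the bucket to its right. The NumPy sketch below reproduces this mapping with `np.digitize` as an illustration only. It is not the Keras implementation.

```python
import numpy as np

boundaries = [1., 4., 5.]
values = np.array([1., 2., 3., 4., 5.])

# With right=False, np.digitize returns i such that boundaries[i-1] <= x < boundaries[i],
# i.e. buckets (-inf, 1), [1, 4), [4, 5), [5, inf) get ids 0, 1, 2, 3.
bucket_ids = np.digitize(values, boundaries, right=False)

# len(boundaries) + 1 buckets in total, one-hot encoded.
one_hot = np.eye(len(boundaries) + 1, dtype=np.float32)[bucket_ids]
print(bucket_ids)  # [1 1 1 2 3]
```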

One-hot encoding string data with a vocabulary

Handling string features often requires a vocabulary lookup to translate strings into indices. Here is an example using feature columns to look up strings and then one-hot encode the indices:

vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(
    'sizes',
    vocabulary_list=['small', 'medium', 'large'],
    num_oov_buckets=0)
indicator_col = tf1.feature_column.indicator_column(vocab_col)
call_feature_columns(indicator_col, {'sizes': ['small', 'medium', 'large']})

With Keras preprocessing layers, use the tf.keras.layers.StringLookup layer with output_mode set to 'one_hot':

string_lookup_layer = tf.keras.layers.StringLookup(
    vocabulary=['small', 'medium', 'large'],
    num_oov_indices=0,
    output_mode='one_hot')
string_lookup_layer(['small', 'medium', 'large'])

Note: For large one-hot encodings, it is much more efficient to use a sparse representation of the output. If you pass sparse=True to the StringLookup layer, the output of the layer will be a tf.sparse.SparseTensor, which can be efficiently handled as input to a tf.keras.layers.Dense layer.
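The examples above set `num_oov_indices=0` and `num_oov_buckets=0`. Once out-of-vocabulary (OOV) handling is enabled, the two APIs lay out indices differently. With one OOV bucket, a vocabulary-list feature column assigns the vocabulary to 0..len(vocab)-1 and unknown strings to len(vocab). StringLookup with its defaults (`num_oov_indices=1`, no mask token) reserves index 0 for unknown strings and shifts the vocabulary to 1..len(vocab). This matters when sizing anything that consumes the indices, such as an embedding table. The pure-Python sketch below illustrates the two layouts. Both helper names are hypothetical.

```python
vocab = ['small', 'medium', 'large']

def feature_column_ids(tokens, vocab):
    # Mimics categorical_column_with_vocabulary_list(..., num_oov_buckets=1):
    # vocabulary -> 0..len(vocab)-1, any unknown token -> len(vocab).
    index = {tok: i for i, tok in enumerate(vocab)}
    return [index.get(tok, len(vocab)) for tok in tokens]

def string_lookup_ids(tokens, vocab):
    # Mimics StringLookup defaults (num_oov_indices=1, mask_token=None):
    # unknown token -> 0, vocabulary -> 1..len(vocab).
    index = {tok: i + 1 for i, tok in enumerate(vocab)}
    return [index.get(tok, 0) for tok in tokens]

tokens = ['small', 'large', 'foo']
print(feature_column_ids(tokens, vocab))  # [0, 2, 3]
print(string_lookup_ids(tokens, vocab))   # [1, 3, 0]
```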

Embedding string data with a vocabulary

For larger vocabularies, an embedding is often needed for good performance. Here is an example embedding a string feature using feature columns:

vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(
    'col',
    vocabulary_list=['small', 'medium', 'large'],
    num_oov_buckets=0)
embedding_col = tf1.feature_column.embedding_column(vocab_col, 4)
call_feature_columns(embedding_col, {'col': ['small', 'medium', 'large']})

Using Keras preprocessing layers, this can be achieved by combining a tf.keras.layers.StringLookup layer and a tf.keras.layers.Embedding layer. The default output of StringLookup is integer indices, which can be fed directly into an embedding.

Note: The Embedding layer contains trainable parameters. While the StringLookup layer can be applied to data inside or outside of a model, the Embedding must always be part of a trainable Keras model to train correctly.

string_lookup_layer = tf.keras.layers.StringLookup(
    vocabulary=['small', 'medium', 'large'], num_oov_indices=0)
embedding = tf.keras.layers.Embedding(3, 4)
embedding(string_lookup_layer(['small', 'medium', 'large']))
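Conceptually, an embedding lookup selects rows of a trainable matrix by index. That is the same as multiplying one-hot vectors by the matrix, without ever materializing the one-hot vectors. The NumPy sketch below checks this equivalence. It uses a random matrix as a stand-in for the Embedding layer's weights, which in practice are learned during training.

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, embedding_dim = 3, 4
# Random stand-in for the Embedding layer's trainable (num_tokens, embedding_dim) matrix.
table = rng.normal(size=(num_tokens, embedding_dim)).astype(np.float32)

# Indices as StringLookup(num_oov_indices=0) would emit for ['small', 'medium', 'large'].
ids = np.array([0, 1, 2])

gathered = table[ids]                                            # row lookup
via_one_hot = np.eye(num_tokens, dtype=np.float32)[ids] @ table  # one-hot times matrix
```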

Summing weighted categorical data

In some cases, you need to deal with categorical data where each occurrence of a category comes with an associated weight. In feature columns, this is handled with tf.feature_column.weighted_categorical_column. When paired with an indicator_column, this has the effect of summing weights per category.

ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])
categorical_col = tf1.feature_column.categorical_column_with_identity(
    'ids', num_buckets=20)
weighted_categorical_col = tf1.feature_column.weighted_categorical_column(
    categorical_col, 'weights')
indicator_col = tf1.feature_column.indicator_column(weighted_categorical_col)
call_feature_columns(indicator_col, {'ids': ids, 'weights': weights})

In Keras, this can be done by passing a count_weights input to tf.keras.layers.CategoryEncoding with output_mode='count'.

ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])
# Using sparse output is more efficient when `num_tokens` is large.
count_layer = tf.keras.layers.CategoryEncoding(
    num_tokens=20, output_mode='count', sparse=True)
tf.sparse.to_dense(count_layer(ids, count_weights=weights))
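"Summing weights per category" means that each category's output slot is the sum of the weights attached to its occurrences. For the single example above, that gives 0.5 + 0.7 = 1.2 for id 5, 1.5 for id 11, and 1.8 + 0.2 = 2.0 for id 17. The NumPy sketch below computes the same dense vector with `np.bincount` for one unbatched row. It illustrates the result only, not the layer's implementation.

```python
import numpy as np

ids = np.array([5, 11, 5, 17, 17])
weights = np.array([0.5, 1.5, 0.7, 1.8, 0.2])

# Sum the weights that land on each category id; minlength pads to num_tokens=20.
weighted_counts = np.bincount(ids, weights=weights, minlength=20)
print(weighted_counts[[5, 11, 17]])
```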

Embedding weighted categorical data

You might alternatively want to embed weighted categorical inputs. In feature columns, the embedding_column contains a combiner argument. If any sample contains multiple entries for a category, they will be combined according to the argument setting (by default 'mean').

ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])
categorical_col = tf1.feature_column.categorical_column_with_identity(
    'ids', num_buckets=20)
weighted_categorical_col = tf1.feature_column.weighted_categorical_column(
    categorical_col, 'weights')
embedding_col = tf1.feature_column.embedding_column(
    weighted_categorical_col, 4, combiner='mean')
call_feature_columns(embedding_col, {'ids': ids, 'weights': weights})

In Keras, there is no combiner option to tf.keras.layers.Embedding, but you can achieve the same effect with tf.keras.layers.Dense. The embedding_column above simply computes a linear combination of embedding vectors, weighted by category weight. Though not obvious at first, this is exactly equivalent to representing your categorical inputs as a sparse weight vector of size (num_tokens) and multiplying it by a Dense kernel of shape (num_tokens, embedding_size).

ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])
# For `combiner='mean'`, normalize your weights to sum to 1. Removing this line
# would be equivalent to an `embedding_column` with `combiner='sum'`.
weights = weights / tf.reduce_sum(weights, axis=-1, keepdims=True)
count_layer = tf.keras.layers.CategoryEncoding(
    num_tokens=20, output_mode='count', sparse=True)
embedding_layer = tf.keras.layers.Dense(4, use_bias=False)
embedding_layer(count_layer(ids, count_weights=weights))
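The equivalence can be checked numerically. The NumPy sketch below computes the result in two ways:

1. The weighted mean of embedding rows, sum(w_i * e_i) / sum(w_i), which is what `combiner='mean'` describes.
2. A normalized weighted-count vector multiplied by a kernel of shape (num_tokens, embedding_size), which is the Dense approach.

The kernel is random and stands in for trained weights. The names are illustrative and not part of either API.

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, embedding_size = 20, 4
# Random stand-in for the Dense kernel, stored as (input_dim, units).
kernel = rng.normal(size=(num_tokens, embedding_size))

ids = np.array([5, 11, 5, 17, 17])
weights = np.array([0.5, 1.5, 0.7, 1.8, 0.2])

# Route 1: weighted mean of the embedding rows (combiner='mean').
weighted_mean = (weights[:, None] * kernel[ids]).sum(axis=0) / weights.sum()

# Route 2: weights normalized to sum to 1, summed per category, then times the kernel.
counts = np.bincount(ids, weights=weights / weights.sum(), minlength=num_tokens)
dense_out = counts @ kernel
```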

Complete training example

To show a complete training workflow, first prepare some data with three features of different types:

features = {
    'type': [0, 1, 1],
    'size': ['small', 'small', 'medium'],
    'weight': [2.7, 1.8, 1.6],
}
labels = [1, 1, 0]
predict_features = {'type': [0], 'size': ['foo'], 'weight': [-0.7]}

Define some common constants for both TensorFlow 1 and TensorFlow 2 workflows:

vocab = ['small', 'medium', 'large']
one_hot_dims = 3
embedding_dims = 4
weight_mean = 2.0
weight_variance = 1.0

With feature columns

Feature columns must be passed as a list to the estimator on creation, and will be called implicitly during training.

categorical_col = tf1.feature_column.categorical_column_with_identity(
    'type', num_buckets=one_hot_dims)
# Convert index to one-hot; e.g. [2] -> [0,0,1].
indicator_col = tf1.feature_column.indicator_column(categorical_col)
# Convert strings to indices; e.g. ['small'] -> [0]. Out-of-vocabulary strings
# go to the single OOV bucket, index len(vocab).
vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(
    'size', vocabulary_list=vocab, num_oov_buckets=1)
# Embed the indices.
embedding_col = tf1.feature_column.embedding_column(vocab_col, embedding_dims)
normalizer_fn = lambda x: (x - weight_mean) / math.sqrt(weight_variance)
# Normalize the numeric inputs; e.g. [2.0] -> [0.0].
numeric_col = tf1.feature_column.numeric_column(
    'weight', normalizer_fn=normalizer_fn)
estimator = tf1.estimator.DNNClassifier(
    feature_columns=[indicator_col, embedding_col, numeric_col],
    hidden_units=[1])
def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)
estimator.train(_input_fn)

The feature columns will also be used to transform input data when running inference on the model.

def _predict_fn():
  return tf1.data.Dataset.from_tensor_slices(predict_features).batch(1)
next(estimator.predict(_predict_fn))

With Keras preprocessing layers

Keras preprocessing layers are more flexible in where they can be called. A layer can be applied directly to tensors, used inside a tf.data input pipeline, or built directly into a trainable Keras model.

In this example, you will apply preprocessing layers inside a tf.data input pipeline. To do this, you can define a separate tf.keras.Model to preprocess your input features. This model is not trainable, but is a convenient way to group preprocessing layers.

inputs = {
    'type': tf.keras.Input(shape=(), dtype='int64'),
    'size': tf.keras.Input(shape=(), dtype='string'),
    'weight': tf.keras.Input(shape=(), dtype='float32'),
}
# Convert index to one-hot; e.g. [2] -> [0,0,1].
type_output = tf.keras.layers.CategoryEncoding(
    one_hot_dims, output_mode='one_hot')(inputs['type'])
# Convert size strings to indices; e.g. ['small'] -> [1].
size_output = tf.keras.layers.StringLookup(vocabulary=vocab)(inputs['size'])
# Normalize the numeric inputs; e.g. [2.0] -> [0.0].
weight_output = tf.keras.layers.Normalization(
    axis=None, mean=weight_mean, variance=weight_variance)(inputs['weight'])
outputs = {
    'type': type_output,
    'size': size_output,
    'weight': weight_output,
}
preprocessing_model = tf.keras.Model(inputs, outputs)
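For intuition about what this preprocessing model emits, here is a NumPy approximation of the same three transformations. The `type` feature is one-hot encoded. The `size` feature is mapped to StringLookup-style indices, where 0 is reserved for out-of-vocabulary strings and the vocabulary occupies 1..len(vocab). The `weight` feature is normalized as (x - mean) / sqrt(variance). The function `preprocess` is a hypothetical stand-in written for illustration, not the Keras model itself.

```python
import numpy as np

vocab = ['small', 'medium', 'large']
one_hot_dims, weight_mean, weight_variance = 3, 2.0, 1.0
features = {'type': [0, 1, 1], 'size': ['small', 'small', 'medium'], 'weight': [2.7, 1.8, 1.6]}

def preprocess(batch):
    # 'type': one-hot over one_hot_dims categories.
    type_out = np.eye(one_hot_dims, dtype=np.float32)[np.asarray(batch['type'])]
    # 'size': index 0 for unknown strings, vocabulary at 1..len(vocab) (StringLookup defaults).
    index = {tok: i + 1 for i, tok in enumerate(vocab)}
    size_out = np.array([index.get(t, 0) for t in batch['size']], dtype=np.int64)
    # 'weight': (x - mean) / sqrt(variance).
    weight_out = (np.asarray(batch['weight'], dtype=np.float32) - weight_mean) / np.sqrt(weight_variance)
    return {'type': type_out, 'size': size_out, 'weight': weight_out}

out = preprocess(features)
print(out['size'], out['weight'])
```

Applied to `predict_features`, the unknown string `'foo'` maps to index 0.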

Note: As an alternative to supplying a vocabulary and normalization statistics on layer creation, many preprocessing layers provide an adapt() method for learning layer state directly from the input data. See the preprocessing guide for more details.
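As a rough picture of what adapt() learns, the NumPy sketch below derives the same kinds of state from the training data. For Normalization with `axis=None`, that state is a scalar mean and variance. For StringLookup, it is a vocabulary ordered by descending token frequency. This is a sketch under assumptions. It assumes the population variance (ddof=0) and breaks frequency ties by first occurrence. The real layers compute their state in their own way, and their tie-breaking may differ.

```python
from collections import Counter
import numpy as np

weights = np.array([2.7, 1.8, 1.6])
sizes = ['small', 'small', 'medium']

# Normalization-style state: scalar mean and (assumed population) variance.
learned_mean = weights.mean()
learned_variance = weights.var()  # ddof=0

# StringLookup-style state: tokens ordered by descending frequency
# (assumption: ties broken by first occurrence, as Counter.most_common does).
learned_vocab = [tok for tok, _ in Counter(sizes).most_common()]
print(learned_mean, learned_variance, learned_vocab)
```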

You can now apply this model inside a call to tf.data.Dataset.map. Note that the function passed to map is automatically converted into a tf.function, so the usual caveats for writing tf.function code apply (for example, avoid side effects).

# Apply the preprocessing in tf.data.Dataset.map.
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
dataset = dataset.map(lambda x, y: (preprocessing_model(x), y),
                      num_parallel_calls=tf.data.AUTOTUNE)
# Display a preprocessed input sample.
next(dataset.take(1).as_numpy_iterator())

Next, you can define a separate Model containing the trainable layers. Note how the inputs to this model now reflect the preprocessed feature types and shapes.

inputs = {
    'type': tf.keras.Input(shape=(one_hot_dims,), dtype='float32'),
    'size': tf.keras.Input(shape=(), dtype='int64'),
    'weight': tf.keras.Input(shape=(), dtype='float32'),
}
# Since the embedding is trainable, it needs to be part of the training model.
# StringLookup reserves index 0 for out-of-vocabulary strings, so the table needs
# len(vocab) + 1 rows to cover indices 0 through len(vocab).
embedding = tf.keras.layers.Embedding(len(vocab) + 1, embedding_dims)
outputs = tf.keras.layers.Concatenate()([
    inputs['type'],
    embedding(inputs['size']),
    tf.expand_dims(inputs['weight'], -1),
])
outputs = tf.keras.layers.Dense(1)(outputs)
training_model = tf.keras.Model(inputs, outputs)
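The NumPy sketch below traces the shapes through this training model for the three preprocessed training examples:

1. The one-hot `type` contributes 3 columns.
2. The `size` embedding contributes 4 columns.
3. The `weight` feature, expanded to a column, contributes 1 column.

These concatenate to 8 features, which a single Dense unit reduces to one logit per example. The embedding table is sized to cover every index StringLookup can emit (0 through len(vocab)). All weights are random stand-ins, and the variable names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab = ['small', 'medium', 'large']
one_hot_dims, embedding_dims = 3, 4

# Preprocessed training batch, matching the preprocessing model's outputs.
type_in = np.eye(one_hot_dims, dtype=np.float32)[[0, 1, 1]]  # (3, 3)
size_in = np.array([1, 1, 2])                               # StringLookup ids, 0 = OOV
weight_in = np.array([0.7, -0.2, -0.4], dtype=np.float32)   # (3,)

# Random stand-ins for trainable weights; table rows cover ids 0..len(vocab).
table = rng.normal(size=(len(vocab) + 1, embedding_dims)).astype(np.float32)
dense_kernel = rng.normal(size=(one_hot_dims + embedding_dims + 1, 1)).astype(np.float32)
dense_bias = np.zeros(1, dtype=np.float32)

# Concatenate [one-hot type | size embedding | weight column] -> (3, 8).
concat = np.concatenate([type_in, table[size_in], weight_in[:, None]], axis=-1)
logits = concat @ dense_kernel + dense_bias  # (3, 1)
```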

You can now train the training_model with tf.keras.Model.fit.

# Train on the preprocessed data.
training_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))
training_model.fit(dataset)
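Because the final Dense layer has no activation, the model outputs raw logits, and the loss is built with `from_logits=True`. The NumPy sketch below shows the numerically stable form of sigmoid cross-entropy on logits, max(z, 0) - z*y + log(1 + exp(-|z|)), averaged over the batch. This form is algebraically equal to -[y*log(sigmoid(z)) + (1 - y)*log(1 - sigmoid(z))]. Treating the batch mean as equivalent to the Keras default reduction for single-logit outputs is an assumption of this sketch. The function name is hypothetical.

```python
import numpy as np

def bce_from_logits(labels, logits):
    y = np.asarray(labels, dtype=np.float64)
    z = np.asarray(logits, dtype=np.float64)
    # Stable form: avoids overflow in exp() for large |z|.
    per_example = np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z)))
    # Average over the batch.
    return per_example.mean()

print(bce_from_logits([1.0, 1.0, 0.0], [2.0, 0.5, -3.0]))
```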

Finally, at inference time, it can be useful to combine these separate stages into a single model that handles raw feature inputs.

inputs = preprocessing_model.input
outputs = training_model(preprocessing_model(inputs))
inference_model = tf.keras.Model(inputs, outputs)
predict_dataset = tf.data.Dataset.from_tensor_slices(predict_features).batch(1)
inference_model.predict(predict_dataset)

This composed model can be saved as a .keras file for later use.

inference_model.save('model.keras')
restored_model = tf.keras.models.load_model('model.keras')
restored_model.predict(predict_dataset)

Note: Preprocessing layers are not trainable, which allows you to apply them asynchronously using tf.data. This has performance benefits, as you can both prefetch preprocessed batches, and free up any accelerators to focus on the differentiable parts of a model (learn more in the Prefetching section of the Better performance with the tf.data API guide). As this guide shows, separating preprocessing during training and composing it during inference is a flexible way to leverage these performance gains. However, if your model is small or preprocessing time is negligible, it may be simpler to build preprocessing into a complete model from the start. To do this you can build a single model starting with tf.keras.Input, followed by preprocessing layers, followed by trainable layers.

Feature column equivalence table

For reference, here is an approximate correspondence between feature columns and Keras preprocessing layers:

| Feature column | Keras layer |
|---|---|
| `tf.feature_column.bucketized_column` | `tf.keras.layers.Discretization` |
| `tf.feature_column.categorical_column_with_hash_bucket` | `tf.keras.layers.Hashing` |
| `tf.feature_column.categorical_column_with_identity` | `tf.keras.layers.CategoryEncoding` |
| `tf.feature_column.categorical_column_with_vocabulary_file` | `tf.keras.layers.StringLookup` or `tf.keras.layers.IntegerLookup` |
| `tf.feature_column.categorical_column_with_vocabulary_list` | `tf.keras.layers.StringLookup` or `tf.keras.layers.IntegerLookup` |
| `tf.feature_column.crossed_column` | `tf.keras.layers.experimental.preprocessing.HashedCrossing` |
| `tf.feature_column.embedding_column` | `tf.keras.layers.Embedding` |
| `tf.feature_column.indicator_column` | `output_mode='one_hot'` or `output_mode='multi_hot'`* |
| `tf.feature_column.numeric_column` | `tf.keras.layers.Normalization` |
| `tf.feature_column.sequence_categorical_column_with_hash_bucket` | `tf.keras.layers.Hashing` |
| `tf.feature_column.sequence_categorical_column_with_identity` | `tf.keras.layers.CategoryEncoding` |
| `tf.feature_column.sequence_categorical_column_with_vocabulary_file` | `tf.keras.layers.StringLookup`, `tf.keras.layers.IntegerLookup`, or `tf.keras.layers.TextVectorization`† |
| `tf.feature_column.sequence_categorical_column_with_vocabulary_list` | `tf.keras.layers.StringLookup`, `tf.keras.layers.IntegerLookup`, or `tf.keras.layers.TextVectorization`† |
| `tf.feature_column.sequence_numeric_column` | `tf.keras.layers.Normalization` |
| `tf.feature_column.weighted_categorical_column` | `tf.keras.layers.CategoryEncoding` |

* The output_mode can be passed to tf.keras.layers.CategoryEncoding, tf.keras.layers.StringLookup, tf.keras.layers.IntegerLookup, and tf.keras.layers.TextVectorization.

† tf.keras.layers.TextVectorization can handle freeform text input directly (for example, entire sentences or paragraphs). This is not a one-to-one replacement for categorical sequence handling in TensorFlow 1, but it may offer a convenient replacement for ad-hoc text preprocessing.

Note: Linear estimators, such as tf.estimator.LinearClassifier, can handle direct categorical input (integer indices) without an embedding_column or indicator_column. However, integer indices cannot be passed directly to tf.keras.layers.Dense or tf.keras.experimental.LinearModel. These inputs should first be encoded with tf.keras.layers.CategoryEncoding with output_mode='count' (and sparse=True if the number of categories is large) before being passed to Dense or LinearModel.
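The reason count encoding is the right bridge can be shown with a little arithmetic. A linear model over a categorical feature keeps one weight per category. This sketch assumes the default 'sum' combiner, so the model adds up the weight of every ID present in an example, counting repeats. Multiplying a count-encoded vector by a Dense kernel of shape (num_categories,) gives the same number. The NumPy sketch below checks this with random stand-in weights. The names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
num_categories = 20
# Random stand-in for one trainable weight per category.
per_category_weight = rng.normal(size=num_categories)

ids = np.array([5, 11, 5, 17, 17])  # one example with repeated category ids

# Linear-estimator view (assumed 'sum' combiner): add the weight of each id present.
direct = per_category_weight[ids].sum()

# Keras view: count-encode (output_mode='count'), then apply a Dense-style dot product.
counts = np.bincount(ids, minlength=num_categories).astype(np.float64)
via_counts = counts @ per_category_weight
```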

Next steps