Copyright 2020 The TensorFlow Authors.
TF Lattice Aggregate Function Models
Overview
TFL Premade Aggregate Function Models are quick and easy ways to build TFL tf.keras.Model instances for learning complex aggregation functions. This guide outlines the steps needed to construct a TFL Premade Aggregate Function Model and to train and test it.
Setup
Installing the TF Lattice package:
Importing required packages:
Downloading the Puzzles dataset:
Extract and convert features and labels
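Each row of the raw data describes one review. The model, however, expects one example per puzzle, and each example holds a variable-length (ragged) list of that puzzle's reviews. The pure-Python sketch below shows one way to do that regrouping. Several details are assumptions: the column names (`puzzle_name`, `star_rating`, `word_count`, `sales`), the `group_reviews` helper and the toy rows are all hypothetical, and the label is assumed to be identical for every review of a puzzle.

```python
from collections import defaultdict


def group_reviews(rows, feature_names, label_name):
    """Group flat per-review rows into one ragged list per puzzle.

    Hypothetical helper: `rows` are dicts with an assumed "puzzle_name" key
    plus one key per feature; the label is assumed constant within a puzzle.
    """
    features = {name: defaultdict(list) for name in feature_names}
    labels = {}
    for row in rows:
        puzzle = row["puzzle_name"]
        # Append this review's values to its puzzle's variable-length lists.
        for name in feature_names:
            features[name][puzzle].append(float(row[name]))
        labels[puzzle] = float(row[label_name])
    # Fix one puzzle order (first appearance) so features and labels line up.
    puzzles = list(labels)
    ragged = {name: [features[name][p] for p in puzzles] for name in feature_names}
    return puzzles, ragged, [labels[p] for p in puzzles]


rows = [
    {"puzzle_name": "a", "star_rating": 5, "word_count": 40, "sales": 12},
    {"puzzle_name": "b", "star_rating": 3, "word_count": 10, "sales": 4},
    {"puzzle_name": "a", "star_rating": 4, "word_count": 25, "sales": 12},
]
puzzles, ragged, labels = group_reviews(rows, ["star_rating", "word_count"], "sales")
print(puzzles)                # ['a', 'b']
print(ragged["star_rating"])  # [[5.0, 4.0], [3.0]]
print(labels)                 # [12.0, 4.0]
```

Nested lists in this shape could then be converted to ragged tensors, for example with `tf.ragged.constant`.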
Setting the default values used for training in this guide:
Feature Configs
Feature calibration and per-feature configurations are set using tfl.configs.FeatureConfig. Feature configurations include monotonicity constraints, per-feature regularization (see tfl.configs.RegularizerConfig), and lattice sizes for lattice models.
Note that we must fully specify the feature config for any feature that we want our model to recognize. Otherwise, the model has no way of knowing that the feature exists. For aggregate function models, these features are automatically treated as ragged and handled accordingly.
Compute Quantiles
Although the default setting for pwl_calibration_input_keypoints in tfl.configs.FeatureConfig is 'quantiles', for premade models we have to define the input keypoints manually. To do so, we first define our own helper function for computing quantiles.
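Below is a minimal NumPy sketch of such a helper. It assumes the flattened feature values (one value per review) are available as an array. The name `compute_quantiles`, its optional clipping and missing-value arguments, and the use of `np.quantile(..., method="nearest")` (NumPy 1.22 or later) are our own choices and are not part of the TFL API.

```python
import numpy as np


def compute_quantiles(features, num_keypoints=10, clip_min=None, clip_max=None,
                      missing_value=None):
    """Return `num_keypoints` evenly spaced quantiles of the unique values.

    Hypothetical helper for building pwl_calibration_input_keypoints.
    """
    features = np.asarray(features, dtype=float)
    # Optionally clip values, and make sure the clip bounds become keypoints.
    if clip_min is not None:
        features = np.append(np.maximum(features, clip_min), clip_min)
    if clip_max is not None:
        features = np.append(np.minimum(features, clip_max), clip_max)
    # Use unique values so heavily repeated values don't dominate the quantiles.
    unique_features = np.unique(features)
    # Drop a sentinel "missing" value so it never becomes a keypoint.
    if missing_value is not None:
        unique_features = unique_features[unique_features != missing_value]
    # Evenly spaced quantiles, snapped to values that actually occur.
    return np.quantile(unique_features,
                       np.linspace(0.0, 1.0, num=num_keypoints),
                       method="nearest").astype(float)


keypoints = compute_quantiles([1, 5, 3, 3, 2, 4, 5, 1], num_keypoints=5)
print(keypoints)  # [1. 2. 3. 4. 5.]
```

Because the quantiles are computed over unique values, a value that appears in many reviews still yields only one keypoint. This spreads the keypoints across the observed range.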
Defining Our Feature Configs
Now that we can compute our quantiles, we define a feature config for each feature that we want our model to take as input.
Aggregate Function Model
To construct a TFL premade model, first construct a model configuration from tfl.configs. An aggregate function model is constructed using the tfl.configs.AggregateFunctionConfig. It applies piecewise-linear and categorical calibration, followed by a lattice model on each dimension of the ragged input. It then applies an aggregation layer over the output for each dimension. This is then followed by an optional output piecewise-linear calibration.
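The two calibration steps can be sketched in NumPy under simplifying assumptions. Piecewise-linear calibration is treated as linear interpolation between input keypoints and per-keypoint output values. Categorical calibration is treated as a lookup table with one output per category. The keypoints, output values and helper names are hypothetical, and in a trained model the output values would be learned rather than fixed.

```python
import numpy as np


def pwl_calibrate(x, input_keypoints, output_values):
    """Piecewise-linear calibration via linear interpolation between keypoints.

    np.interp clamps inputs outside the keypoint range to the end values.
    """
    return np.interp(x, input_keypoints, output_values)


def categorical_calibrate(category_ids, category_values):
    """Categorical calibration: map each integer category id to its own value."""
    return np.asarray(category_values)[np.asarray(category_ids)]


# Hypothetical input keypoints, e.g. quantiles of a 1-5 star rating.
input_keypoints = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
# Non-decreasing outputs make this calibrator monotonically increasing.
output_values = np.array([0.0, 0.1, 0.3, 0.7, 1.0])
print(pwl_calibrate(np.array([0.5, 2.5, 4.0, 6.0]), input_keypoints, output_values))

# Hypothetical binary (0/1) feature with one output value per category.
print(categorical_calibrate([0, 1, 1], [0.2, 0.9]))
```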
The output of each Aggregation layer is the averaged output of a calibrated lattice over the ragged inputs. Here is the model used inside the first Aggregation layer:
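The averaging step can be sketched in NumPy. To keep the example small, it assumes that each element already has two calibrated inputs in [0, 1] and uses a single 2x2 lattice evaluated by multilinear interpolation. The `theta` vertex values, the helper names and the toy inputs are hypothetical. A trained model learns its lattice parameters and may use more inputs.

```python
import numpy as np


def lattice_2x2(x0, x1, theta):
    """Multilinear interpolation on a 2x2 lattice; x0 and x1 lie in [0, 1].

    theta[i][j] is the lattice parameter at vertex (x0=i, x1=j).
    """
    return (theta[0][0] * (1 - x0) * (1 - x1) + theta[1][0] * x0 * (1 - x1)
            + theta[0][1] * (1 - x0) * x1 + theta[1][1] * x0 * x1)


def aggregate(ragged_x0, ragged_x1, theta):
    """Average the per-element lattice outputs within each example."""
    return np.array([
        np.mean(lattice_2x2(np.asarray(a), np.asarray(b), theta))
        for a, b in zip(ragged_x0, ragged_x1)
    ])


# Hypothetical vertex values, increasing in both inputs.
theta = [[0.0, 0.4], [0.5, 1.0]]
# Two examples with two elements and one element, respectively.
x0 = [[1.0, 0.0], [0.5]]
x1 = [[1.0, 0.0], [0.5]]
print(aggregate(x0, x1, theta))
```

The output is a mean over elements, so each example can hold a different number of elements, and reordering the elements does not change the result.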
Now, as with any other tf.keras.Model, we compile and fit the model to our data.
After training our model, we can evaluate it on our test set.