"""1Title: Classification with Neural Decision Forests2Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)3Date created: 2021/01/154Last modified: 2021/01/155Description: How to train differentiable decision trees for end-to-end learning in deep neural networks.6Accelerator: GPU7"""89"""10## Introduction1112This example provides an implementation of the13[Deep Neural Decision Forest](https://ieeexplore.ieee.org/document/7410529)14model introduced by P. Kontschieder et al. for structured data classification.15It demonstrates how to build a stochastic and differentiable decision tree model,16train it end-to-end, and unify decision trees with deep representation learning.1718## The dataset1920This example uses the21[United States Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/census+income)22provided by the23[UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/index.php).24The task is binary classification25to predict whether a person is likely to be making over USD 50,000 a year.2627The dataset includes 48,842 instances with 14 input features (such as age, work class, education, occupation, and so on): 5 numerical features28and 9 categorical features.29"""3031"""32## Setup33"""3435import keras36from keras import layers37from keras.layers import StringLookup38from keras import ops394041from tensorflow import data as tf_data42import numpy as np43import pandas as pd4445import math4647"""48## Prepare the data49"""5051CSV_HEADER = [52"age",53"workclass",54"fnlwgt",55"education",56"education_num",57"marital_status",58"occupation",59"relationship",60"race",61"gender",62"capital_gain",63"capital_loss",64"hours_per_week",65"native_country",66"income_bracket",67]6869train_data_url = (70"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"71)72train_data = pd.read_csv(train_data_url, header=None, names=CSV_HEADER)7374test_data_url = (75"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test"76)77test_data = pd.read_csv(test_data_url, header=None, names=CSV_HEADER)7879print(f"Train dataset shape: {train_data.shape}")80print(f"Test dataset shape: {test_data.shape}")8182"""83Remove the first record (because it is not a valid data example) and a trailing84'dot' in the class labels.85"""8687test_data = test_data[1:]88test_data.income_bracket = test_data.income_bracket.apply(89lambda value: value.replace(".", "")90)9192"""93We store the training and test data splits locally as CSV files.94"""9596train_data_file = "train_data.csv"97test_data_file = "test_data.csv"9899train_data.to_csv(train_data_file, index=False, header=False)100test_data.to_csv(test_data_file, index=False, header=False)101102"""103## Define dataset metadata104105Here, we define the metadata of the dataset that will be useful for reading and parsing106and encoding input features.107"""108109# A list of the numerical feature names.110NUMERIC_FEATURE_NAMES = [111"age",112"education_num",113"capital_gain",114"capital_loss",115"hours_per_week",116]117# A dictionary of the categorical features and their vocabulary.118CATEGORICAL_FEATURES_WITH_VOCABULARY = {119"workclass": sorted(list(train_data["workclass"].unique())),120"education": sorted(list(train_data["education"].unique())),121"marital_status": sorted(list(train_data["marital_status"].unique())),122"occupation": sorted(list(train_data["occupation"].unique())),123"relationship": sorted(list(train_data["relationship"].unique())),124"race": 
sorted(list(train_data["race"].unique())),125"gender": sorted(list(train_data["gender"].unique())),126"native_country": sorted(list(train_data["native_country"].unique())),127}128# A list of the columns to ignore from the dataset.129IGNORE_COLUMN_NAMES = ["fnlwgt"]130# A list of the categorical feature names.131CATEGORICAL_FEATURE_NAMES = list(CATEGORICAL_FEATURES_WITH_VOCABULARY.keys())132# A list of all the input features.133FEATURE_NAMES = NUMERIC_FEATURE_NAMES + CATEGORICAL_FEATURE_NAMES134# A list of column default values for each feature.135COLUMN_DEFAULTS = [136[0.0] if feature_name in NUMERIC_FEATURE_NAMES + IGNORE_COLUMN_NAMES else ["NA"]137for feature_name in CSV_HEADER138]139# The name of the target feature.140TARGET_FEATURE_NAME = "income_bracket"141# A list of the labels of the target features.142TARGET_LABELS = [" <=50K", " >50K"]143144"""145## Create `tf_data.Dataset` objects for training and validation146147We create an input function to read and parse the file, and convert features and labels148into a [`tf_data.Dataset`](https://www.tensorflow.org/guide/datasets)149for training and validation. We also preprocess the input by mapping the target label150to an index.151"""152153154target_label_lookup = StringLookup(155vocabulary=TARGET_LABELS, mask_token=None, num_oov_indices=0156)157158159lookup_dict = {}160for feature_name in CATEGORICAL_FEATURE_NAMES:161vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]162# Create a lookup to convert a string values to an integer indices.163# Since we are not using a mask token, nor expecting any out of vocabulary164# (oov) token, we set mask_token to None and num_oov_indices to 0.165lookup = StringLookup(vocabulary=vocabulary, mask_token=None, num_oov_indices=0)166lookup_dict[feature_name] = lookup167168169def encode_categorical(batch_x, batch_y):170for feature_name in CATEGORICAL_FEATURE_NAMES:171batch_x[feature_name] = lookup_dict[feature_name](batch_x[feature_name])172173return batch_x, batch_y174175176def get_dataset_from_csv(csv_file_path, shuffle=False, batch_size=128):177dataset = (178tf_data.experimental.make_csv_dataset(179csv_file_path,180batch_size=batch_size,181column_names=CSV_HEADER,182column_defaults=COLUMN_DEFAULTS,183label_name=TARGET_FEATURE_NAME,184num_epochs=1,185header=False,186na_value="?",187shuffle=shuffle,188)189.map(lambda features, target: (features, target_label_lookup(target)))190.map(encode_categorical)191)192193return dataset.cache()194195196"""197## Create model inputs198"""199200201def create_model_inputs():202inputs = {}203for feature_name in FEATURE_NAMES:204if feature_name in NUMERIC_FEATURE_NAMES:205inputs[feature_name] = layers.Input(206name=feature_name, shape=(), dtype="float32"207)208else:209inputs[feature_name] = layers.Input(210name=feature_name, shape=(), dtype="int32"211)212return inputs213214215"""216## Encode input features217"""218219220def encode_inputs(inputs):221encoded_features = []222for feature_name in inputs:223if feature_name in CATEGORICAL_FEATURE_NAMES:224vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]225# Create a lookup to convert a string values to an integer indices.226# Since we are not using a mask token, nor expecting any out of vocabulary227# (oov) token, we set mask_token to None and num_oov_indices to 0.228value_index = inputs[feature_name]229embedding_dims = int(math.sqrt(lookup.vocabulary_size()))230# Create an embedding layer with the specified dimensions.231embedding = layers.Embedding(232input_dim=lookup.vocabulary_size(), 
"""
## Deep Neural Decision Tree

A neural decision tree model has two sets of weights to learn. The first set is `pi`,
which represents the probability distribution of the classes in the tree leaves.
The second set is the weights of the routing layer `decision_fn`, which represents the
probability of going to each leaf. The forward pass of the model works as follows:

1. The model expects input `features` as a single vector encoding all the features of
an instance in the batch. This vector can be generated from a Convolutional Neural
Network (CNN) applied to images, or from dense transformations applied to structured
data features.
2. The model first applies a `used_features_mask` to randomly select a subset of input
features to use.
3. Then, the model computes the probabilities (`mu`) for the input instances to reach
the tree leaves by iteratively performing a *stochastic* routing throughout the tree
levels, as illustrated by the sketch after this list.
4. Finally, the probabilities of reaching the leaves are combined with the class
probabilities at the leaves to produce the final `outputs`.
"""
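"""
To make the stochastic routing concrete, here is a tiny standalone NumPy sketch of step
3 for a tree of depth 2, using made-up sigmoid outputs. Each internal node holds a
[go-left, go-right] probability pair; multiplying these pairs level by level yields the
probability `mu` of reaching each of the four leaves.
"""

# Sigmoid outputs for nodes in heap order: index 0 is unused, index 1 is the root,
# and indices 2-3 are the two nodes of the next level.
node_probas = np.array([[0.0, 0.9, 0.3, 0.7]])
demo_decisions = np.stack([node_probas, 1 - node_probas], axis=2)  # [1, 4, 2]

demo_mu = np.ones([1, 1, 1])
begin_idx, end_idx = 1, 2
for level in range(2):
    demo_mu = np.reshape(demo_mu, [1, -1, 1])  # [1, 2 ** level, 1]
    demo_mu = np.tile(demo_mu, (1, 1, 2))  # [1, 2 ** level, 2]
    demo_mu = demo_mu * demo_decisions[:, begin_idx:end_idx, :]
    begin_idx, end_idx = end_idx, end_idx + 2 ** (level + 1)

demo_mu = np.reshape(demo_mu, [1, 4])
# Prints [[0.27 0.63 0.07 0.03]] 1.0 -- the leaf-reach probabilities sum to 1.
print(demo_mu, demo_mu.sum())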
class NeuralDecisionTree(keras.Model):
    def __init__(self, depth, num_features, used_features_rate, num_classes):
        super().__init__()
        self.depth = depth
        self.num_leaves = 2**depth
        self.num_classes = num_classes

        # Create a mask for the randomly selected features.
        num_used_features = int(num_features * used_features_rate)
        one_hot = np.eye(num_features)
        sampled_feature_indices = np.random.choice(
            np.arange(num_features), num_used_features, replace=False
        )
        self.used_features_mask = ops.convert_to_tensor(
            one_hot[sampled_feature_indices], dtype="float32"
        )

        # Initialize the weights of the classes in leaves.
        self.pi = self.add_weight(
            initializer="random_normal",
            shape=[self.num_leaves, self.num_classes],
            dtype="float32",
            trainable=True,
        )

        # Initialize the stochastic routing layer.
        self.decision_fn = layers.Dense(
            units=self.num_leaves, activation="sigmoid", name="decision"
        )

    def call(self, features):
        batch_size = ops.shape(features)[0]

        # Apply the feature mask to the input features.
        features = ops.matmul(
            features, ops.transpose(self.used_features_mask)
        )  # [batch_size, num_used_features]
        # Compute the routing probabilities.
        decisions = ops.expand_dims(
            self.decision_fn(features), axis=2
        )  # [batch_size, num_leaves, 1]
        # Concatenate the routing probabilities with their complements.
        decisions = layers.concatenate(
            [decisions, 1 - decisions], axis=2
        )  # [batch_size, num_leaves, 2]

        mu = ops.ones([batch_size, 1, 1])

        # The decision nodes are laid out in heap order: index 0 is unused,
        # index 1 is the root, and level `l` occupies indices [2**l, 2**(l+1)).
        begin_idx = 1
        end_idx = 2
        # Traverse the tree in breadth-first order.
        for level in range(self.depth):
            mu = ops.reshape(mu, [batch_size, -1, 1])  # [batch_size, 2 ** level, 1]
            mu = ops.tile(mu, (1, 1, 2))  # [batch_size, 2 ** level, 2]
            level_decisions = decisions[
                :, begin_idx:end_idx, :
            ]  # [batch_size, 2 ** level, 2]
            mu = mu * level_decisions  # [batch_size, 2 ** level, 2]
            begin_idx = end_idx
            end_idx = begin_idx + 2 ** (level + 1)

        mu = ops.reshape(mu, [batch_size, self.num_leaves])  # [batch_size, num_leaves]
        probabilities = keras.activations.softmax(self.pi)  # [num_leaves, num_classes]
        outputs = ops.matmul(mu, probabilities)  # [batch_size, num_classes]
        return outputs
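"""
As a quick, optional check (with illustrative `_demo_*` names), a freshly initialized
tree applied to random features should already produce valid class distributions: the
leaf-reach probabilities sum to one, and each leaf holds a softmax distribution.
"""

_demo_tree = NeuralDecisionTree(
    depth=3, num_features=8, used_features_rate=0.6, num_classes=2
)
_demo_features = ops.convert_to_tensor(np.random.rand(4, 8), dtype="float32")
print(ops.sum(_demo_tree(_demo_features), axis=1))  # each entry is ~1.0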
"""
## Deep Neural Decision Forest

The neural decision forest model consists of a set of neural decision trees that are
trained simultaneously. The output of the forest model is the average of the outputs
of its trees.
"""


class NeuralDecisionForest(keras.Model):
    def __init__(self, num_trees, depth, num_features, used_features_rate, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.ensemble = []
        # Initialize the ensemble by adding NeuralDecisionTree instances.
        # Each tree will have its own randomly selected input features to use.
        for _ in range(num_trees):
            self.ensemble.append(
                NeuralDecisionTree(depth, num_features, used_features_rate, num_classes)
            )

    def call(self, inputs):
        # Initialize the outputs: a [batch_size, num_classes] matrix of zeros.
        batch_size = ops.shape(inputs)[0]
        outputs = ops.zeros([batch_size, self.num_classes])

        # Aggregate the outputs of trees in the ensemble.
        for tree in self.ensemble:
            outputs += tree(inputs)
        # Divide the outputs by the ensemble size to get the average.
        outputs /= len(self.ensemble)
        return outputs


"""
Finally, let's set up the code that will train and evaluate the model.
"""

learning_rate = 0.01
batch_size = 265
num_epochs = 10


def run_experiment(model):
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )

    print("Start training the model...")
    train_dataset = get_dataset_from_csv(
        train_data_file, shuffle=True, batch_size=batch_size
    )

    model.fit(train_dataset, epochs=num_epochs)
    print("Model training finished")

    print("Evaluating the model on the test data...")
    test_dataset = get_dataset_from_csv(test_data_file, batch_size=batch_size)

    _, accuracy = model.evaluate(test_dataset)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")


"""
## Experiment 1: train a decision tree model

In this experiment, we train a single neural decision tree model
that uses all of the input features.
"""

num_trees = 10
depth = 10
used_features_rate = 1.0
num_classes = len(TARGET_LABELS)


def create_tree_model():
    inputs = create_model_inputs()
    features = encode_inputs(inputs)
    features = layers.BatchNormalization()(features)
    num_features = features.shape[1]

    tree = NeuralDecisionTree(depth, num_features, used_features_rate, num_classes)

    outputs = tree(features)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


tree_model = create_tree_model()
run_experiment(tree_model)


"""
## Experiment 2: train a forest model

In this experiment, we train a neural decision forest with `num_trees` trees,
where each tree uses a randomly selected 50% of the input features. You can control
the number of features used by each tree by setting the `used_features_rate` variable.
We also set the depth to 5, instead of the 10 used in the previous experiment.
"""

num_trees = 25
depth = 5
used_features_rate = 0.5


def create_forest_model():
    inputs = create_model_inputs()
    features = encode_inputs(inputs)
    features = layers.BatchNormalization()(features)
    num_features = features.shape[1]

    forest_model = NeuralDecisionForest(
        num_trees, depth, num_features, used_features_rate, num_classes
    )

    outputs = forest_model(features)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


forest_model = create_forest_model()

run_experiment(forest_model)
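"""
Once trained, the forest can be used for inference like any Keras model. As a minimal
sketch, we can pull a small batch from the test CSV and map the predicted class indices
back to the original labels.
"""

sample_batch_x, sample_batch_y = next(
    iter(get_dataset_from_csv(test_data_file, batch_size=5))
)
predicted_probabilities = forest_model.predict(sample_batch_x)
predicted_labels = [
    TARGET_LABELS[index] for index in np.argmax(predicted_probabilities, axis=1)
]
print(predicted_labels)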