Copyright 2019 The TensorFlow Authors.
Classification on imbalanced data
This tutorial demonstrates how to classify a highly imbalanced dataset in which the number of examples in one class greatly outnumbers the examples in another. You will work with the Credit Card Fraud Detection dataset hosted on Kaggle. The aim is to detect a mere 492 fraudulent transactions out of 284,807 transactions in total. You will use Keras to define the model and class weights to help the model learn from the imbalanced data.
This tutorial contains complete code to:
Load a CSV file using Pandas.
Create train, validation, and test sets.
Define and train a model using Keras (including setting class weights).
Evaluate the model using various metrics (including precision and recall).
Select a threshold for a probabilistic classifier to get a deterministic classifier.
Try, and compare, two common techniques for handling imbalanced data: class weighting and oversampling.
Setup
Data processing and exploration
Download the Kaggle Credit Card Fraud data set
Pandas is a Python library with many helpful utilities for loading and working with structured data. It can be used to download CSVs into a Pandas DataFrame.
Note: This dataset has been collected and analysed during a research collaboration of Worldline and the Machine Learning Group of ULB (Université Libre de Bruxelles) on big data mining and fraud detection. More details on current and past projects on related topics are available here and on the page of the DefeatFraud project.
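As a minimal sketch, the download might look like the following. The URL below is assumed to be a mirror of the Kaggle dataset; substitute your own copy of the CSV if it is unavailable.

```python
import pandas as pd

# Assumed mirror of the Kaggle Credit Card Fraud dataset; replace with your own copy if needed.
file = 'https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv'
raw_df = pd.read_csv(file)
raw_df.head()
```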
Examine the class label imbalance
Let's look at the dataset imbalance:
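A minimal way to do this, assuming the raw_df DataFrame loaded above and its Class label column, is to count the examples per class:

```python
import numpy as np

# Count negative (legitimate) and positive (fraud) examples in the 'Class' column.
neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print(f'Examples:\n    Total: {total}\n    Positive: {pos} ({100 * pos / total:.2f}% of total)')
```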
This shows the small fraction of positive samples.
Clean, split and normalize the data
The raw data has a few issues. First, the Time and Amount columns are too variable to use directly. Drop the Time column (since it's not clear what it means) and take the log of the Amount column to reduce its range.
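A sketch of this cleanup, assuming the raw_df DataFrame from above (the small eps constant only avoids taking the log of zero):

```python
import numpy as np

cleaned_df = raw_df.copy()

# The meaning of the `Time` column is unclear, so drop it.
cleaned_df.pop('Time')

# The `Amount` column spans a huge range; compress it with a log transform.
eps = 0.001  # avoids log(0) for zero-valued amounts
cleaned_df['Log Amount'] = np.log(cleaned_df.pop('Amount') + eps)
```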
Split the dataset into train, validation, and test sets. The validation set is used during the model fitting to evaluate the loss and any metrics, however the model is not fit with this data. The test set is completely unused during the training phase and is only used at the end to evaluate how well the model generalizes to new data. This is especially important with imbalanced datasets where overfitting is a significant concern from the lack of training data.
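One way to do the split, sketched here with scikit-learn's train_test_split (the 80/20 fractions are assumptions, not prescriptions):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# 80/20 split into train and test, then carve a validation set out of train.
train_df, test_df = train_test_split(cleaned_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.2)

# Separate the labels (the 'Class' column) from the features.
train_labels = np.array(train_df.pop('Class'))
val_labels = np.array(val_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))

train_features = np.array(train_df)
val_features = np.array(val_df)
test_features = np.array(test_df)

print(f'Positive labels in train/val/test: '
      f'{train_labels.sum()}, {val_labels.sum()}, {test_labels.sum()}')
```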
Given the small number of positive labels, this seems about right.
Normalize the input features using the sklearn StandardScaler. This will set the mean to 0 and standard deviation to 1.
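For example, fitting the scaler on the training features only and reusing it for the other splits:

```python
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
# Fit on the training set only, then apply the same transform to every split.
train_features = scaler.fit_transform(train_features)
val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)
```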
Note: The StandardScaler is only fit using the train_features to be sure the model is not peeking at the validation or test sets.
Caution: If you want to deploy a model, it's critical that you preserve the preprocessing calculations. The easiest way is to implement them as layers and attach them to your model before export.
Look at the data distribution
Next compare the distributions of the positive and negative examples over a few features. Good questions to ask yourself at this point are:
Do these distributions make sense?
Yes. You've normalized the input and these are mostly concentrated in the +/- 2 range.
Can you see the difference between the distributions?
Yes, the positive examples contain a much higher rate of extreme values.
Define the model and metrics
Define a function that creates a simple neural network with a densely connected hidden layer, a dropout layer to reduce overfitting, and an output sigmoid layer that returns the probability of a transaction being fraudulent:
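A sketch of such a function is below. The layer sizes, dropout rate, learning rate, and metric names ('auc', 'prc', and so on) are assumptions made for this example, and train_features is taken from the preprocessing step above.

```python
from tensorflow import keras

# Illustrative metric set; the names are assumptions used by later sketches.
METRICS = [
    keras.metrics.BinaryCrossentropy(name='cross entropy'),
    keras.metrics.FalseNegatives(name='fn'),
    keras.metrics.FalsePositives(name='fp'),
    keras.metrics.BinaryAccuracy(name='accuracy'),
    keras.metrics.Precision(name='precision'),
    keras.metrics.Recall(name='recall'),
    keras.metrics.AUC(name='auc'),
    keras.metrics.AUC(name='prc', curve='PR'),
]

def make_model(metrics=METRICS, output_bias=None):
  # A small dense network: one hidden layer, dropout, and a sigmoid output
  # that returns the probability of fraud.
  if output_bias is not None:
    output_bias = keras.initializers.Constant(output_bias)
  model = keras.Sequential([
      keras.Input(shape=(train_features.shape[-1],)),
      keras.layers.Dense(16, activation='relu'),
      keras.layers.Dropout(0.5),
      keras.layers.Dense(1, activation='sigmoid', bias_initializer=output_bias),
  ])
  model.compile(
      optimizer=keras.optimizers.Adam(learning_rate=1e-3),
      loss=keras.losses.BinaryCrossentropy(),
      metrics=metrics)
  return model
```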
Understanding useful metrics
Notice that there are a few metrics defined above that can be computed by the model that will be helpful when evaluating the performance. These can be divided into three groups.
Metrics for probability predictions
As we train our network with the cross entropy as a loss function, it is fully capable of predicting class probabilities, i.e. it is a probabilistic classifier. Good metrics to assess probabilistic predictions are, in fact, proper scoring rules. Their key property is that predicting the true probability is optimal. We give two well-known examples:
Cross entropy, also known as log loss
Mean squared error, also known as the Brier score
Metrics for deterministic 0/1 predictions
In the end, one often wants to predict a class label, 0 or 1: no fraud or fraud. This is called a deterministic classifier. To get a label prediction from our probabilistic classifier, one needs to choose a probability threshold. The default is to predict label 1 (fraud) if the predicted probability is larger than 50%, and all the following metrics implicitly use this default.
False negatives and false positives are samples that were incorrectly classified
True negatives and true positives are samples that were correctly classified
Accuracy is the percentage of examples correctly classified
Precision is the percentage of predicted positives that were correctly classified
Recall is the percentage of actual positives that were correctly classified
Note: Accuracy is not a helpful metric for this task. You can have 99.8%+ accuracy on this task by predicting False all the time.
Other metrics
The following metrics take into account all possible choices of the probability threshold.
AUC refers to the Area Under the Curve of a Receiver Operating Characteristic curve (ROC-AUC). This metric is equal to the probability that a classifier will rank a random positive sample higher than a random negative sample.
AUPRC refers to Area Under the Curve of the Precision-Recall Curve. This metric computes precision-recall pairs for different probability thresholds.
Read more:
Baseline model
Build the model
Now create and train your model using the function that was defined earlier. Notice that the model is fit using a larger-than-default batch size of 2048. This is important to ensure that each batch has a decent chance of containing a few positive samples. If the batch size were too small, many batches would likely have no fraudulent transactions to learn from.
Note: Fitting this model will not handle the class imbalance efficiently. You will improve it later in this tutorial.
Test run the model:
Optional: Set the correct initial bias.
These initial guesses are not great. You know the dataset is imbalanced. Set the output layer's bias to reflect that, see A Recipe for Training Neural Networks: "init well". This can help with initial convergence.
With the default bias initialization the loss should be about math.log(2) = 0.69314
The correct bias to set can be derived from:
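In code, assuming the pos and neg counts computed earlier, the derivation boils down to:

```python
import numpy as np

# With a sigmoid output, the initial prediction is p_0 = 1 / (1 + exp(-b_0)).
# Setting p_0 equal to the positive fraction pos / (pos + neg) and solving
# for the bias gives b_0 = log(pos / neg).
initial_bias = np.log([pos / neg])
initial_bias
```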
Set that as the initial bias, and the model will give much more reasonable initial guesses.
It should be near: pos/total = 0.0018
With this initialization the initial loss should be approximately:
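Assuming the counts from above, the expected binary cross entropy for a model that always predicts the positive fraction p_0 is:

```python
import numpy as np

p_0 = pos / (pos + neg)
# Binary cross entropy of a constant prediction p_0 against labels with mean p_0.
expected_loss = -p_0 * np.log(p_0) - (1 - p_0) * np.log(1 - p_0)
print(f'Expected initial loss: {expected_loss:.4f}')
```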
This initial loss is about 50 times less than it would have been with naive initialization.
This way the model doesn't need to spend the first few epochs just learning that positive examples are unlikely. It also makes it easier to read plots of the loss during training.
Checkpoint the initial weights
To make the various training runs more comparable, keep this initial model's weights in a checkpoint file, and load them into each model before training:
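One way to do this, assuming the make_model function and initial_bias from the sketches above:

```python
import os
import tempfile

# Build the model with the carefully chosen output bias.
model = make_model(output_bias=initial_bias)

# Save the freshly initialized weights so every later run starts from the same point.
initial_weights = os.path.join(tempfile.mkdtemp(), 'initial.weights.h5')
model.save_weights(initial_weights)

# Before each subsequent training run, restore them with:
# model.load_weights(initial_weights)
```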
Confirm that the bias fix helps
Before moving on, quickly confirm that the careful bias initialization actually helped.
Train the model for 20 epochs, with and without this careful initialization, and compare the losses:
The above figure makes it clear: In terms of validation loss, on this problem, this careful initialization gives a clear advantage.
Train the model
Check training history
In this section, you will produce plots of your model's accuracy and loss on the training and validation set. These are useful to check for overfitting, which you can learn more about in the Overfit and underfit tutorial.
Additionally, you can produce these plots for any of the metrics you created above. False negatives are included as an example.
Note that the validation curve generally performs better than the training curve. This is mainly because the dropout layer is not active when evaluating the model.
Evaluate metrics
You can use a confusion matrix to summarize the actual vs. predicted labels, where the X axis is the predicted label and the Y axis is the actual label:
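A possible helper for this plot, assuming scikit-learn, matplotlib, and seaborn are available and that predictions are the probabilities returned by model.predict:

```python
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

def plot_cm(labels, predictions, threshold=0.5):
  # Convert predicted probabilities to 0/1 labels at the given threshold,
  # then draw the confusion matrix as an annotated heatmap.
  cm = confusion_matrix(labels, predictions > threshold)
  plt.figure(figsize=(5, 5))
  sns.heatmap(cm, annot=True, fmt='d')
  plt.title(f'Confusion matrix @ {threshold:.2f}')
  plt.ylabel('Actual label')
  plt.xlabel('Predicted label')
```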
Evaluate your model on the test dataset and display the results for the metrics you created above:
If the model had predicted everything perfectly (impossible with true randomness), this would be a diagonal matrix where values off the main diagonal, indicating incorrect predictions, would be zero. In this case, the matrix shows that you have relatively few false positives, meaning that there were relatively few legitimate transactions that were incorrectly flagged.
Plot the ROC
Now plot the ROC. This plot is useful because it shows, at a glance, the range of performance the model can reach by tuning the output threshold over its full range (0 to 1). So each point corresponds to a single value of the threshold.
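A small plotting helper, using scikit-learn's roc_curve (the label names and styling are illustrative):

```python
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve

def plot_roc(name, labels, predictions, **kwargs):
  # False positive rate vs. true positive rate over all possible thresholds.
  fp, tp, _ = roc_curve(labels, predictions)
  plt.plot(100 * fp, 100 * tp, label=name, linewidth=2, **kwargs)
  plt.xlabel('False positives [%]')
  plt.ylabel('True positives [%]')
  plt.grid(True)
  plt.legend()
```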
Plot the PRC
Now plot the AUPRC: the area under the interpolated precision-recall curve, obtained by plotting (recall, precision) points for different values of the classification threshold. Depending on how it's calculated, PR AUC may be equivalent to the average precision of the model.
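Analogously, a sketch using scikit-learn's precision_recall_curve:

```python
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve

def plot_prc(name, labels, predictions, **kwargs):
  # Precision vs. recall over all possible thresholds.
  precision, recall, _ = precision_recall_curve(labels, predictions)
  plt.plot(recall, precision, label=name, linewidth=2, **kwargs)
  plt.xlabel('Recall')
  plt.ylabel('Precision')
  plt.grid(True)
  plt.legend()
```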
It looks like the precision is relatively high, but the recall and the area under the ROC curve (AUC) aren't as high as you might like. Classifiers often face challenges when trying to maximize both precision and recall, which is especially true when working with imbalanced datasets. It is important to consider the costs of different types of errors in the context of the problem you care about. In this example, a false negative (a fraudulent transaction is missed) may have a financial cost, while a false positive (a transaction is incorrectly flagged as fraudulent) may decrease user happiness.
Class weights
Calculate class weights
The goal is to identify fraudulent transactions, but you don't have very many of those positive samples to work with, so you would want the classifier to heavily weight the few examples that are available. You can do this by passing Keras weights for each class through the class_weight parameter. These cause the model to "pay more attention" to examples from an under-represented class. Note, however, that this does not increase the amount of information in your dataset in any way. In the end, using class weights is more or less equivalent to changing the output bias or changing the threshold. Let's see how it works out.
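One common weighting scheme, sketched here with the pos, neg, and total counts from earlier, scales each class inversely to its frequency while keeping the overall loss magnitude roughly unchanged:

```python
# Scaling by total / 2 keeps the loss on a similar scale to the unweighted case.
weight_for_0 = (1 / neg) * (total / 2.0)
weight_for_1 = (1 / pos) * (total / 2.0)

class_weight = {0: weight_for_0, 1: weight_for_1}

print(f'Weight for class 0: {weight_for_0:.2f}')
print(f'Weight for class 1: {weight_for_1:.2f}')
```

The resulting dictionary is then passed to Model.fit via its class_weight argument.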
Train a model with class weights
Now try re-training and evaluating the model with class weights to see how that affects the predictions.
Note: Using class_weights changes the range of the loss. This may affect the stability of the training depending on the optimizer. Optimizers whose step size is dependent on the magnitude of the gradient, like tf.keras.optimizers.SGD, may fail. The optimizer used here, tf.keras.optimizers.Adam, is unaffected by the scaling change. Also note that because of the weighting, the total losses are not comparable between the two models.
Check training history
Evaluate metrics
Here you can see that with class weights the accuracy and precision are lower because there are more false positives, but conversely the recall and AUC are higher because the model also found more true positives. Despite having lower accuracy, this model has higher recall (and identifies more fraudulent transactions than the baseline model at threshold 50%). Of course, there is a cost to both types of error (you wouldn't want to bug users by flagging too many legitimate transactions as fraudulent, either). Carefully consider the trade-offs between these different types of errors for your application.
Compared to the baseline model with changed threshold, the class weighted model is clearly inferior. The superiority of the baseline model is further confirmed by the lower test loss value (cross entropy and mean squared error) and additionally can be seen by plotting the ROC curves of both models together.
Plot the ROC
Plot the PRC
Oversampling
Oversample the minority class
A related approach would be to resample the dataset by oversampling the minority class.
Using NumPy
You can balance the dataset manually by choosing the right number of random indices from the positive examples:
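For instance, assuming the train_features and train_labels arrays from the split above:

```python
import numpy as np

# Indices of the positive and negative examples in the training set.
pos_indices = np.where(train_labels == 1)[0]
neg_indices = np.where(train_labels == 0)[0]

# Sample positive indices with replacement until both classes are the same size.
resampled_pos_indices = np.random.choice(pos_indices, size=len(neg_indices), replace=True)

resampled_indices = np.concatenate([resampled_pos_indices, neg_indices])
np.random.shuffle(resampled_indices)

resampled_features = train_features[resampled_indices]
resampled_labels = train_labels[resampled_indices]
resampled_features.shape
```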
Using tf.data
If you're using tf.data, the easiest way to produce balanced examples is to start with a positive and a negative dataset, and merge them. See the tf.data guide for more examples.
Each dataset provides (feature, label) pairs:
Merge the two together using tf.data.Dataset.sample_from_datasets:
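A sketch of that pipeline; the BUFFER_SIZE and BATCH_SIZE values are assumptions, and the feature and label arrays come from the split above:

```python
import tensorflow as tf

BUFFER_SIZE = 100000
BATCH_SIZE = 2048  # assumed to match the batch size used for training above

def make_ds(features, labels):
  # An infinitely repeating, shuffled dataset of (feature, label) pairs.
  ds = tf.data.Dataset.from_tensor_slices((features, labels))
  return ds.shuffle(BUFFER_SIZE).repeat()

pos_ds = make_ds(train_features[train_labels == 1], train_labels[train_labels == 1])
neg_ds = make_ds(train_features[train_labels == 0], train_labels[train_labels == 0])

# Draw from each dataset with equal probability so batches are roughly balanced.
resampled_ds = tf.data.Dataset.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(2)
```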
To use this dataset, you'll need the number of steps per epoch.
The definition of "epoch" in this case is less clear. Say it's the number of batches required to see each negative example once:
Train on the oversampled data
Now try training the model with the resampled data set instead of using class weights to see how these methods compare.
Note: Because the data was balanced by replicating the positive examples, the total dataset size is larger, and each epoch runs for more training steps.
If the training process were considering the whole dataset on each gradient update, this oversampling would be basically identical to the class weighting.
But when training the model batch-wise, as you did here, the oversampled data provides a smoother gradient signal: Instead of each positive example being shown in one batch with a large weight, they're shown in many different batches each time with a small weight.
This smoother gradient signal makes it easier to train the model.
Check training history
Note that the distributions of metrics will be different here, because the training data has a totally different distribution from the validation and test data.
Re-train
Because training is easier on the balanced data, the above training procedure may overfit quickly. So break up the epochs to give the tf.keras.callbacks.EarlyStopping callback finer control over when to stop training.
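A possible callback configuration; monitoring 'val_prc' assumes the PR AUC metric was registered under the name 'prc', as in the model sketch above:

```python
import tensorflow as tf

# Stop when the validation PR AUC stops improving and keep the best weights seen.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_prc',
    verbose=1,
    patience=10,
    mode='max',
    restore_best_weights=True)
```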
Re-check training history
Evaluate metrics
Plot the ROC
Plot the AUPRC
Applying this tutorial to your problem
Imbalanced data classification is an inherently difficult task since there are so few samples to learn from. You should always start with the data first: do your best to collect as many samples as possible, and give substantial thought to which features may be relevant so the model can get the most out of your minority class. At some point your model may struggle to improve and yield the results you want, so it is important to keep in mind the context of your problem and the trade-offs between different types of errors.