Path: blob/master/unbalanced/unbalanced_code/unbalanced_functions.R
# Useful functions when working with logistic regression
library(ROCR)
library(grid)
library(caret)
library(dplyr)
library(tidyr)
library(scales)
library(ggplot2)
library(gridExtra)
library(data.table)

# ------------------------------------------------------------------------------------------
# [AccuracyCutoffInfo] :
# Obtain the accuracy on the training and testing dataset
# for cutoff values ranging from .4 to .8 ( with a .05 increase ).
# @train   : your data.table or data.frame type training data ( assumes you have the predicted score in it ).
# @test    : your data.table or data.frame type testing data
# @predict : prediction's column name ( assumes the same for the training and testing set )
# @actual  : actual results' column name
# returns  : 1. data : a data.table with three columns;
#                      each row indicates the cutoff value and the accuracy for the
#                      train and test set respectively.
#            2. plot : plot that visualizes the data.table

AccuracyCutoffInfo <- function( train, test, predict, actual )
{
    # change the cutoff value's range as you please
    cutoff <- seq( .4, .8, by = .05 )

    accuracy <- lapply( cutoff, function(c)
    {
        # use the confusionMatrix from the caret package
        data_train <- as.factor( as.numeric( train[[predict]] > c ) )
        cm_train   <- confusionMatrix( data_train, as.factor( train[[actual]] ) )
        data_test  <- as.factor( as.numeric( test[[predict]] > c ) )
        cm_test    <- confusionMatrix( data_test, as.factor( test[[actual]] ) )

        dt <- data.table( cutoff = c,
                          train  = cm_train$overall[["Accuracy"]],
                          test   = cm_test$overall[["Accuracy"]] )
        return(dt)
    }) %>% rbindlist()

    # visualize the accuracy of the train and test set for different cutoff values,
    # accuracy in percentage ( gather is from the tidyr package )
    accuracy_long <- gather( accuracy, "data", "accuracy", -1 )

    plot <- ggplot( accuracy_long, aes( cutoff, accuracy, group = data, color = data ) ) +
            geom_line( size = 1 ) + geom_point( size = 3 ) +
            scale_y_continuous( labels = percent ) +
            ggtitle( "Train/Test Accuracy for Different Cutoff" )

    return( list( data = accuracy, plot = plot ) )
}
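# Example usage ( a minimal sketch, not part of the original file ) : assumes a fitted
# logistic regression `model_glm` and data frames `dtrain` / `dtest` containing a
# binary outcome column "y"; all of these names are hypothetical.
#
#   dtrain$prediction <- predict( model_glm, newdata = dtrain, type = "response" )
#   dtest$prediction  <- predict( model_glm, newdata = dtest,  type = "response" )
#   accuracy_info <- AccuracyCutoffInfo( train = dtrain, test = dtest,
#                                        predict = "prediction", actual = "y" )
#   accuracy_info$data   # data.table of train/test accuracy per cutoff
#   accuracy_info$plot   # accuracy curves for the two sets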
# ------------------------------------------------------------------------------------------
# [ConfusionMatrixInfo] :
# Obtain the confusion matrix plot and data.table for a given
# dataset that already contains the predicted score and actual outcome.
# @data    : your data.table or data.frame type data that contains the columns
#            of the predicted score and actual outcome
# @predict : predicted score's column name
# @actual  : actual results' column name
# @cutoff  : cutoff value for the prediction score
# return   : 1. data : a data.table consisting of three columns;
#                      the first two store the original values of the prediction and actual outcome from
#                      the passed in data frame, the third indicates the type, i.e. given the chosen
#                      cutoff value, whether this row is a true/false positive/negative
#            2. plot : plot that visualizes the data.table

ConfusionMatrixInfo <- function( data, predict, actual, cutoff )
{
    # extract the columns;
    # relevel so that 1 appears in the position where it is more commonly seen in
    # a two by two confusion matrix
    predict <- data[[predict]]
    actual  <- relevel( as.factor( data[[actual]] ), "1" )

    result <- data.table( actual = actual, predict = predict )

    # determine which confusion matrix category each prediction falls into
    result[ , type := ifelse( predict >= cutoff & actual == 1, "TP",
                      ifelse( predict >= cutoff & actual == 0, "FP",
                      ifelse( predict <  cutoff & actual == 1, "FN", "TN" ) ) ) %>% as.factor() ]

    # jittering : spreads the points along the x axis
    plot <- ggplot( result, aes( actual, predict, color = type ) ) +
            geom_violin( fill = "white", color = NA ) +
            geom_jitter( shape = 1 ) +
            geom_hline( yintercept = cutoff, color = "blue", alpha = 0.6 ) +
            scale_y_continuous( limits = c( 0, 1 ) ) +
            scale_color_discrete( breaks = c( "TP", "FN", "FP", "TN" ) ) + # ordering of the legend
            guides( col = guide_legend( nrow = 2 ) ) + # adjust the legend to have two rows
            ggtitle( sprintf( "Confusion Matrix with Cutoff at %.2f", cutoff ) )

    return( list( data = result, plot = plot ) )
}
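# Example usage ( a sketch; `dtest` with score column "prediction" and outcome
# column "y" are the same hypothetical names as above ) :
#
#   cm_info <- ConfusionMatrixInfo( data = dtest, predict = "prediction",
#                                   actual = "y", cutoff = .6 )
#   cm_info$data   # per-row TP / FP / TN / FN label at the chosen cutoff
#   cm_info$plot   # violin + jitter plot split by actual outcome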
# ------------------------------------------------------------------------------------------
# [ROCInfo] :
# Pass in the data that already contains the predicted score and actual outcome
# to obtain the ROC curve.
# @data    : your data.table or data.frame type data that contains the columns
#            of the predicted score and actual outcome
# @predict : predicted score's column name
# @actual  : actual results' column name
# @cost.fp : associated cost for a false positive
# @cost.fn : associated cost for a false negative
# return   : a list containing
#            1. plot        : a side by side roc and cost plot, with the title showing the
#                             optimal cutoff, total cost, and area under the curve (auc)
#            2. cutoff      : optimal cutoff value according to the specified fp/fn cost
#            3. totalcost   : total cost according to the specified fp/fn cost
#            4. auc         : area under the curve
#            5. sensitivity : TP / (TP + FN)
#            6. specificity : TN / (FP + TN)

ROCInfo <- function( data, predict, actual, cost.fp, cost.fn )
{
    # calculate the true positive and false positive rate
    # using the ROCR library
    pred <- prediction( data[[predict]], data[[actual]] )
    perf <- performance( pred, "tpr", "fpr" )
    roc_dt <- data.frame( fpr = perf@x.values[[1]], tpr = perf@y.values[[1]] )

    # cost with the specified false positive and false negative cost :
    # false positive rate * number of negative instances * false positive cost +
    # false negative rate * number of positive instances * false negative cost
    cost <- perf@x.values[[1]] * cost.fp * sum( data[[actual]] == 0 ) +
            ( 1 - perf@y.values[[1]] ) * cost.fn * sum( data[[actual]] == 1 )

    cost_dt <- data.frame( cutoff = pred@cutoffs[[1]], cost = cost )

    # optimal cutoff value, and the corresponding true positive and false positive rate
    best_index  <- which.min(cost)
    best_cost   <- cost_dt[ best_index, "cost" ]
    best_tpr    <- roc_dt[ best_index, "tpr" ]
    best_fpr    <- roc_dt[ best_index, "fpr" ]
    best_cutoff <- pred@cutoffs[[1]][ best_index ]

    # area under the curve
    auc <- performance( pred, "auc" )@y.values[[1]]

    # normalize the cost to the 0 ~ 1 range for assigning colors
    normalize <- function(v) ( v - min(v) ) / diff( range(v) )

    # create colors from a palette of 100 values, then map each normalized cost
    # onto one of them : the higher the cost, the darker the color.
    # multiply by 99 and add 1 so the index stays within 1 ~ 100
    # ( multiplying by 100 alone could produce an invalid index of 0 )
    col_ramp <- colorRampPalette( c( "green", "orange", "red", "black" ) )(100)
    col_by_cost <- col_ramp[ ceiling( normalize(cost) * 99 ) + 1 ]

    roc_plot <- ggplot( roc_dt, aes( fpr, tpr ) ) +
                geom_line( color = rgb( 0, 0, 1, alpha = 0.3 ) ) +
                geom_point( color = col_by_cost, size = 4, alpha = 0.2 ) +
                geom_segment( aes( x = 0, y = 0, xend = 1, yend = 1 ), alpha = 0.8, color = "royalblue" ) +
                labs( title = "ROC", x = "False Positive Rate", y = "True Positive Rate" ) +
                geom_hline( yintercept = best_tpr, alpha = 0.8, linetype = "dashed", color = "steelblue4" ) +
                geom_vline( xintercept = best_fpr, alpha = 0.8, linetype = "dashed", color = "steelblue4" )

    cost_plot <- ggplot( cost_dt, aes( cutoff, cost ) ) +
                 geom_line( color = "blue", alpha = 0.5 ) +
                 geom_point( color = col_by_cost, size = 4, alpha = 0.5 ) +
                 ggtitle( "Cost" ) +
                 scale_y_continuous( labels = comma ) +
                 geom_vline( xintercept = best_cutoff, alpha = 0.8, linetype = "dashed", color = "steelblue4" )

    # the main title for the two arranged plots;
    # use %.0f since the total cost can be a non-integer numeric value
    sub_title <- sprintf( "Cutoff at %.2f - Total Cost = %.0f, AUC = %.3f",
                          best_cutoff, best_cost, auc )

    # arrange into a side by side plot
    plot <- arrangeGrob( roc_plot, cost_plot, ncol = 2,
                         top = textGrob( sub_title, gp = gpar( fontsize = 16, fontface = "bold" ) ) )

    return( list( plot        = plot,
                  cutoff      = best_cutoff,
                  totalcost   = best_cost,
                  auc         = auc,
                  sensitivity = best_tpr,
                  specificity = 1 - best_fpr ) )
}
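# Example usage ( a sketch under the same hypothetical names; the 100 / 200
# fp / fn costs are illustrative only ). ConfusionMatrixInfo's output can be fed
# straight in, since its data has "predict" and "actual" columns :
#
#   roc_info <- ROCInfo( data = cm_info$data, predict = "predict",
#                        actual = "actual", cost.fp = 100, cost.fn = 200 )
#   grid.draw( roc_info$plot )   # arrangeGrob output must be drawn explicitly
#   roc_info$cutoff              # cost-minimizing cutoff
#   roc_info$auc                 # area under the ROC curve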