Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ethen8181
GitHub Repository: ethen8181/machine-learning
Path: blob/master/unbalanced/unbalanced_code/unbalanced_functions.R
1480 views
1
# Useful functions when working with logistic regression
2
library(ROCR)
3
library(grid)
4
library(caret)
5
library(dplyr)
6
library(scales)
7
library(ggplot2)
8
library(gridExtra)
9
library(data.table)
10
11
# ------------------------------------------------------------------------------------------
# [AccuracyCutoffInfo] :
# Obtain the accuracy on the training and testing dataset
# for a range of cutoff values ( by default .4 to .8 with a .05 increase ).
# @train   : your data.table or data.frame type training data ( assumes you have the predicted score in it ).
# @test    : your data.table or data.frame type testing data
# @predict : prediction's column name ( assumes the same for training and testing set )
# @actual  : actual results' column name ( assumes the same for both sets, coded as 0 / 1 )
# @cutoff  : numeric vector of cutoff values to evaluate; a row is predicted as 1
#            when its score is strictly greater than the cutoff
# returns  : 1. data : a data.table with three columns ( cutoff, train, test ).
#               each row indicates the cutoff value and the accuracy for the
#               train and test set respectively.
#            2. plot : plot that visualizes the data.table
AccuracyCutoffInfo <- function(train, test, predict, actual,
                               cutoff = seq(.4, .8, by = .05)) {
  stopifnot(is.numeric(cutoff), length(cutoff) > 0)

  # Accuracy is the share of rows whose thresholded score matches the label.
  # Computing it directly (instead of building a confusion matrix from freshly
  # created factors) keeps working when every score falls on one side of the
  # cutoff or when a dataset contains a single class. Labels are compared as
  # character so numeric, integer and factor 0/1 columns all work; rows with a
  # missing score or label are ignored.
  accuracy_at <- function(data, threshold) {
    predicted <- as.numeric(data[[predict]] > threshold)
    observed <- as.character(data[[actual]])
    mean(as.character(predicted) == observed, na.rm = TRUE)
  }

  accuracy <- rbindlist(lapply(cutoff, function(threshold) {
    data.table(
      cutoff = threshold,
      train = accuracy_at(train, threshold),
      test = accuracy_at(test, threshold)
    )
  }))

  # long format for plotting: one row per ( cutoff, dataset ) pair.
  # data.table::melt is used because data.table is already loaded by this file;
  # variable.factor = FALSE keeps the dataset name a plain character column.
  accuracy_long <- data.table::melt(
    accuracy,
    id.vars = "cutoff",
    variable.name = "data",
    value.name = "accuracy",
    variable.factor = FALSE
  )

  # visualize the accuracy of the train and test set for different cutoff value
  # accuracy in percentage.
  plot <- ggplot(accuracy_long, aes(cutoff, accuracy, group = data, color = data)) +
    geom_line(size = 1) +
    geom_point(size = 3) +
    scale_y_continuous(labels = percent) +
    ggtitle("Train/Test Accuracy for Different Cutoff")

  list(data = accuracy, plot = plot)
}
54
55
56
# ------------------------------------------------------------------------------------------
# [ConfusionMatrixInfo] :
# Obtain the confusion matrix plot and data.table for a given
# dataset that already contains the predicted score and actual outcome.
# @data    : your data.table or data.frame type data that contains the columns
#            of the predicted score and actual outcome
# @predict : predicted score's column name
# @actual  : actual results' column name ( must contain the value 1 )
# @cutoff  : cutoff value for the prediction score; scores greater than or
#            equal to it are treated as positive predictions
# return   : 1. data : a data.table consisting of three columns.
#               `actual` and `predict` store the actual outcome ( as a factor with
#               1 as the first level ) and the prediction score from the passed in
#               data, `type` is a factor telling whether the row is a
#               TP / FP / FN / TN after applying the cutoff
#            2. plot : plot that visualizes the data.table
ConfusionMatrixInfo <- function(data, predict, actual, cutoff) {
  # pull out the two columns; putting level 1 first places the positive class
  # on the more commonly seen position of a two by two confusion matrix
  scores <- data[[predict]]
  outcome <- relevel(as.factor(data[[actual]]), "1")

  # decide which cell of the confusion matrix each row falls into
  flagged <- scores >= cutoff
  category <- ifelse(
    flagged & outcome == 1, "TP",
    ifelse(
      flagged & outcome == 0, "FP",
      ifelse(scores < cutoff & outcome == 1, "FN", "TN")
    )
  )

  result <- data.table(
    actual = outcome,
    predict = scores,
    type = as.factor(category)
  )

  # layers are collected in a list and added in order on top of the base plot
  layers <- list(
    geom_violin(fill = "white", color = NA),
    # jittering spreads the points along the x axis
    geom_jitter(shape = 1),
    geom_hline(yintercept = cutoff, color = "blue", alpha = 0.6),
    scale_y_continuous(limits = c(0, 1)),
    # ordering of the legend
    scale_color_discrete(breaks = c("TP", "FN", "FP", "TN")),
    # adjust the legend to have two rows
    guides(col = guide_legend(nrow = 2)),
    ggtitle(sprintf("Confusion Matrix with Cutoff at %.2f", cutoff))
  )
  plot <- ggplot(result, aes(actual, predict, color = type)) + layers

  list(data = result, plot = plot)
}
98
99
100
# ------------------------------------------------------------------------------------------
# [ROCInfo] :
# Pass in the data that already contains the predicted score and actual outcome
# to obtain the ROC curve and the cost-optimal cutoff.
# @data    : your data.table or data.frame type data that contains the columns
#            of the predicted score and actual outcome
# @predict : predicted score's column name
# @actual  : actual results' column name ( coded as 0 / 1 )
# @cost.fp : associated cost for a false positive
# @cost.fn : associated cost for a false negative
# return   : a list containing
#            1. plot        : a side by side roc and cost plot ( a grob built with arrangeGrob,
#                             draw it with grid::grid.draw ), with a title showing the
#                             optimal cutoff, total cost, and area under the curve (auc)
#            2. cutoff      : optimal cutoff value according to the specified fp/fn cost
#            3. totalcost   : total cost at the optimal cutoff
#            4. auc         : area under the curve
#            5. sensitivity : TP / (TP + FN) at the optimal cutoff
#            6. specificity : TN / (FP + TN) at the optimal cutoff
ROCInfo <- function(data, predict, actual, cost.fp, cost.fn) {
  # calculate the true positive / false positive rates using the ROCR library
  pred <- prediction(data[[predict]], data[[actual]])
  perf <- performance(pred, "tpr", "fpr")
  fpr <- perf@x.values[[1]]
  tpr <- perf@y.values[[1]]
  cutoffs <- pred@cutoffs[[1]]
  roc_dt <- data.frame(fpr = fpr, tpr = tpr)

  # cost with the specified false positive and false negative cost :
  # false positive rate * number of negative instances * false positive cost +
  # false negative rate * number of positive instances * false negative cost
  n_negative <- sum(data[[actual]] == 0)
  n_positive <- sum(data[[actual]] == 1)
  cost <- fpr * cost.fp * n_negative + (1 - tpr) * cost.fn * n_positive
  cost_dt <- data.frame(cutoff = cutoffs, cost = cost)

  # optimal cutoff value, and the corresponding true positive and false positive rate
  best_index <- which.min(cost)
  best_cost <- cost[best_index]
  best_tpr <- tpr[best_index]
  best_fpr <- fpr[best_index]
  best_cutoff <- cutoffs[best_index]

  # area under the curve
  auc <- performance(pred, "auc")@y.values[[1]]

  # rescale the cost to [0, 1] so it can be mapped onto a color palette;
  # when every cutoff has the same cost there is no spread to rescale, so
  # all points share the lowest color instead of producing NaN
  normalize <- function(v) {
    spread <- diff(range(v))
    if (isTRUE(spread == 0)) {
      return(rep(0, length(v)))
    }
    (v - min(v)) / spread
  }

  # 100-step palette, the higher the cost the blacker the point.
  # the normalized cost is scaled by 99 and shifted by 1 so the index
  # always lies in 1 ~ 100 ( scaling by 100 would produce index 0 )
  col_ramp <- colorRampPalette(c("green", "orange", "red", "black"))(100)
  col_by_cost <- col_ramp[ceiling(normalize(cost) * 99) + 1]

  roc_plot <- ggplot(roc_dt, aes(fpr, tpr)) +
    geom_line(color = rgb(0, 0, 1, alpha = 0.3)) +
    geom_point(color = col_by_cost, size = 4, alpha = 0.2) +
    # single reference diagonal; annotate draws it once rather than once per row
    annotate("segment", x = 0, y = 0, xend = 1, yend = 1,
             alpha = 0.8, color = "royalblue") +
    labs(title = "ROC", x = "False Positive Rate", y = "True Positive Rate") +
    geom_hline(yintercept = best_tpr, alpha = 0.8, linetype = "dashed", color = "steelblue4") +
    geom_vline(xintercept = best_fpr, alpha = 0.8, linetype = "dashed", color = "steelblue4")

  cost_plot <- ggplot(cost_dt, aes(cutoff, cost)) +
    geom_line(color = "blue", alpha = 0.5) +
    geom_point(color = col_by_cost, size = 4, alpha = 0.5) +
    ggtitle("Cost") +
    scale_y_continuous(labels = comma) +
    geom_vline(xintercept = best_cutoff, alpha = 0.8, linetype = "dashed", color = "steelblue4")

  # the main title for the two arranged plots.
  # the cost is a double and may be fractional, so it is formatted as text
  # rather than with %d, which rejects non-whole numbers
  sub_title <- sprintf(
    "Cutoff at %.2f - Total Cost = %s, AUC = %.3f",
    best_cutoff, format(best_cost, scientific = FALSE, trim = TRUE), auc
  )

  # arranged into a side by side plot
  plot <- arrangeGrob(
    roc_plot, cost_plot,
    ncol = 2,
    top = textGrob(sub_title, gp = gpar(fontsize = 16, fontface = "bold"))
  )

  list(
    plot = plot,
    cutoff = best_cutoff,
    totalcost = best_cost,
    auc = auc,
    sensitivity = best_tpr,
    specificity = 1 - best_fpr
  )
}
184
185
186