Skip to content

Commit e458f76

Browse files
author
Ramtin Zargari Marandi
committed
Updated SHAP hovering information
1 parent b11985a commit e458f76

16 files changed

Lines changed: 355 additions & 238 deletions

R/SHAPclust.R

Lines changed: 49 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -115,35 +115,34 @@ SHAPclust <- function(task,
115115
algorithm="Hartigan-Wong",
116116
iter.max = 1000
117117
){
118-
119-
prediction_correctness <- NULL
120-
truth <- NULL
121-
response <- NULL
118+
cluster <- NULL
119+
correct_prediction <- NULL
120+
feature <- NULL
121+
f_val <- NULL
122122
fval <- NULL
123-
variable <- NULL
124123
mean_absolute_shap <- NULL
125-
feature <- NULL
126-
value <- NULL
124+
mean_phi <- NULL
125+
Phi <- NULL
126+
pred_class <- NULL
127+
pred_prob <- NULL
128+
prediction_correctness <- NULL
129+
response <- NULL
127130
sample_num <- NULL
128-
cluster <- NULL
131+
truth <- NULL
132+
unscaled_f_val <- NULL
133+
value <- NULL
134+
variable <- NULL
135+
129136
mydata <- task$data()
130137
# randomly subset the target variable and the corresponding rows
131-
if (subset < 1) {
132-
set.seed(seed) # set seed for reproducibility
133-
n <- round(subset * length(splits$test))
134-
target_index <- sample(splits$test, size = n, replace = FALSE)
135-
mydata <- mydata[target_index, ]
136-
137-
# do the prediction for the test set
138-
pred_results <- trained_model$predict(task,target_index)
139138

140-
} else {
141-
mydata <- mydata[splits$test, ]
139+
set.seed(seed) # set seed for reproducibility
140+
n <- round(subset * length(splits$test))
141+
target_index <- sample(splits$test, size = n, replace = FALSE)
142+
mydata <- mydata[target_index, ]
142143

143-
# do the prediction for the test set
144-
pred_results <- trained_model$predict(task,splits$test)
145-
146-
}
144+
# do the prediction for the test set
145+
pred_results <- trained_model$predict(task,target_index)
147146

148147
# the test set based on the data split is used to calculate SHAP values
149148
test_set <- as.data.frame(mydata)
@@ -165,39 +164,44 @@ SHAPclust <- function(task,
165164
shap_Mean_wide_kmeans$row_ids <- shap_Mean_wide_kmeans$row_ids - shap_Mean_wide_kmeans$row_ids[1] + 1
166165
shap_Mean_wide_kmeans[, prediction_correctness := (truth == response)]
167166
shap_Mean_wide_kmeans_forCM <- shap_Mean_wide_kmeans
168-
shap_Mean_wide_kmeans[,c(1,2,3,4,5)] <- NULL
167+
168+
shap_Mean_wide_kmeans[,c(1,2,5)] <- NULL # ,3,4
169+
colnames(shap_Mean_wide_kmeans)[2] <- "prob_positive_class"
169170
variables_for_long_format <- colnames(shap_Mean_wide_kmeans)
170171

171-
variables_for_long_format <- variables_for_long_format[!variables_for_long_format %in% c(colnames(pred_results), "sample_num", "prediction_correctness", "cluster")]
172+
variables_for_long_format <- variables_for_long_format[!variables_for_long_format %in% c("sample_num", "prediction_correctness", "cluster","response","prob_positive_class")]
172173

173174
# Melt the data.table from wide to long format
174-
dt_long <- data.table::melt(shap_Mean_wide_kmeans, id.vars = c("sample_num","prediction_correctness","cluster"),
175+
dt_long <- data.table::melt(shap_Mean_wide_kmeans,
176+
id.vars = c("sample_num", "prediction_correctness", "cluster","response","prob_positive_class"),
175177
measure.vars = variables_for_long_format,
176178
variable.name = "variable",
177179
value.name = "value")
178180

179-
dt_long$fval <- NA
180-
dt_long_vars <- as.character(dt_long$variable)
181-
for (i in 1:nrow(dt_long)){
182-
dt_long$mean_absolute_shap[i] <- mean(abs(dt_long$value[dt_long$variable==dt_long$variable[i]]))
183-
idx <- which(row(test_set)[,1]==dt_long$sample_num[i])
184-
dt_long$fval[i] <- test_set[idx, which(colnames(test_set)==dt_long_vars[idx])]
185-
}
186-
187-
dt_long$fval <- as.numeric(dt_long$fval)
188-
dt_long[, fval := lapply(.SD, range01), by = variable, .SDcols = "fval"]
181+
# Remove specified columns
182+
dt_long[, c("response", "prob_positive_class", "prediction_correctness") := NULL]
183+
# Rename columns
184+
names(dt_long)[names(dt_long) == "variable"] <- "feature"
185+
names(dt_long)[names(dt_long) == "value"] <- "Phi"
189186

187+
# Merge the two dataframes
188+
dt_long <- merge(dt_long, shap_Mean_long, by = c("sample_num", "feature", "Phi"))
189+
print(dt_long)
190190
############## SHAP plots for clusters
191191
shap_plot1 <- dt_long %>%
192-
mutate(feature = forcats::fct_reorder(variable, mean_absolute_shap)) %>%
193-
ggplot(aes(x = feature, y = value, color = fval)) +
192+
mutate(feature = forcats::fct_reorder(feature, mean_phi)) %>%
193+
ggplot(aes(x = feature, y = Phi, color = f_val))+
194194
geom_violin(colour = "grey") +
195-
geom_line(aes(group = sample_num), alpha = 0.1, size = 0.2) +
195+
geom_line(aes(group = sample_num), alpha = 0.1,size=0.2) +
196196
coord_flip() +
197-
geom_jitter(aes(shape = factor(prediction_correctness, levels = c(FALSE, TRUE), labels = c("Incorrect","Correct"))), alpha = 0.6, size = 1.5, position = position_jitter(width = 0.2, height = 0)) +
198-
# geom_jitter(aes(shape = factor(prediction_correctness)), alpha = 0.6, size = 1, position = position_jitter(width = 0.2, height = 0)) +
197+
geom_jitter(aes(shape=correct_prediction, text = paste("Feature: ", feature,
198+
"<br>Unscaled feature value: ", unscaled_f_val,
199+
"<br>SHAP value: ", Phi,
200+
"<br>Prediction correctness: ", correct_prediction,
201+
"<br>Predicted probability: ", pred_prob,
202+
"<br>Predicted class: ", pred_class)),
203+
alpha = 0.6, size=1.5, position=position_jitter(width=0.2, height=0)) +
199204
scale_shape_manual(values = c(4, 19)) + # 19 for correct predictions (circle), 4 for incorrect predictions (cross)
200-
# labs(shape = "Prediction Correctness") +
201205
labs(shape = "model prediction") +
202206
scale_colour_gradient2(low = "blue", mid = "green", high = "red", midpoint = 0.5, breaks = c(0, 1), labels = c("Low", "High")) +
203207
geom_text(aes(x = feature, y = -Inf, label = ""), hjust = -0.2, alpha = 0.7, color = "black") +
@@ -207,19 +211,19 @@ SHAPclust <- function(task,
207211
theme(text = element_text(size = 8, family = "Helvetica"), panel.border = element_blank(),
208212
panel.grid.major = element_blank(), panel.grid.minor = element_blank(), panel.background = element_blank(),
209213
axis.line = element_line(colour = "grey"), legend.key.width = grid::unit(2, "mm")) +
210-
ylim(min(dt_long$value) - 0.05, max(dt_long$value) + 0.05) +
214+
ylim(min(dt_long$Phi) - 0.05, max(dt_long$Phi) + 0.05) +
211215
guides(
212216
shape = ggplot2::guide_legend(color = "black")
213217
)
214218

215219
shap_plot_onerow <- shap_plot1 + facet_wrap(~ cluster, ncol = num_of_clusters)
216220

217-
shap_plot_onerow <- ggplotly(shap_plot_onerow)
221+
shap_plot_onerow <- ggplotly(shap_plot_onerow, tooltip="text")
218222

219223
CM_plt <- list()
220224
# Create a tibble for each cluster and calculate the confusion matrix for each cluster
221225
for (i in 1:num_of_clusters) {
222-
d_binomial <- tibble("Truth" = shap_Mean_wide_kmeans_forCM$truth[which(shap_Mean_wide_kmeans_forCM$cluster==i)],
226+
d_binomial <- tibble::tibble("Truth" = shap_Mean_wide_kmeans_forCM$truth[which(shap_Mean_wide_kmeans_forCM$cluster==i)],
223227
"Prediction" = shap_Mean_wide_kmeans_forCM$response[which(shap_Mean_wide_kmeans_forCM$cluster==i)])
224228
cvms::confusion_matrix(targets = d_binomial$Truth, predictions = d_binomial$Prediction)
225229
# basic_table <- table(d_binomial)
@@ -228,7 +232,7 @@ SHAPclust <- function(task,
228232

229233
cm_tbl <- data.frame(matrix(nrow = 4, ncol = 3))
230234
colnames(cm_tbl) <- c("Target", "Prediction", "N")
231-
cm_tbl <- as_tibble(cm_tbl)
235+
cm_tbl <- tibble::as_tibble(cm_tbl)
232236
cm_tbl[1:2,1] <- levels(d_binomial$Truth)[1]
233237
cm_tbl[3:4,1] <- levels(d_binomial$Truth)[2]
234238
cm_tbl[1,2] <- levels(d_binomial$Truth)[1]

R/eCM_plot.R

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ eCM_plot <- function(task,
7171
featset_total_test <- as.data.frame(featset_total_test)
7272
pred_results <- trained_model$predict(task, splits$test)
7373
# plot confusion matrix
74-
d_binomial <- tibble("Truth" = featset_total_test[, task$target_names],
74+
d_binomial <- tibble::tibble("Truth" = featset_total_test[, task$target_names],
7575
"Prediction" = pred_results$response)
7676
basic_table <- table(d_binomial)
7777
cfm <- tibble::as_tibble(basic_table)
@@ -84,7 +84,7 @@ eCM_plot <- function(task,
8484
palette = "Oranges",
8585
label = "Total",
8686
tc_tile_border_color = "black"
87-
))
87+
)) + ggtitle("Confusion matrix for the train set")
8888
CM_plt_test[["labels"]][["x"]] <- 'Truth (observation)'
8989
CM_plt_test[["labels"]][["y"]] <- 'Prediction (model output)'
9090
CM_plt_test[["theme"]][["text"]][["size"]] <- 9
@@ -95,7 +95,7 @@ eCM_plot <- function(task,
9595
featset_total_train <- mydata[splits$train,]
9696
featset_total_train <- as.data.frame(featset_total_train)
9797
pred_results <- trained_model$predict(task, splits$train)
98-
d_binomial <- tibble("Truth" = featset_total_train[, task$target_names],
98+
d_binomial <- tibble::tibble("Truth" = featset_total_train[, task$target_names],
9999
"Prediction" = pred_results$response)
100100
basic_table <- table(d_binomial)
101101
# cfm <- broom::tidy(basic_table)
@@ -109,17 +109,12 @@ eCM_plot <- function(task,
109109
palette = "Oranges",
110110
label = "Total",
111111
tc_tile_border_color = "black"
112-
))
112+
)) + ggtitle("Confusion matrix for the test set")
113113
CM_plt_train[["labels"]][["x"]] <- 'Truth (observation)'
114114
CM_plt_train[["labels"]][["y"]] <- 'Prediction (model output)'
115115
CM_plt_train[["theme"]][["text"]][["size"]] <- 9
116116
CM_plt_train[["theme"]][["axis.text"]][["size"]] <- 9
117117
# CM_plt_train[["theme"]][["text"]][["family"]] <- 'Helvetica'
118-
119-
CM_plt_both <- egg::ggarrange(CM_plt_train,
120-
CM_plt_test,
121-
labels = c("train set", "test set"),
122-
nrow = 1,
123-
ncol = 2)
124-
return(CM_plt_both)
118+
# Return a list containing both plots
119+
return(list(train_set = CM_plt_train, test_set = CM_plt_test))
125120
}

R/eSHAP_plot.R

Lines changed: 64 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -73,56 +73,60 @@ eSHAP_plot <- function(task,
7373
sample.size = 30,
7474
seed = 246,
7575
subset = 1) {
76-
77-
# utils::globalVariables(c("feature", "sample_num", "correct_prediction"))
76+
cluster <- NULL
77+
correct_prediction <- NULL
7878
feature <- NULL
79+
f_val <- NULL
80+
fval <- NULL
81+
mean_absolute_shap <- NULL
82+
mean_phi <- NULL
83+
Phi <- NULL
84+
pred_class <- NULL
85+
pred_prob <- NULL
86+
prediction_correctness <- NULL
87+
response <- NULL
7988
sample_num <- NULL
80-
correct_prediction <- NULL
81-
# library(ggplot2)
89+
truth <- NULL
90+
unscaled_f_val <- NULL
91+
92+
set.seed(seed) # set seed for reproducibility
8293
mydata <- task$data()
8394
mydata <- as.data.frame(mydata)
8495
X <- mydata[which(names(mydata[splits$train,]) != task$target_names)]
8596
model <- iml::Predictor$new(trained_model, data = X, y = mydata[, task$target_names])
86-
# randomly subset the target variable and the corresponding rows
87-
if (subset < 1) {
88-
set.seed(seed) # set seed for reproducibility
89-
n <- round(subset * length(splits$test))
90-
target_index <- sample(splits$test, size = n, replace = FALSE)
91-
mydata <- mydata[target_index, ]
92-
93-
# do the prediction for the test set
94-
pred_results <- trained_model$predict(task,target_index)
95-
96-
} else {
97-
mydata <- mydata[splits$test, ]
9897

99-
# do the prediction for the test set
100-
pred_results <- trained_model$predict(task,splits$test)
101-
102-
}
98+
# randomly subset the target variable and the corresponding rows
99+
n <- round(subset * length(splits$test))
100+
target_index <- sample(splits$test, size = n, replace = FALSE)
101+
mydata <- mydata[target_index, ]
102+
# do the prediction for the test set
103+
pred_results <- trained_model$predict(task,target_index)
103104

104105
# the test set based on the data split is used to calculate SHAP values
105106
test_set <- as.data.frame(mydata)
106107
feature_names <- colnames(X)
107108
nfeats <- length(feature_names)
108109

110+
# print(pred_results)
111+
# print(pred_results$prob)
112+
# save the predicted probability for the positive class
113+
pred_prob <- pred_results$prob[,1]
109114

110-
111-
# save the predicted probability for the positive class (assuming with have a binary classification task)
112-
pred_prob <- pred_results$prob[,2]
113115
# mark which samples were correctly predicted and which samples were not
114116
predicted_correct <- mydata$Class==pred_results$response
115117

116-
test_set.nolab <- mydata
118+
# test_set.nolab <- mydata
117119
# initialize the results list.
118120
shap_values <- vector("list", nrow(test_set))
119121
for (i in seq_along(shap_values)) {
120-
set.seed(seed)
121-
shap_values[[i]] <- iml::Shapley$new(model, x.interest = test_set[i,feature_names],
122+
# set.seed(seed)
123+
shap_values[[i]] <- iml::Shapley$new(model,
124+
x.interest = test_set[i,feature_names],
122125
sample.size = sample.size)$results
123126
shap_values[[i]]$sample_num <- i # identifier to track our instances.
124127
shap_values[[i]]$predcorrectness <- predicted_correct[i]
125128
shap_values[[i]]$pred_prob <- pred_prob[i]
129+
shap_values[[i]]$pred_class <- pred_results$response[i]
126130
}
127131
data_shap_values <- dplyr::bind_rows(shap_values) # collapse the list.
128132

@@ -134,6 +138,7 @@ eSHAP_plot <- function(task,
134138
f_val_lst <- rep(0,nfeats)
135139
indiv_correctness <- rep(0,nfeats)
136140
pred_prob_rep <- rep(0,nfeats)
141+
pred_class_rep <- rep(0,nfeats)
137142

138143
feature_values <- gsub(".*=",'',shap$feature.value)
139144
shap$feature.value <- as.numeric(feature_values)
@@ -143,49 +148,63 @@ eSHAP_plot <- function(task,
143148
f_val_lst[i] = list(feature_values[seq(i,nrow(shap),nfeats)])
144149
indiv_correctness[i] = list(shap$predcorrectness[seq(i,nrow(shap),nfeats)])
145150
pred_prob_rep[i] = list(shap$pred_prob[seq(i,nrow(shap),nfeats)])
151+
pred_class_rep[i] = list(shap$pred_class[seq(i,nrow(shap),nfeats)])
146152
}
147153

148-
149154
# test_set.nolab[,task$target_names:=NULL]
150-
test_set.nolab[,task$target_names] <- NULL
155+
mydata[,task$target_names] <- NULL
151156
# get the column names of the data frame
152-
cols <- colnames(test_set.nolab)
157+
cols <- colnames(mydata)
153158

154159
# loop through each column
155160
for (col in cols) {
156161
# check if the column is numeric
157-
if (!is.numeric(test_set.nolab[[col]])) {
162+
if (!is.numeric(mydata[[col]])) {
158163
# convert non-numeric columns to numeric
159-
test_set.nolab[[col]] <- as.numeric(test_set.nolab[[col]])
164+
mydata[[col]] <- as.numeric(mydata[[col]])
160165
}
161166
}
167+
# store feature values
168+
unscaled_f_val_lst <- f_val_lst
162169

163170
# apply transformation for visualization
164171
for (i in 1:length(f_val_lst)){
165-
f_val_lst[[i]] <- range01(test_set.nolab[,i])
172+
unscaled_f_val_lst[[i]] <- mydata[,i] # not scaled
173+
f_val_lst[[i]] <- range01(mydata[,i]) # normalization
166174
}
167175

176+
(unscaled_f_val = as.numeric(unlist(unscaled_f_val_lst)))
168177
(f_val = as.numeric(unlist(f_val_lst)))
169178
(Phi = unlist(indiv_phi))
170179

171180
shap_Mean <- data.table::data.table(feature=rep(feature_names,each=total_reps),
172181
mean_phi = rep(mean_phi,each=total_reps),
173182
Phi = Phi,
174183
f_val = f_val,
184+
unscaled_f_val = unscaled_f_val,
175185
sample_num = rep(1:nrow(test_set),length(feature_names)),
176186
correct_prediction = unlist(indiv_correctness),
177-
pred_prob = unlist(pred_prob_rep))
187+
pred_prob = unlist(pred_prob_rep),
188+
pred_class = unlist(pred_class_rep))
178189

179190
shap_Mean_wide <- data.table::dcast(shap_Mean, sample_num ~ feature, value.var="Phi")
180191

181192
shap_Mean$correct_prediction <- factor(shap_Mean$correct_prediction, levels = c(FALSE, TRUE), labels = c("Incorrect","Correct"))
193+
194+
182195
shap_plot <- shap_Mean %>%
183196
mutate(feature = forcats::fct_reorder(feature, mean_phi)) %>%
184197
ggplot(aes(x = feature, y = Phi, color = f_val))+
185198
geom_violin(colour = "grey") +
186-
geom_line(aes(group = sample_num), alpha = 0.1,size=0.2) +
199+
geom_line(aes(group = sample_num), alpha = 0.1, size=0.2) +
187200
coord_flip() +
188-
geom_jitter(alpha = 0.6,size=1.5, position=position_jitter(width=0.2, height=0),aes(shape=correct_prediction)) +
201+
geom_jitter(aes(shape=correct_prediction, text = paste("Feature: ", feature,
202+
"<br>Unscaled feature value: ", unscaled_f_val,
203+
"<br>SHAP value: ", Phi,
204+
"<br>Prediction correctness: ", correct_prediction,
205+
"<br>Predicted probability: ", pred_prob,
206+
"<br>Predicted class: ", pred_class)),
207+
alpha = 0.6, size=1.5, position=position_jitter(width=0.2, height=0)) +
189208
scale_shape_manual(values=c(4, 19), guide = FALSE)+
190209
# scale_color_manual(values=c("black","grey")) +
191210
labs(shape = "model prediction") +
@@ -209,8 +228,16 @@ eSHAP_plot <- function(task,
209228
legend.key.width = grid::unit(2,"mm")) +
210229
ylim(min(shap_Mean$Phi)-0.05, max(shap_Mean$Phi)+0.05)
211230

231+
# Convert ggplot to Plotly
232+
shap_plot <- ggplotly(shap_plot, tooltip="text")
212233

213-
shap_plot <- ggplotly(shap_plot)
234+
# Additional plot to show SHAP values vs. predicted probabilities
235+
shap_pred_plot <- shap_Mean %>%
236+
ggplot(aes(x = Phi, y = pred_prob, shape=pred_class)) +
237+
geom_point() +
238+
geom_smooth(method = "loess", se = FALSE) +
239+
labs(x = "SHAP value", y = "Predicted probability") +
240+
theme_minimal()
214241

215-
return(list(shap_plot, shap_Mean_wide, shap_Mean, shap))
242+
return(list(shap_plot, shap_Mean_wide, shap_Mean, shap, shap_pred_plot))
216243
}

0 commit comments

Comments
 (0)