@@ -73,56 +73,60 @@ eSHAP_plot <- function(task,
7373 sample.size = 30 ,
7474 seed = 246 ,
7575 subset = 1 ) {
76-
77- # utils::globalVariables(c("feature", "sample_num", "correct_prediction"))
76+ cluster <- NULL
77+ correct_prediction <- NULL
7878 feature <- NULL
79+ f_val <- NULL
80+ fval <- NULL
81+ mean_absolute_shap <- NULL
82+ mean_phi <- NULL
83+ Phi <- NULL
84+ pred_class <- NULL
85+ pred_prob <- NULL
86+ prediction_correctness <- NULL
87+ response <- NULL
7988 sample_num <- NULL
80- correct_prediction <- NULL
81- # library(ggplot2)
89+ truth <- NULL
90+ unscaled_f_val <- NULL
91+
92+ set.seed(seed ) # set seed for reproducibility
8293 mydata <- task $ data()
8394 mydata <- as.data.frame(mydata )
8495 X <- mydata [which(names(mydata [splits $ train ,]) != task $ target_names )]
8596 model <- iml :: Predictor $ new(trained_model , data = X , y = mydata [, task $ target_names ])
86- # randomly subset the target variable and the corresponding rows
87- if (subset < 1 ) {
88- set.seed(seed ) # set seed for reproducibility
89- n <- round(subset * length(splits $ test ))
90- target_index <- sample(splits $ test , size = n , replace = FALSE )
91- mydata <- mydata [target_index , ]
92-
93- # do the prediction for the test set
94- pred_results <- trained_model $ predict(task ,target_index )
95-
96- } else {
97- mydata <- mydata [splits $ test , ]
9897
99- # do the prediction for the test set
100- pred_results <- trained_model $ predict(task ,splits $ test )
101-
102- }
98+ # randomly subset the target variable and the corresponding rows
99+ n <- round(subset * length(splits $ test ))
100+ target_index <- sample(splits $ test , size = n , replace = FALSE )
101+ mydata <- mydata [target_index , ]
102+ # do the prediction for the test set
103+ pred_results <- trained_model $ predict(task ,target_index )
103104
104105 # the test set based on the data split is used to calculate SHAP values
105106 test_set <- as.data.frame(mydata )
106107 feature_names <- colnames(X )
107108 nfeats <- length(feature_names )
108109
110+ # print(pred_results)
111+ # print(pred_results$prob)
112+ # save the predicted probability for the positive class
113+ pred_prob <- pred_results $ prob [,1 ]
109114
110-
111- # save the predicted probability for the positive class (assuming with have a binary classification task)
112- pred_prob <- pred_results $ prob [,2 ]
113115 # mark which samples were correctly predicted and which samples were not
114116 predicted_correct <- mydata $ Class == pred_results $ response
115117
116- test_set.nolab <- mydata
118+ # test_set.nolab <- mydata
117119 # initialize the results list.
118120 shap_values <- vector(" list" , nrow(test_set ))
119121 for (i in seq_along(shap_values )) {
120- set.seed(seed )
121- shap_values [[i ]] <- iml :: Shapley $ new(model , x.interest = test_set [i ,feature_names ],
122+ # set.seed(seed)
123+ shap_values [[i ]] <- iml :: Shapley $ new(model ,
124+ x.interest = test_set [i ,feature_names ],
122125 sample.size = sample.size )$ results
123126 shap_values [[i ]]$ sample_num <- i # identifier to track our instances.
124127 shap_values [[i ]]$ predcorrectness <- predicted_correct [i ]
125128 shap_values [[i ]]$ pred_prob <- pred_prob [i ]
129+ shap_values [[i ]]$ pred_class <- pred_results $ response [i ]
126130 }
127131 data_shap_values <- dplyr :: bind_rows(shap_values ) # collapse the list.
128132
@@ -134,6 +138,7 @@ eSHAP_plot <- function(task,
134138 f_val_lst <- rep(0 ,nfeats )
135139 indiv_correctness <- rep(0 ,nfeats )
136140 pred_prob_rep <- rep(0 ,nfeats )
141+ pred_class_rep <- rep(0 ,nfeats )
137142
138143 feature_values <- gsub(" .*=" ,' ' ,shap $ feature.value )
139144 shap $ feature.value <- as.numeric(feature_values )
@@ -143,49 +148,63 @@ eSHAP_plot <- function(task,
143148 f_val_lst [i ] = list (feature_values [seq(i ,nrow(shap ),nfeats )])
144149 indiv_correctness [i ] = list (shap $ predcorrectness [seq(i ,nrow(shap ),nfeats )])
145150 pred_prob_rep [i ] = list (shap $ pred_prob [seq(i ,nrow(shap ),nfeats )])
151+ pred_class_rep [i ] = list (shap $ pred_class [seq(i ,nrow(shap ),nfeats )])
146152 }
147153
148-
149154 # test_set.nolab[,task$target_names:=NULL]
150- test_set.nolab [,task $ target_names ] <- NULL
155+ mydata [,task $ target_names ] <- NULL
151156 # get the column names of the data frame
152- cols <- colnames(test_set.nolab )
157+ cols <- colnames(mydata )
153158
154159 # loop through each column
155160 for (col in cols ) {
156161 # check if the column is numeric
157- if (! is.numeric(test_set.nolab [[col ]])) {
162+ if (! is.numeric(mydata [[col ]])) {
158163 # convert non-numeric columns to numeric
159- test_set.nolab [[col ]] <- as.numeric(test_set.nolab [[col ]])
164+ mydata [[col ]] <- as.numeric(mydata [[col ]])
160165 }
161166 }
167+ # store feature values
168+ unscaled_f_val_lst <- f_val_lst
162169
163170 # apply transformation for visualization
164171 for (i in 1 : length(f_val_lst )){
165- f_val_lst [[i ]] <- range01(test_set.nolab [,i ])
172+ unscaled_f_val_lst [[i ]] <- mydata [,i ] # not scaled
173+ f_val_lst [[i ]] <- range01(mydata [,i ]) # normalization
166174 }
167175
176+ (unscaled_f_val = as.numeric(unlist(unscaled_f_val_lst )))
168177 (f_val = as.numeric(unlist(f_val_lst )))
169178 (Phi = unlist(indiv_phi ))
170179
171180 shap_Mean <- data.table :: data.table(feature = rep(feature_names ,each = total_reps ),
172181 mean_phi = rep(mean_phi ,each = total_reps ),
173182 Phi = Phi ,
174183 f_val = f_val ,
184+ unscaled_f_val = unscaled_f_val ,
175185 sample_num = rep(1 : nrow(test_set ),length(feature_names )),
176186 correct_prediction = unlist(indiv_correctness ),
177- pred_prob = unlist(pred_prob_rep ))
187+ pred_prob = unlist(pred_prob_rep ),
188+ pred_class = unlist(pred_class_rep ))
178189
179190 shap_Mean_wide <- data.table :: dcast(shap_Mean , sample_num ~ feature , value.var = " Phi" )
180191
181192 shap_Mean $ correct_prediction <- factor (shap_Mean $ correct_prediction , levels = c(FALSE , TRUE ), labels = c(" Incorrect" ," Correct" ))
193+
194+
182195 shap_plot <- shap_Mean %> %
183196 mutate(feature = forcats :: fct_reorder(feature , mean_phi )) %> %
184197 ggplot(aes(x = feature , y = Phi , color = f_val ))+
185198 geom_violin(colour = " grey" ) +
186- geom_line(aes(group = sample_num ), alpha = 0.1 ,size = 0.2 ) +
199+ geom_line(aes(group = sample_num ), alpha = 0.1 , size = 0.2 ) +
187200 coord_flip() +
188- geom_jitter(alpha = 0.6 ,size = 1.5 , position = position_jitter(width = 0.2 , height = 0 ),aes(shape = correct_prediction )) +
201+ geom_jitter(aes(shape = correct_prediction , text = paste(" Feature: " , feature ,
202+ " <br>Unscaled feature value: " , unscaled_f_val ,
203+ " <br>SHAP value: " , Phi ,
204+ " <br>Prediction correctness: " , correct_prediction ,
205+ " <br>Predicted probability: " , pred_prob ,
206+ " <br>Predicted class: " , pred_class )),
207+ alpha = 0.6 , size = 1.5 , position = position_jitter(width = 0.2 , height = 0 )) +
189208 scale_shape_manual(values = c(4 , 19 ), guide = FALSE )+
190209 # scale_color_manual(values=c("black","grey")) +
191210 labs(shape = " model prediction" ) +
@@ -209,8 +228,16 @@ eSHAP_plot <- function(task,
209228 legend.key.width = grid :: unit(2 ," mm" )) +
210229 ylim(min(shap_Mean $ Phi )- 0.05 , max(shap_Mean $ Phi )+ 0.05 )
211230
231+ # Convert ggplot to Plotly
232+ shap_plot <- ggplotly(shap_plot , tooltip = " text" )
212233
213- shap_plot <- ggplotly(shap_plot )
234+ # Additional plot to show SHAP values vs. predicted probabilities
235+ shap_pred_plot <- shap_Mean %> %
236+ ggplot(aes(x = Phi , y = pred_prob , shape = pred_class )) +
237+ geom_point() +
238+ geom_smooth(method = " loess" , se = FALSE ) +
239+ labs(x = " SHAP value" , y = " Predicted probability" ) +
240+ theme_minimal()
214241
215- return (list (shap_plot , shap_Mean_wide , shap_Mean , shap ))
242+ return (list (shap_plot , shap_Mean_wide , shap_Mean , shap , shap_pred_plot ))
216243}
0 commit comments