@@ -137,22 +137,6 @@ def plot_convergence(
137
137
return ax
138
138
139
139
140
- def plot_dependence (* args , kind = "pdp" , ** kwargs ): # pylint: disable=unused-argument
141
- """
142
- Partial dependence or individual conditional expectation plot.
143
- """
144
- if kind == "pdp" :
145
- warnings .warn (
146
- "This function has been deprecated. Use plot_pdp instead." ,
147
- FutureWarning ,
148
- )
149
- elif kind == "ice" :
150
- warnings .warn (
151
- "This function has been deprecated. Use plot_ice instead." ,
152
- FutureWarning ,
153
- )
154
-
155
-
156
140
def plot_ice (
157
141
bartrv : Variable ,
158
142
X : npt .NDArray [np .float64 ],
@@ -307,6 +291,7 @@ def plot_pdp(
307
291
var_discrete : Optional [list [int ]] = None ,
308
292
func : Optional [Callable ] = None ,
309
293
samples : int = 200 ,
294
+ ref_line : bool = True ,
310
295
random_seed : Optional [int ] = None ,
311
296
sharey : bool = True ,
312
297
smooth : bool = True ,
@@ -347,6 +332,8 @@ def plot_pdp(
347
332
Arbitrary function to apply to the predictions. Defaults to the identity function.
348
333
samples : int
349
334
Number of posterior samples used in the predictions. Defaults to 200
335
+ ref_line : bool
336
+ If True a reference line is plotted at the mean of the partial dependence. Defaults to True.
350
337
random_seed : Optional[int], by default None.
351
338
Seed used to sample from the posterior. Defaults to None.
352
339
sharey : bool
@@ -402,6 +389,7 @@ def identity(x):
402
389
403
390
count = 0
404
391
fake_X = _create_pdp_data (X , xs_interval , xs_values )
392
+ null_pd = []
405
393
for var in range (len (var_idx )):
406
394
excluded = indices [:]
407
395
excluded .remove (var )
@@ -413,6 +401,7 @@ def identity(x):
413
401
new_x = fake_X [:, var ]
414
402
for s_i in range (shape ):
415
403
p_di = func (p_d [:, :, s_i ])
404
+ null_pd .append (p_di .mean ())
416
405
if var in var_discrete :
417
406
_ , idx_uni = np .unique (new_x , return_index = True )
418
407
y_means = p_di .mean (0 )[idx_uni ]
@@ -442,6 +431,11 @@ def identity(x):
442
431
443
432
count += 1
444
433
434
+ if ref_line :
435
+ ref_val = sum (null_pd ) / len (null_pd )
436
+ for ax_ in np .ravel (axes ):
437
+ ax_ .axhline (ref_val , color = "0.7" , linestyle = "--" )
438
+
445
439
fig .text (- 0.05 , 0.5 , y_label , va = "center" , rotation = "vertical" , fontsize = 15 )
446
440
447
441
return axes
@@ -949,11 +943,13 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
949
943
950
944
indices = least_important_vars [::- 1 ]
951
945
952
- labels = np .array (["+ " + ele if index != 0 else ele for index , ele in enumerate (labels )])
946
+ labels = np .array (
947
+ ["+ " + ele if index != 0 else ele for index , ele in enumerate (labels [indices ])]
948
+ )
953
949
954
950
vi_results = {
955
951
"indices" : np .asarray (indices ),
956
- "labels" : labels [ indices ] ,
952
+ "labels" : labels ,
957
953
"r2_mean" : r2_mean ,
958
954
"r2_hdi" : r2_hdi ,
959
955
"preds" : preds ,
0 commit comments