Skip to content

Commit 1ec251b

Browse files
authored
Fix bug with labels in variable importance, add reference line, remove deprecation warning (#207)
* fix bug labels variable importance, add reference line * revert change
1 parent 77116d1 commit 1ec251b

File tree

3 files changed

+17
-21
lines changed

3 files changed

+17
-21
lines changed

pymc_bart/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from pymc_bart.utils import (
2020
compute_variable_importance,
2121
plot_convergence,
22-
plot_dependence,
2322
plot_ice,
2423
plot_pdp,
2524
plot_scatter_submodels,
@@ -35,14 +34,13 @@
3534
"SubsetSplitRule",
3635
"compute_variable_importance",
3736
"plot_convergence",
38-
"plot_dependence",
3937
"plot_ice",
4038
"plot_pdp",
4139
"plot_scatter_submodels",
4240
"plot_variable_importance",
4341
"plot_variable_inclusion",
4442
]
45-
__version__ = "0.8.0"
43+
__version__ = "0.8.1"
4644

4745

4846
pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]

pymc_bart/utils.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -137,22 +137,6 @@ def plot_convergence(
137137
return ax
138138

139139

140-
def plot_dependence(*args, kind="pdp", **kwargs): # pylint: disable=unused-argument
141-
"""
142-
Partial dependence or individual conditional expectation plot.
143-
"""
144-
if kind == "pdp":
145-
warnings.warn(
146-
"This function has been deprecated. Use plot_pdp instead.",
147-
FutureWarning,
148-
)
149-
elif kind == "ice":
150-
warnings.warn(
151-
"This function has been deprecated. Use plot_ice instead.",
152-
FutureWarning,
153-
)
154-
155-
156140
def plot_ice(
157141
bartrv: Variable,
158142
X: npt.NDArray[np.float64],
@@ -307,6 +291,7 @@ def plot_pdp(
307291
var_discrete: Optional[list[int]] = None,
308292
func: Optional[Callable] = None,
309293
samples: int = 200,
294+
ref_line: bool = True,
310295
random_seed: Optional[int] = None,
311296
sharey: bool = True,
312297
smooth: bool = True,
@@ -347,6 +332,8 @@ def plot_pdp(
347332
Arbitrary function to apply to the predictions. Defaults to the identity function.
348333
samples : int
349334
Number of posterior samples used in the predictions. Defaults to 200.
335+
ref_line : bool
336+
If True, a reference line is plotted at the mean of the partial dependence. Defaults to True.
350337
random_seed : Optional[int], by default None.
351338
Seed used to sample from the posterior. Defaults to None.
352339
sharey : bool
@@ -402,6 +389,7 @@ def identity(x):
402389

403390
count = 0
404391
fake_X = _create_pdp_data(X, xs_interval, xs_values)
392+
null_pd = []
405393
for var in range(len(var_idx)):
406394
excluded = indices[:]
407395
excluded.remove(var)
@@ -413,6 +401,7 @@ def identity(x):
413401
new_x = fake_X[:, var]
414402
for s_i in range(shape):
415403
p_di = func(p_d[:, :, s_i])
404+
null_pd.append(p_di.mean())
416405
if var in var_discrete:
417406
_, idx_uni = np.unique(new_x, return_index=True)
418407
y_means = p_di.mean(0)[idx_uni]
@@ -442,6 +431,11 @@ def identity(x):
442431

443432
count += 1
444433

434+
if ref_line:
435+
ref_val = sum(null_pd) / len(null_pd)
436+
for ax_ in np.ravel(axes):
437+
ax_.axhline(ref_val, color="0.7", linestyle="--")
438+
445439
fig.text(-0.05, 0.5, y_label, va="center", rotation="vertical", fontsize=15)
446440

447441
return axes
@@ -949,11 +943,13 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
949943

950944
indices = least_important_vars[::-1]
951945

952-
labels = np.array(["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)])
946+
labels = np.array(
947+
["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])]
948+
)
953949

954950
vi_results = {
955951
"indices": np.asarray(indices),
956-
"labels": labels[indices],
952+
"labels": labels,
957953
"r2_mean": r2_mean,
958954
"r2_hdi": r2_hdi,
959955
"preds": preds,

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ line-length = 100
1010
select = ["E", "F", "I", "PL", "UP", "W"]
1111
ignore = [
1212
"PLR2004", # Checks for the use of unnamed numerical constants ("magic") values in comparisons.
13+
"PLR0913", # Too many arguments in function definition
14+
1315
]
1416

1517
[tool.ruff.lint.pylint]

0 commit comments

Comments
 (0)