
Commit 43d6c3a

Update BART Hawks NBs with new plots (#775)
* Improve BART categorical with new plots
* remove autoreload cell
* Use func for softmax
* handle func upstream
1 parent ccef86c commit 43d6c3a

2 files changed: +294 -185 lines

examples/bart/bart_categorical_hawks.ipynb

+244 -175
Large diffs are not rendered by default.

examples/bart/bart_categorical_hawks.myst.md

+50 -10
@@ -5,9 +5,9 @@ jupytext:
     format_name: myst
     format_version: 0.13
 kernelspec:
-  display_name: Python 3 (ipykernel)
+  display_name: pymc-examples
   language: python
-  name: python3
+  name: pymc-examples
 myst:
   substitutions:
     conda_dependencies: pymc-bart
@@ -43,6 +43,8 @@ import pymc as pm
 import pymc_bart as pmb
 import seaborn as sns
 
+from scipy.special import softmax
+
 warnings.simplefilter(action="ignore", category=FutureWarning)
 ```

@@ -143,21 +145,48 @@ vi_results = pmb.compute_variable_importance(idata, μ, x_0, method="VI", random
 pmb.plot_variable_importance(vi_results);
 ```
 
-It can be observed that with the covariables `Hallux`, `Culmen`, and `Wing` we achieve the same R$^2$ value that we obtained with all the covariables, this is that the last two covariables contribute less than the other three to the classification. One thing we have to take into account in this is that the HDI is quite wide, which gives us less precision on the results, later we are going to see a way to reduce this.
+It can be observed that with the covariables `Hallux`, `Culmen`, and `Wing` we achieve the same $R^2$ value that we obtained with all the covariables; that is, the last two covariables contribute less than the other three to the classification. One thing we have to take into account is that the HDI is quite wide, which gives us less precision in the results; later we will see a way to reduce this.
 
-+++
+We can also plot the submodels' predictions against the full model's predictions, to get an idea of how each new covariate improves the submodel's predictions.
+
+```{code-cell} ipython3
+axes = pmb.plot_scatter_submodels(
+    vi_results, grid=(5, 3), figsize=(12, 14), plot_kwargs={"alpha_scatter": 0.05}
+)
+plt.suptitle("Comparison of submodels' predictions to full model's\n", fontsize=18)
+for ax, cat in zip(axes, np.repeat(species, len(vi_results["labels"]))):
+    ax.set(title=f"Species {cat}")
+```

 ### Partial Dependence Plot
 
-Let's check the behavior of each covariable for each species with `pmb.plot_pdp()`, which shows the marginal effect a covariate has on the predicted variable, while we average over all the other covariates.
+Let's check the behavior of each covariable for each species with `pmb.plot_pdp()`, which shows the marginal effect a covariate has on the predicted variable, while we average over all the other covariates. Since our response variable is categorical, we'll pass `softmax` as the inverse link function to `plot_pdp`.
+
+We have to be careful with the `softmax` function, because it is not applied element-wise: it considers relationships between elements, so the specific axis along which we apply it matters. By default, scipy applies it over the whole flattened array, but we want to apply it along the last axis, since that's where the categories are. To make sure of that, we use `np.apply_along_axis` and wrap `softmax` in a lambda function.
 
 ```{code-cell} ipython3
-pmb.plot_pdp(μ, X=x_0, Y=y_0, grid=(5, 3), figsize=(12, 7));
+axes = pmb.plot_pdp(
+    μ,
+    X=x_0,
+    Y=y_0,
+    grid=(5, 3),
+    figsize=(12, 12),
+    func=lambda x: np.apply_along_axis(softmax, axis=-1, arr=x),
+)
+plt.suptitle("Partial Dependence Plots\n", fontsize=18)
+for (i, ax), cat in zip(enumerate(axes), np.tile(species, len(vi_results["labels"]))):
+    ax.set(title=f"Species {cat}")
 ```
 
-The pdp plot, together with the Variable Importance plot, confirms that `Tail` is the covariable with the smaller effect over the predicted variable. In the Variable Importance plot `Tail` is the last covariable to be added and does not improve the result, in the pdp plot `Tail` has the flattest response. For the rest of the covariables in this plot, it's hard to see which of them have more effect over the predicted variable, because they have great variability, showed in the HDI wide, same as before later we are going to see a way to reduce this variability. Finally, some variability depends on the amount of data for each species, which we can see in the `counts` from one of the covariables using Pandas `.describe()` and grouping the data from "Species" with `.groupby("Species")`.
+The Partial Dependence Plot, together with the Variable Importance plot, confirms that `Tail` is the covariable with the smallest effect on the predicted variable: in the Variable Importance plot, `Tail` is the last covariate to be added and does not improve the result; in the PDP plot, `Tail` has the flattest response.
 
-+++
+For the rest of the covariables in this plot, it's hard to see which of them has more effect on the predicted variable, because they show great variability, reflected in the HDI width.
+
+Finally, some variability depends on the amount of data for each species, which we can see in the `counts` of each covariable for each species:
+
+```{code-cell} ipython3
+Hawks.groupby("Species").count()
+```
 
 ### Predicted vs Observed
 
@@ -222,10 +251,20 @@ pmb.plot_variable_importance(vi_results);
 ```
 
 ```{code-cell} ipython3
-pmb.plot_pdp(μ_t, X=x_0, Y=y_0, grid=(5, 3), figsize=(12, 7));
+axes = pmb.plot_pdp(
+    μ_t,
+    X=x_0,
+    Y=y_0,
+    grid=(5, 3),
+    figsize=(12, 12),
+    func=lambda x: np.apply_along_axis(softmax, axis=-1, arr=x),
+)
+plt.suptitle("Partial Dependence Plots\n", fontsize=18)
+for (i, ax), cat in zip(enumerate(axes), np.tile(species, len(vi_results["labels"]))):
+    ax.set(title=f"Species {cat}")
 ```
 
-Comparing these two plots with the previous ones shows a marked reduction in the variance for each one. In the case of `pmb.plot_variable_importance()` there are smallers error bands with an R$^{2}$ value more close to 1. And for `pm.plot_pdp()` we can see thinner bands and a reduction in the limits on the y-axis, this is a representation of the reduction of the uncertainty due to adjusting the trees separately. Another benefit of this is that is more visible the behavior of each covariable for each one of the species.
+Comparing these two plots with the previous ones shows a marked reduction in the variance for each one. In the case of `pmb.plot_variable_importance()` there are smaller error bands, with an $R^{2}$ value closer to 1. And for `pmb.plot_pdp()` we can see thinner HDI bands. This reflects the reduction in uncertainty that comes from fitting the trees separately. Another benefit is that the behavior of each covariable for each of the species is more visible.
 
 With all these together, we can select `Hallux`, `Culmen`, and `Wing` as covariables to make the classification.
 
@@ -259,6 +298,7 @@ all
 ## Authors
 - Authored by [Pablo Garay](https://github.com/PabloGGaray) and [Osvaldo Martin](https://aloctavodia.github.io/) in May, 2024
 - Updated by Osvaldo Martin in Dec, 2024
+- Expanded by [Alex Andorra](https://github.com/AlexAndorra) in Feb, 2025
 
 +++

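The `softmax` axis caveat mentioned in the notebook text above can be checked with a minimal standalone sketch; the array shape here is illustrative only and is not taken from the notebook:

```python
# Minimal sketch of why the softmax axis matters; shapes are illustrative only.
import numpy as np
from scipy.special import softmax

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 3))  # e.g. 4 grid points x 3 species

# Default (axis=None): softmax over the flattened array, one normalization for all entries.
print(softmax(logits).sum())  # ~1.0 for the whole array

# What the notebook wants: one probability vector per grid point, along the last axis.
probs = np.apply_along_axis(softmax, -1, logits)
print(probs.sum(axis=-1))  # each row sums to ~1

# Passing the axis directly gives the same result.
np.testing.assert_allclose(probs, softmax(logits, axis=-1))
```

Passing `axis=-1` directly to `softmax` is equivalent and vectorized, so the `np.apply_along_axis` wrapper used in the diff is just one of several ways to obtain per-category probabilities.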