Skip to content

Commit d4e8cad

Browse files
authored
bump release (#205)
* bump release * fix zip and new args
1 parent bcdf77d commit d4e8cad

File tree

5 files changed

+15
-9
lines changed

5 files changed

+15
-9
lines changed

pymc_bart/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
"plot_variable_importance",
4343
"plot_variable_inclusion",
4444
]
45-
__version__ = "0.7.1"
45+
__version__ = "0.8.0"
4646

4747

4848
pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]

pymc_bart/bart.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def get_moment(rv, size, *rv_inputs):
175175
return cls.get_moment(rv, size, *rv_inputs)
176176

177177
cls.rv_op = bart_op
178-
params = [X, Y, m, alpha, beta, split_prior]
178+
params = [X, Y, m, alpha, beta]
179179
return super().__new__(cls, name, *params, **kwargs)
180180

181181
@classmethod

pymc_bart/pgbart.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy as np
1818
import numpy.typing as npt
1919
from numba import njit
20+
from pymc.initial_point import PointType
2021
from pymc.model import Model, modelcontext
2122
from pymc.pytensorf import inputvars, join_nonshared_inputs, make_shared_replacements
2223
from pymc.step_methods.arraystep import ArrayStepShared
@@ -125,9 +126,12 @@ def __init__( # noqa: PLR0915
125126
num_particles: int = 10,
126127
batch: tuple[float, float] = (0.1, 0.1),
127128
model: Optional[Model] = None,
129+
initial_point: PointType | None = None,
130+
compile_kwargs: dict | None = None, # pylint: disable=unused-argument
128131
):
129132
model = modelcontext(model)
130-
initial_values = model.initial_point()
133+
if initial_point is None:
134+
initial_point = model.initial_point()
131135
if vars is None:
132136
vars = model.value_vars
133137
else:
@@ -150,7 +154,7 @@ def __init__( # noqa: PLR0915
150154
self.m = self.bart.m
151155
self.response = self.bart.response
152156

153-
shape = initial_values[value_bart.name].shape
157+
shape = initial_point[value_bart.name].shape
154158

155159
self.shape = 1 if len(shape) == 1 else shape[0]
156160

@@ -217,8 +221,8 @@ def __init__( # noqa: PLR0915
217221

218222
self.num_particles = num_particles
219223
self.indices = list(range(1, num_particles))
220-
shared = make_shared_replacements(initial_values, vars, model)
221-
self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared)
224+
shared = make_shared_replacements(initial_point, vars, model)
225+
self.likelihood_logp = logp(initial_point, [model.datalogp], vars, shared)
222226
self.all_particles = [
223227
[ParticleTree(self.a_tree) for _ in range(self.m)] for _ in range(self.trees_shape)
224228
]

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
pymc>=5.16.2, <=5.18
1+
pymc>=5.16.2, <=5.19.1
22
arviz>=0.18.0
33
numba
44
matplotlib

tests/test_bart.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,10 @@ def test_categorical_model(separate_trees, split_rule):
248248
separate_trees=separate_trees,
249249
)
250250
y = pm.Categorical("y", p=pm.math.softmax(lo.T, axis=-1), observed=Y)
251-
idata = pm.sample(random_seed=3415, tune=300, draws=300)
252-
idata = pm.sample_posterior_predictive(idata, predictions=True, extend_inferencedata=True)
251+
idata = pm.sample(tune=300, draws=300, random_seed=3415)
252+
idata = pm.sample_posterior_predictive(
253+
idata, predictions=True, extend_inferencedata=True, random_seed=3415
254+
)
253255

254256
# Fit should be good enough so right category is selected over 50% of time
255257
assert (idata.predictions.y.median(["chain", "draw"]) == Y).all()

0 commit comments

Comments
 (0)