|
4 | 4 | from dataclasses import dataclass |
5 | 5 | from importlib.util import find_spec |
6 | 6 | from math import prod |
7 | | -from typing import TYPE_CHECKING, Any, Optional |
| 7 | +from typing import TYPE_CHECKING, Any, Literal, Optional |
8 | 8 |
|
9 | 9 | import numpy as np |
10 | 10 | import pandas as pd |
11 | 11 | from numpy.typing import NDArray |
12 | | -from nutpie.compiled_pyfunc import from_pyfunc |
13 | | -from nutpie.sample import CompiledModel |
14 | 12 |
|
15 | 13 | from nutpie import _lib |
| 14 | +from nutpie.compiled_pyfunc import from_pyfunc |
| 15 | +from nutpie.sample import CompiledModel |
16 | 16 |
|
17 | 17 | try: |
18 | 18 | from numba.extending import intrinsic |
@@ -184,7 +184,7 @@ def _compile_pymc_model_numba(model: "pm.Model", **kwargs) -> CompiledPyMCModel: |
184 | 184 | for val in [*logp_fn_pt.get_shared(), *expand_fn_pt.get_shared()]: |
185 | 185 | if val.name in shared_data and val not in seen: |
186 | 186 | raise ValueError(f"Shared variables must have unique names: {val.name}") |
187 | | - shared_data[val.name] = val.get_value().copy() |
| 187 | + shared_data[val.name] = val.get_value() |
188 | 188 | shared_vars[val.name] = val |
189 | 189 | seen.add(val) |
190 | 190 |
|
@@ -308,7 +308,7 @@ def logp_fn_jax_grad(x, **shared): |
308 | 308 | for val in [*logp_fn_pt.get_shared(), *expand_fn_pt.get_shared()]: |
309 | 309 | if val.name in shared_data and val not in seen: |
310 | 310 | raise ValueError(f"Shared variables must have unique names: {val.name}") |
311 | | - shared_data[val.name] = jax.numpy.asarray(val.get_value().copy()) |
| 311 | + shared_data[val.name] = jax.numpy.asarray(val.get_value()) |
312 | 312 | shared_vars[val.name] = val |
313 | 313 | seen.add(val) |
314 | 314 |
|
@@ -356,8 +356,12 @@ def expand(x, **shared): |
356 | 356 |
|
357 | 357 |
|
358 | 358 | def compile_pymc_model( |
359 | | - model: "pm.Model", *, backend="numba", gradient_backend=None, **kwargs |
360 | | -) -> CompiledPyMCModel: |
| 359 | + model: "pm.Model", |
| 360 | + *, |
| 361 | + backend: Literal["numba", "jax"] = "numba", |
| 362 | + gradient_backend: Literal["pytensor", "jax"] | None = None, |
| 363 | + **kwargs, |
| 364 | +) -> CompiledModel: |
361 | 365 | """Compile necessary functions for sampling a pymc model. |
362 | 366 |
|
363 | 367 | Parameters |
|
0 commit comments