Inference Gym: adding and/or updating ground truth expectations #1992

Open
reubenharry opened this issue Feb 16, 2025 · 3 comments

Comments

@reubenharry

(I've divided this into three subsections and can split them into separate issues if preferred.)

Length of ground truth runs

As I understand it, the current ground truth estimates are obtained from Stan with 150000 samples and 10 chains.

For certain models, such as `gym.targets.VectorModel(gym.targets.BrownianMotionUnknownScalesMissingMiddleObservations(), flatten_sample_transformations=True)`, I have produced my own ground truths via longer runs of Blackjax's NUTS (10 million steps, 4 chains) and found results that differ enough to matter for my use cases (namely, estimating the efficiency of different samplers):

From the Blackjax run: [ 0.11525708 0.09256472 0.05635736 -0.03410918 -0.05100336 -0.18196875
-0.18945307 -0.25923407 -0.25987643 -0.32402724 -0.22958763 -0.28165078
-0.3362609 -0.38868254 -0.44175696 -0.4945148 -0.5447676 -0.6013282
-0.6559048 -0.7087315 -0.75866866 -0.8134075 -0.8074223 -0.7784713
-0.82167107 -0.7737639 -0.743899 -0.7613981 -0.6401507 -0.6669518
-0.64461184 0.11305185]

From inference-gym: [ 0.11984811 0.10274264 0.06093274 -0.03870019 -0.04362268 -0.19021639
-0.1856622 -0.26851514 -0.26010785 -0.3334386 -0.21788554 -0.2735482
-0.33083084 -0.38252977 -0.43280044 -0.49400684 -0.54860604 -0.60449123
-0.65569454 -0.7083658 -0.76391494 -0.8189823 -0.8105346 -0.7771473
-0.8268097 -0.7768991 -0.7374106 -0.7740582 -0.6294383 -0.670295
-0.6432216 0.10105278]

(See, e.g., the hierarchical parameters, in particular the second and final elements of the array.)
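A quick way to see where the two estimates disagree is to difference the vectors printed above. The NumPy sketch below does only that, using the numbers copied from this issue; on its own it cannot say which estimate is closer to the truth.

```python
import numpy as np

# Posterior-mean estimates copied verbatim from this issue (32 entries each).
blackjax_run = np.array([
    0.11525708, 0.09256472, 0.05635736, -0.03410918, -0.05100336, -0.18196875,
    -0.18945307, -0.25923407, -0.25987643, -0.32402724, -0.22958763, -0.28165078,
    -0.3362609, -0.38868254, -0.44175696, -0.4945148, -0.5447676, -0.6013282,
    -0.6559048, -0.7087315, -0.75866866, -0.8134075, -0.8074223, -0.7784713,
    -0.82167107, -0.7737639, -0.743899, -0.7613981, -0.6401507, -0.6669518,
    -0.64461184, 0.11305185,
])
gym_truth = np.array([
    0.11984811, 0.10274264, 0.06093274, -0.03870019, -0.04362268, -0.19021639,
    -0.1856622, -0.26851514, -0.26010785, -0.3334386, -0.21788554, -0.2735482,
    -0.33083084, -0.38252977, -0.43280044, -0.49400684, -0.54860604, -0.60449123,
    -0.65569454, -0.7083658, -0.76391494, -0.8189823, -0.8105346, -0.7771473,
    -0.8268097, -0.7768991, -0.7374106, -0.7740582, -0.6294383, -0.670295,
    -0.6432216, 0.10105278,
])

diff = blackjax_run - gym_truth
# Indices ordered by absolute discrepancy, largest first.
order = np.argsort(-np.abs(diff))
for i in order[:5]:
    print(f"index {i:2d}: blackjax={blackjax_run[i]: .5f}  "
          f"gym={gym_truth[i]: .5f}  diff={diff[i]: .5f}")
```

For these numbers the five largest absolute differences fall at indices 27, 31, 10, 28 and 1, all close to 0.01 in magnitude. Whether a gap of that size is meaningful depends on the Monte Carlo standard error of each ground truth, which is not reported in this thread.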

If my results are actually more accurate (of course it's possible there's a mistake on my end), then would it be possible to switch to the results of a longer run (either of Stan or Blackjax, but see the final section below) in inference-gym?
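Whether a longer run actually gives a more accurate ground truth can be checked through the Monte Carlo standard error (MCSE) of each estimate, which shrinks roughly like $1/\sqrt{N}$ for a well-mixing chain of $N$ draws. Below is a minimal batch-means MCSE sketch in NumPy; the function names are hypothetical, and the AR(1) chains are an assumed stand-in for real sampler output, not draws from the Brownian motion model.

```python
import numpy as np


def batch_means_mcse(draws: np.ndarray, num_batches: int = 50) -> np.ndarray:
    """Batch-means Monte Carlo standard error of the posterior mean.

    draws: shape (num_chains, num_draws, dim). Each chain is cut into
    `num_batches` contiguous batches; the spread of the batch means, pooled
    over chains, estimates the error of the overall mean. Batches should be
    much longer than the chain's autocorrelation time.
    """
    num_chains, num_draws, dim = draws.shape
    batch_size = num_draws // num_batches
    # Drop trailing draws that do not fill a whole batch.
    trimmed = draws[:, : batch_size * num_batches, :]
    batch_means = trimmed.reshape(num_chains, num_batches, batch_size, dim).mean(axis=2)
    batch_means = batch_means.reshape(num_chains * num_batches, dim)
    return batch_means.std(axis=0, ddof=1) / np.sqrt(num_chains * num_batches)


def ar1_chains(rng, phi, num_chains, num_draws, dim):
    """Stationary AR(1) chains with unit marginal variance (hypothetical stand-in for MCMC)."""
    noise = rng.standard_normal((num_chains, num_draws, dim))
    x = np.empty_like(noise)
    x[:, 0] = noise[:, 0]  # start in the stationary distribution N(0, 1)
    scale = np.sqrt(1.0 - phi**2)
    for t in range(1, num_draws):
        x[:, t] = phi * x[:, t - 1] + scale * noise[:, t]
    return x


rng = np.random.default_rng(0)
short = ar1_chains(rng, phi=0.9, num_chains=4, num_draws=10_000, dim=1)
long = ar1_chains(rng, phi=0.9, num_chains=4, num_draws=160_000, dim=1)
# 16x more draws per chain should shrink the MCSE by roughly sqrt(16) = 4.
print(batch_means_mcse(short), batch_means_mcse(long))
```

For $\phi = 0.9$ the asymptotic MCSE of the short run's mean is about $\sqrt{19/40000} \approx 0.022$, and the long run's should be roughly four times smaller. Batch means is used here only because it is short; Stan and ArviZ estimate MCSE from an autocorrelation-based effective sample size instead.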

Adding ground truth expectations of second moment

I would also like to add ground-truth estimates of the second moment, i.e. $\mathbb{E}[x^2]$. Would it be possible for me to add these to certain inference-gym models?
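A second-moment ground truth can be computed from the same draws as the first moment, by averaging $x^2$ per coordinate over the pooled chains; the variance then follows as $\mathbb{E}[x^2] - \mathbb{E}[x]^2$. The sketch below assumes the draws are already in a NumPy array of shape `(num_chains, num_draws, dim)`; `moment_estimates` is a hypothetical helper, not an inference-gym function.

```python
import numpy as np


def moment_estimates(draws: np.ndarray) -> dict:
    """Per-coordinate first and second moments from pooled MCMC draws.

    draws: array of shape (num_chains, num_draws, dim).
    """
    pooled = draws.reshape(-1, draws.shape[-1])  # (num_chains * num_draws, dim)
    first = pooled.mean(axis=0)
    second = np.mean(pooled**2, axis=0)
    return {
        "mean": first,
        "second_moment": second,
        # Var[x] = E[x^2] - E[x]^2. This subtraction loses precision when the
        # mean is large relative to the spread, another reason to use float64.
        "variance": second - first**2,
    }


# Sanity check with independent draws from N(1, 2^2): E[x^2] = 1^2 + 2^2 = 5.
rng = np.random.default_rng(1)
draws = 1.0 + 2.0 * rng.standard_normal((4, 100_000, 3))
est = moment_estimates(draws)
print(est["second_moment"])
```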

Blackjax vs Stan

Currently, Stan is used by inference-gym to produce samples for ground truth estimates, run via CmdStanPy. How open would inference-gym be to switching to Blackjax's NUTS implementation instead, to obtain a fully Python setup? (Or even the TFP NUTS implementation)

@SiegeLordEx
Member

> then would it be possible to switch to the results of a longer run (either of Stan or Blackjax, but see the final section below) in inference-gym?

Yeah. For Stan, just do the longer run and report the diagnostics.
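"Report the diagnostics" usually means at least $\hat{R}$ and effective sample size per quantity, which CmdStanPy exposes through the fit object's `summary()` and `diagnose()` methods. For intuition, the classic split-$\hat{R}$ (Gelman et al., *Bayesian Data Analysis*, 3rd ed.) fits in a few lines of NumPy. This is a sketch of the older, non-rank-normalized statistic; ArviZ defaults to the rank-normalized variant of Vehtari et al. (2021), so values can differ slightly.

```python
import numpy as np


def split_rhat(draws: np.ndarray) -> float:
    """Classic (non-rank-normalized) split-R-hat for one scalar quantity.

    draws: array of shape (num_chains, num_draws).
    """
    num_chains, num_draws = draws.shape
    half = num_draws // 2
    # Split each chain in two so within-chain drift also inflates R-hat.
    # For odd num_draws the middle draw is dropped.
    split = np.concatenate([draws[:, :half], draws[:, num_draws - half:]], axis=0)
    m, n = split.shape
    chain_means = split.mean(axis=1)
    within = split.var(axis=1, ddof=1).mean()   # W: mean within-chain variance
    between = n * chain_means.var(ddof=1)       # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n
    return float(np.sqrt(var_plus / within))


rng = np.random.default_rng(2)
mixed = rng.standard_normal((4, 2_000))
# Hypothetical failure: one chain stuck 3 standard deviations away.
stuck = mixed + np.array([0.0, 0.0, 0.0, 3.0])[:, None]
print(split_rhat(mixed), split_rhat(stuck))
```

`mixed` should give a value very close to 1, while the shifted chain pushes split-$\hat{R}$ far above the commonly used 1.01 threshold.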

> Would it be possible for me to add these to certain inference-gym models?

Yep, no problem.

> How open would inference-gym be to switching to Blackjax's NUTS implementation instead, to obtain a fully Python setup? (Or even the TFP NUTS implementation)

It's mainly a question of trust. Stan comes with pretty "battle tested" and stringent diagnostics and has a wide userbase to report any systematic errors. I know TFP does not have such diagnostics (although we can outsource to, e.g., arviz), and I'm not sure about blackjax. If you can convince me that they can be trusted, then there's no technical reason not to have an option for alternate sources (we'd refactor get_ground_truth.py to have a flag to select the inference engine).
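If `get_ground_truth.py` gains an engine flag as suggested above, it could look roughly like the following. This is a hypothetical `argparse` sketch: the `--engine` and `--target` flags, the `ENGINES` registry and the target name are illustrative and not the script's current interface.

```python
import argparse

# Hypothetical registry; a real refactor would map each name to a function
# that runs the sampler and returns draws for the requested target.
ENGINES = {
    "stan": {"backend": "CmdStanPy"},
    "blackjax": {"backend": "Blackjax NUTS"},
    "tfp": {"backend": "TFP NUTS"},
}


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Compute ground-truth expectations (sketch).")
    parser.add_argument("--target", required=True, help="inference-gym target name")
    parser.add_argument(
        "--engine",
        choices=sorted(ENGINES),
        default="stan",  # keep Stan as the trusted default
        help="sampler used to generate the reference draws",
    )
    return parser


args = build_parser().parse_args(["--target", "brownian_motion", "--engine", "blackjax"])
print(args.engine, ENGINES[args.engine]["backend"])
```

Keeping `stan` as the default means existing ground truths are regenerated the same way unless another engine is explicitly requested.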

@SiegeLordEx
Member

One perennial issue with Stan alternatives based on JAX (or any ML framework, really) is that they tend to run in single precision, which can mask certain issues (https://arxiv.org/html/2411.04260v1#S5 has some discussion). Also, when running on accelerators, I've seen non-trivial changes in behavior compared to CPU. I'd be somewhat uncomfortable using non-double-precision, non-CPU-based inference for ground truths. Of course, JAX can be run in such a configuration, so it's not a blocker, but it is an important consideration.
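One concrete way single precision can bite in long ground-truth runs (an illustration chosen here, not necessarily the mechanism discussed in the linked paper) is accumulation error in running sums: once the running total is large, each float32 addition is rounded to the coarse spacing of representable values near that total. JAX computes in float32 by default unless `jax_enable_x64` is enabled (e.g. `jax.config.update("jax_enable_x64", True)`), so a JAX-based run has to opt in to double precision. The NumPy sketch below sums 10 million copies of 0.1 sequentially in both precisions.

```python
import numpy as np

num_steps = 10_000_000  # same order as the 10-million-step Blackjax runs in this thread
exact = 0.1 * num_steps
rel_errors = {}

for dtype in (np.float32, np.float64):
    increments = np.full(num_steps, 0.1, dtype=dtype)
    # np.cumsum accumulates sequentially in the input dtype, like a running sum
    # updated once per MCMC step. (np.sum would use pairwise summation instead.)
    running_total = float(np.cumsum(increments)[-1])
    name = np.dtype(dtype).name
    rel_errors[name] = abs(running_total - exact) / exact
    print(f"{name}: total={running_total:.1f}, relative error={rel_errors[name]:.2e}")
```

Here the float32 running total ends up more than 0.1 percent away from the exact 1,000,000, while the float64 total stays within one part per million.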

@reubenharry
Author

OK, sounds good. Probably easiest for me to use Stan for now.

By the way, in terms of building trust for other implementations, what I'm currently working on (and indeed why I'm using inference-gym) is a benchmarking and diagnostics package that runs a subset of the inference-gym models on Blackjax's samplers and shows that the behavior matches Stan (in this case in double precision, on CPU).
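A benchmark like this eventually has to decide whether a sampler's estimate "matches" the Stan reference. One simple check (assumed here, not taken from the package described above) is a per-coordinate z-score that scales the discrepancy by the combined Monte Carlo standard errors of both estimates and flags coordinates beyond a chosen threshold. The function names and all numbers in the example are hypothetical.

```python
import numpy as np


def discrepancy_zscores(est, est_mcse, ref, ref_mcse):
    """Per-coordinate z-scores of (est - ref), assuming independent MC errors."""
    est, est_mcse, ref, ref_mcse = map(np.asarray, (est, est_mcse, ref, ref_mcse))
    return (est - ref) / np.sqrt(est_mcse**2 + ref_mcse**2)


def flag_mismatches(z, threshold=4.0):
    """Indices whose |z| exceeds the threshold.

    With many coordinates, some |z| > 2 is expected by chance alone,
    hence the fairly loose default.
    """
    return np.flatnonzero(np.abs(z) > threshold)


# Hypothetical estimates and MCSEs, for illustration only.
est = [0.10, -0.50, 0.80]
est_mcse = [0.001, 0.002, 0.001]
ref = [0.10, -0.49, 0.70]
ref_mcse = [0.001, 0.002, 0.001]
z = discrepancy_zscores(est, est_mcse, ref, ref_mcse)
print(z, flag_mismatches(z))
```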

2 participants