Hi, great paper!
I implemented the regression colab example (or at least the first VBLLMLP example) in JAX. I wrote the equivalent of distributions.py by subclassing numpyro.distributions, and implemented the Regression and VBLLMLP classes in Flax. The model trains, but the uncertainty bands are a bit of a mess.
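For comparison, here is a minimal sketch of a scalar-output VBLL regression head in plain JAX (no Flax or numpyro). It assumes a dense covariance over the last-layer weights parameterized by a Cholesky factor, an isotropic Gaussian prior N(0, prior_var * I), and a learned noise variance with no prior on it. The names `init_params`, `predictive` and `vbll_loss` are hypothetical, not the repo's API, and the official implementation may differ (for example, in how it handles the noise prior).

```python
import jax
import jax.numpy as jnp


def init_params(key, feat_dim):
    # Hypothetical layout: mean of the last-layer weights, an unconstrained
    # matrix holding the Cholesky factor of their covariance, and the log
    # of the (homoscedastic) noise variance.
    return {
        "w_mean": 0.01 * jax.random.normal(key, (feat_dim,)),
        "w_chol_raw": jnp.zeros((feat_dim, feat_dim)),
        "log_noise_var": jnp.array(0.0),
    }


def chol_factor(raw):
    # Strictly lower part kept as-is; diagonal made positive with softplus,
    # so S = L @ L.T is always a valid covariance.
    lower = jnp.tril(raw, k=-1)
    diag = jax.nn.softplus(jnp.diag(raw)) + 1e-6
    return lower + jnp.diag(diag)


def _moments(params, phi):
    # phi: (batch, feat_dim) features from the MLP backbone.
    L = chol_factor(params["w_chol_raw"])
    mean = phi @ params["w_mean"]
    # phi^T S phi = ||L^T phi||^2, computed row-wise as sum((phi @ L)**2).
    epistemic = jnp.sum((phi @ L) ** 2, axis=-1)
    return L, mean, epistemic


def predictive(params, phi):
    # Posterior predictive N(mean, epistemic + noise_var); dropping either
    # term gives bands that are too narrow.
    _, mean, epistemic = _moments(params, phi)
    return mean, epistemic + jnp.exp(params["log_noise_var"])


def vbll_loss(params, phi, y, dataset_size, prior_var=1.0):
    L, mean, epistemic = _moments(params, phi)
    noise_var = jnp.exp(params["log_noise_var"])
    # E_q[log N(y | w^T phi, noise_var)] in closed form: the Gaussian
    # log-density at the mean minus phi^T S phi / (2 * noise_var).
    exp_log_lik = (
        -0.5 * jnp.log(2.0 * jnp.pi * noise_var)
        - 0.5 * (y - mean) ** 2 / noise_var
        - 0.5 * epistemic / noise_var
    )
    # KL(N(m, S) || N(0, prior_var * I)), closed form.
    d = phi.shape[-1]
    kl = 0.5 * (
        jnp.sum(L ** 2) / prior_var
        + jnp.sum(params["w_mean"] ** 2) / prior_var
        - d
        + d * jnp.log(prior_var)
        - 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
    )
    # Negative ELBO per data point; KL scaled by the full dataset size.
    return -jnp.mean(exp_log_lik) + kl / dataset_size
```

If the bands from a port look wrong, two places likely worth checking against a sketch like this are whether the predictive variance includes both φᵀSφ and the noise variance, and whether the KL term is divided by the dataset size rather than the minibatch size.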
Are there any plans for a JAX implementation? I'd be keen to help out a little if so. I'd also like to track down whatever is going wrong in my colab...
Here is the colab: https://colab.research.google.com/drive/1Rh895u0jP9xEpK7eMOz9JHUX_2CluyLO?usp=sharing
Thanks,
Conor