
Regression colab example - JAX implementation #8

@conorhassan


Hi, great paper!

I implemented the regression colab example (or at least the first VBLLMLP example) in JAX. I wrote an equivalent of distributions.py by subclassing numpyro.distributions, and implemented the Regression and VBLLMLP classes in Flax. The model trains, but the uncertainty bands are a bit of a mess.
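
For comparison, here is a minimal sketch of a scalar-output VBLL regression head in plain JAX (no Flax or numpyro). It is not the repository's implementation. The diagonal weight covariance, the hypothetical parameter names (`w_mean`, `w_logvar`, `noise_logvar`), the zero-mean isotropic Gaussian prior, and `reg_weight` (assumed here to be 1 / number of training points) are all assumptions, and the sketch puts no prior on the noise variance. The features `phi` are taken as an input; in the Flax version they would come from the MLP backbone.

```python
import jax
import jax.numpy as jnp


def init_params(key, feature_dim, init_logvar=-3.0):
    """Hypothetical parameter layout for a scalar-output VBLL regression head."""
    return {
        "w_mean": 0.1 * jax.random.normal(key, (feature_dim,)),
        # Log of the diagonal of the weight covariance S (diagonal S is an assumption).
        "w_logvar": jnp.full((feature_dim,), init_logvar),
        # Log of the observation-noise variance sigma^2.
        "noise_logvar": jnp.array(0.0),
    }


def predictive(params, phi):
    """Gaussian predictive N(phi @ w_mean, phi^T S phi + sigma^2), one entry per row of phi."""
    mean = phi @ params["w_mean"]
    s_diag = jnp.exp(params["w_logvar"])
    # phi^T S phi for a diagonal S: the epistemic part of the variance.
    epistemic = jnp.sum(phi**2 * s_diag, axis=-1)
    # sigma^2: the aleatoric (noise) part of the variance.
    aleatoric = jnp.exp(params["noise_logvar"])
    return mean, epistemic + aleatoric


def vbll_loss(params, phi, y, reg_weight, prior_var=1.0):
    """Negative ELBO: expected NLL under q(W) plus a weighted KL(q(W) || p(W))."""
    mean = phi @ params["w_mean"]
    s_diag = jnp.exp(params["w_logvar"])
    noise_var = jnp.exp(params["noise_logvar"])
    # E_q[-log N(y | w^T phi, sigma^2)] = -log N(y | mean, sigma^2) + 0.5 * phi^T S phi / sigma^2
    nll = 0.5 * (jnp.log(2 * jnp.pi * noise_var) + (y - mean) ** 2 / noise_var)
    trace_term = 0.5 * jnp.sum(phi**2 * s_diag, axis=-1) / noise_var
    expected_nll = jnp.mean(nll + trace_term)
    # Closed-form KL(N(w_mean, diag(S)) || N(0, prior_var * I)).
    kl = 0.5 * jnp.sum(
        s_diag / prior_var
        + params["w_mean"] ** 2 / prior_var
        - 1.0
        - jnp.log(s_diag / prior_var)
    )
    return expected_nll + reg_weight * kl


if __name__ == "__main__":
    # Toy check on random "features"; in the full model phi would come from the MLP.
    k_phi, k_noise, k_init = jax.random.split(jax.random.PRNGKey(0), 3)
    phi = jax.random.normal(k_phi, (64, 8))
    y = phi @ jnp.linspace(-1.0, 1.0, 8) + 0.1 * jax.random.normal(k_noise, (64,))
    params = init_params(k_init, feature_dim=8)
    loss = vbll_loss(params, phi, y, 1.0 / phi.shape[0])
    grads = jax.grad(vbll_loss)(params, phi, y, 1.0 / phi.shape[0])
    mean, var = predictive(params, phi)
    print(loss, var[:3])
```

Under these assumptions, the predictive variance has two parts, `phi^T S phi` and `sigma^2`. When the bands look wrong, a reasonable first check is that both terms are present and that `S` is used as a variance (not a log-variance or a standard deviation) everywhere it appears.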

Are there any plans for a JAX implementation? If so, I'd be keen to help out. I'd also like to track down the errors in my colab.

Here is the colab: https://colab.research.google.com/drive/1Rh895u0jP9xEpK7eMOz9JHUX_2CluyLO?usp=sharing

Thanks,
Conor
