
Timeline for JaxOpt migration #977

Open
Joshuaalbert opened this issue May 28, 2024 · 29 comments

Comments

@Joshuaalbert

Joshuaalbert commented May 28, 2024

Hello, what's the current roadmap for the jaxopt migration into optax? Will the scope of jaxopt be maintained, or will features be trimmed or expanded?

EDIT 27 October, 2024:

Dear all, based on optax's PRs, I suspect optax is adopting an ML-first/only roadmap. This is fine, as it's in their interest. However, the larger science community requires an orthogonal set of optimisation algorithms. I propose, if there is enough interest among you all, opening a separate community-maintained repository covering the missing space of tools. It would focus on performance, robustness, and ease of use, and welcome ANY category of algorithm, predominantly focusing on the science community. If you're in favour, leave a +1 or <3 on this comment, and I'll tag you when I open the repository.

@jlperla

jlperla commented Jun 19, 2024

Another question I had about jaxopt: to what extent will the differentiable optimization features be ported?

That is, will there be easy ways to get the jvp and vjp linearization around the argmin or min of the optimization process? (e.g. https://jaxopt.github.io/stable/implicit_diff.html#implicit-diff) Of course, none of this is specific to the optimization algorithm itself, but with training-loop style optimization it isn't clear where the custom jvp/vjp rules should be placed. You don't want to recurse through the training loop itself to differentiate.

For reference, in jaxopt this is done with a wrapper around the solver itself which registers the @jax.custom_vjp in things like https://github.com/google/jaxopt/blob/4a50198711e53ed0c25f2be5394825092c2427db/jaxopt/_src/implicit_diff.py#L204
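
For concreteness, here is a minimal sketch of that jaxopt pattern (the inner problem and all names are purely illustrative): custom_root takes an optimality condition F(x*, theta) = 0 and attaches implicit-differentiation rules to an otherwise opaque solver.

import jax
import jax.numpy as jnp
from jaxopt.implicit_diff import custom_root

# Toy inner problem: x*(theta) = argmin_x ||x - theta||^2, so x* = theta.
def inner_objective(x, theta):
    return jnp.sum((x - theta) ** 2)

# Optimality condition F(x*, theta) = 0: the gradient in x vanishes.
optimality = jax.grad(inner_objective, argnums=0)

@custom_root(optimality)
def solver(init_x, theta):
    # Any solver body works here, even a non-differentiable loop; gradients
    # come from the implicit function theorem, not from unrolling the loop.
    del init_x
    return theta  # closed-form minimizer for this toy problem

# d x*/d theta via implicit differentiation (the identity matrix here).
print(jax.jacobian(solver, argnums=1)(jnp.zeros(2), jnp.array([1.0, 2.0])))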

@Joshuaalbert
Author

I hope there will be an RFC around this, because there are numerous groups interested in JAX-based optimization, from science (e.g. the MODE consortium https://indico.cern.ch/event/1380163/) to entry-level ML students and everything in between.

@mblondel
Collaborator

mblondel commented Jun 24, 2024

So far, we've migrated all loss functions, all tree utilities, perturbation utilities, some projections, and an optax-compatible LBFGS (with backtracking line search or zoom line search).

The next big item we want to migrate is implicit differentiation. We'll start by migrating the custom_ decorators from JAXopt. This will require users to write a small wrapper on their side, but this is fairly simple; see this example.

We've also been brainstorming a solver API similar to JAXopt's, but this will take more time, as we want to integrate it well with the current Optax API. Once we've figured out a good solver API, it would be great to also migrate the SciPy wrappers.
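
For reference, driving the migrated LBFGS looks roughly like this (a minimal sketch; the objective f is illustrative, and the exact keyword arguments may evolve):

import jax.numpy as jnp
import optax

def f(x):  # illustrative objective with minimum at x = 1
    return jnp.sum((x - 1.0) ** 2)

opt = optax.lbfgs()  # zoom line search by default
params = jnp.zeros(3)
state = opt.init(params)
# Reuses function/gradient values cached by the line search.
value_and_grad = optax.value_and_grad_from_state(f)

for _ in range(10):
    value, grad = value_and_grad(params, state=state)
    updates, state = opt.update(
        grad, state, params, value=value, grad=grad, value_fn=f
    )
    params = optax.apply_updates(params, updates)

print(params)  # close to jnp.ones(3)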

@Joshuaalbert
Author

Good to hear an update. This sounds like it will drop support for some of the classic solvers, e.g. LM and BFGS? Will the main focus still be mini-batched optimisation?

Re: API. Could I suggest that once you have some candidate API structures you open an RFC (request for community comment) issue for a week or so just to see if there is any useful feedback?

@krzysztofrusek

@Joshuaalbert there was one PR (#777) introducing a solver API, but it was reverted.

Regarding BFGS and similar, they are covered by optimistix.

@jlperla

jlperla commented Jun 24, 2024

@mblondel Having a wrapper, and even a coding pattern where you register your own VJP/JVP, is fine, but I am a little confused: it sounds like you are decoupling it from the solver API. Or maybe I misunderstood?

FWIW, I think it would be completely fine if you guys gave an example along the lines of

@jax.custom_vjp
def my_solver(params):
    # some closure over data/etc.
    # run the jaxopt optimizer training loop in whatever minibatch loop you wish
    # return the min and argmin values
    return min_value, argmin

def my_solver_fwd(params):
    # run the solver and save the residuals the backward pass needs, assuming
    # my_solver has delivered the optimum to the precision required for
    # implicit differentiation
    min_value, argmin = my_solver(params)
    return (min_value, argmin), (params, argmin)

def my_solver_bwd(residuals, cotangents):
    # implement using the AD rules for implicit differentiation (VJP)
    ...

my_solver.defvjp(my_solver_fwd, my_solver_bwd)
my_solver = jax.jit(my_solver)

Or whatever is correct, providing correct template code that can be copy/pasted for the AD rules (which are independent of the optimization method).

Having a solver interface that automatically does that registration would be nice, but a hand-coded example is good enough to start?

@mblondel
Collaborator

> This sounds like it will drop support for some of the classic solvers, e.g. LM and BFGS?

BFGS should be easy to migrate using the same optax API as LBFGS. For LM, we want to implement the algorithm from scratch, but it's not high on our priority list.

> Could I suggest that once you have some candidate API structures you open an RFC (request for community comment) issue for a week or so just to see if there is any useful feedback?

Good idea

> Having a solver interface that automatically does that registration would be nice, but a hand-coded example is good enough to start?

Yep, that's the plan. We'll document how to use the custom decorators. Automatic registration will be possible when we figure out a solver API.

@diegoferigo

Is there any interest in migrating jaxopt.BoxOSQP? There is some recent interest within the robotics community in the development of JAX-based controllers to be used in closed-loop simulations that can be executed at scale on hardware accelerators (end-to-end differentiable).

As far as I know, this OSQP implementation is the only quadratic programming framework currently available for the JAX ecosystem.
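
For readers unfamiliar with it: BoxOSQP solves problems of the form min_x 0.5 x^T Q x + c^T x subject to l <= A x <= u. A minimal sketch (the problem data below are illustrative):

import jax.numpy as jnp
from jaxopt import BoxOSQP

# Minimize 0.5 x^T Q x + c^T x  subject to  l <= A x <= u.
Q = 2.0 * jnp.eye(2)
c = jnp.array([1.0, 1.0])
A = jnp.eye(2)
l = jnp.zeros(2)
u = jnp.ones(2)

qp = BoxOSQP()
sol = qp.run(params_obj=(Q, c), params_eq=A, params_ineq=(l, u))
x = sol.params.primal[0]  # primal solution; the second entry is z = A x
print(x)  # the box clips the unconstrained minimum to [0, 1]^2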

@jlperla

jlperla commented Jul 7, 2024

@mblondel Exciting to see all of the progress here! Is there a planned timeline for the release with linesearch/lbfgs/etc.? I would love to showcase some of it for teaching (and as sample code for a survey paper where I want to highlight optimization techniques).

@vroulet
Collaborator

vroulet commented Jul 8, 2024

We are planning a release soon (just waiting for #998 to be merged). Hopefully by the end of the week.

@lvjonok

lvjonok commented Jul 13, 2024

I completely agree with @diegoferigo. Including jaxopt.BoxOSQP in optax would be a fantastic addition. I am enthusiastic about contributing to this feature.

If this aligns with the repository's roadmap, I can begin the implementation and collaborate on reviewing it to ensure it's ready for integration.

@diegoferigo

> Is there any interest in migrating jaxopt.BoxOSQP? There is some recent interest within the robotics community in the development of JAX-based controllers to be used in closed-loop simulations that can be executed at scale on hardware accelerators (end-to-end differentiable).
>
> As far as I know, this OSQP implementation is the only quadratic programming framework currently available for the JAX ecosystem.

For those reading my previous comment who need a QP solver in JAX, I'd suggest also having a look at kevin-tracy/qpax.

@lvjonok

lvjonok commented Oct 24, 2024

> > Is there any interest in migrating jaxopt.BoxOSQP? There is some recent interest within the robotics community in the development of JAX-based controllers to be used in closed-loop simulations that can be executed at scale on hardware accelerators (end-to-end differentiable).
> >
> > As far as I know, this OSQP implementation is the only quadratic programming framework currently available for the JAX ecosystem.
>
> For those reading my previous comment who need a QP solver in JAX, I'd suggest also having a look at kevin-tracy/qpax.

It sure works, but I wonder if someone has conducted a comparison between the two implementations. I have heard that qpax might not be as efficient as BoxOSQP.

@diegoferigo

I'm working on a project in which we are exploring the usage of QP solvers in JAX to implement rigid contacts for a hardware-accelerated physics engine. In ami-iit/jaxsim#218, @xela-95 provided a first implementation (here where qpax is called). If I recall correctly, we had problems with primal infeasibilities in jaxopt.BoxOSQP that didn't affect qpax. For the moment, we don't care much about performance; we have other bottlenecks to solve, and I believe we can use something better than a QP solver (like projected Gauss-Seidel) in the long term.

@dangpzanco

I think it would be very useful to have some kind of open tracking of the Jaxopt porting process, with a list of features and their respective statuses (e.g. planned, not planned, in development, finished). It's not very obvious to have to look into a closed issue in order to get updates...

@Joshuaalbert
Author

Leave a +1 or <3 to indicate you want to be invited to the new repo.

Dear all, based on optax's PRs, I suspect optax is adopting an ML-first/only roadmap. This is fine, as it's in their interest. However, the larger science community requires an orthogonal set of optimisation algorithms. I propose, if there is enough interest among you all, opening a separate community-maintained repository covering the missing space of tools. It would focus on performance, robustness, and ease of use, and welcome ANY category of algorithm, predominantly focusing on the science community. If you're in favour, leave a +1 or <3 on this comment, and I'll tag you when I open the repository.

@carlosgmartin
Contributor

@Joshuaalbert Can you list some of the algorithms you have in mind?

@Joshuaalbert
Author

@carlosgmartin I would make it completely contribution-based, with wide inclusion criteria. There would be a simple community-driven peer-review process for accepting new contributions. I would say any algorithm would be a suitable candidate as long as it is useful to others in its field. I myself would contribute some quasi-Newton optimisers, a gradient-free global optimiser, and some EM-type algorithms for maximising Bayesian evidence.

@vroulet
Collaborator

vroulet commented Oct 28, 2024

Hello @Joshuaalbert, @carlosgmartin,

Sincere apologies about this. We are stretched thin, and as you said, deterministic algorithms are not considered a priority internally.

That said, if you want, we can make a "solvers" folder in optax and let you handle it (we have little time to code these algorithms ourselves given other ongoing issues, but we can still help integrate PRs). Carlos, for example, contributed an implementation of the Hungarian algorithm that we are happy to have. Optax requires internal approval, but we can let you handle reviews. The advantage of having such algorithms in optax is greater visibility.

If you prefer to work outside of optax, I would recommend taking a look at optimistix. Optimistix solved a few issues that jaxopt had, in particular around compilation times of the solvers (it requires some mental gymnastics to ensure a single compilation of the objective function, so in any case, it would be good to look at how this has been handled in optimistix).
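
For a flavour of the API, a minimal optimistix call looks roughly like this (a sketch; the objective is illustrative):

import jax.numpy as jnp
import optimistix as optx

def fn(y, args):  # optimistix objectives take (y, args)
    return jnp.sum((y - 1.0) ** 2)

solver = optx.BFGS(rtol=1e-8, atol=1e-8)
sol = optx.minimise(fn, solver, y0=jnp.zeros(3))
print(sol.value)  # close to jnp.ones(3)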

I would sincerely encourage you to consider one of the two options above, to avoid fragmenting the code and to ensure that the algorithms get greater visibility.

Again we sincerely apologize for the situation and thank you for your continued interest.

PS: I'll reopen the issue to encourage further discussion/better visibility of your proposals for now.

@Joshuaalbert
Author

@vroulet I am open to creating a new section inside optax with increased scope. However, I'd like to have a call to discuss details. My desire is to 1) involve the community in the contribution, review, and maintenance process, 2) focus on HPC at scientific scale, 3) focus more on examples (scientists like to copy and modify), and 4) encourage continuous improvement of the suite of methods. This last point is crucial: as people experiment, they typically improve a method (e.g. the initialisation of hyperparameters). A set of standard benchmarks would allow deciding whether a modification/contribution has merit.

In terms of the communities I'm thinking of: HPC groups doing large-scale calibration, simulation, or inference (astro, physics, biology, epidemiology, engineering, ...), those doing experimental-design optimisation leveraging AD, and anyone who has a difficult optimisation problem. Many of these groups are in the process of porting Fortran code to JAX, and many are currently operating in distributed cluster environments.

My reason for not wishing to do this via optimistix is that there is too strong an enforcement of using Equinox, and I believe this will prevent the community from participating. I also don't think modularisation is so crucial here. As long as the code is well-written, it is alright for there to be some redundancy. Otherwise, the code review process is too lengthy and pushes people away. Hence, my philosophy is if a method does well on benchmarks, and is reasonably well-written, it passes, and can be improved over time.

@vroulet
Collaborator

vroulet commented Oct 29, 2024

@Joshuaalbert I've sent you an invite at your Leiden University email address. If you haven't received it, let me know.

@Joshuaalbert
Author

@vroulet works for me

@patrick-kidger

Ah, I've just been tagged on this. Interesting reading! I'm just jumping in to comment on this point:

> My reason for not wishing to do this via optimistix is that there is too strong an enforcement of using Equinox

This isn't a requirement at all. Optimistix just asks that you pass it any pytree-of-arrays; that's it.

(I would obviously encourage using Equinox anyway ;) But all of the JAX scientific computing libraries I've been involved with have consciously avoided making this a requirement.)

@Joshuaalbert
Author

@patrick-kidger would you like to join a call this Thursday with @vroulet and me?

@patrick-kidger

Depends on the time! Send me an email (check my website).

@vroulet
Collaborator

vroulet commented Oct 30, 2024

I've sent you an invite to [email protected]; let me know if you have received it, and feel free to move the meeting around if needed (after 7am PDT, preferably).

@patrick-kidger

LGTM! Received it. The current time works for me :)

@carlosgmartin
Contributor

Just a comment:

One unfortunate thing I've seen happen in the JAX ecosystem is the following: A person or handful of people develop a library. Others find it useful and start adopting it as a dependency. After a while, the developers stop maintaining it and become unresponsive, leaving downstream projects adrift.

This is not a knock on the developers. Life happens and people get busy with other things. Most of the time, they're providing a free service to the community. It's just unfortunate for the ecosystem, because it creates uncertainty about which tools to use (that can be relied on long-term), fragmentation in terms of what to use as a replacement, and duplication of work (not just in terms of writing code, but maintaining codebases).

It seems to me that, for foundational libraries, it's generally better to have a few canonical repos with many maintainers and several eyes on the same code (to find bugs, add features, improve documentation, respond to issues, etc.), than a myriad of lightly-maintained repos. In other words, it's good to have "Schelling points" for software development.

@Joshuaalbert
Author

@carlosgmartin I think if one looks at the string of unofficial Google products that are now in maintenance mode, we see this pattern of marooning occurring in a more impactful way. This is precisely why there is such an interest in the roadmap of the jaxopt migration; it's an interest born of desperation. The small repos out there that get abandoned will get picked up again, and often improved, if they are useful. This has happened since open source began, and it is not a reason against spawning more projects.
