Feature: adaptive learning rate schedulers #1629
base: main
Conversation
- Add scheduler infrastructure to the base NeuralInference class
- Implement _create_lr_scheduler() supporting 6 scheduler types:
  * plateau (ReduceLROnPlateau)
  * exponential (ExponentialLR)
  * cosine (CosineAnnealingLR)
  * step (StepLR)
  * multistep (MultiStepLR)
  * cyclic (CyclicLR)
- Enhance _converged() with an optional minimum-LR threshold
- Update NPE_A and NPE_C train() method signatures
- Integrate scheduler stepping in the NPE training loop
- Add learning rate tracking and TensorBoard logging
- Maintain full backward compatibility

🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <[email protected]>
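For illustration, a minimal sketch of how such a string-to-scheduler dispatch could look. This is a hypothetical standalone version, not the PR's actual `_create_lr_scheduler()` method on the base class, and the default kwargs below are assumptions:

```python
from typing import Any, Dict, Optional

from torch.optim import Optimizer, lr_scheduler

# Hypothetical mapping from string names to torch scheduler classes.
_SCHEDULERS = {
    "plateau": lr_scheduler.ReduceLROnPlateau,
    "exponential": lr_scheduler.ExponentialLR,
    "cosine": lr_scheduler.CosineAnnealingLR,
    "step": lr_scheduler.StepLR,
    "multistep": lr_scheduler.MultiStepLR,
    "cyclic": lr_scheduler.CyclicLR,
}

# Assumed defaults so each scheduler can be built from just its name.
_DEFAULT_KWARGS = {
    "plateau": {"factor": 0.5, "patience": 5},
    "exponential": {"gamma": 0.95},
    "cosine": {"T_max": 50},
    "step": {"step_size": 10, "gamma": 0.1},
    "multistep": {"milestones": [30, 60], "gamma": 0.1},
    "cyclic": {"base_lr": 1e-5, "max_lr": 1e-3},
}


def create_lr_scheduler(
    optimizer: Optimizer,
    name: str,
    kwargs: Optional[Dict[str, Any]] = None,
):
    """Build a torch LR scheduler from a string name plus optional kwargs."""
    if name not in _SCHEDULERS:
        raise ValueError(f"Unknown lr_scheduler: {name!r}")
    # User-provided kwargs override the defaults.
    merged = {**_DEFAULT_KWARGS[name], **(kwargs or {})}
    return _SCHEDULERS[name](optimizer, **merged)
```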
- 16 comprehensive tests covering all scheduler functionality
- Tests all 6 scheduler types with parameter validation
- Verifies actual learning rate reduction behavior
- Tests dictionary-based configuration
- Validates minimum LR threshold convergence
- Ensures backward compatibility (no scheduler = no change)
- Tests all NPE variants (NPE_A, NPE_B, NPE_C)
- Error handling for invalid scheduler types
- Parameter override functionality
- Resume training with scheduler state preservation

Utilizes existing SBI test fixtures and follows established patterns. All tests pass with the new scheduler implementation.

🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <[email protected]>
- Add scheduler parameters (lr_scheduler, lr_scheduler_kwargs, min_lr_threshold) to the NLE base class
- Integrate scheduler creation and stepping in the nle_base.py training loop
- Add learning rate tracking and logging to the summary
- Update the MNLE train() method signature to support scheduler parameters
- Support all 6 scheduler types: plateau, exponential, cosine, step, multistep, cyclic
- Maintain backward compatibility - no scheduler defaults to a constant learning rate
- Add min_lr_threshold for early stopping when the learning rate becomes too low

🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <[email protected]>
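For context, a runnable sketch of how the scheduler stepping and the min_lr_threshold check described above could fit into a training loop. The dummy model and random "validation loss" are placeholders; this is not the PR's actual nle_base.py code:

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = nn.Linear(2, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
min_lr_threshold = 1e-6  # assumed early-stopping threshold

for epoch in range(100):
    val_loss = torch.rand(1).item()  # stand-in for a real validation loss

    # Plateau schedulers step on the validation metric;
    # all other schedulers step once per epoch.
    if isinstance(scheduler, ReduceLROnPlateau):
        scheduler.step(val_loss)
    else:
        scheduler.step()

    current_lr = optimizer.param_groups[0]["lr"]  # track LR for logging
    if current_lr < min_lr_threshold:
        break  # stop early once the LR has decayed below the threshold
```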
- Add scheduler parameters (lr_scheduler, lr_scheduler_kwargs, min_lr_threshold) to the NRE base class
- Integrate scheduler creation and stepping in the nre_base.py training loop
- Add learning rate tracking and logging to the summary
- Update all NRE variant train() method signatures (NRE_A, NRE_B, NRE_C, BNRE)
- Add missing 'Any' imports to typing statements in the NRE files
- Support all 6 scheduler types: plateau, exponential, cosine, step, multistep, cyclic
- Maintain backward compatibility - no scheduler defaults to a constant learning rate
- Add min_lr_threshold for early stopping when the learning rate becomes too low

🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <[email protected]>
- Test scheduler creation and integration for NLE_A and MNLE
- Test all scheduler types: plateau, exponential, step, cosine
- Test dictionary configuration for schedulers
- Test learning rate reduction and tracking
- Test min_lr_threshold early stopping
- Test backward compatibility without schedulers
- Test error handling for invalid scheduler types
- Test scheduler kwargs override functionality
- Test resume training with scheduler state preservation

🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <[email protected]>
- Test scheduler creation and integration for all NRE variants (NRE_A, NRE_B, NRE_C, BNRE)
- Test all scheduler types: plateau, exponential, step, multistep, cyclic
- Test dictionary configuration for schedulers
- Test learning rate reduction and tracking
- Test min_lr_threshold early stopping
- Test backward compatibility without schedulers
- Test error handling for invalid scheduler types
- Test scheduler kwargs override functionality
- Test resume training with scheduler state preservation
- Test CyclicLR scheduler with learning rate variation

🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <[email protected]>
- Add how-to guide for learning rate schedulers with practical examples
- Add advanced tutorial notebook with comparative analysis and visualizations
- Update documentation index files to include new scheduler docs
- Cover all supported schedulers: plateau, exponential, step, multistep, cosine, cyclic
- Include configuration examples, best practices, and troubleshooting
- Demonstrate usage across NPE, NLE, and NRE methods

🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <[email protected]>
Thanks a lot @christopherlovell for creating this PR, looks like a great addition. We will do a review of the proposed changes soon!
I did a first high-level pass over the changes - great additions overall! 👏
Before we go into the details, please have a look at my two main comments on the class structure for the LR scheduler options and their kwargs, and on the tests. Thanks!
lr_scheduler: Optional[Union[str, Dict[str, Any]]],
lr_scheduler_kwargs: Optional[Dict[str, Any]],
Instead of passing both lr_scheduler and the corresponding kwargs, I suggest using dataclasses like we recently introduced for the PosteriorParameters, see #1619. You would need to define a base class LrSchedulerParameters defining the interface and holding all shared parameters, and then a subclass for every LR scheduler type.
We should still let basic users pass a string with the type; then we would just pick the corresponding parameter class and use its defaults. But as soon as users want to specify particular options, they should create the parameter class themselves and pass it here.
This will reduce some of the if-else below and strongly improve type hints, IDE suggestions, and debugging. Let me know if there are any questions or whether I am missing something.
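A rough sketch of the suggested dataclass interface. All class names, fields, and defaults below are hypothetical, modeled on the comment above rather than on sbi's actual code (note that torch.optim.lr_scheduler.LRScheduler requires torch >= 2.0; older versions use _LRScheduler):

```python
from dataclasses import dataclass

from torch.optim import Optimizer
from torch.optim.lr_scheduler import ExponentialLR, LRScheduler, StepLR


@dataclass
class LrSchedulerParameters:
    """Hypothetical base class holding parameters shared by all schedulers."""

    def build(self, optimizer: Optimizer) -> LRScheduler:
        raise NotImplementedError


@dataclass
class StepLrParameters(LrSchedulerParameters):
    step_size: int = 10
    gamma: float = 0.1

    def build(self, optimizer: Optimizer) -> LRScheduler:
        return StepLR(optimizer, step_size=self.step_size, gamma=self.gamma)


@dataclass
class ExponentialLrParameters(LrSchedulerParameters):
    gamma: float = 0.95

    def build(self, optimizer: Optimizer) -> LRScheduler:
        return ExponentialLR(optimizer, gamma=self.gamma)


# A plain string could still map to the default parameter class:
_DEFAULTS = {"step": StepLrParameters, "exponential": ExponentialLrParameters}
```

A user wanting non-default options would then pass, e.g., StepLrParameters(step_size=5) directly to train(), giving typed, IDE-discoverable configuration instead of a free-form kwargs dict.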
The fixture on top is great, but the tests are quite verbose and repetitive at the moment. I suggest using pytest.mark.parametrize to test the LR schedulers for the different trainers (NPE, NLE, NRE) and scheduler types all in one test file (see the sketch below). As MNLE is basically NLE, I think it's fine to not include it here.
Another thing to include would be the vector field estimators, e.g., FMPE and NPSE, because they handle convergence criteria a bit differently, see #1544.
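For illustration, a condensed parametrized test could look roughly like this. The fixtures `prior` and `simulated_data` are hypothetical, and the lr_scheduler argument to train() is the PR's proposed parameter, not a released sbi API:

```python
import pytest

from sbi.inference import NLE, NPE, NRE  # assumed import path


@pytest.mark.parametrize("trainer_cls", [NPE, NLE, NRE])
@pytest.mark.parametrize(
    "scheduler", ["plateau", "exponential", "cosine", "step", "multistep", "cyclic"]
)
def test_lr_scheduler(trainer_cls, scheduler, prior, simulated_data):
    # `prior` and `simulated_data` are assumed fixtures providing a prior
    # distribution and a small (theta, x) simulation budget.
    theta, x = simulated_data
    trainer = trainer_cls(prior=prior)
    trainer.append_simulations(theta, x)
    # A short run is enough to exercise scheduler creation and stepping.
    trainer.train(lr_scheduler=scheduler, max_num_epochs=5)
```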
Add Adaptive Learning Rate Schedulers to Training
Implements adaptive learning rate schedulers for the NPE, NLE, and NRE methods. All changes maintain backward compatibility: existing code continues to work unchanged.
This feature was implemented with assistance from https://claude.ai/code
Co-Authored-By: Claude [email protected]
Changes Made
Core Implementation
Support for six PyTorch schedulers: ReduceLROnPlateau, ExponentialLR, CosineAnnealingLR, StepLR, MultiStepLR, and CyclicLR
Testing
Documentation
API Usage
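As a sketch of the proposed usage, based on the parameters this PR introduces (lr_scheduler, lr_scheduler_kwargs, min_lr_threshold); the toy simulator is a placeholder and exact signatures may differ:

```python
import torch

from sbi.inference import NPE
from sbi.utils import BoxUniform

# Toy setup: a 2D uniform prior and a trivial Gaussian-noise "simulator".
prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
theta = prior.sample((1000,))
x = theta + 0.1 * torch.randn_like(theta)

trainer = NPE(prior=prior)
trainer.append_simulations(theta, x)

# Proposed scheduler arguments from this PR (not yet a released sbi API):
density_estimator = trainer.train(
    lr_scheduler="plateau",  # one of the six supported types
    lr_scheduler_kwargs={"factor": 0.5, "patience": 5},
    min_lr_threshold=1e-6,  # stop training once the LR decays below this
)
```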