Only track non-scalar attributes in UNVECTORIZED_NUM_ATTR_DIMS#627
Open
cr-xu wants to merge 3 commits into
Open
Only track non-scalar attributes in UNVECTORIZED_NUM_ATTR_DIMS#627cr-xu wants to merge 3 commits into
UNVECTORIZED_NUM_ATTR_DIMS#627cr-xu wants to merge 3 commits into
Conversation
…ntroducing a defining_features property in beam classes;
Member
Author
|
This one is actually ready for review since a while. |
Contributor
There was a problem hiding this comment.
Pull request overview
This PR refactors how beam attributes are enumerated by (1) clearing the scalar-heavy default Beam.UNVECTORIZED_NUM_ATTR_DIMS, (2) introducing a defining_features property for ParticleBeam/ParameterBeam (mirroring Element), and (3) updating tests and segment metric collection to use these new patterns.
Changes:
- Empties
Beam.UNVECTORIZED_NUM_ATTR_DIMSand relies on subclasses to list only non-scalar attributes. - Adds
defining_featurestoParticleBeamandParameterBeam, and updates multiple tests to use it (excludingspecies). - Updates
Segment.get_beam_attrs_along_segmentto default missingUNVECTORIZED_NUM_ATTR_DIMSentries to0.
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/test_particle_beam.py | Switch dtype/device assertions to iterate over defining_features (excluding species). |
| tests/test_elements.py | Use outgoing_beam.defining_features to validate dtype/device propagation. |
| tests/test_drift.py | Use defining_features for equality checks between tracked beams (but introduces >88-char lines). |
| tests/test_clone.py | Validate cloned beam buffers via defining_features rather than attr-dims dict keys. |
| tests/test_beam.py | Use defining_features for dtype conversion assertions. |
| cheetah/particles/particle_beam.py | Adds ParticleBeam.defining_features. |
| cheetah/particles/parameter_beam.py | Adds ParameterBeam.defining_features. |
| cheetah/particles/beam.py | Clears default UNVECTORIZED_NUM_ATTR_DIMS and introduces abstract defining_features. |
| cheetah/accelerator/segment.py | Defaults missing attr-dims to 0 when stacking metrics along a segment. |
| CHANGELOG.md | Adds an entry describing the attr-dims refactor and segment stacking default. |
Comments suppressed due to low confidence (1)
tests/test_drift.py:109
- This line exceeds the repo’s Flake8
max-line-length = 88(see.flake8). Please wrap the list comprehension across multiple lines (as done in other updated tests) to avoid CI lint failures.
beam_attributes = [attr for attr in outgoing.defining_features if attr != "species"]
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
573
to
578
| broadcasted_results = tuple( | ||
| torch.stack( | ||
| torch.broadcast_tensors(*attr_tensor), | ||
| dim=-(incoming.UNVECTORIZED_NUM_ATTR_DIMS[attr_name] + 1), | ||
| dim=-(incoming.UNVECTORIZED_NUM_ATTR_DIMS.get(attr_name, 0) + 1), | ||
| ) | ||
| for attr_tensor, attr_name in zip(results, attr_name_tuple) |
|
|
||
| # Check that all properties of the two outgoing beams are same | ||
| for attribute in outgoing.UNVECTORIZED_NUM_ATTR_DIMS.keys(): | ||
| beam_attributes = [attr for attr in outgoing.defining_features if attr != "species"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
defining_featuresproperty in the same fashion as for Elements.defining_featuresproperty in beam classes to verify dtypes and devices are tracked properly. All the tested attributes are generateed from these definiing_features and the test was bit redundant.Motivation and Context
Types of changes
Checklist
flake8(required).pytesttests pass (required).pyteston a machine with a CUDA GPU and made sure all tests pass (required).Note: We are using a maximum length of 88 characters per line.