Speedup AdvancedSubtensor1 and AdvancedIncSubtensor1 in C backend #1346
base: main
Conversation
Force-pushed from 337b91f to 2862c42
Force-pushed from c0045c3 to c211405
Codecov Report
❌ Attention: Your patch check has failed because the patch coverage (87.01%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #1346      +/-   ##
==========================================
+ Coverage   82.04%   82.05%   +0.01%
==========================================
  Files         203      203
  Lines       48837    48912      +75
  Branches     8689     8709      +20
==========================================
+ Hits        40067    40134      +67
- Misses       6619     6624       +5
- Partials     2151     2154       +3
Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.
Comments suppressed due to low confidence (3)
tests/tensor/test_subtensor.py:1277
- The generator expression used for 'inc_var_static_shape' should be converted to a tuple (e.g. tuple(...)) to ensure it produces a concrete shape tuple.
inc_var_static_shape = (1 if dim_length == 1 else None for dim_length in inc_shape)
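A minimal illustration of the difference (hypothetical `inc_shape` value; plain Python semantics): a generator expression is a lazy, single-use iterator rather than a concrete tuple, so downstream code expecting a static shape tuple would misbehave.

```python
inc_shape = (1, 5, 1)  # hypothetical example value

# Generator expression: a one-shot iterator, not a shape tuple
gen = (1 if dim_length == 1 else None for dim_length in inc_shape)

# Wrapping in tuple(...) produces the concrete static shape
inc_var_static_shape = tuple(1 if dim_length == 1 else None for dim_length in inc_shape)
assert inc_var_static_shape == (1, None, 1)
```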
pytensor/link/pytorch/dispatch/subtensor.py:112
- The _check_runtime_broadcasting method expects four arguments (including the node), so the node argument should be passed to ensure correct runtime checking.
if isinstance(op, AdvancedIncSubtensor1): op._check_runtime_broadcasting(x, y, indices)
pytensor/link/jax/dispatch/subtensor.py:70
- Similar to the PyTorch dispatch, the node argument is missing when calling _check_runtime_broadcasting. Ensure that the node is passed as the first parameter after self.
if isinstance(op, AdvancedIncSubtensor1): op._check_runtime_broadcasting(x, y, indices)
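Assuming the signature described above (the Apply node as the first argument after self), the corrected call in both dispatchers would look roughly like this sketch, with `node` being the Apply node available in the enclosing funcify function:

```python
# Hypothetical corrected call: the Apply node from the dispatch closure is
# forwarded so _check_runtime_broadcasting receives all four arguments.
if isinstance(op, AdvancedIncSubtensor1):
    op._check_runtime_broadcasting(node, x, y, indices)
```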
Useless AI strikes again!
It was actually correct. I'm surprised no tests failed; I guess we are not really covering the dispatch of these in JAX/PyTorch, because the Op is only introduced during rewrites.
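A hedged sketch of how that dispatch could be exercised in a test: compiling with the backend mode so the rewrite pipeline introduces AdvancedIncSubtensor1 before funcify runs (the exact test setup here is an assumption):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
idx = pt.lvector("idx")
y = pt.vector("y")

# Compiling in JAX mode runs the rewrites that introduce
# AdvancedIncSubtensor1, so the JAX dispatch actually gets hit.
f = pytensor.function([x, idx, y], pt.inc_subtensor(x[idx], y), mode="JAX")
f(np.zeros(5), np.array([0, 1]), np.ones(2))
```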
Also add checks for runtime broadcast
Force-pushed from c211405 to 38f9036
Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.
Comments suppressed due to low confidence (1)
pytensor/tensor/subtensor.py:2247
- The condition validating index values in _idx_may_be_invalid is non-obvious; please verify that it correctly handles negative indices and reflects the intended bounds check.
return not (min_idx >= 0 or min_idx >= -shape0) and (max_idx < 0 or max_idx < shape0)
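For reference, the intended bounds follow NumPy indexing semantics: an index i into an axis of length shape0 is valid iff -shape0 <= i < shape0, so indices *may* be invalid when min_idx < -shape0 or max_idx >= shape0. A standalone sketch of that predicate (a hypothetical helper, not the PR's exact code):

```python
import numpy as np

def idx_may_be_invalid(idx, shape0):
    """True if some index could be out of bounds for an axis of length
    shape0; valid NumPy indices satisfy -shape0 <= i < shape0."""
    min_idx, max_idx = int(idx.min()), int(idx.max())
    return min_idx < -shape0 or max_idx >= shape0

idx = np.array([-3, 0, 2])
assert not idx_may_be_invalid(idx, 3)  # all of [-3, 0, 2] valid for length 3
assert idx_may_be_invalid(idx, 2)      # -3 and 2 are out of bounds for length 2
```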
}}

if ((PyArray_NDIM({out}) != 1) || ({unexpected_shape0})) {{
    PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: fist input (x) does not have right shape or ndim");
Typo detected in the error message: 'fist input' should be changed to 'first input'.
- PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: fist input (x) does not have right shape or ndim");
+ PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: first input (x) does not have right shape or ndim");
These are some of the biggest drags in the C backend. This PR makes a few tweaks that increase performance substantially.
AdvancedSubtensor1 benchmark: before/after results (tables shown in the original PR)
AdvancedIncSubtensor1 benchmark: before/after results (tables shown in the original PR)
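The benchmark tables themselves are not reproduced here. A minimal sketch of how such a comparison can be run (illustrative sizes; the PR's real benchmarks live in the test suite):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
idx = pt.lvector("idx")
y = pt.vector("y")

# x[idx] lowers to AdvancedSubtensor1; inc_subtensor on it exercises
# AdvancedIncSubtensor1. FAST_RUN uses the C backend when available.
take = pytensor.function([x, idx], x[idx], mode="FAST_RUN")
inc = pytensor.function([x, idx, y], pt.inc_subtensor(x[idx], y), mode="FAST_RUN")

rng = np.random.default_rng(0)
xv = rng.normal(size=100_000)
iv = rng.integers(0, 100_000, size=10_000)
yv = rng.normal(size=10_000)

take(xv, iv)      # time these calls (e.g. with timeit) before/after the PR
inc(xv, iv, yv)
```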
I added a long-missing check for runtime broadcasting in the Python/C/PyTorch implementations (Numba would require a few more code changes), which moves towards #1348.
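To illustrate what the new check guards against (a hedged example; the exact error type and message are assumptions): an increment value whose runtime length is 1 against several indices previously broadcast silently, even though its static shape (None,) never declared it broadcastable.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
idx = pt.lvector("idx")
y = pt.vector("y")  # static shape (None,): not declared broadcastable

f = pytensor.function([x, idx, y], pt.inc_subtensor(x[idx], y))

# Three indices but a length-1 y would require runtime broadcasting; with
# this PR's check the Python/C/PyTorch implementations raise an error
# instead of broadcasting silently (error type assumed to be ValueError).
f(np.zeros(5), np.array([0, 1, 2]), np.ones(1))
```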
Provides a restricted case of #1325
Alloc of zeros is also about twice as fast now, which is benchmarked indirectly in the AdvancedIncSubtensor1 tests.

📚 Documentation preview 📚: https://pytensor--1346.org.readthedocs.build/en/1346/