Speedup AdvancedSubtensor1 and AdvancedIncSubtensor1 in C backend #1346

Open
wants to merge 4 commits into
base: main
from


ricardoV94
Member

@ricardoV94 ricardoV94 commented Apr 7, 2025

These are some of the biggest drags in the C backend. This PR makes some tweaks that increase performance substantially.

AdvancedSubtensor1 benchmark

Before
---------------------------------------------------------------------------------------------------- benchmark: 4 tests ---------------------------------------------------------------------------------------------------
Name (time in us)                                            Min                Max              Mean            StdDev            Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_advanced_subtensor1[gc=True-static_shape=False]      1.7730 (1.0)      62.5270 (5.48)     2.0225 (1.00)     0.6402 (2.22)     1.9640 (1.0)      0.0400 (1.00)    1785;6702      494.4466 (1.00)      83320           1
test_advanced_subtensor1[gc=True-static_shape=True]       1.7830 (1.01)     57.1570 (5.01)     2.0216 (1.0)      0.6144 (2.13)     1.9840 (1.01)     0.0400 (1.0)      690;2893      494.6573 (1.0)       73660           1
test_advanced_subtensor1[gc=False-static_shape=True]      2.1040 (1.19)     11.4020 (1.0)      2.3344 (1.15)     0.2878 (1.0)      2.3040 (1.17)     0.0500 (1.25)    1349;4021      428.3810 (0.87)     102691           1
test_advanced_subtensor1[gc=False-static_shape=False]     2.1140 (1.19)     19.3860 (1.70)     2.2857 (1.13)     0.2986 (1.04)     2.2740 (1.16)     0.0510 (1.27)       65;547      437.5102 (0.88)      11134           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
After
---------------------------------------------------------------------------------------------------- benchmark: 4 tests ----------------------------------------------------------------------------------------------------
Name (time in us)                                            Min                 Max              Mean            StdDev            Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_advanced_subtensor1[gc=False-static_shape=True]      1.2820 (1.0)       45.9260 (1.0)      1.4129 (1.0)      0.5339 (1.0)      1.3730 (1.0)      0.0490 (1.22)      783;971      707.7531 (1.0)       57594           1
test_advanced_subtensor1[gc=True-static_shape=True]       1.4620 (1.14)      87.1230 (1.90)     1.6104 (1.14)     0.5974 (1.12)     1.5630 (1.14)     0.0400 (1.0)     2326;3531      620.9714 (0.88)     115929           1
test_advanced_subtensor1[gc=False-static_shape=False]     1.7530 (1.37)      71.8950 (1.57)     1.9472 (1.38)     0.7248 (1.36)     1.9240 (1.40)     0.0400 (1.00)       65;408      513.5575 (0.73)      10931           1
test_advanced_subtensor1[gc=True-static_shape=False]      1.7730 (1.38)     282.6200 (6.15)     2.0129 (1.42)     1.1193 (2.10)     1.9540 (1.42)     0.0400 (1.0)     1410;6501      496.7973 (0.70)     113676           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
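For context on what is being timed: AdvancedSubtensor1 covers indexing the first axis of a tensor with a vector of integers, i.e. `x[idx]`. Below is a minimal NumPy sketch of that semantics. The array values and variable names are illustrative assumptions, not taken from the benchmark.

```python
import numpy as np

x = np.arange(12).reshape(4, 3)   # 4 rows, 3 columns
idx = np.array([0, 2, 2, -1])     # integer vector; repeats and negatives are allowed

out = x[idx]                      # gathers rows 0, 2, 2 and the last row
same = np.take(x, idx, axis=0)    # equivalent formulation of the same gather
```

The result has one row per index, so `out.shape` is `(len(idx),) + x.shape[1:]`.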

AdvancedIncSubtensor1 benchmark

Before
------------------------------------------------------------------------------------------------------------- benchmark: 8 tests ------------------------------------------------------------------------------------------------------------
Name (time in us)                                                             Min                 Max              Mean            StdDev            Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_advanced_incsubtensor1[set_subtensor-gc=False-static_shape=True]      5.2000 (1.0)       80.1600 (2.73)     5.5870 (1.01)     1.7055 (2.56)     5.4600 (1.01)     0.1000 (1.64)     875;1577      178.9867 (0.99)      48810           1
test_advanced_incsubtensor1[set_subtensor-gc=False-static_shape=False]     5.2000 (1.00)     291.1850 (9.92)     5.5165 (1.0)      2.0229 (3.04)     5.4000 (1.0)      0.0610 (1.0)      744;1847      181.2759 (1.0)       50209           1
test_advanced_incsubtensor1[set_subtensor-gc=True-static_shape=True]       5.3200 (1.02)      71.1430 (2.42)     5.6233 (1.02)     1.3633 (2.05)     5.5210 (1.02)     0.0700 (1.15)     628;1601      177.8307 (0.98)      38184           1
test_advanced_incsubtensor1[set_subtensor-gc=True-static_shape=False]      5.4200 (1.04)      29.3550 (1.0)      5.6929 (1.03)     0.6650 (1.0)      5.6400 (1.04)     0.0700 (1.15)     702;1393      175.6585 (0.97)      47170           1
test_advanced_incsubtensor1[inc_subtensor-gc=False-static_shape=True]      5.8510 (1.13)     346.4190 (11.80)    6.2323 (1.13)     2.3741 (3.57)     6.1110 (1.13)     0.0800 (1.31)     579;1374      160.4549 (0.89)      46905           1
test_advanced_incsubtensor1[inc_subtensor-gc=False-static_shape=False]     5.9310 (1.14)      29.5250 (1.01)     6.2690 (1.14)     0.8971 (1.35)     6.1410 (1.14)     0.0800 (1.31)      377;584      159.5152 (0.88)      10442           1
test_advanced_incsubtensor1[inc_subtensor-gc=True-static_shape=True]       5.9720 (1.15)     111.7590 (3.81)     6.3686 (1.15)     1.8011 (2.71)     6.2510 (1.16)     0.0900 (1.48)     806;1734      157.0216 (0.87)      48265           1
test_advanced_incsubtensor1[inc_subtensor-gc=True-static_shape=False]      6.0720 (1.17)      86.5220 (2.95)     6.5262 (1.18)     1.9851 (2.99)     6.3420 (1.17)     0.1510 (2.48)    1039;1465      153.2284 (0.85)      42257           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
After

------------------------------------------------------------------------------------------------------------- benchmark: 8 tests ------------------------------------------------------------------------------------------------------------
Name (time in us)                                                             Min                 Max              Mean            StdDev            Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_advanced_incsubtensor1[set_subtensor-gc=False-static_shape=True]      1.6530 (1.0)       58.8300 (7.25)     1.8268 (1.0)      0.6546 (3.03)     1.7930 (1.0)      0.0500 (1.22)     588;1745      547.4165 (1.0)       98727           1
test_advanced_incsubtensor1[set_subtensor-gc=False-static_shape=False]     1.7730 (1.07)     250.6090 (30.88)    1.9281 (1.06)     1.0618 (4.92)     1.8940 (1.06)     0.0500 (1.22)     290;1008      518.6360 (0.95)      81215           1
test_advanced_incsubtensor1[set_subtensor-gc=True-static_shape=True]       1.9040 (1.15)      92.4230 (11.39)    2.0656 (1.13)     0.7314 (3.39)     2.0040 (1.12)     0.0500 (1.22)    1165;5456      484.1202 (0.88)      93730           1
test_advanced_incsubtensor1[set_subtensor-gc=True-static_shape=False]      2.0840 (1.26)      53.9000 (6.64)     2.3187 (1.27)     0.5710 (2.65)     2.3040 (1.28)     0.0910 (2.22)     764;1078      431.2728 (0.79)      88645           1
test_advanced_incsubtensor1[inc_subtensor-gc=False-static_shape=True]      2.4540 (1.48)     163.9370 (20.20)    2.6131 (1.43)     0.8179 (3.79)     2.5750 (1.44)     0.0410 (1.0)      921;2552      382.6924 (0.70)      90827           1
test_advanced_incsubtensor1[inc_subtensor-gc=True-static_shape=True]       2.7160 (1.64)      51.4560 (6.34)     2.9719 (1.63)     0.7313 (3.39)     2.9150 (1.63)     0.0600 (1.46)    1406;4193      336.4881 (0.61)      95612           1
test_advanced_incsubtensor1[inc_subtensor-gc=False-static_shape=False]     2.9050 (1.76)       8.1160 (1.0)      3.0595 (1.67)     0.2158 (1.0)      3.0460 (1.70)     0.0600 (1.46)      120;249      326.8474 (0.60)      23965           1
test_advanced_incsubtensor1[inc_subtensor-gc=True-static_shape=False]      3.2560 (1.97)     229.7100 (28.30)    3.4273 (1.88)     1.3326 (6.17)     3.3760 (1.88)     0.0600 (1.46)     312;1105      291.7782 (0.53)      52149           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
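The two benchmark variants differ in how repeated indices are handled. A minimal NumPy sketch of the difference follows; the values are illustrative assumptions, and NumPy stands in for the compiled Op here.

```python
import numpy as np

idx = np.array([0, 2, 2])
y = np.array([10.0, 20.0, 30.0])

# set_subtensor-like: plain assignment, the last write to a repeated index wins
x_set = np.zeros(4)
x_set[idx] = y

# inc_subtensor-like: np.add.at accumulates every contribution to a repeated index
x_inc = np.zeros(4)
np.add.at(x_inc, idx, y)
```

With these inputs, position 2 ends up as 30 in the set case and as 50 in the increment case.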

I added a long-missing check for runtime broadcasting in the Python/C/Torch implementations (Numba would require a few more code changes), which moves towards #1348
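A rough sketch of what such a check could look like, in plain NumPy. The function name, its signature, and the error message are hypothetical and are not the code in this PR. The sketch also assumes `y` has the same number of dimensions as `x`, and that `y_static_bcast` records which dimensions of `y` were declared as length 1 at graph-construction time.

```python
import numpy as np

def check_runtime_broadcasting(y_static_bcast, x, y, idx):
    """Hypothetical helper: reject y dims that are 1 at runtime without being declared so."""
    # Assumption: y.ndim == x.ndim, so y must match (len(idx), *x.shape[1:])
    target_shape = (idx.shape[0], *x.shape[1:])
    for declared, y_dim, t_dim in zip(y_static_bcast, y.shape, target_shape):
        # Broadcasting a length-1 dim is only accepted when it was declared statically
        if y_dim == 1 and t_dim != 1 and not declared:
            raise ValueError("Runtime broadcasting not allowed")

x = np.zeros((5, 3))
idx = np.array([0, 1])
check_runtime_broadcasting((True, False), x, np.ones((1, 3)), idx)   # declared: accepted
check_runtime_broadcasting((False, False), x, np.ones((2, 3)), idx)  # exact match: accepted
```

Passing `np.ones((1, 3))` with `(False, False)` would raise, since the leading dimension is 1 only at runtime.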

Provides a restricted case of #1325

Alloc of zeros is also about twice as fast now, which is benchmarked indirectly in the AdvancedIncSubtensor1 tests


📚 Documentation preview 📚: https://pytensor--1346.org.readthedocs.build/en/1346/

@ricardoV94 ricardoV94 changed the title from "Speedup AdvancedSubtensor1 and AdvancedIncSubtensor1 in C backends" to "Speedup AdvancedSubtensor1 and AdvancedIncSubtensor1 in C backend" Apr 7, 2025
@ricardoV94 ricardoV94 force-pushed the faster_indexing branch 4 times, most recently from 337b91f to 2862c42 Compare April 9, 2025 11:42
@ricardoV94 ricardoV94 force-pushed the faster_indexing branch 2 times, most recently from c0045c3 to c211405 Compare April 9, 2025 13:46

codecov bot commented Apr 9, 2025

Codecov Report

Attention: Patch coverage is 87.01299% with 10 lines in your changes missing coverage. Please review.

Project coverage is 82.05%. Comparing base (3c66aa6) to head (38f9036).
Report is 4 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| pytensor/link/pytorch/dispatch/subtensor.py | 0.00% | 2 Missing and 2 partials ⚠️ |
| pytensor/tensor/rewriting/subtensor.py | 63.63% | 3 Missing and 1 partial ⚠️ |
| pytensor/link/jax/dispatch/subtensor.py | 0.00% | 1 Missing and 1 partial ⚠️ |
❌ Your patch check has failed because the patch coverage (87.01%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1346      +/-   ##
==========================================
+ Coverage   82.04%   82.05%   +0.01%     
==========================================
  Files         203      203              
  Lines       48837    48912      +75     
  Branches     8689     8709      +20     
==========================================
+ Hits        40067    40134      +67     
- Misses       6619     6624       +5     
- Partials     2151     2154       +3     

| Files with missing lines | Coverage Δ |
|---|---|
| pytensor/link/numba/dispatch/subtensor.py | 95.58% <100.00%> (ø) |
| pytensor/tensor/basic.py | 91.18% <100.00%> (+0.04%) ⬆️ |
| pytensor/tensor/subtensor.py | 89.98% <100.00%> (+0.50%) ⬆️ |
| pytensor/link/jax/dispatch/subtensor.py | 95.83% <0.00%> (-4.17%) ⬇️ |
| pytensor/link/pytorch/dispatch/subtensor.py | 85.55% <0.00%> (-3.98%) ⬇️ |
| pytensor/tensor/rewriting/subtensor.py | 89.58% <63.63%> (-0.38%) ⬇️ |

... and 3 files with indirect coverage changes


@Copilot Copilot AI left a comment

Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.

Comments suppressed due to low confidence (3)

tests/tensor/test_subtensor.py:1277

  • The generator expression used for 'inc_var_static_shape' should be converted to a tuple (e.g. tuple(...)) to ensure it produces a concrete shape tuple.
inc_var_static_shape = (1 if dim_length == 1 else None for dim_length in inc_shape)

pytensor/link/pytorch/dispatch/subtensor.py:112

  • The _check_runtime_broadcasting method expects four arguments (including the node), so the node argument should be passed to ensure correct runtime checking.
if isinstance(op, AdvancedIncSubtensor1): op._check_runtime_broadcasting(x, y, indices)

pytensor/link/jax/dispatch/subtensor.py:70

  • Similar to the PyTorch dispatch, the node argument is missing when calling _check_runtime_broadcasting. Ensure that the node is passed as the first parameter after self.
if isinstance(op, AdvancedIncSubtensor1): op._check_runtime_broadcasting(x, y, indices)
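The first suppressed comment concerns a general Python pitfall: parentheses around a comprehension create a generator, not a tuple. A small standalone illustration follows; `inc_shape` here is an illustrative value, not the one used in the test suite.

```python
inc_shape = (1, 5, 1)  # illustrative shape, an assumption for this sketch

as_generator = (1 if dim_length == 1 else None for dim_length in inc_shape)
as_tuple = tuple(1 if dim_length == 1 else None for dim_length in inc_shape)

# A generator can be consumed only once; a second pass yields nothing
first_pass = list(as_generator)
second_pass = list(as_generator)
```

Here `as_tuple` is a concrete `(1, None, 1)`, while the generator is empty after its first use.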

@jessegrabowski
Member

Useless AI strikes again!

@ricardoV94
Member Author

ricardoV94 commented Apr 9, 2025

> Useless AI strikes again!

It was actually correct. I'm surprised no tests failed; I guess we are not really covering the dispatch of these in JAX/PyTorch because the Op is only introduced during rewrites.
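A toy illustration of why a call with a missing argument can go unnoticed: Python only raises the `TypeError` when the call actually executes, so an untested branch hides it. `FakeOp` and `dispatch` are hypothetical stand-ins, not PyTensor code.

```python
class FakeOp:
    """Hypothetical stand-in for an Op whose checker takes (node, x, y, idx)."""

    def _check_runtime_broadcasting(self, node, x, y, idx):
        return True


def dispatch(op, x, y, idx, run_check):
    # Same call shape as the flagged lines: `node` is missing from the arguments
    if run_check:
        op._check_runtime_broadcasting(x, y, idx)
    return "ok"


# The branch is never taken, so the bug stays hidden
result_without_check = dispatch(FakeOp(), 1, 2, 3, run_check=False)
```

Calling `dispatch(..., run_check=True)` raises `TypeError` for the missing positional argument, which is the failure a test exercising that path would have caught.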

@ricardoV94 ricardoV94 requested a review from Copilot April 9, 2025 18:13

@Copilot Copilot AI left a comment

Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.

Comments suppressed due to low confidence (1)

pytensor/tensor/subtensor.py:2247

  • The condition validating index values in _idx_may_be_invalid is non-obvious; please verify that it correctly handles negative indices and reflects the intended bounds check.
return not (min_idx >= 0 or min_idx >= -shape0) and (max_idx < 0 or max_idx < shape0)
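To make the precedence concern concrete: `not` binds tighter than `and`, so the quoted expression parses as `(not A) and B`. The sketch below compares it with one possible reading of the intended check. That reading is an assumption (valid indices lie in `[-shape0, shape0 - 1]`), and both function names are hypothetical.

```python
def flagged(min_idx, max_idx, shape0):
    # Quoted expression as written: evaluates as (not A) and B
    return not (min_idx >= 0 or min_idx >= -shape0) and (max_idx < 0 or max_idx < shape0)


def may_be_invalid(min_idx, max_idx, shape0):
    # Assumed intent: anything outside [-shape0, shape0 - 1] may be invalid
    return not (-shape0 <= min_idx and max_idx < shape0)


# max_idx = 5 is out of bounds for shape0 = 3, yet the flagged form reports False
disagreement = (flagged(0, 5, 3), may_be_invalid(0, 5, 3))
```

The two forms agree when the indices are in bounds or when only `min_idx` is too small; they differ when `max_idx` exceeds the upper bound.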

}}

if ((PyArray_NDIM({out}) != 1) || ({unexpected_shape0})) {{
PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: fist input (x) does not have right shape or ndim");

Copilot AI Apr 9, 2025

Typo detected in the error message: 'fist input' should be changed to 'first input'.

Suggested change
- PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: fist input (x) does not have right shape or ndim");
+ PyErr_SetString(PyExc_ValueError, "AdvancedIncSubtensor1: first input (x) does not have right shape or ndim");
