Numba CAReduce: respect acc_dtype #1773
3c61b88 to ad960d6
Pull request overview
This PR fixes the Numba backend implementation of CAReduce operations to properly respect the acc_dtype (accumulation dtype) parameter and corrects the handling of infinity identities for unsigned integer types.
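To illustrate why the accumulation dtype matters independently of the output dtype, here is a minimal pure-Python sketch. It is not the PR's Numba code: `cast` and `reduce_sum` are hypothetical helpers, and dtypes are emulated with `struct` format codes (`"e"` for float16, `"d"` for float64).

```python
import struct


def cast(value, fmt):
    """Round a Python float to the precision of a struct float format."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


def reduce_sum(values, acc_fmt, out_fmt):
    """Sum `values`, rounding the running total to `acc_fmt` after every step,
    then cast the final result to `out_fmt` (hypothetical helper)."""
    acc = cast(0.0, acc_fmt)
    for v in values:
        acc = cast(acc + v, acc_fmt)
    return cast(acc, out_fmt)


ones = [1.0] * 4096

# Accumulating in float16: the running total stops growing once the
# spacing between representable values exceeds the increment.
narrow = reduce_sum(ones, acc_fmt="e", out_fmt="e")

# Accumulating in float64 and casting only the final result to float16.
wide = reduce_sum(ones, acc_fmt="d", out_fmt="e")

print(narrow, wide)
```

Both calls return a float16-representable value, but only the one with the wider accumulator reaches the exact sum, which is the situation a separate `acc_dtype` is meant to handle.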
Key Changes:
- Modified `create_multiaxis_reducer` to properly handle the `acc_dtype` parameter separately from `out_dtype`, including support for complex-to-real conversions
- Fixed infinity identity values for discrete (unsigned/signed integer) types by replacing infinite values with appropriate min/max values
- Updated `__str__` methods in `CAReduce` and `FixedOpCAReduce` to display accumulation dtype when it differs from output dtype
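The identity fix for discrete types listed above can be sketched as follows. This is an assumption-laden illustration in plain Python, not the code from the PR: `int_bounds`, `discrete_identity`, and `reduce_min` are hypothetical names, and integer dtypes are described by a bit width and a signedness flag.

```python
import math


def int_bounds(bits, signed):
    """Return (min, max) for an integer type of the given width."""
    if signed:
        return -(1 << (bits - 1)), (1 << (bits - 1)) - 1
    return 0, (1 << bits) - 1


def discrete_identity(identity, bits, signed):
    """Replace an infinite identity with the nearest representable bound.

    +inf (identity of min) becomes the type's maximum;
    -inf (identity of max) becomes the type's minimum.
    """
    lo, hi = int_bounds(bits, signed)
    if math.isinf(identity):
        return hi if identity > 0 else lo
    return int(identity)


def reduce_min(values, bits, signed):
    """Minimum reduction that starts from the integer-safe identity."""
    res = discrete_identity(math.inf, bits, signed)
    for v in values:
        res = min(res, v)
    return res


print(discrete_identity(math.inf, 8, signed=False))
print(discrete_identity(-math.inf, 8, signed=False))
print(reduce_min([7, 3, 9], bits=8, signed=False))
```

An infinite value cannot be stored in an integer array, so starting the accumulator from the type's own bound keeps the reduction well defined for both signed and unsigned inputs.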
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| tests/link/numba/test_elemwise.py | Adds three new tests: verifying acc_dtype is respected during accumulation, testing complex-to-float conversions, and testing discrete infinity identities |
| pytensor/tensor/math.py | Updates FixedOpCAReduce.__str__() to display acc_dtype in string representation when different from output dtype |
| pytensor/tensor/elemwise.py | Updates CAReduce.__str__() to display acc_dtype in string representation when different from output dtype |
| pytensor/link/numba/dispatch/elemwise.py | Core implementation changes: refactors create_multiaxis_reducer to handle acc_dtype parameter, fixes infinity identities for discrete types, updates Softmax/SoftmaxGrad/LogSoftmax to use keyword arguments, adds cache versioning |
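The `__str__` changes described in the table follow a simple pattern: mention the accumulation dtype only when it differs from the output dtype. A rough sketch of that pattern is shown here. `ReduceSketch` is a hypothetical stand-in, and the exact string format used by PyTensor's `CAReduce` is not shown on this page, so the format below is an assumption.

```python
class ReduceSketch:
    """Hypothetical stand-in for a reduction op; not PyTensor's CAReduce."""

    def __init__(self, name, axis, dtype, acc_dtype):
        self.name = name
        self.axis = axis
        self.dtype = dtype          # output dtype
        self.acc_dtype = acc_dtype  # accumulation dtype

    def __str__(self):
        parts = [f"axis={self.axis}"]
        # Only show acc_dtype when it adds information.
        if self.acc_dtype != self.dtype:
            parts.append(f"acc_dtype={self.acc_dtype}")
        return self.name + "{" + ", ".join(parts) + "}"


print(ReduceSketch("Sum", 0, "float16", "float32"))
print(ReduceSketch("Sum", 0, "float32", "float32"))
```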
Also fix infinity identities for unsigned integers
ad960d6 to 690015f
Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.
    res_shape = x_shape[2]
    res = np.full(res_shape, numba_basic.to_scalar(0.0), dtype=out_dtype)
    res_shape = (x_shape[0], x_shape[1])
    # identity = 0.0
Why is this commented?
It's to make the generated code readable when debugging. Otherwise you just see a global `identity` being used but won't know its value.
    # if len(axes) == 1:
    #     return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
Might as well remove this while we're here?
    def {careduce_fn_name}(x):
        x_shape = x.shape
        res_shape = {res_shape}
        # identity = {identity}
I don't get it
just for readability
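For readers unfamiliar with the pattern discussed in this thread, here is a small sketch of generating a reducer from a source template. The value of `identity` is supplied through the function's globals, while a comment in the generated source records that value for anyone reading it during debugging. `make_sum_reducer` and the template are hypothetical and far simpler than the PR's `create_multiaxis_reducer`.

```python
import textwrap


def make_sum_reducer(identity, fn_name="careduce_sum"):
    """Generate a sum reducer from a source template (hypothetical helper).

    Returns the compiled function together with its source text.
    """
    source = textwrap.dedent(f"""
        def {fn_name}(x):
            # identity = {identity!r}
            res = identity
            for v in x:
                res = res + v
            return res
    """)
    # `identity` is not inlined into the code; the generated function
    # looks it up in the globals it was compiled with.
    namespace = {"identity": identity}
    exec(source, namespace)
    return namespace[fn_name], source


reducer, source = make_sum_reducer(0.0)
print(source)                    # the comment line shows the identity value
print(reducer([1.0, 2.0, 3.0]))
```

Without the comment, the printed source would only show the name `identity`, which is the readability problem the reply above describes.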
Cherry pick from #811