-
Notifications
You must be signed in to change notification settings - Fork 33
MAINT: simplify torch
dtype promotion
#303
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copilot reviewed 1 out of 1 changed files in this pull request and generated no comments.
Comments suppressed due to low confidence (2)
array_api_compat/torch/_aliases.py:140
- The try/except block in _result_type silently swallows a KeyError when a dtype pair is not found in the promotion table. Consider handling the error explicitly or adding a clarifying comment to document the intended fallback behavior.
return _promotion_table[xdt, ydt]
array_api_compat/torch/_aliases.py:296
- Consider adding tests for _sum_prod_no_axis to verify that the dtype conversion and deep copy behavior work as expected, especially for edge cases involving small integer types like uint8.
return x.clone() if dtype == x.dtype else x.to(dtype)
(torch.complex64, torch.complex128): torch.complex128, | ||
(torch.complex128, torch.complex64): torch.complex128, | ||
(torch.complex128, torch.complex128): torch.complex128, | ||
# Mixed float and complex | ||
(torch.float32, torch.complex64): torch.complex64, | ||
(torch.float32, torch.complex128): torch.complex128, | ||
(torch.float64, torch.complex64): torch.complex128, | ||
(torch.float64, torch.complex128): torch.complex128, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(complex, float)
use cases were missing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks @crusaderky
if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32): | ||
return x.to(torch.int64) | ||
|
||
return x.clone() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note to self: this looks scary, but is in fact just a refactoring. Previously this stanza was duplicated in sum
and prod
.
Returning a copy looks reasonable, too.
Salvaged from #298
sum(x, axis=(), dtype=x.dtype)
andprod
with the same parameters, which were previously returningx
itself, to return a deep copy instead. While this is not part of the Array API, it's what both numpy an base torch do, and probably the most healthy thing to do xref https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html.Discussion here: Spell out where views are allowed array-api#921