Skip to content

MAINT: simplify torch dtype promotion #303

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 15, 2025

Conversation

crusaderky
Copy link
Contributor

Salvaged from #298

@Copilot Copilot AI review requested due to automatic review settings April 10, 2025 10:01
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot reviewed 1 out of 1 changed files in this pull request and generated no comments.

Comments suppressed due to low confidence (2)

array_api_compat/torch/_aliases.py:140

  • The try/except block in _result_type silently swallows a KeyError when a dtype pair is not found in the promotion table. Consider handling the error explicitly or adding a clarifying comment to document the intended fallback behavior.
return _promotion_table[xdt, ydt]

array_api_compat/torch/_aliases.py:296

  • Consider adding tests for _sum_prod_no_axis to verify that the dtype conversion and deep copy behavior work as expected, especially for edge cases involving small integer types like uint8.
return x.clone() if dtype == x.dtype else x.to(dtype)

(torch.complex64, torch.complex128): torch.complex128,
(torch.complex128, torch.complex64): torch.complex128,
(torch.complex128, torch.complex128): torch.complex128,
# Mixed float and complex
(torch.float32, torch.complex64): torch.complex64,
(torch.float32, torch.complex128): torch.complex128,
(torch.float64, torch.complex64): torch.complex128,
(torch.float64, torch.complex128): torch.complex128,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(complex, float) use cases were missing

Copy link
Member

@ev-br ev-br left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @crusaderky

if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
return x.to(torch.int64)

return x.clone()
Copy link
Member

@ev-br ev-br Apr 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: this looks scary, but is in fact just a refactoring. Previously this stanza was duplicated in sum and prod.

Returning a copy looks reasonable, too.

@ev-br ev-br merged commit 9194c5c into data-apis:main Apr 15, 2025
40 checks passed
@ev-br ev-br added this to the 1.12 milestone Apr 15, 2025
@crusaderky crusaderky deleted the torch_promotion_table branch April 15, 2025 12:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants