BUG/TST: special.logsumexp on non-default device #22756


Merged: 9 commits into scipy:main on Apr 4, 2025

Conversation

@crusaderky (Contributor) commented Mar 28, 2025

This PR sets up test infrastructure for device propagation: it checks that when a function receives input arrays that don't live on the default device, the output stays on that device. This benefits all multi-device backends, namely PyTorch, JAX, and (for testing purposes only) array-api-strict.

This PR fixes logsumexp, and incidentally reveals that, at the time of writing, there are bugs in both JAX and PyTorch that prevent this from working properly; I don't plan to extend the same treatment to other functions until the upstream bugs are fixed.
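Concretely, the failure mode looks like this. A minimal sketch using array-api-strict, which exposes fake extra devices such as Device("device1") for testing (illustrative only, not the actual SciPy code):

import array_api_strict as xp

x = xp.asarray([1.0, 2.0], device=xp.Device("device1"))
# A function that creates intermediate arrays without forwarding the
# input's device silently moves the computation to the default device:
out = xp.zeros(x.shape)
assert out.device != x.device  # device was not propagated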

Upstream issues

@github-actions bot added the scipy.special, defect, and maintenance labels on Mar 28, 2025
@crusaderky changed the title from "BUG/TST: special: run logsumexp on non-default device" to "BUG/TST: special.logsumexp on non-default device" on Mar 28, 2025
@mdhaber (Contributor) commented Mar 28, 2025

Closes gh-22680?

@crusaderky (Contributor, Author) commented:

> Closes gh-22680?

  • Yes on array-api-strict;
  • Yes on PyTorch, as long as the user hasn't called torch.set_default_device(...);
  • Yes on JAX, as long as the user doesn't wrap it in jax.jit.

TL;DR: yes. It removes SciPy's bugs, and the user is left with the bugs of their backend of choice.
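For reference, the PyTorch caveat boils down to default-device semantics; a minimal sketch (assuming a CUDA-enabled build; not SciPy code):

import torch

x = torch.arange(3.0)             # CPU tensor, created before the switch
torch.set_default_device("cuda")  # user changes the default device
inter = torch.zeros(3)            # allocated on cuda, the new default
x + inter                         # RuntimeError: tensors on different devices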

"""Test input device propagation to output."""
x = xp.asarray(x, device=nondefault_device)
assert xp_device(logsumexp(x)) == nondefault_device
assert xp_device(logsumexp(x, b=x)) == nondefault_device

Contributor commented:

At first glance, if you're working on a machine with a single GPU, this PR may appear fairly straightforward.

However, when I test on a node with multiple GPUs and run, e.g., SCIPY_DEVICE=cuda python dev.py test -t scipy/special/tests/test_logsumexp.py::TestLogSumExp::test_device -b cupy

this test currently fails on this branch:

scipy/special/tests/test_logsumexp.py:299: in test_device
    assert xp_device(logsumexp(x)) == nondefault_device
E   assert <CUDA Device 0> == <CUDA Device 1>

Does CuPy require special treatment? Do multiple GPUs require special treatment? It isn't immediately obvious to me, but the nature of the failure suggests that device propagation is not working as intended.

I'm using CuPy 13.3.0, which is fairly recent; I could try bumping the version.

Contributor commented:

Is our ecosystem currently shimming around cupy.cuda.Device to compensate for not having the device kwarg on the array coercion for CuPy?

@crusaderky (Contributor, Author) commented:

> At first glance, if you're working on a machine with a single GPU, this PR may appear fairly straightforward.

On a single-GPU machine, CuPy has only one device, so the test introduced by this PR is skipped.

> Is our ecosystem currently shimming around cupy.cuda.Device to compensate for not having the device kwarg on the array coercion for CuPy?

Yes, by array-api-compat.
It looks like multi-device support was not thought through: data-apis/array-api-compat#293
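For context, the shim is conceptually along these lines (a simplified sketch, not array-api-compat's actual implementation):

import cupy as cp

def asarray(obj, /, *, device=None, **kwargs):
    # CuPy's own asarray has no `device` kwarg, so the wrapper switches
    # the current CUDA device for the duration of the allocation instead.
    if device is None:
        return cp.asarray(obj, **kwargs)
    with device:  # a cupy.cuda.Device context manager
        return cp.asarray(obj, **kwargs)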

@mdhaber (Contributor) commented Mar 28, 2025

Well, actually, it probably shouldn't close gh-22680. Whether we call it a "bug" or a "lack of support for an experimental feature we eventually want to support", this is a systematic shortcoming of a lot of xp-translated code. There wasn't even GPU default-device testing in CI when many of the translations were done, and there have been a lot of issues in backends surrounding the device keyword, so it wasn't used anywhere except fft and some parts of signal. Once we can actually test the non-default device, we can start adding its use throughout. That should still be tracked by gh-22680 or a new issue.

@crusaderky (Contributor, Author) commented:

I've re-enabled the test on torch; however, it will only run on a GPU-enabled host with SCIPY_DEVICE=cpu, which is not something that ever happens in CI.
I've xfailed the test on CuPy and am tracking progress in data-apis/array-api-compat#293. I have low expectations that it is fixable, because we can't override methods. Full discussion in the linked PR.

@mdhaber (Contributor) left a comment:

This only superficially has to do with logsumexp, so I'll provide a superficial review and approval from the logsumexp side of things.

The changes there are minimal, but they appear to be complete. I think the only xp functions used by logsumexp (and private functions) that accept device are full, arange, and asarray. This PR provides the correct device argument to full and arange, and I think the calls to asarray can infer the device correctly based on the input.
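For illustration, the pattern amounts to forwarding the input's device to every creation call; a hedged sketch using array-api-strict (not the actual logsumexp code):

import array_api_strict as xp

a = xp.asarray([1.0, 2.0, 3.0], device=xp.Device("device1"))
# full() and arange() receive the input's device explicitly; without
# the device argument they would allocate on the default device.
fill = xp.full(a.shape, 0.0, dtype=a.dtype, device=a.device)
idx = xp.arange(a.shape[0], device=a.device)
assert fill.device == a.device and idx.device == a.device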

Setting up the test fixture, etc., is not really in my wheelhouse, so I'll let others comment on that.

@tylerjereddy (Contributor) commented:

Just to continue providing feedback on the multi-device scenario with the latest version of this branch plus the latest version of the cognate array-api-compat branch: the next point of failure for one of the test cases is here:

@@ -115,14 +119,18 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
             # Where result is infinite, we use the direct logsumexp calculation to
             # delegate edge case handling to the behavior of `xp.log` and `xp.exp`,
             # which should follow the C99 standard for complex values.
+            print("xp_device(a) at logsumexp checkpoint 4b:", xp_device(a))
+            xp.exp(a)
             b_exp_a = xp.exp(a) if b is None else b * xp.exp(a)
             sum_ = xp.sum(b_exp_a, axis=axis, keepdims=True)

Even the isolated xp.exp(a) on its own, which I manually added there, fails with ValueError: The device where the array resides (1) is different from the current device (0). Peer access is unavailable between these devices. The confusion around this is perhaps better discussed over at the array-api-compat PR, though, and you've xfailed the tests with CuPy for now, so I'm not trying to block this.
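For the record, this class of error is reproducible in plain CuPy; a hedged sketch (assuming two visible GPUs without peer access):

import cupy as cp

with cp.cuda.Device(1):
    x = cp.arange(3.0)  # allocated on GPU 1
# Back on the current device (GPU 0), computing on x raises the
# ValueError quoted above when peer access is unavailable.
cp.exp(x)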

@crusaderky (Contributor, Author) commented:

Tested that data-apis/array-api-compat#296 fully fixes PyTorch.

@crusaderky (Contributor, Author) commented:

I've reworked the fixture to incorporate prior art from #19900 (@lucascolley).
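Roughly, the fixture follows this shape; an illustrative sketch using the standard inspection API (names and details differ from the actual code):

import pytest

@pytest.fixture
def nondefault_device(xp):
    # Pick any device other than the backend's default; skip when the
    # backend only has one device (e.g. CuPy on a single-GPU host).
    info = xp.__array_namespace_info__()
    default = info.default_device()
    for device in info.devices():
        if device != default:
            return device
    pytest.skip("no non-default device available")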

# Note workaround when parsing SCIPY_DEVICE above.
# Also note that when SCIPY_DEVICE=cpu this test won't run in CI
# because CUDA-enabled CI boxes always use SCIPY_DEVICE=cuda.
pytest.xfail(reason="pytorch/pytorch#150199")

# While this issue is specific to jax.jit, it would be unnecessarily
# verbose to skip the test for each jit-capable function and run it for
# those that only support eager mode.
pytest.xfail(reason="jax-ml/jax#26000")
@crusaderky (Contributor, Author) commented:

Also see jax-ml/jax#27606 (fixed in the next JAX release).
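A hedged illustration of the jax.jit caveat, using JAX's flag for multiple fake CPU devices (the behavior shown is the upstream bug, not SciPy code):

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"  # two fake CPU devices
import jax
import jax.numpy as jnp

x = jax.device_put(jnp.arange(3.0), jax.devices()[1])
print(jnp.exp(x).devices())           # eager: stays on device 1
print(jax.jit(jnp.exp)(x).devices())  # under jax-ml/jax#26000, may land on device 0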

@lucascolley (Member) left a comment:

thanks Guido!

@lucascolley lucascolley added this to the 1.16.0 milestone Apr 4, 2025
@lucascolley lucascolley merged commit bc5d86a into scipy:main Apr 4, 2025
39 of 41 checks passed
@crusaderky crusaderky deleted the logsumexp_device branch April 4, 2025 14:20
Labels
  • array types: Items related to array API support and input array validation (see gh-18286)
  • defect: A clear bug or issue that prevents SciPy from being installed or used as expected
  • maintenance: Items related to regular maintenance tasks
  • scipy.special