BUG/TST: special.logsumexp on non-default device #22756
Conversation
special: run logsumexp on non-default device
Closes gh-22680?

TL;DR yes. It removes scipy's bugs and the user is left with the bugs of their backend of choice.
"""Test input device propagation to output.""" | ||
x = xp.asarray(x, device=nondefault_device) | ||
assert xp_device(logsumexp(x)) == nondefault_device | ||
assert xp_device(logsumexp(x, b=x)) == nondefault_device |
At first, if you're working on a machine with a single GPU for example, this PR may appear to be fairly straightforward. However, when I test on a node/machine that has multiple GPUs and use e.g.

SCIPY_DEVICE=cuda python dev.py test -t scipy/special/tests/test_logsumexp.py::TestLogSumExp::test_device -b cupy

this test currently fails on this branch:

scipy/special/tests/test_logsumexp.py:299: in test_device
    assert xp_device(logsumexp(x)) == nondefault_device
E   assert <CUDA Device 0> == <CUDA Device 1>

Does CuPy require special treatment? Do multiple GPUs require special treatment? It isn't immediately obvious to me, but the nature of the failure suggests that device propagation is not working as intended.

I'm using CuPy 13.3.0, which is fairly recent. I could try bumping the version, maybe.
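For context on where such a mismatch can come from, here is a hedged illustration (not taken from the PR diff; it requires a host with at least two CUDA devices): any helper allocation that doesn't receive an explicit device lands on the current default device, so intermediate results can drift back to device 0 even when the input lives on device 1.

```python
# Hedged illustration, not from the PR: shows how an allocation made
# without an explicit device ends up on the current (default) device.
# Requires a host with at least two CUDA devices.
import cupy as cp

with cp.cuda.Device(1):
    x = cp.asarray([1.0, 2.0, 3.0])   # input lives on device 1

helper = cp.full(x.shape, 0.0)        # allocated on the current device (0)
print(x.device, helper.device)        # <CUDA Device 1> <CUDA Device 0>
```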
Is our ecosystem currently shimming around cupy.cuda.Device to compensate for not having the device kwarg on the array coercion for CuPy?
> At first, if you're working on a machine with a single GPU for example, this PR may appear to be fairly straightforward.

On a single-GPU machine, CuPy has only one device and the test introduced by this PR is skipped.

> Is our ecosystem currently shimming around cupy.cuda.Device to compensate for not having the device kwarg on the array coercion for CuPy?

Yes, by array-api-compat.

Looks like multi-device support was not thought through: data-apis/array-api-compat#293
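For reference, the shim is roughly of the following shape; this is a simplified sketch of the approach rather than array-api-compat's actual code, and the function name is illustrative.

```python
# Simplified sketch of an array-api-compat-style device shim (illustrative
# name, not the library's actual code): since CuPy's asarray has no `device`
# keyword, the target device's context is entered around the allocation.
import cupy as cp

def _asarray_on_device(obj, /, *, device=None, **kwargs):
    if device is None:
        return cp.asarray(obj, **kwargs)
    with device:                      # cupy.cuda.Device is a context manager
        # Note: if obj is already a CuPy array on another device, asarray
        # may return it unchanged rather than copying it over.
        return cp.asarray(obj, **kwargs)
```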
Well, actually, it probably shouldn't close gh-22680. "Bug" or "lack of support for an experimental feature we eventually want to support", this is a systematic shortcoming of a lot of xp-translated code. There wasn't even GPU default-device testing in CI when a lot of the translations were done, and there have been a lot of issues in backends surrounding the …

I've re-enabled the test on torch; however, it will only run on a GPU-enabled host with SCIPY_DEVICE=cpu, which is not something that ever happens in CI.
This only superficially has to do with logsumexp, so I'll provide a superficial review and approval from the logsumexp side of things.

The changes there are minimal, but they appear to be complete. I think the only xp functions used by logsumexp (and private functions) that accept device are full, arange, and asarray. This PR provides the correct device argument to full and arange, and I think the calls to asarray can infer the device correctly based on the input.
Setting up the test fixture, etc., is not really in my wheelhouse, so I'll let others comment on that.
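As a concrete illustration of the pattern described above, here is a hedged sketch; the helper name and body are illustrative, not the PR's actual diff.

```python
# Hedged sketch of the device-threading pattern; illustrative, not the PR's
# actual code. `xp` is an array API namespace obtained from the input array.
def _device_aware_allocations(a, xp):
    device = a.device                          # device of the input array
    # Fresh allocations get the input's device passed explicitly ...
    zeros = xp.full(a.shape, 0.0, device=device)
    idx = xp.arange(a.shape[-1], device=device)
    # ... while asarray of an existing array already keeps that array's device.
    a2 = xp.asarray(a)
    return zeros, idx, a2
```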
Just to continue providing feedback on the multi-device scenario with the latest version of this branch plus the latest version of the cognate array-api-compat branch, the next point of failure for one of the test cases is here:

@@ -115,14 +119,18 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
     # Where result is infinite, we use the direct logsumexp calculation to
     # delegate edge case handling to the behavior of `xp.log` and `xp.exp`,
     # which should follow the C99 standard for complex values.
+    print("xp_device(a) at logsumexp checkpoint 4b:", xp_device(a))
+    xp.exp(a)
     b_exp_a = xp.exp(a) if b is None else b * xp.exp(a)
     sum_ = xp.sum(b_exp_a, axis=axis, keepdims=True)

Even the isolated xp.exp(a) call added for debugging exhibits the failure.
Tested that data-apis/array-api-compat#296 fully fixes PyTorch.

I've reworked the fixture to incorporate prior art from #19900 (@lucascolley).
# Note workaround when parsing SCIPY_DEVICE above.
# Also note that when SCIPY_DEVICE=cpu this test won't run in CI
# because CUDA-enabled CI boxes always use SCIPY_DEVICE=cuda.
pytest.xfail(reason="pytorch/pytorch#150199")
Workaround: data-apis/array-api-compat#299
# While this issue is specific to jax.jit, it would be unnecessarily
# verbose to skip the test for each jit-capable function and run it for
# those that only support eager mode.
pytest.xfail(reason="jax-ml/jax#26000")
Also see jax-ml/jax#27606 (fixed in next JAX release).
thanks Guido!
This PR sets up a test infrastructure to test device propagation when a function receives input arrays that do not live on the default device.
This benefits all multi-device backends, which means PyTorch, JAX, and (for testing purposes only) array-api-strict.
This PR fixes logsumexp, and incidentally finds that, at the time of writing, there are bugs in both JAX and PyTorch that prevent this from functioning properly; so I don't plan to extend the same treatment to other functions until the upstream bugs are solved.

Upstream issues:

- asarray: device does not propagate from input to output after set_default_device (pytorch/pytorch#150199)
- .device attribute inside @jax.jit (jax-ml/jax#26000)
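For readers unfamiliar with the setup, here is a minimal sketch of what such a device-propagation check can look like with array-api-strict, which ships extra fake devices for exactly this purpose; the fixture in the PR is more general, and running this assumes SciPy's experimental array API support is enabled (SCIPY_ARRAY_API=1).

```python
# Minimal sketch of a device-propagation check, assuming array-api-strict's
# extra fake devices and SciPy's experimental array API mode
# (SCIPY_ARRAY_API=1). The actual test fixture in the PR is more general.
import array_api_strict as xp
from scipy.special import logsumexp

info = xp.__array_namespace_info__()
nondefault_device = next(
    d for d in info.devices() if d != info.default_device()
)

x = xp.asarray([1.0, 2.0, 3.0], device=nondefault_device)
assert logsumexp(x).device == nondefault_device  # output follows the input
```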