Fix the JAX Scan dispatcher
#1202
base: main
Conversation
We can actually work around a lot of the dynamic indexing issues with `jax.lax.dynamic_slice`. Then I'll keep making my way through testing more and more scan features.
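As a rough illustration of that workaround (my own sketch, not code from this PR): JAX rejects traced slice bounds under `jit`, but `lax.dynamic_slice` accepts a traced start index as long as the slice *size* is static.

import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def dynamic_window(x, start):
    # `start` may be a traced value; the window size (2) must be static.
    return lax.dynamic_slice(x, (start,), (2,))

print(dynamic_window(jnp.arange(5), 1))  # [1 2]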
After a lot of messing around I decided to go for a full rewrite and follow the Numba implementation. I have a minimal version that passes the first 3 tests. That JAX readily complains about dynamic slicing may be a blessing in disguise, as it highlights some gaps in Aesara's rewrites, e.g. #1257 and others. Workarounds that I currently had to implement could easily be avoided with the adequate rewrites at compile time. I also switched to running the tests without rewrites, and I should probably start gathering a set of rewrites that would help with transpilation. How would we go about having backend-specific rewrites?
Numba mode already specializes its rewrites, so check out its definition.
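For a concrete picture of what such a backend-specific mode can look like (a sketch only; the query mirrors the one used later in this thread, and the excluded tags are illustrative):

from aesara.compile.mode import Mode
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.link.jax.linker import JAXLinker

# A JAX mode whose rewrite query excludes rewrites that require the C
# backend, mirroring how the Numba mode specializes its own query.
jax_opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), jax_opts)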
This is turning into a much bigger PR than expected, as I am also trying to fix any issue that prevents me from running the `Scan` tests.
While I'm at it I am going to fix as many known issues with the JAX dispatcher as possible (issues and tests currently marked as expected failures).
Codecov Report

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1202      +/-   ##
==========================================
+ Coverage   74.35%   74.49%   +0.14%
==========================================
  Files         177      173       -4
  Lines       49046    48658     -388
  Branches    10379    10390      +11
==========================================
- Hits        36468    36250     -218
+ Misses      10285    10112     -173
- Partials     2293     2296       +3
The following test with a shared random state:

import numpy as np

import aesara
from aesara.tensor.random.utils import RandomStream

def test_nit_sot_shared():
    res, updates = aesara.scan(
        fn=lambda: RandomStream(seed=1930, rng_ctor=np.random.RandomState).normal(
            0, 1, name="a"
        ),
        n_steps=3,
    )
    jax_fn = aesara.function((), res, updates=updates, mode="JAX")
    jax_res = jax_fn()
    assert jax_res.shape == (3,)

The values are correct, but the update of the shared random state fails the following check:

gen_keys = ["bit_generator", "gauss", "has_gauss", "state"]
state_keys = ["key", "pos"]

for key in gen_keys:
    if key not in data:
        raise TypeError()

for key in state_keys:
    if key not in data["state"]:
        raise TypeError()

state_key = data["state"]["key"]
if state_key.shape == (624,) and state_key.dtype == np.uint32:
    # TODO: Add an option to convert to a `RandomState` instance?
    return data

Indeed, the shared state for random variables in the JAX backend also contains an extra JAX-specific entry. Plus, I don't think we need to carry this NumPy state around in the JAX backend; isn't only the JAX key actually used?
brandonwillard
left a comment
For the commit entitled "Return a scalar when the tensor values is a scalar", is there an associated MWE/test case?
Also, the commit description mentions that `ScalarFromTensor` is being called on scalars, and I want to make sure that those input scalars are `TensorType` scalars, and not `ScalarType` scalars. The latter would imply that we're missing a rewrite for useless `ScalarFromTensor`s.
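To make the distinction concrete (my own example, not part of the PR): `at.iscalar` creates a 0-d `TensorType` variable, while variables built from `aesara.scalar` types have a `ScalarType`; `ScalarFromTensor` converts the former into the latter.

import aesara.scalar as aes
import aesara.tensor as at

x_tensor = at.iscalar("x")  # TensorType(int32, ()) -- a 0-d tensor
x_scalar = aes.int32("x")   # ScalarType(int32) -- a true scalar
print(type(x_tensor.type).__name__, type(x_scalar.type).__name__)
# TensorType ScalarType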
import aesara
import aesara.tensor as at
a = at.iscalar("a")
x = at.arange(3)
out = x[:a]
aesara.dprint(out)
# Subtensor{:int32:} [id A]
# |ARange{dtype='int64'} [id B]
# | |TensorConstant{0} [id C]
# | |TensorConstant{3} [id D]
# | |TensorConstant{1} [id E]
# |ScalarFromTensor [id F]
# |a [id G]
try:
fn = aesara.function((a,), out, mode="JAX")
fn(1)
except Exception as e:
print(f"\n{e}")
# Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
# Apply node that caused the error: DeepCopyOp(Subtensor{:int32:}.0)
# Toposort index: 2
# Inputs types: [TensorType(int64, (None,))]
# Inputs shapes: [()]
# Inputs strides: [()]
# Inputs values: [array(1, dtype=int32)]
# Outputs clients: [['output']]

In this case there are two possible solutions.
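One of them, discussed just below, is to mark the offending input as static with `jax.jit`'s `static_argnums`; a minimal sketch of the idea in plain JAX (my own example):

import jax
import jax.numpy as jnp

def f(a):
    # The slice bound is static here, so NumPy-style indexing is fine.
    return jnp.arange(3)[:a]

f_jit = jax.jit(f, static_argnums=(0,))
print(f_jit(1))  # [0]

The cost is one recompilation per distinct value of `a`.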
As I recall, the trouble with using that is that it's limited to only the (outermost) graph inputs, and we can't compose it.
We can always ask users to JIT-compile functions themselves if that's the case, and raise a warning at compilation ("JAX will only be able to JIT-compile your function if you specify the {input_position}-th argument ({variable_name}) as static"). Given the number of issues with the JAX backend this work is uncovering, I decided to break the changes down into several smaller PRs and fix the issues unrelated to `Scan` first.
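A sketch of what that compile-time warning could look like (hypothetical helper; neither the function name nor the message exists in Aesara):

import warnings

def warn_static_argument(input_position: int, variable_name: str) -> None:
    # Hypothetical helper: emitted once per input that JAX would need
    # to treat as static in order to JIT-compile the function.
    warnings.warn(
        f"JAX will only be able to JIT-compile your function if you "
        f"specify the {input_position}-th argument ({variable_name}) "
        f"as static"
    )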
The following code fails:

import aesara
import aesara.tensor as at
a_at = at.dvector("a")
res, updates = aesara.scan(
fn=lambda a_t: 2 * a_t,
sequences=a_at
)
jax_fn = aesara.function((a_at,), res, updates=updates, mode="JAX")
jax_fn([0, 1, 2, 3, 4])
# IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
# Apply node that caused the error: Elemwise{mul,no_inplace}(TensorConstant{(1,) of 2.0}, Subtensor{int64:int64:int8}.0)

I print the associated function graph:

aesara.dprint(jax_fn)
# Elemwise{mul,no_inplace} [id A] 5
# |TensorConstant{(1,) of 2.0} [id B]
# |Subtensor{int64:int64:int8} [id C] 4
# |a [id D]
# |ScalarFromTensor [id E] 3
# | |Elemwise{Composite{Switch(LE(i0, i1), i1, i2)}}[(0, 0)] [id F] 2
# | |Shape_i{0} [id G] 0
# | | |a [id D]
# | |TensorConstant{0} [id H]
# | |TensorConstant{0} [id I]
# |ScalarFromTensor [id J] 1
# | |Shape_i{0} [id G] 0
# |ScalarConstant{1} [id K]
jax_fn.maker.fgraph.toposort()[4].tag
# scratchpad{'imported_by': ['local_subtensor_merge']}
jax_fn.maker.fgraph.toposort()[2].tag
# scratchpad{'imported_by': ['inplace_elemwise_optimizer']}
jax_fn.maker.fgraph.toposort()[1].tag
# scratchpad{'imported_by': ['local_subtensor_merge']}

Several remarks; notably, the tags show that the offending dynamic `Subtensor` was introduced by the `local_subtensor_merge` rewrite.
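Given those tags, one way to test that diagnosis (my own sketch, reusing the mode construction from elsewhere in this thread) is to exclude `local_subtensor_merge` from the rewrite query and recompile:

import aesara
import aesara.tensor as at
from aesara.compile.mode import Mode
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.link.jax.linker import JAXLinker

# The same JAX mode as above, minus the rewrite that introduced the
# dynamic `Subtensor`.
opts = RewriteDatabaseQuery(
    include=[None],
    exclude=["cxx_only", "BlasOpt", "local_subtensor_merge"],
)
jax_mode = Mode(JAXLinker(), opts)

a_at = at.dvector("a")
res, updates = aesara.scan(fn=lambda a_t: 2 * a_t, sequences=a_at)
jax_fn = aesara.function((a_at,), res, updates=updates, mode=jax_mode)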
The following code also fails:

import aesara
import aesara.tensor as at
from aesara.compile.mode import Mode
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.link.jax.linker import JAXLinker
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
res, updates = aesara.scan(
fn=lambda a_tm1, b_tm1: (2 * a_tm1, 2 * b_tm1),
outputs_info=[
{"initial": at.as_tensor(1.0, dtype="floatX"), "taps": [-1]},
{"initial": at.as_tensor(0.5, dtype="floatX"), "taps": [-1]},
],
n_steps=10,
)
jax_fn = aesara.function((), res, updates=updates, mode=jax_mode)
aesara.dprint(jax_fn)
# Subtensor{int64::} [id A] 17
# |for{cpu,scan_fn}.0 [id B] 16
# | |TensorConstant{10} [id C]
# | |IncSubtensor{Set;:int64:} [id D] 15
# | | |AllocEmpty{dtype='float64'} [id E] 14
# | | | |Elemwise{add,no_inplace} [id F] 13
# | | | |TensorConstant{10} [id C]
# | | | |Subtensor{int64} [id G] 11
# | | | |Shape [id H] 10
# | | | | |Unbroadcast{0} [id I] 9
# | | | | |InplaceDimShuffle{x} [id J] 8
# | | | | |TensorConstant{1.0} [id K]
# | | | |ScalarConstant{0} [id L]
# | | |Unbroadcast{0} [id I] 9
# | | |ScalarFromTensor [id M] 12
# | | |Subtensor{int64} [id G] 11
# | |IncSubtensor{Set;:int64:} [id N] 7
# | |AllocEmpty{dtype='float64'} [id O] 6
# | | |Elemwise{add,no_inplace} [id P] 5
# | | |TensorConstant{10} [id C]
# | | |Subtensor{int64} [id Q] 3
# | | |Shape [id R] 2
# | | | |Unbroadcast{0} [id S] 1
# | | | |InplaceDimShuffle{x} [id T] 0
# | | | |TensorConstant{0.5} [id U]
# | | |ScalarConstant{0} [id V]
# | |Unbroadcast{0} [id S] 1
# | |ScalarFromTensor [id W] 4
# | |Subtensor{int64} [id Q] 3
# |ScalarConstant{1} [id X]
# Subtensor{int64::} [id Y] 18
# |for{cpu,scan_fn}.1 [id B] 16
# |ScalarConstant{1} [id Z]
# Inner graphs:
# for{cpu,scan_fn}.0 [id B]
# >Elemwise{mul,no_inplace} [id BA]
# > |TensorConstant{2} [id BB]
# > |*0-<TensorType(float64, ())> [id BC] -> [id D]
# >Elemwise{mul,no_inplace} [id BD]
# > |TensorConstant{2} [id BE]
# > |*1-<TensorType(float64, ())> [id BF] -> [id N]
# for{cpu,scan_fn}.1 [id B]
# >Elemwise{mul,no_inplace} [id BA]
# >Elemwise{mul,no_inplace} [id BD]

JAX indeed complains about this input. Here is a simpler example with a single output and taps=[-2]:

import aesara
import aesara.tensor as at
from aesara.compile.mode import Mode
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.link.jax.linker import JAXLinker
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
res, updates = aesara.scan(
fn=lambda a_tm1: 2 * a_tm1,
outputs_info=[
{"initial": at.as_tensor([0.0, 1.0], dtype="floatX"), "taps": [-2]}
],
n_steps=6,
)
jax_fn = aesara.function((), res, updates=updates, mode=jax_mode)
aesara.dprint(jax_fn)
# Subtensor{int64::} [id A] 8
# |for{cpu,scan_fn} [id B] 7
# | |TensorConstant{6} [id C]
# | |IncSubtensor{Set;:int64:} [id D] 6
# | |AllocEmpty{dtype='float64'} [id E] 5
# | | |Elemwise{add,no_inplace} [id F] 4
# | | |TensorConstant{6} [id C]
# | | |Subtensor{int64} [id G] 2
# | | |Shape [id H] 1
# | | | |Subtensor{:int64:} [id I] 0
# | | | |TensorConstant{[0. 1.]} [id J]
# | | | |ScalarConstant{2} [id K]
# | | |ScalarConstant{0} [id L]
# | |Subtensor{:int64:} [id I] 0
# | |ScalarFromTensor [id M] 3
# | |Subtensor{int64} [id G] 2
# |ScalarConstant{2} [id N]
# Inner graphs:
# for{cpu,scan_fn} [id B]
# >Elemwise{mul,no_inplace} [id O]
# > |TensorConstant{2} [id P]
# > |*0-<TensorType(float64, ())> [id Q] -> [id D]
# fn = aesara.function((), res, updates=updates)
# assert np.allclose(fn(), jax_fn())
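For reference, what the taps=[-2] scan above is supposed to compute, written as a plain NumPy loop over the same initial values (my own check, not part of the PR):

import numpy as np

vals = [0.0, 1.0]          # the `initial` values
for _ in range(6):         # n_steps=6
    vals.append(2 * vals[-2])
print(np.array(vals[2:]))  # [0. 2. 0. 4. 0. 8.]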
I am currently waiting for #1338 to be merged to see what else needs to be fixed in the backend to allow the tests to pass.
This PR tries to address the issues observed in #710 and #924 with the transpilation of `Scan` operators. Most importantly, we increase the test coverage of `Scan`'s functionalities.