Commit d718bdd

Merge branch 'jax-fix-subtensor' into rewrite-jax-scan
2 parents: d43bf96 + d2b41c4

148 files changed: +3416 −2791 lines

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
@@ -182,7 +182,7 @@ jobs:
           path: coverage
 
       - name: Upload coverage to Codecov
-        uses: codecov/codecov-action@v1
+        uses: codecov/codecov-action@v3
         with:
           directory: ./coverage/
           fail_ci_if_error: true

.pre-commit-config.yaml

Lines changed: 4 additions & 3 deletions
@@ -7,7 +7,7 @@ exclude: |
     )$
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.3.0
+    rev: v4.4.0
     hooks:
       - id: debug-statements
         exclude: |
@@ -25,9 +25,10 @@ repos:
       - id: black
         language_version: python3
   - repo: https://github.com/pycqa/flake8
-    rev: 5.0.4
+    rev: 6.0.0
     hooks:
       - id: flake8
+        language_version: python39
   - repo: https://github.com/pycqa/isort
     rev: 5.10.1
     hooks:
@@ -47,7 +48,7 @@ repos:
         )$
         args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable']
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v0.982
+    rev: v0.991
     hooks:
      - id: mypy
        additional_dependencies:

README.rst

Lines changed: 2 additions & 2 deletions
@@ -122,10 +122,10 @@ Contributing
 We welcome bug reports and fixes and improvements to the documentation.
 
 For more information on contributing, please see the
-`contributing guide <https://github.com/aesara-devs/aesara/CONTRIBUTING.md>`.
+`contributing guide <https://github.com/aesara-devs/aesara/CONTRIBUTING.md>`__.
 
 A good place to start contributing is by looking through the issues
-`here <https://github.com/aesara-devs/aesara/issues`.
+`here <https://github.com/aesara-devs/aesara/issues>`__.
 
 Support
 =======

aesara/compile/builders.py

Lines changed: 84 additions & 54 deletions
@@ -2,7 +2,7 @@
 from collections import OrderedDict
 from copy import copy
 from functools import partial
-from typing import List, Optional, Sequence, cast
+from typing import Dict, List, Optional, Sequence, Tuple, cast
 
 import aesara.tensor as at
 from aesara import function
@@ -19,7 +19,6 @@
     clone_replace,
     graph_inputs,
     io_connection_pattern,
-    replace_nominals_with_dummies,
 )
 from aesara.graph.fg import FunctionGraph
 from aesara.graph.null_type import NullType
@@ -82,6 +81,81 @@ def local_traverse(out):
     return ret
 
 
+def construct_nominal_fgraph(
+    inputs: Sequence[Variable], outputs: Sequence[Variable]
+) -> Tuple[
+    FunctionGraph,
+    Sequence[Variable],
+    Dict[Variable, Variable],
+    Dict[Variable, Variable],
+]:
+    """Construct an inner-`FunctionGraph` with ordered nominal inputs."""
+    dummy_inputs = []
+    for n, inp in enumerate(inputs):
+        if (
+            not isinstance(inp, Variable)
+            or isinstance(inp, Constant)
+            or isinstance(inp, SharedVariable)
+        ):
+            raise TypeError(
+                f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
+            )
+
+        dummy_inputs.append(inp.type())
+
+    dummy_shared_inputs = []
+    shared_inputs = []
+    for var in graph_inputs(outputs, inputs):
+        if isinstance(var, SharedVariable):
+            # To correctly support shared variables the inner-graph should
+            # not see them; otherwise, there will be problems with
+            # gradients.
+            # That's why we collect the shared variables and replace them
+            # with dummies.
+            shared_inputs.append(var)
+            dummy_shared_inputs.append(var.type())
+        elif var not in inputs and not isinstance(var, Constant):
+            raise MissingInputError(f"OpFromGraph is missing an input: {var}")
+
+    replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs))
+
+    new = rebuild_collect_shared(
+        cast(Sequence[Variable], outputs),
+        inputs=inputs + shared_inputs,
+        replace=replacements,
+        copy_inputs_over=False,
+    )
+    (
+        local_inputs,
+        local_outputs,
+        (clone_d, update_d, update_expr, new_shared_inputs),
+    ) = new
+
+    assert len(local_inputs) == len(inputs) + len(shared_inputs)
+    assert len(local_outputs) == len(outputs)
+    assert not update_d
+    assert not update_expr
+    assert not new_shared_inputs
+
+    fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
+
+    # The inputs need to be `NominalVariable`s so that we can merge
+    # inner-graphs
+    nominal_local_inputs = tuple(
+        NominalVariable(n, var.type) for n, var in enumerate(local_inputs)
+    )
+
+    fgraph.replace_all(zip(local_inputs, nominal_local_inputs))
+
+    for i, inp in enumerate(fgraph.inputs):
+        nom_inp = nominal_local_inputs[i]
+        fgraph.inputs[i] = nom_inp
+        fgraph.clients.pop(inp, None)
+        fgraph.add_input(nom_inp)
+
+    return fgraph, shared_inputs, update_d, update_expr
+
+
 class OpFromGraph(Op, HasInnerGraph):
     r"""
     This creates an `Op` from inputs and outputs lists of variables.
@@ -333,66 +407,21 @@ def __init__(
         if not (isinstance(inputs, list) and isinstance(outputs, list)):
            raise TypeError("Inputs and outputs must be lists")
 
-        for i in inputs + outputs:
-            if not isinstance(i, Variable):
+        for out in outputs:
+            if not isinstance(out, Variable):
                 raise TypeError(
-                    f"Inputs and outputs must be Variable instances; got {i}"
+                    f"Inputs and outputs must be Variable instances; got {out}"
                 )
-            if i in inputs:
-                if isinstance(i, Constant):
-                    raise TypeError(f"Constants not allowed as inputs; {i}")
-                if isinstance(i, SharedVariable):
-                    raise TypeError(f"SharedVariables not allowed as inputs; {i}")
-
-        for var in graph_inputs(outputs, inputs):
-            if var not in inputs and not isinstance(var, (Constant, SharedVariable)):
-                raise MissingInputError(f"OpFromGraph is missing an input: {var}")
 
         if "updates" in kwargs or "givens" in kwargs:
-            raise NotImplementedError("Updates and givens are not allowed here")
+            raise NotImplementedError("Updates and givens are not supported")
 
         self.is_inline = inline
 
-        # To correctly support shared variables the inner fct should
-        # not see them. Otherwise there is a problem with the gradient.
-        self.shared_inputs = []
-        for var in graph_inputs(outputs):
-            if isinstance(var, SharedVariable):
-                self.shared_inputs.append(var)
-
-        inputs, outputs = replace_nominals_with_dummies(inputs, outputs)
-
-        # The inputs should be `NominalVariable`s, so that graphs can be merged
-        replacements = {}
-        for n, v in enumerate(inputs):
-            replacements[v] = NominalVariable(n, v.type)
-
-        shared_vars = [
-            NominalVariable(n, var.type)
-            for n, var in enumerate(self.shared_inputs, start=len(inputs) + 1)
-        ]
-
-        replacements.update(dict(zip(self.shared_inputs, shared_vars)))
-
-        new = rebuild_collect_shared(
-            cast(Sequence[Variable], outputs),
-            inputs=inputs + shared_vars,
-            replace=replacements,
-            copy_inputs_over=False,
+        self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph(
+            inputs, outputs
         )
-        (
-            local_inputs,
-            local_outputs,
-            (clone_d, update_d, update_expr, shared_inputs),
-        ) = new
-
-        assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
-        assert len(local_outputs) == len(outputs)
-        assert not update_d
-        assert not update_expr
-        assert not shared_inputs
-
-        self.fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
+
         self.kwargs = kwargs
         self.input_types = [inp.type for inp in inputs]
         self.output_types = [out.type for out in outputs]
@@ -415,6 +444,7 @@ def __init__(
         else:
             self.set_lop_overrides("default")
             self._lop_type = "lop"
+
         self.set_rop_overrides(rop_overrides)
 
         self._connection_pattern = connection_pattern

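For context (not part of this commit): the refactored constructor above runs whenever an `OpFromGraph` is built, so a minimal usage sketch along the lines of the `OpFromGraph` docstring looks like the following; the variable names here are illustrative only.

    import aesara.tensor as at
    from aesara import function
    from aesara.compile.builders import OpFromGraph

    # Build a reusable Op from a small inner graph.
    x = at.scalar("x")
    y = at.scalar("y")
    z = at.scalar("z")
    op = OpFromGraph([x, y, z], [x + y * z])

    # The Op can be applied like any other; its inner FunctionGraph is the one
    # assembled by `construct_nominal_fgraph` in the constructor above.
    out = op(x, y, z) + op(z, y, x)
    fn = function([x, y, z], out)

As the added comment notes, the inner graph's inputs are replaced with `NominalVariable`s so that structurally identical inner-graphs can be merged.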
aesara/compile/debugmode.py

Lines changed: 14 additions & 14 deletions
@@ -848,17 +848,17 @@ def _get_preallocated_maps(
         or "ALL" in prealloc_modes
     ):
         max_ndim = 0
-        rev_out_broadcastable = []
+        rev_out_shape = []
         for r in considered_outputs:
             if isinstance(r.type, TensorType):
                 if max_ndim < r.ndim:
-                    rev_out_broadcastable += [True] * (r.ndim - max_ndim)
+                    rev_out_shape += [1] * (r.ndim - max_ndim)
                     max_ndim = r.ndim
-                assert len(rev_out_broadcastable) == max_ndim
+                assert len(rev_out_shape) == max_ndim
 
-                for i, b in enumerate(r.broadcastable[::-1]):
-                    rev_out_broadcastable[i] = rev_out_broadcastable[i] and b
-        out_broadcastable = rev_out_broadcastable[::-1]
+                for i, s in enumerate(r.type.shape[::-1]):
+                    rev_out_shape[i] = 1 if rev_out_shape[i] == 1 and s == 1 else None
+        out_shape = rev_out_shape[::-1]
 
         if "strided" in prealloc_modes or "ALL" in prealloc_modes:
             check_ndim = config.DebugMode__check_preallocated_output_ndim
@@ -887,14 +887,14 @@ def _get_preallocated_maps(
             # Moreover, to avoid memory problems, we do not test with strides
             # 2 and -2 on those dimensions.
             step_signs_list = []
-            for b in out_broadcastable[-check_ndim:]:
-                if b:
+            for s in out_shape[-check_ndim:]:
+                if s == 1:
                     step_signs_list.append((1,))
                 else:
                     step_signs_list.append((-1, 1))
 
             # Use the same step on all dimensions before the last check_ndim.
-            if all(out_broadcastable[:-check_ndim]):
+            if all(s == 1 for s in out_shape[:-check_ndim]):
                 step_signs_list = [(1,)] + step_signs_list
             else:
                 step_signs_list = [(-1, 1)] + step_signs_list
@@ -905,7 +905,7 @@ def _get_preallocated_maps(
 
                     # First, the dimensions above check_ndim, then the other ones
                     # Do not test with 2 or -2 for dimensions above check_ndim
-                    steps = [step_signs[0]] * len(out_broadcastable[:-check_ndim])
+                    steps = [step_signs[0]] * len(out_shape[:-check_ndim])
                     steps += [s * step_size for s in step_signs[1:]]
 
                     name = f"strided{tuple(steps)}"
@@ -932,8 +932,8 @@ def _get_preallocated_maps(
 
         if "wrong_size" in prealloc_modes or "ALL" in prealloc_modes:
             # For each dimension, try size-1, size, size+1
-            for dim, b in enumerate(out_broadcastable):
-                if b:
+            for dim, s in enumerate(out_shape):
+                if s == 1:
                     # The shape has to be 1
                     continue
 
@@ -947,11 +947,11 @@ def _get_preallocated_maps(
                 for r in considered_outputs:
                     if isinstance(r.type, TensorType):
                         r_shape_diff = shape_diff[: r.ndim]
-                        out_shape = [
+                        new_buf_shape = [
                             max((s + sd), 0)
                             for s, sd in zip(r_vals[r].shape, r_shape_diff)
                         ]
-                        new_buf = np.empty(out_shape, dtype=r.type.dtype)
+                        new_buf = np.empty(new_buf_shape, dtype=r.type.dtype)
                         new_buf[...] = np.asarray(def_val).astype(r.type.dtype)
                         wrong_size[r] = new_buf
 

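The rewrite above drops per-dimension `broadcastable` booleans in favour of static shape entries, where a stored length of 1 plays the role that `True` used to play and `None` marks an unknown extent. A small sketch of that correspondence, illustrative only and assuming a `TensorType` that accepts a `shape` argument (as contemporary Aesara versions do):

    from aesara.tensor.type import TensorType

    # A matrix whose first dimension is statically known to have length 1
    # (i.e. broadcastable) and whose second dimension is unknown.
    x = TensorType("float64", shape=(1, None))("x")

    assert x.type.shape == (1, None)
    # The old boolean view can still be derived from the static shape.
    assert x.broadcastable == (True, False)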
aesara/compile/function/pfunc.py

Lines changed: 17 additions & 19 deletions
@@ -3,7 +3,6 @@
 
 """
 
-import logging
 from copy import copy
 from typing import Optional
 
@@ -16,11 +15,6 @@
 from aesara.graph.fg import FunctionGraph
 
 
-_logger = logging.getLogger("aesara.compile.function.pfunc")
-
-__docformat__ = "restructuredtext en"
-
-
 def rebuild_collect_shared(
     outputs,
     inputs=None,
@@ -78,10 +72,12 @@ def rebuild_collect_shared(
     shared_inputs = []
 
     def clone_v_get_shared_updates(v, copy_inputs_over):
-        """
-        Clones a variable and its inputs recursively until all are in clone_d.
-        Also appends all shared variables met along the way to shared inputs,
-        and their default_update (if applicable) to update_d and update_expr.
+        r"""Clones a variable and its inputs recursively until all are in `clone_d`.
+
+        Also, it appends all `SharedVariable`\s met along the way to
+        `shared_inputs` and their corresponding
+        `SharedVariable.default_update`\s (when applicable) to `update_d` and
+        `update_expr`.
 
         """
         # this co-recurses with clone_a
@@ -103,7 +99,7 @@ def clone_v_get_shared_updates(v, copy_inputs_over):
         elif isinstance(v, SharedVariable):
             if v not in shared_inputs:
                 shared_inputs.append(v)
-            if hasattr(v, "default_update"):
+            if v.default_update is not None:
                 # Check that v should not be excluded from the default
                 # updates list
                 if no_default_updates is False or (
@@ -419,22 +415,24 @@ def construct_pfunc_ins_and_outs(
         givens = []
 
     if not isinstance(params, (list, tuple)):
-        raise Exception("in pfunc() the first argument must be a list or " "a tuple")
+        raise TypeError("The `params` argument must be a list or a tuple")
 
     if not isinstance(no_default_updates, bool) and not isinstance(
         no_default_updates, list
     ):
-        raise TypeError("no_default_update should be either a boolean or " "a list")
+        raise TypeError("The `no_default_update` argument must be a boolean or list")
 
-    if len(updates) > 0 and any(
-        isinstance(v, Variable) for v in iter_over_pairs(updates)
+    if len(updates) > 0 and not all(
+        isinstance(pair, (tuple, list))
+        and len(pair) == 2
+        and isinstance(pair[0], Variable)
+        for pair in iter_over_pairs(updates)
     ):
-        raise ValueError(
-            "The updates parameter must be an OrderedDict/dict or a list of "
-            "lists/tuples with 2 elements"
+        raise TypeError(
+            "The `updates` parameter must be an ordered mapping or a list of pairs"
        )
 
-    # transform params into aesara.compile.In objects.
+    # Transform params into aesara.compile.In objects.
     inputs = [
        _pfunc_param_to_in(p, allow_downcast=allow_input_downcast) for p in params
    ]

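The tightened check above accepts `updates` only as an ordered mapping or a list of two-element pairs whose first element is a `Variable` (typically a shared variable); anything else now raises a `TypeError`. A minimal sketch of the accepted form, illustrative only and not part of this commit:

    import numpy as np

    import aesara
    import aesara.tensor as at

    total = aesara.shared(np.zeros(3), name="total")
    x = at.vector("x")

    # `updates` given as a list of (shared_variable, new_value) pairs.
    fn = aesara.function([x], total, updates=[(total, total + x)])

    fn(np.ones(3))  # returns the old value of `total`, then applies the update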