Commit 4b02064

Finish Copilot code

1 parent 53adf9a commit 4b02064

File tree

8 files changed: +513 -201 lines changed

pytensor/link/jax/dispatch/subtensor.py

Lines changed: 27 additions & 6 deletions

@@ -31,11 +31,18 @@
 """


+@jax_funcify.register(AdvancedSubtensor1)
+def jax_funcify_AdvancedSubtensor1(op, node, **kwargs):
+    def advanced_subtensor1(x, ilist):
+        return x[ilist]
+
+    return advanced_subtensor1
+
+
 @jax_funcify.register(Subtensor)
 @jax_funcify.register(AdvancedSubtensor)
-@jax_funcify.register(AdvancedSubtensor1)
 def jax_funcify_Subtensor(op, node, **kwargs):
-    idx_list = getattr(op, "idx_list", None)
+    idx_list = op.idx_list

     def subtensor(x, *ilists):
         indices = indices_from_subtensor(ilists, idx_list)
@@ -47,10 +54,24 @@ def subtensor(x, *ilists):
     return subtensor


-@jax_funcify.register(IncSubtensor)
 @jax_funcify.register(AdvancedIncSubtensor1)
+def jax_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
+    if getattr(op, "set_instead_of_inc", False):
+
+        def jax_fn(x, y, ilist):
+            return x.at[ilist].set(y)
+
+    else:
+
+        def jax_fn(x, y, ilist):
+            return x.at[ilist].add(y)
+
+    return jax_fn
+
+
+@jax_funcify.register(IncSubtensor)
 def jax_funcify_IncSubtensor(op, node, **kwargs):
-    idx_list = getattr(op, "idx_list", None)
+    idx_list = op.idx_list

     if getattr(op, "set_instead_of_inc", False):

@@ -77,8 +98,8 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):

 @jax_funcify.register(AdvancedIncSubtensor)
 def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
-    idx_list = getattr(op, "idx_list", None)
-
+    idx_list = op.idx_list
+
     if getattr(op, "set_instead_of_inc", False):

         def jax_fn(x, indices, y):

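Note on the new JAX translations above: JAX arrays are immutable, so PyTensor's set/increment ops map onto JAX's functional x.at[...] API, exactly as the registered functions show. A minimal standalone sketch of the two behaviors (illustrative, not part of the commit):

import jax.numpy as jnp

x = jnp.zeros(5)
ilist = jnp.array([0, 2, 2])
y = jnp.array([1.0, 1.0, 1.0])

# set_instead_of_inc=True -> functional "set"
print(x.at[ilist].set(y))  # [1. 0. 1. 0. 0.]

# set_instead_of_inc=False -> functional "add"; duplicate indices accumulate
print(x.at[ilist].add(y))  # [1. 0. 2. 0. 0.]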
pytensor/link/numba/dispatch/subtensor.py

Lines changed: 12 additions & 11 deletions

@@ -20,7 +20,6 @@
 )
 from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
 from pytensor.tensor import TensorType
-from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
     AdvancedIncSubtensor1,
@@ -29,7 +28,7 @@
     IncSubtensor,
     Subtensor,
 )
-from pytensor.tensor.type_other import MakeSlice, NoneTypeT, SliceType
+from pytensor.tensor.type_other import MakeSlice


 def slice_new(self, start, stop, step):
@@ -239,15 +238,15 @@ def {function_name}({", ".join(input_names)}):
 @register_funcify_and_cache_key(AdvancedIncSubtensor)
 def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
     if isinstance(op, AdvancedSubtensor):
-        x, y, tensor_inputs = node.inputs[0], None, node.inputs[1:]
+        tensor_inputs = node.inputs[1:]
     else:
-        x, y, *tensor_inputs = node.inputs
+        tensor_inputs = node.inputs[2:]

     # Reconstruct indexing information from idx_list and tensor inputs
     basic_idxs = []
     adv_idxs = []
     input_idx = 0
-
+
     for i, entry in enumerate(op.idx_list):
         if isinstance(entry, slice):
             # Basic slice index
@@ -256,12 +255,14 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
             # Advanced tensor index
             if input_idx < len(tensor_inputs):
                 idx_input = tensor_inputs[input_idx]
-                adv_idxs.append({
-                    "axis": i,
-                    "dtype": idx_input.type.dtype,
-                    "bcast": idx_input.type.broadcastable,
-                    "ndim": idx_input.type.ndim,
-                })
+                adv_idxs.append(
+                    {
+                        "axis": i,
+                        "dtype": idx_input.type.dtype,
+                        "bcast": idx_input.type.broadcastable,
+                        "ndim": idx_input.type.ndim,
+                    }
+                )
                 input_idx += 1

     # Special implementation for consecutive integer vector indices

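For orientation, the loop above pairs entries of op.idx_list with the op's variadic tensor inputs: slice entries are stored inline in idx_list, while each non-slice (type) entry consumes the next tensor input. A plain-Python sketch of that pairing, using a stand-in for PyTensor's type entries (illustrative only):

class FakeType:
    """Stand-in for a PyTensor type entry in idx_list."""

def split_indices(idx_list, tensor_inputs):
    # Mirrors the dispatch loop: slices stay inline; type entries
    # consume one tensor input each, in order.
    basic_idxs, adv_idxs = [], []
    input_idx = 0
    for axis, entry in enumerate(idx_list):
        if isinstance(entry, slice):
            basic_idxs.append((axis, entry))
        elif input_idx < len(tensor_inputs):
            adv_idxs.append((axis, tensor_inputs[input_idx]))
            input_idx += 1
    return basic_idxs, adv_idxs

# x[1:3, v] is stored as idx_list=(slice(1, 3), <type>) plus one tensor input v
print(split_indices((slice(1, 3), FakeType()), ["v"]))
# -> ([(0, slice(1, 3))], [(1, 'v')])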
pytensor/link/pytorch/dispatch/subtensor.py

Lines changed: 7 additions & 5 deletions

@@ -9,7 +9,7 @@
     Subtensor,
     indices_from_subtensor,
 )
-from pytensor.tensor.type_other import MakeSlice, SliceType
+from pytensor.tensor.type_other import MakeSlice


 def check_negative_steps(indices):
@@ -63,8 +63,8 @@ def makeslice(start, stop, step):
 @pytorch_funcify.register(AdvancedSubtensor1)
 @pytorch_funcify.register(AdvancedSubtensor)
 def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
-    idx_list = getattr(op, "idx_list", None)
-
+    idx_list = op.idx_list
+
     def advsubtensor(x, *flattened_indices):
         indices = indices_from_subtensor(flattened_indices, idx_list)
         check_negative_steps(indices)
@@ -105,7 +105,7 @@ def inc_subtensor(x, y, *flattened_indices):
 @pytorch_funcify.register(AdvancedIncSubtensor)
 @pytorch_funcify.register(AdvancedIncSubtensor1)
 def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
-    idx_list = getattr(op, "idx_list", None)
+    idx_list = op.idx_list
     inplace = op.inplace
     ignore_duplicates = getattr(op, "ignore_duplicates", False)

@@ -139,7 +139,9 @@ def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices):

     else:
         # Check if we have slice indexing in idx_list
-        has_slice_indexing = any(isinstance(entry, slice) for entry in idx_list) if idx_list else False
+        has_slice_indexing = (
+            any(isinstance(entry, slice) for entry in idx_list) if idx_list else False
+        )
        if has_slice_indexing:
             raise NotImplementedError(
                 "IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"

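Side note on the duplicates distinction in this backend: the diff keeps separate paths depending on ignore_duplicates, and, assuming the surrounding code still lowers to torch.Tensor.index_put_ (not shown in this hunk), the difference is whether repeated indices accumulate. A hedged sketch in plain PyTorch:

import torch

x = torch.zeros(5)
idx = torch.tensor([0, 2, 2])
y = torch.tensor([1.0, 1.0, 1.0])

# accumulate=False acts like "set"; with duplicate indices the result at
# index 2 is whichever write lands (the order is not guaranteed)
print(x.clone().index_put_((idx,), y))

# accumulate=True acts like "inc"; duplicate writes to index 2 add up
print(x.clone().index_put_((idx,), y, accumulate=True))  # tensor([1., 0., 2., 0., 0.])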
pytensor/tensor/basic.py

Lines changed: 27 additions & 0 deletions

@@ -1818,6 +1818,33 @@ def do_constant_folding(self, fgraph, node):
         return True


+@_vectorize_node.register(Alloc)
+def vectorize_alloc(op: Alloc, node: Apply, batch_val, *batch_shapes):
+    # batch_shapes are usually not batched (they are scalars for the shape)
+    # batch_val is the value being allocated.
+
+    # If shapes are batched, we fall back (complex case)
+    if any(
+        b_shp.type.ndim > shp.type.ndim
+        for b_shp, shp in zip(batch_shapes, node.inputs[1:], strict=True)
+    ):
+        return vectorize_node_fallback(op, node, batch_val, *batch_shapes)
+
+    # If value is batched, we need to prepend batch dims to the output shape
+    val = node.inputs[0]
+    batch_ndim = batch_val.type.ndim - val.type.ndim
+
+    if batch_ndim == 0:
+        return op.make_node(batch_val, *batch_shapes)
+
+    # We need the size of the batch dimensions
+    # batch_val has shape (B1, B2, ..., val_dims...)
+    batch_dims = [batch_val.shape[i] for i in range(batch_ndim)]
+
+    new_shapes = batch_dims + list(batch_shapes)
+    return op.make_node(batch_val, *new_shapes)
+
+
 alloc = Alloc()
 pprint.assign(alloc, printing.FunctionPrinter(["alloc"]))

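The shape bookkeeping in vectorize_alloc, restated with NumPy for clarity (illustrative only): when the allocated value gains batch_ndim leading dimensions, those batch sizes are prepended to the requested shape, and everything else is left to Alloc's usual broadcasting.

import numpy as np

# Original node: alloc(val, 3, 4) with val of shape (4,)
val = np.zeros(4)
shapes = (3, 4)

# Vectorized value gained one leading batch dim of size 2
batch_val = np.zeros((2, 4))
batch_ndim = batch_val.ndim - val.ndim  # 1

# The rewrite prepends the batch sizes to the requested shape
new_shape = batch_val.shape[:batch_ndim] + shapes
print(new_shape)  # (2, 3, 4) -- output shape of the batched Alloc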
pytensor/tensor/rewriting/subtensor.py

Lines changed: 28 additions & 35 deletions

@@ -14,6 +14,7 @@
     in2out,
     node_rewriter,
 )
+from pytensor.graph.type import Type
 from pytensor.raise_op import Assert
 from pytensor.scalar import Add, ScalarConstant, ScalarType
 from pytensor.scalar import constant as scalar_constant
@@ -212,6 +213,20 @@ def get_advsubtensor_axis(indices):
     return axis


+def reconstruct_indices(idx_list, tensor_inputs):
+    """Reconstruct indices from idx_list and tensor inputs."""
+    indices = []
+    input_idx = 0
+    for entry in idx_list:
+        if isinstance(entry, slice):
+            indices.append(entry)
+        elif isinstance(entry, Type):
+            if input_idx < len(tensor_inputs):
+                indices.append(tensor_inputs[input_idx])
+            input_idx += 1
+    return indices
+
+
 @register_specialize
 @node_rewriter([AdvancedSubtensor])
 def local_replace_AdvancedSubtensor(fgraph, node):
@@ -229,17 +244,9 @@ def local_replace_AdvancedSubtensor(fgraph, node):

     indexed_var = node.inputs[0]
     tensor_inputs = node.inputs[1:]
-
+
     # Reconstruct indices from idx_list and tensor inputs
-    indices = []
-    input_idx = 0
-    for entry in node.op.idx_list:
-        if isinstance(entry, slice):
-            indices.append(entry)
-        elif isinstance(entry, Type):
-            if input_idx < len(tensor_inputs):
-                indices.append(tensor_inputs[input_idx])
-            input_idx += 1
+    indices = reconstruct_indices(node.op.idx_list, tensor_inputs)

     axis = get_advsubtensor_axis(indices)

@@ -267,17 +274,9 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
     res = node.inputs[0]
     val = node.inputs[1]
     tensor_inputs = node.inputs[2:]
-
+
     # Reconstruct indices from idx_list and tensor inputs
-    indices = []
-    input_idx = 0
-    for entry in node.op.idx_list:
-        if isinstance(entry, slice):
-            indices.append(entry)
-        elif isinstance(entry, Type):
-            if input_idx < len(tensor_inputs):
-                indices.append(tensor_inputs[input_idx])
-            input_idx += 1
+    indices = reconstruct_indices(node.op.idx_list, tensor_inputs)

     axis = get_advsubtensor_axis(indices)

@@ -1112,6 +1111,7 @@ def local_inplace_AdvancedIncSubtensor1(fgraph, node):
 def local_inplace_AdvancedIncSubtensor(fgraph, node):
     if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace:
         new_op = type(node.op)(
+            node.op.idx_list,
             inplace=True,
             set_instead_of_inc=node.op.set_instead_of_inc,
             ignore_duplicates=node.op.ignore_duplicates,
@@ -1376,6 +1376,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
                 z_broad[k]
                 and not same_shape(xi, y, dim_x=k, dim_y=k)
                 and shape_of[y][k] != 1
+                and shape_of[xi][k] == 1
             )
         ]

@@ -1778,17 +1779,9 @@ def ravel_multidimensional_bool_idx(fgraph, node):
     else:
         x, y = node.inputs[0], node.inputs[1]
         tensor_inputs = node.inputs[2:]
-
+
     # Reconstruct indices from idx_list and tensor inputs
-    idxs = []
-    input_idx = 0
-    for entry in node.op.idx_list:
-        if isinstance(entry, slice):
-            idxs.append(entry)
-        elif isinstance(entry, Type):
-            if input_idx < len(tensor_inputs):
-                idxs.append(tensor_inputs[input_idx])
-            input_idx += 1
+    idxs = reconstruct_indices(node.op.idx_list, tensor_inputs)

     if any(
         (
@@ -1829,36 +1822,36 @@ def ravel_multidimensional_bool_idx(fgraph, node):
         # Create new AdvancedSubtensor with updated idx_list
         new_idx_list = list(node.op.idx_list)
         new_tensor_inputs = list(tensor_inputs)
-
+
         # Update the idx_list and tensor_inputs for the raveled boolean index
         input_idx = 0
         for i, entry in enumerate(node.op.idx_list):
             if isinstance(entry, Type):
                 if input_idx == bool_idx_pos:
                     new_tensor_inputs[input_idx] = raveled_bool_idx
                 input_idx += 1
-
+
         new_out = AdvancedSubtensor(new_idx_list)(raveled_x, *new_tensor_inputs)
     else:
         # Create new AdvancedIncSubtensor with updated idx_list
         new_idx_list = list(node.op.idx_list)
         new_tensor_inputs = list(tensor_inputs)
-
+
         # Update the tensor_inputs for the raveled boolean index
         input_idx = 0
         for i, entry in enumerate(node.op.idx_list):
             if isinstance(entry, Type):
                 if input_idx == bool_idx_pos:
                     new_tensor_inputs[input_idx] = raveled_bool_idx
                 input_idx += 1
-
+
         # The dimensions of y that correspond to the boolean indices
         # must already be raveled in the original graph, so we don't need to do anything to it
         new_out = AdvancedIncSubtensor(
             new_idx_list,
             inplace=node.op.inplace,
             set_instead_of_inc=node.op.set_instead_of_inc,
-            ignore_duplicates=node.op.ignore_duplicates
+            ignore_duplicates=node.op.ignore_duplicates,
         )(raveled_x, y, *new_tensor_inputs)
         # But we must reshape the output to match the original shape
         new_out = new_out.reshape(x_shape)

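For reference, the kind of user-level graph these subtensor rewrites act on (standard PyTensor API; the lowering to AdvancedSubtensor1 and the inplace flip happen during rewriting, so the names below are only entry points, not a guarantee of which op ends up in the compiled graph):

import pytensor.tensor as pt
from pytensor import function

x = pt.matrix("x")
idx = pt.lvector("idx")

# Integer-vector indexing builds an advanced-subtensor op that the
# specialize rewrites may lower to AdvancedSubtensor1
out = x[idx]

# inc_subtensor builds an advanced-inc-subtensor op; the inplace
# rewrite can later rebuild it with inplace=True
out2 = pt.inc_subtensor(x[idx], pt.ones_like(x[idx]))

f = function([x, idx], [out, out2])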