Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 2c9f692

Browse files
Move tensor_copy rewrite to aesara.tensor.rewriting.basic
1 parent 63f5253 commit 2c9f692

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

aesara/tensor/rewriting/basic.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from aesara.graph.basic import Constant, Variable
1212
from aesara.graph.rewriting.basic import (
1313
NodeRewriter,
14+
RemovalNodeRewriter,
1415
Rewriter,
1516
copy_stack_trace,
1617
in2out,
@@ -35,6 +36,7 @@
3536
join,
3637
ones_like,
3738
switch,
39+
tensor_copy,
3840
zeros,
3941
zeros_like,
4042
)
@@ -1294,3 +1296,6 @@ def __getattr__(name):
12941296
return fn()
12951297

12961298
raise AttributeError(f"module {__name__} has no attribute {name}")
1299+
1300+
1301+
register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy")

aesara/tensor/rewriting/shape.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from aesara.graph.fg import FunctionGraph
1414
from aesara.graph.rewriting.basic import (
1515
GraphRewriter,
16-
RemovalNodeRewriter,
1716
check_chain,
1817
copy_stack_trace,
1918
node_rewriter,
@@ -27,7 +26,6 @@
2726
extract_constant,
2827
get_scalar_constant_value,
2928
stack,
30-
tensor_copy,
3129
)
3230
from aesara.tensor.elemwise import DimShuffle, Elemwise
3331
from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
@@ -972,9 +970,6 @@ def local_reshape_lift(fgraph, node):
972970
return [e]
973971

974972

975-
register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy")
976-
977-
978973
@register_useless
979974
@register_canonicalize
980975
@node_rewriter([SpecifyShape])

0 commit comments

Comments
 (0)