@@ -736,17 +736,23 @@ def FoldConstant() -> tvm.ir.transform.Pass:
736736
737737
738738def CanonicalizeShapeExpr () -> tvm .ir .transform .Pass :
739- """Canonicalize ShapeExpr by lifting compound PrimExpr into separate bindings .
739+ """Canonicalize ShapeExpr by replacing compound PrimExpr with fresh symbolic variables .
740740
741741 VMShapeLower can only handle ShapeExpr where each dimension is either:
742742 - IntImm (concrete integer constant)
743- - tir::Var (symbolic variable)
743+ - tir::Var (symbolic variable from function parameters or match_cast )
744744
745- This pass lifts compound PrimExpr (e.g., n+1, 4*n*m, etc.) into separate shape bindings
746- with MatchCast to extract symbolic variables, ensuring VMShapeLower receives only
747- canonical shape expressions.
745+ This pass transforms compound PrimExpr (e.g., n+1, 4*n*m) by:
746+ 1. Creating a fresh tir::Var for each compound expression
747+ 2. Emitting a MatchCast that binds the fresh var to a PrimValue computing the expression
748+ 3. Replacing the compound expression in ShapeExpr with teh fresh var
748749
749- This pass should be applied after ComputePrimValue and before VMShapeLower.
750+ Example transformation:
751+ Before: y = R.zeros(R.shape([n + 1]), dtype="float32")
752+ After: _s0_pv: R.Prim(value=_s0) = R.match_cast(R.prim_value(n+1), R.Prim(value=_s0))
753+ y = R.zeros(R.shape([_s0]), dtype="float32")
754+
755+ This pass should be applied before ComputePrimValue and before VMShapeLower.
750756
751757 Returns
752758 -------
0 commit comments