diff --git a/python/tvm/relax/backend/cpu_generic/pipeline.py b/python/tvm/relax/backend/cpu_generic/pipeline.py index 74d951b817b1..527cda28d8cc 100644 --- a/python/tvm/relax/backend/cpu_generic/pipeline.py +++ b/python/tvm/relax/backend/cpu_generic/pipeline.py @@ -52,6 +52,7 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume relax.transform.LowerAllocTensor(), relax.transform.KillAfterLastUse(), relax.transform.LowerRuntimeBuiltin(), + relax.transform.CanonicalizeShapeExpr(), relax.transform.ComputePrimValue(), relax.transform.VMShapeLower(), relax.transform.AttachGlobalSymbol(), diff --git a/python/tvm/relax/backend/cuda/pipeline.py b/python/tvm/relax/backend/cuda/pipeline.py index d5c4c0856165..3861036c383b 100644 --- a/python/tvm/relax/backend/cuda/pipeline.py +++ b/python/tvm/relax/backend/cuda/pipeline.py @@ -64,6 +64,7 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume relax.transform.LowerAllocTensor(), relax.transform.KillAfterLastUse(), relax.transform.LowerRuntimeBuiltin(), + relax.transform.CanonicalizeShapeExpr(), relax.transform.ComputePrimValue(), relax.transform.VMShapeLower(), relax.transform.AttachGlobalSymbol(), diff --git a/python/tvm/relax/backend/gpu_generic/pipeline.py b/python/tvm/relax/backend/gpu_generic/pipeline.py index 86c60114c699..f3df2510ad51 100644 --- a/python/tvm/relax/backend/gpu_generic/pipeline.py +++ b/python/tvm/relax/backend/gpu_generic/pipeline.py @@ -63,6 +63,7 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume relax.transform.LowerAllocTensor(), relax.transform.KillAfterLastUse(), relax.transform.LowerRuntimeBuiltin(), + relax.transform.CanonicalizeShapeExpr(), relax.transform.ComputePrimValue(), relax.transform.VMShapeLower(), relax.transform.AttachGlobalSymbol(), diff --git a/python/tvm/relax/backend/rocm/pipeline.py b/python/tvm/relax/backend/rocm/pipeline.py index e74039ca8634..fa1da7cde689 100644 --- 
a/python/tvm/relax/backend/rocm/pipeline.py +++ b/python/tvm/relax/backend/rocm/pipeline.py @@ -63,6 +63,7 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume relax.transform.LowerAllocTensor(), relax.transform.KillAfterLastUse(), relax.transform.LowerRuntimeBuiltin(), + relax.transform.CanonicalizeShapeExpr(), relax.transform.ComputePrimValue(), relax.transform.VMShapeLower(), relax.transform.AttachGlobalSymbol(), diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index dacbc667be2b..72e23e089519 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -28,6 +28,7 @@ BundleModelParams, CallTIRRewrite, CanonicalizeBindings, + CanonicalizeShapeExpr, CombineParallelMatmul, ComputePrimValue, ConvertLayout, diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index bfd7dbf87d70..2babf0c9ba90 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -735,6 +735,32 @@ def FoldConstant() -> tvm.ir.transform.Pass: return _ffi_api.FoldConstant() # type: ignore +def CanonicalizeShapeExpr() -> tvm.ir.transform.Pass: + """Canonicalize ShapeExpr by replacing compound PrimExpr with fresh symbolic variables. + + VMShapeLower can only handle ShapeExpr where each dimension is either: + - IntImm (concrete integer constant) + - tir::Var (symbolic variable from function parameters or match_cast) + + This pass transforms compound PrimExpr (e.g., n+1, 4*n*m) by: + 1. Creating a fresh tir::Var for each compound expression + 2. Emitting a MatchCast that binds the fresh var to a PrimValue computing the expression + 3. 
Replacing the compound expression in ShapeExpr with the fresh var + + Example transformation: + Before: y = R.zeros(R.shape([n + 1]), dtype="float32") + After: _s0_pv: R.Prim(value=_s0) = R.match_cast(R.prim_value(n+1), R.Prim(value=_s0)) + y = R.zeros(R.shape([_s0]), dtype="float32") + + This pass should be applied before ComputePrimValue and before VMShapeLower. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.CanonicalizeShapeExpr() # type: ignore + + def ExpandTupleArguments() -> tvm.ir.transform.Pass: """Expand tuple arguments to internal functions diff --git a/src/relax/transform/canonicalize_shape_expr.cc b/src/relax/transform/canonicalize_shape_expr.cc new file mode 100644 index 000000000000..3a7f997d6bf1 --- /dev/null +++ b/src/relax/transform/canonicalize_shape_expr.cc @@ -0,0 +1,368 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/canonicalize_shape_expr.cc + * \brief Canonicalize ShapeExpr by replacing compound PrimExpr with fresh symbolic variables. 
+ * + * VMShapeLower can only handle expressions where each PrimExpr dimension is either: + * - IntImm (concrete integer constant) + * - tir::Var (symbolic variable from function parameters or match_cast) + * + * This pass transforms compound PrimExpr (e.g., n+1, 4*n*m) in ShapeExpr and struct_info by: + * 1. Creating a fresh tir::Var for each compound expression + * 2. Emitting a MatchCast that binds the fresh var to a PrimValue computing the expression + * 3. Replacing the compound expression with the fresh var everywhere (ShapeExpr and struct_info) + * + * Example transformation: + * Before: y = R.Tensor((n + 1,)) = R.zeros(R.shape([n + 1]), dtype="float32") + * After: _s0_pv: R.Prim(value=_s0) = R.match_cast(R.prim_value(n + 1), R.Prim(value=_s0)) + * y = R.Tensor((_s0,)) = R.zeros(R.shape([_s0]), dtype="float32") + * + * This ensures VMShapeLower only sees simple tir::Var references, which it can resolve + * through the MatchCast bindings. + */ + +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +namespace { + +/*! + * \brief Check if a PrimExpr is trivial (already canonical for VMShapeLower) + * + * Trivial expressions are: + * - IntImm: concrete integer constants + * - tir::Var: symbolic variables + * + * Any other expression (arithmetic, casts, etc.) is compound and needs canonicalization. + */ +bool IsTrivialPrimExpr(const PrimExpr& expr) { + return expr->IsInstance() || expr->IsInstance(); +} + +/*! + * \brief Collector for compound PrimExpr in an expression tree. + * + * Scans ShapeExpr nodes and collects all compound (non-trivial) PrimExpr. + */ +class CompoundExprCollector : public ExprVisitor { + public: + void VisitExpr_(const ShapeExprNode* op) override { + for (const PrimExpr& dim : op->values) { + if (!IsTrivialPrimExpr(dim)) { + compound_exprs_.insert(dim); + } + } + ExprVisitor::VisitExpr_(op); + } + + std::unordered_set compound_exprs_; +}; + +/*! 
+ * \brief StructInfo mutator that substitutes PrimExpr according to a mapping. + */ +class StructInfoPrimExprMutator : public StructInfoMutator { + public: + explicit StructInfoPrimExprMutator( + const std::unordered_map& expr_map) + : expr_map_(expr_map) {} + + StructInfo VisitStructInfo_(const TensorStructInfoNode* op) override { + // Substitute PrimExpr in shape + ffi::Shape new_shape = op->shape; + if (new_shape.defined()) { + ffi::Array new_shape_values; + bool shape_changed = false; + + for (const PrimExpr& dim : new_shape->values) { + auto it = expr_map_.find(dim); + if (it != expr_map_.end()) { + new_shape_values.push_back(it->second); + shape_changed = true; + } else { + new_shape_values.push_back(dim); + } + } + + if (shape_changed) { + new_shape = Shape(new_shape_values); + } + } + + DataType new_dtype = op->dtype; + + if (new_shape.same_as(op->shape) && new_dtype == op->dtype) { + return StructInfoMutator::VisitStructInfo_(op); + } + + return TensorStructInfo(new_shape, new_dtype, new_ndim_sinfo); + } + + StructInfo VisitStructInfo_(const ShapeStructInfoNode* op) override { + // Substitute PrimExpr in shape + if (op->values.defined()) { + ffi::Array new_values; + bool changed = false; + + for (size_t i = 0; i < op->values.size(); ++i) { + const PrimExpr& dim = op->values[i]; + auto it = expr_map_.find(dim); + if (it != expr_map_.end()) { + new_values.push_back(it->second); + changed = true; + } else { + new_values.push_back(dim); + } + } + + if (changed) { + return ShapeStructInfo(new_values); + } + } + return StructInfoMutator::VisitStructInfo_(op); + } + + private: + const std::unordered_map& expr_map_; +}; + +/*! + * \brief Mutator to canonicalize ShapeExpr and struct_info by replacing compound PrimExpr + * with fresh symbolic variables bound via MatchCast. 
+ */ +class ShapeExprCanonicalizer : public ExprMutator { + public: + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const FunctionNode* func) override { + // Save old state + BlockBuilder saved_builder = builder_; + + // Create new scope with builder + builder_ = BlockBuilder(); + + // Reset state for each function + sym_var_counter_ = 0; + expr_to_var_.clear(); + + // First pass: collect all compound expressions in the function body + // so we can emit MatchCast bindings at the beginning + CollectCompoundExprsInFunction(func); + + // Visit params + ffi::Array params; + bool all_params_unchanged = true; + for (Var param : func->params) { + Var new_param = this->VisitVarDef(param); + params.push_back(new_param); + if (!param.same_as(new_param)) { + var_remap_[param->vid] = new_param; + all_params_unchanged = false; + } + } + + // Process the function body + Expr new_body = this->VisitWithNewScope(func->body, params); + + // Also substitute in the return struct_info + StructInfo new_ret_sinfo = SubstituteStructInfo(func->ret_struct_info); + + bool ret_sinfo_changed = !StructuralEqual()(new_ret_sinfo, func->ret_struct_info); + bool body_changed = !new_body.same_as(func->body); + + builder_ = saved_builder; + + if (all_params_unchanged && !ret_sinfo_changed && !body_changed) { + return ffi::GetRef(func); + } + + return Function(params, new_body, new_ret_sinfo, func->is_pure, func->attrs, func->span); + } + + void VisitBinding_(const VarBindingNode* binding) override { + // First, emit MatchCast bindings for any compound PrimExpr in ShapeExpr + // This populates expr_to_var_ with mappings from compound expr to fresh vars + EmitMatchCastForCompoundExprs(binding->value); + + // Now visit the binding with substitution + Expr new_value = this->VisitExpr(binding->value); + + // Get the struct_info from the new value and substitute compound exprs + StructInfo new_sinfo = SubstituteStructInfo(GetStructInfo(new_value)); + + // Create a new relax::Var with the substituted 
struct_info + Var new_var(binding->var->name_hint(), new_sinfo, binding->var->span); + + // Remap the old var to the new var + var_remap_[binding->var->vid] = new_var; + + // Emit the new binding + builder_->EmitNormalized(VarBinding(new_var, new_value)); + } + + void VisitBinding_(const MatchCastNode* binding) override { + // Emit MatchCast bindings for compound PrimExpr in ShapeExpr first + EmitMatchCastForCompoundExprs(binding->value); + + // Visit the value + Expr new_value = this->VisitExpr(binding->value); + + // Substitute in the struct_info + StructInfo new_sinfo = SubstituteStructInfo(GetStructInfo(binding->value)); + + // Create a new relax::Var with the substituted struct_info + Var new_var(binding->var->name_hint(), new_sinfo, binding->var->span); + + var_remap_[binding->var->vid] = new_var; + + builder_->EmitNormalized(MatchCast(new_var, new_value, new_sinfo)); + } + + Expr VisitExpr_(const ShapeExprNode* op) override { + // Rewrite ShapeExpr to replace compound PrimExpr with fresh symbolic variables + ffi::Array new_values; + bool changed = false; + + for (const PrimExpr& dim : op->values) { + if (IsTrivialPrimExpr(dim)) { + new_values.push_back(dim); + } else { + auto it = expr_to_var_.find(dim); + if (it != expr_to_var_.end()) { + new_values.push_back(it->second); + changed = true; + } else { + new_values.push_back(dim); + } + } + } + + if (changed) { + return ShapeExpr(new_values, op->span); + } + return ffi::GetRef(op); + } + + private: + /*! + * \brief Collect all compound expressions in a function body. + */ + void CollectCompoundExprsInFunction(const FunctionNode* func) { + CompoundExprCollector collector; + collector.VisitExpr(func->body); + } + + /*! + * \brief Scan an expression for ShapeExpr nodes and emit MatchCast bindings + * for any compound PrimExpr dimensions. 
+ */ + void EmitMatchCastForCompoundExprs(const Expr& expr) { + CompoundExprCollector collector; + collector.VisitExpr(expr); + + for (const PrimExpr& compound_expr : collector.compound_exprs_) { + EmitMatchCastIfNeeded(compound_expr); + } + } + + /*! + * \brief Substitute compound PrimExpr in a StructInfo with fresh variables. + */ + StructInfo SubstituteStructInfo(const StructInfo& sinfo) { + if (expr_to_var_.empty()) { + return sinfo; + } + StructInfoPrimExprMutator mutator(expr_to_var_); + return mutator.VisitStructInfo(sinfo); + } + + /*! + * \brief Emit a MatchCast binding for a compound PrimExpr if not already done. + */ + void EmitMatchCastIfNeeded(const PrimExpr& expr) { + if (IsTrivialPrimExpr(expr)) { + return; + } + + if (expr_to_var_.count(expr)) { + return; + } + + // Create a fresh tir::Var to hold the computed value + std::string var_name = "_s" + std::to_string(sym_var_counter_++); + tir::Var fresh_tir_var(var_name, expr->dtype); + + // Record the mapping for substitution + expr_to_var_[expr] = fresh_tir_var; + + // Create a PrimValue that computes the compound expression + PrimValue prim_value(expr); + + // Create a PrimStructInfo that declares the fresh variable as the value + PrimStructInfo target_sinfo(fresh_tir_var); + + // Create a Relax Var to hold the MatchCast result + std::string relax_var_name = var_name + "_pv"; + relax::Var match_var(relax_var_name, target_sinfo); + + // Emit the MatchCast binding + builder_->EmitNormalized(MatchCast(match_var, prim_value, target_sinfo)); + } + + BlockBuilder builder_; + int sym_var_counter_ = 0; + std::unordered_map expr_to_var_; +}; + +} // namespace + +Expr CanonicalizeShapeExpr(Expr expr) { return ShapeExprCanonicalizer()(std::move(expr)); } + +namespace transform { + +Pass CanonicalizeShapeExpr() { + auto pass_func = [=](Function f, IRModule m, PassContext pc) { + return Downcast(relax::CanonicalizeShapeExpr(f)); + }; + return CreateFunctionPass(/*pass_function=*/pass_func, + /*opt_level=*/0, + 
/*pass_name=*/"CanonicalizeShapeExpr", + /*required=*/{}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.CanonicalizeShapeExpr", CanonicalizeShapeExpr); +} + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_transform_canonicalize_shape_expr.py b/tests/python/relax/test_transform_canonicalize_shape_expr.py new file mode 100644 index 000000000000..6723ceae6991 --- /dev/null +++ b/tests/python/relax/test_transform_canonicalize_shape_expr.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
"""Unit tests for the CanonicalizeShapeExpr pass"""

import pytest
import tvm
import tvm.testing
from tvm import relax
from tvm.script import relax as R
from tvm.script import tir as T


def _compound_shape_dims(func):
    """Collect the non-canonical dimensions of every ShapeExpr reachable from ``func``.

    A dimension counts as canonical when it is a ``tir.IntImm`` or a ``tir.Var``;
    anything else (arithmetic, casts, ...) is reported.

    Parameters
    ----------
    func : relax.Expr
        The expression (typically a ``relax.Function``) to scan.

    Returns
    -------
    offending : list
        The compound ``PrimExpr`` dimensions that were found, in visit order.
    """
    offending = []

    def _visit(node):
        if isinstance(node, relax.ShapeExpr):
            offending.extend(
                dim for dim in node.values if not isinstance(dim, (tvm.tir.IntImm, tvm.tir.Var))
            )

    relax.analysis.post_order_visit(func, _visit)
    return offending


def test_nested_compound_shape():
    """Test canonicalization with nested compound shape expressions"""

    @R.function
    def before(x: R.Tensor(("n", "m"), "float32")):
        n = T.int64()
        m = T.int64()
        # Nested compound expression: (n + m) * 2
        y: R.Tensor(((n + m) * 2,), "float32") = R.zeros(R.shape([(n + m) * 2]), dtype="float32")
        return y

    mod = tvm.IRModule.from_expr(before)
    mod = relax.transform.CanonicalizeShapeExpr()(mod)

    # After canonicalization every ShapeExpr dimension must be an IntImm or a
    # tir.Var; the compound expression may only survive inside a PrimValue.
    leftover = _compound_shape_dims(mod["before"])
    assert not leftover, f"compound dims survived canonicalization: {leftover}"

    # The canonical form must flow through the rest of the lowering pipeline.
    mod = relax.transform.Normalize()(mod)
    mod = relax.transform.ComputePrimValue()(mod)
    mod = relax.transform.VMShapeLower()(mod)

    # Compare on the symbol name itself, not on the printed form of the GlobalVar.
    assert "compute_symbolic_expr" in {gv.name_hint for gv in mod.get_global_vars()}


if __name__ == "__main__":
    tvm.testing.main()