diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 73d1a3dbebce..b2f946328917 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -34,6 +34,7 @@ #include #include +#include #include namespace tvm { @@ -494,6 +495,17 @@ struct VarUsageInfo { */ VarUsageInfo CollectVarUsage(const Expr& expr); +/*! + * \brief Get the used variables in an expression. + * + * This function collects all variables that are referenced within the given expression. + * + * \param expr The expression to analyze + * + * \return A set of variable nodes that are used in the expression + */ +TVM_DLL std::set GetUsedVars(const Expr& expr); + /*! * \brief Remove unused statements inside DataflowBlocks. * diff --git a/python/tvm/relax/analysis/__init__.py b/python/tvm/relax/analysis/__init__.py index 592e3bb5db51..7e267b0f7867 100644 --- a/python/tvm/relax/analysis/__init__.py +++ b/python/tvm/relax/analysis/__init__.py @@ -32,6 +32,7 @@ free_symbolic_vars, free_vars, get_static_type, + used_vars, get_var2val, has_reshape_pattern, name_to_binding, diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index af0772ea6cbe..8d40d3d42780 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -312,6 +312,26 @@ def all_vars(expr: Expr) -> List[Var]: return _ffi_api.all_vars(expr) +def used_vars(expr: Expr) -> List[Var]: + """ + Return all variables used in an expression. + + This function collects all variable references within the given expression, + which is useful for analyzing variable dependencies. + + Parameters + ---------- + expr: Expr + The expression to analyze. + + Returns + ------- + ret: List[Var] + List of variables used in the expression. + """ + return _ffi_api.used_vars(expr) # type: ignore + + def all_global_vars(expr: Expr) -> List[GlobalVar]: """ Return all global variables from expression expr. diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc index bbdbb7b644ef..fcd628f606cf 100644 --- a/src/relax/analysis/udchain.cc +++ b/src/relax/analysis/udchain.cc @@ -121,10 +121,29 @@ ffi::Map> DataflowBlockUseDef(const DataflowBlock& dfb) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("relax.analysis.udchain", DataflowBlockUseDef); + refl::GlobalDef() + .def("relax.analysis.udchain", DataflowBlockUseDef) + .def("relax.analysis.used_vars", [](const Expr& expr) { + auto used_vars = GetUsedVars(expr); + ffi::Array result; + for (const VarNode* var_node : used_vars) { + result.push_back(ffi::GetRef(var_node)); + } + return result; + }); } VarUsageInfo CollectVarUsage(const Expr& expr) { return UDChain::Collect(expr); } +std::set GetUsedVars(const Expr& expr) { + class UsedVars : public ExprVisitor { + public: + std::set used_vars; + void VisitExpr_(const VarNode* op) override { used_vars.insert(op); } + } visitor; + visitor.VisitExpr(expr); + return std::move(visitor.used_vars); +} + } // namespace relax } // namespace tvm diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc index 0bbfef31b83a..a8dcf78155dd 100644 --- a/src/relax/ir/binding_rewrite.cc +++ b/src/relax/ir/binding_rewrite.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -134,17 +135,6 @@ class UpdateDFB : public ExprMutator { } }; -// TODO(masahi): Consider moving this to analysis -std::set GetUsedVars(Expr val) { - class UsedVars : public ExprVisitor { - public: - std::set used_vars; - void VisitExpr_(const VarNode* op) override { used_vars.insert(op); } - } uvar{}; - uvar.VisitExpr(val); - return std::move(uvar.used_vars); -} - void DataflowBlockRewriteNode::Add(Binding binding) { auto [var, val] = [binding] { if (auto vb = binding.as()) { diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index 9f5c200cde47..2845622bbe3e 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -17,6 +17,7 @@ from typing import List, Set, Union +import pytest import tvm import tvm.testing from tvm import relax as rx @@ -26,6 +27,7 @@ all_vars, bound_vars, free_vars, + used_vars, has_reshape_pattern, name_to_binding, remove_all_unused, @@ -61,6 +63,27 @@ def test_use_def(): assert set(udc[gv0]) == set() +@pytest.mark.parametrize( + "expr_fn, expected_var_names", + [ + (lambda x, y, z: rx.op.add(x, y), {"x", "y"}), + (lambda x, y, z: rx.op.multiply(x, x), {"x"}), + (lambda x, y, z: rx.Tuple([x, y, z]), {"x", "y", "z"}), + ], + ids=["binary_op", "self_reference", "tuple"], +) +def test_used_vars(expr_fn, expected_var_names): + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x = rx.Var("x", R.Tensor([m, n], "float16")) + y = rx.Var("y", R.Tensor([n], "float16")) + z = rx.Var("z", R.Tensor([m], "float16")) + + expr = expr_fn(x, y, z) + result = used_vars(expr) + assert var_name_set(result) == expected_var_names + + def test_chained_remove_all_unused(): @tvm.script.ir_module class IdentityUnused: