diff --git a/CMakeLists.txt b/CMakeLists.txt
index ec7bd6c51453..7a228d67995a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -484,6 +484,15 @@ include(cmake/modules/Git.cmake)
 include(cmake/modules/LibInfo.cmake)
 include(cmake/modules/contrib/Mrvl.cmake)
 
+tvm_option(USE_Z3 "Build with Z3 SMT solver support" OFF)
+
+if (USE_Z3)
+  find_package(Z3 REQUIRED)
+  list(APPEND COMPILER_SRCS src/target/z3/z3_prover_on.cc)
+else()
+  list(APPEND COMPILER_SRCS src/target/z3/z3_prover_off.cc)
+endif()
+
 set(LIBINFO_FILE ${CMAKE_CURRENT_LIST_DIR}/src/support/libinfo.cc)
 add_lib_info(${LIBINFO_FILE})
 list(REMOVE_ITEM COMPILER_SRCS ${LIBINFO_FILE})
@@ -541,6 +550,42 @@ else()
   target_link_libraries(tvm_runtime PUBLIC tvm_ffi_shared)
 endif()
 
+if(USE_Z3)
+  target_include_directories(tvm_objs PRIVATE ${Z3_CXX_INCLUDE_DIRS})
+  target_include_directories(tvm_runtime_objs PRIVATE ${Z3_CXX_INCLUDE_DIRS})
+  target_include_directories(tvm_libinfo_objs PRIVATE ${Z3_CXX_INCLUDE_DIRS})
+  target_link_libraries(tvm PRIVATE z3::libz3)
+
+  if (APPLE)
+    # The libz3.dylib shipped in the z3-solver wheel on PyPI carries the bare
+    # install name `libz3.dylib` (no `@rpath/` prefix), so the dynamic loader
+    # will not resolve it via rpath. Patch the reference to `@rpath/libz3.dylib`.
+    # A `POST_BUILD` command must live in the same CMake file that creates the target.
+    add_custom_command(TARGET tvm POST_BUILD
+      COMMAND install_name_tool -change "libz3.dylib" "@rpath/libz3.dylib" $<TARGET_FILE:tvm>
+      COMMENT "Patching libz3 reference to use @rpath"
+    )
+  else()
+    find_program(PATCHELF_EXECUTABLE patchelf)
+    if (PATCHELF_EXECUTABLE)
+      execute_process(
+        COMMAND ${PATCHELF_EXECUTABLE} --print-soname ${Z3_LIBRARY}
+        OUTPUT_VARIABLE Z3_SONAME
+        OUTPUT_STRIP_TRAILING_WHITESPACE
+        RESULT_VARIABLE Z3_SONAME_RESULT
+      )
+      if(NOT Z3_SONAME_RESULT EQUAL "0")
+        message(FATAL_ERROR "Failed to get Z3 soname using patchelf")
+      endif()
+      message(STATUS "Z3 SONAME: ${Z3_SONAME}")
+      add_custom_command(TARGET tvm POST_BUILD
+        COMMAND ${PATCHELF_EXECUTABLE} --replace-needed ${Z3_SONAME} libz3.so $<TARGET_FILE:tvm>
+        COMMENT "Patching libz3 reference to use soname ${Z3_SONAME}"
+      )
+    else()
+      message(STATUS "patchelf not found, skipping the libz3 soname patch.")
+    endif()
+  endif()
+endif()
 
 target_include_directories(tvm_runtime PUBLIC "$")
 set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS "${TVM_VISIBILITY_FLAG}")
@@ -748,7 +793,7 @@ endif()
 
 # Change relative paths in backtrace to absolute ones
 if(TVM_IS_DEBUG_BUILD)
-  set(FILE_PREFIX_MAP_FLAG "-ffile-prefix-map=..=${CMAKE_CURRENT_SOURCE_DIR}")
+  # set(FILE_PREFIX_MAP_FLAG "-ffile-prefix-map=..=${CMAKE_CURRENT_SOURCE_DIR}")
   target_compile_options(tvm PRIVATE "${FILE_PREFIX_MAP_FLAG}")
   CHECK_CXX_COMPILER_FLAG("${FILE_PREFIX_MAP_FLAG}" FILE_PREFIX_MAP_SUPPORTED)
   if(FILE_PREFIX_MAP_SUPPORTED)
diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index 099643d0a0bb..7c4fdbe75c7e 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -33,6 +33,7 @@
 #include
 #include
 #include
+#include <tvm/ffi/object.h>
 
 namespace tvm {
 /*! \brief namespace of arithmetic analysis. */
@@ -175,6 +176,8 @@ class ConstIntBoundAnalyzer {
   friend class ConstraintContext;
   explicit ConstIntBoundAnalyzer(Analyzer* parent);
   TVM_DLL ~ConstIntBoundAnalyzer();
+  // Deep-copy internal state from another instance (for Analyzer::Clone)
+  void CopyFrom(const ConstIntBoundAnalyzer& other);
  /*!
   * \brief Update the internal state to enter constraint.
   * \param constraint A constraint expression.
@@ -254,6 +257,8 @@ class ModularSetAnalyzer {
   friend class ConstraintContext;
   explicit ModularSetAnalyzer(Analyzer* parent);
   TVM_DLL ~ModularSetAnalyzer();
+  // Deep-copy internal state from another instance (for Analyzer::Clone)
+  void CopyFrom(const ModularSetAnalyzer& other);
  /*!
   * \brief Update the internal state to enter constraint.
   * \param constraint A constraint expression.
@@ -294,7 +299,7 @@ class RewriteSimplifier {
   *
   * \return an exit function that must be called to clean up the constraint; can be nullptr.
   */
-  TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);
+  TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint, bool is_assume = false);
 
  /*! \brief Flags to enable more computationally-intensive simplifications
   *
@@ -407,6 +412,8 @@ class RewriteSimplifier {
   friend class CanonicalSimplifier;
   explicit RewriteSimplifier(Analyzer* parent);
   TVM_DLL ~RewriteSimplifier();
+  // Deep-copy internal state from another instance (for Analyzer::Clone)
+  void CopyFrom(const RewriteSimplifier& other);
   class Impl;
   /*! \brief Internal impl */
   Impl* impl_;
@@ -438,6 +445,8 @@ class CanonicalSimplifier {
   friend class ConstraintContext;
   explicit CanonicalSimplifier(Analyzer* parent);
   TVM_DLL ~CanonicalSimplifier();
+  // Deep-copy internal state from another instance (for Analyzer::Clone)
+  void CopyFrom(const CanonicalSimplifier& other);
   class Impl;
   /*! \brief Internal impl */
   Impl* impl_;
@@ -523,6 +532,8 @@ class TransitiveComparisonAnalyzer {
   friend class ConstraintContext;
   TransitiveComparisonAnalyzer();
   TVM_DLL ~TransitiveComparisonAnalyzer();
+  // Deep-copy internal state from another instance (for Analyzer::Clone)
+  void CopyFrom(const TransitiveComparisonAnalyzer& other);
   class Impl;
   /*! \brief Internal impl */
   std::unique_ptr<Impl> impl_;
@@ -553,8 +564,8 @@ class ConstraintContext {
   * \param analyzer The analyzer.
   * \param constraint The constraint to be applied.
   */
-  ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
-      : analyzer_(analyzer), constraint_(constraint) {}
+  ConstraintContext(Analyzer* analyzer, PrimExpr constraint, bool is_assume = false)
+      : analyzer_(analyzer), constraint_(constraint), is_assume_(is_assume) {}
   // enter the scope.
   void EnterWithScope();
   // exit the scope.
@@ -565,6 +576,8 @@ class ConstraintContext {
   PrimExpr constraint_;
   /*! \brief function to be called in recovery */
   std::vector<std::function<void()>> recovery_functions_;
+  /*! \brief whether the constraint is only an assumption */
+  bool is_assume_;
 };
 
 /*!
@@ -616,11 +628,86 @@ class IntSetAnalyzer {
   friend class Analyzer;
   explicit IntSetAnalyzer(Analyzer* parent);
   TVM_DLL ~IntSetAnalyzer();
+  // Deep-copy internal state from another instance (for Analyzer::Clone)
+  void CopyFrom(const IntSetAnalyzer& other);
   class Impl;
   /*! \brief Internal impl */
   Impl* impl_;
 };
 
+/*! \brief Prover that discharges queries to the Z3 SMT solver. */
+class Z3Prover {
+ public:
+  /*!
+   * \brief Update the binding of var to a new range of allowed values.
+   *
+   * \param var The variable of interest.
+   * \param new_range The range of allowed values for this var.
+   * \param allow_override Whether we allow overriding existing information.
+   */
+  TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false);
+
+  /*!
+   * \brief Update the binding of var to a new expression.
+   *
+   * \param var The variable of interest.
+   * \param expr The bound expression.
+   * \param allow_override Whether we allow overriding existing information.
+   */
+  TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
+
+  /*!
+   * \brief Whether we can prove that expr is always true.
+   *
+   * \param expr The expression.
+   * \return Whether we can prove it.
+   *
+   * \note The query, together with the constraints recorded so far, is
+   *       discharged to the Z3 SMT solver.
+   */
+  TVM_DLL bool CanProve(const PrimExpr& expr);
+
+  /*!
+   * \brief Update the internal state to enter constraint.
+   * \param constraint A constraint expression.
+   *
+   * \return an exit function that must be called to clean up the constraint; can be nullptr.
+   */
+  std::function<void()> EnterConstraint(const PrimExpr& constraint, bool is_assume = false);
+
+  /*!
+   * \brief Get the SMT-LIB2 representation of the current context.
+   * \param expr The optional expression to check.
+   * \return The SMT-LIB2 string.
+   */
+  ffi::String GetSMTLIB2(const ffi::Optional<PrimExpr> expr);
+
+  /*!
+   * \brief Get statistics about the Z3 prover.
+   * \return The statistics string.
+   */
+  ffi::String GetStats();
+
+  /*!
+   * \brief Set the timeout in milliseconds for the Z3 prover.
+   * \param timeout_ms The timeout in milliseconds.
+   */
+  void SetTimeoutMs(unsigned timeout_ms);
+
+  /*!
+   * \brief Set the resource limit for the Z3 prover.
+   * \param rlimit The resource limit (e.g. the maximum number of solver steps).
+   */
+  void SetRLimit(unsigned rlimit);
+
+ private:
+  friend class Analyzer;
+  explicit Z3Prover(Analyzer* parent);
+  TVM_DLL ~Z3Prover();
+  // Deep-copy internal state from another instance (for Analyzer::Clone)
+  void CopyFrom(const Z3Prover& other);
+  class Impl;
+  /*! \brief Internal impl */
+  Impl* impl_;
+};
+
 /*!
  * \brief Analyzer that contains a bunch of sub-analyzers.
  *
@@ -650,8 +736,15 @@ class TVM_DLL Analyzer {
   IntSetAnalyzer int_set;
   /*! \brief sub-analyzer transitive comparisons */
   TransitiveComparisonAnalyzer transitive_comparisons;
+  /*! \brief sub-analyzer backed by the Z3 SMT solver */
+  Z3Prover z3_prover;
   /*! \brief constructor */
   Analyzer();
+  /*!
+   * \brief Create a deep copy of this Analyzer, including all sub-analyzer states.
+   * \return A new Analyzer with copied internal state.
+   */
+  std::unique_ptr<Analyzer> Clone() const;
   /*!
    * \brief Mark the value as non-negative value globally in analyzer.
    *
diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index 5e38f3876937..117198214a0e 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -309,6 +309,5 @@ class TensorMapType : public Type {
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE_WITHOUT_DEFAULT_CONSTRUCTOR(TensorMapType, Type,
                                                                          TensorMapTypeNode);
 };
-
 }  // namespace tvm
 #endif  // TVM_IR_TYPE_H_
diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h
index db5776890ab9..4be475c09419 100644
--- a/include/tvm/script/ir_builder/tir/frame.h
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -350,11 +350,12 @@ class LetFrameNode : public TIRFrameNode {
   /*! \brief The value we bind var to */
   PrimExpr value;
+
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
     refl::ObjectDef<LetFrameNode>()
         .def_ro("var", &LetFrameNode::var)
-        .def_ro("value", &LetFrameNode::value);
+        .def_rw("value", &LetFrameNode::value);
   }
 
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.LetFrame", LetFrameNode, TIRFrameNode);
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index 07c7fe262bb3..273aa7f63f4b 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -488,6 +488,7 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(UInt, DataType::UInt);
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int);
 
 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(FuncName, FDType, Size) \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x2, FDType(Size, 2));     \
   TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x4, FDType(Size, 4));     \
   TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x8, FDType(Size, 8));     \
   TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x16, FDType(Size, 16));   \
@@ -507,6 +508,7 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
 
 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, FDType) \
   TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType, FDType(1));                    \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x2, FDType(2));                \
   TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x4, FDType(4));                \
   TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x8, FDType(8));                \
   TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x16, FDType(16));              \
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index 529765469165..b615ab503522 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -731,9 +731,21 @@ class CallNode : public PrimExprNode {
   /*! \brief The arguments. */
   ffi::Array<PrimExpr> args;
 
+  /*!
+   * \brief Additional annotations about the call.
+   *
+   * These annotations can be used to pass additional metadata
+   * to lowering passes. For tile operators, this can include
+   * coalesced_width, disable_tma, eviction_policy, etc.
+   */
+  ffi::Map<ffi::String, ffi::Any> annotations;
+
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
-    refl::ObjectDef<CallNode>().def_ro("op", &CallNode::op).def_ro("args", &CallNode::args);
+    refl::ObjectDef<CallNode>()
+        .def_ro("op", &CallNode::op)
+        .def_ro("args", &CallNode::args)
+        .def_ro("annotations", &CallNode::annotations);
   }
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Call", CallNode, PrimExprNode);
 };
@@ -744,7 +756,9 @@ class CallNode : public PrimExprNode {
  */
 class Call : public PrimExpr {
  public:
-  TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args, Span span = Span());
+  TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args,
+               ffi::Map<ffi::String, ffi::Any> annotations = {},
+               Span span = Span());
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, PrimExpr, CallNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
 };
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 57f868151418..005e8f5532ee 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -722,17 +722,17 @@ TVM_DLL PrimExpr fast_erf_float_expr(PrimExpr arg, int bits);
 
 // Intrinsic operators
 #define TVM_DECLARE_INTRIN_UNARY(OpName)                                      \
-  inline PrimExpr OpName(PrimExpr x, Span span = Span()) {                    \
-    static const Op& op = Op::Get("tir." #OpName);                            \
-    if (x.dtype().is_bfloat16()) {                                            \
-      DataType bf16_dtype = x.dtype();                                        \
-      DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes());                  \
-      PrimExpr x_fp32 = tir::Cast(fp32_dtype, {x}, span);                     \
-      PrimExpr result_fp32 = tir::Call(fp32_dtype, op, {x_fp32}, span);       \
-      return tir::Cast(bf16_dtype, {result_fp32}, span);                      \
-    } else {                                                                  \
-      return tir::Call(x.dtype(), op, {x}, span);                             \
-    }                                                                         \
+  inline PrimExpr OpName(PrimExpr x, Span span = Span()) {                    \
+    static const Op& op = Op::Get("tir." #OpName);                            \
+    if (x.dtype().is_bfloat16()) {                                            \
+      DataType bf16_dtype = x.dtype();                                        \
+      DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes());                  \
+      PrimExpr x_fp32 = tir::Cast(fp32_dtype, {x}, span);                     \
+      PrimExpr result_fp32 = tir::Call(fp32_dtype, op, {x_fp32}, {}, span);   \
+      return tir::Cast(bf16_dtype, {result_fp32}, span);                      \
+    } else {                                                                  \
+      return tir::Call(x.dtype(), op, {x}, {}, span);                         \
+    }                                                                         \
   }
 
 TVM_DECLARE_INTRIN_UNARY(exp);
@@ -764,7 +764,7 @@ TVM_DECLARE_INTRIN_UNARY(clz);
 #define TVM_DECLARE_INTRIN_BINARY(OpName)                                 \
   inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) {    \
     static const Op& op = Op::Get("tir." #OpName);                        \
-    return tir::Call(x.dtype(), op, {x, y}, span);                        \
+    return tir::Call(x.dtype(), op, {x, y}, {}, span);                    \
   }
 
 TVM_DECLARE_INTRIN_BINARY(atan2);
diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index a768a7dd4f31..12b2e66429ce 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -535,7 +535,7 @@ class ScheduleNode : public runtime::Object {
    * \return The reindex stage block.
    */
   virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
-                          BufferIndexType buffer_index_type) = 0;
+                          BufferIndexType buffer_index_type, bool skip_simplify = false) = 0;
   /******** Schedule: Data movement ********/
   virtual BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index,
                          const ffi::String& storage_scope) = 0;
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 0831b84cf6fe..75ba37b43fb8 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1319,6 +1319,9 @@ constexpr const char* explicit_read_region = "explicit_read_region";
  */
 constexpr const char* explicit_write_region = "explicit_write_region";
 
+/*! \brief Mark a constraint that should be treated as an assumption (tl.assume). */
+constexpr const char* tilelang_assume = "tl.assume";
+
 /*! \brief Mark a ForNode as representing an irregular loop with non-structural control flow edges.
  */
 constexpr const char* irregular_loop_mark = "irregular_loop_mark";
 
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index ef4830a46adf..4d0678099582 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1285,12 +1285,16 @@ inline Tensor take(const Tensor& a, ffi::Variant indices, int
         for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
           real_indices.push_back(out_index[j]);
         }
-        auto idx = truncmod(truncmod(get_index(indices_position), axis_dim) + axis_dim, axis_dim);
+        PrimExpr idx = get_index(indices_position);
         real_indices.push_back(idx);
         for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
           real_indices.push_back(out_index[j]);
         }
-        return a(real_indices);
+        // Out-of-bounds indices now yield NaN instead of wrapping around.
+        PrimExpr in_bounds = idx >= 0 && idx < axis_dim;
+        return tvm::if_then_else(
+            in_bounds, a(real_indices),
+            tvm::tir::make_const(a->dtype, std::numeric_limits<double>::quiet_NaN()));
       },
       name, tag);
 }
diff --git a/pyproject.toml b/pyproject.toml
index 987e17928408..54e0cc91dc3e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,7 +55,7 @@ dependencies = [
     "psutil",
     "scipy",
     "tornado",
-    "typing_extensions",
+    "typing_extensions"
 ]
 
 # Optional dependencies for different features
@@ -73,6 +73,9 @@ importer-paddle = ["paddlepaddle"]
 autotvm = ["xgboost"]
 autoscheduler = ["xgboost"]
 
+# SMT support
+z3 = ["z3-solver>=4.13.0"]
+
 # Development and testing
 dev = [
     "black",
@@ -110,6 +113,7 @@ all = [
     "tflite",
     "paddlepaddle",
     "xgboost",
+    "z3-solver>=4.13.0"
 ]
 
 [project.urls]
diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py
index c5c8fc067cc8..d8c7e88656b9 100644
--- a/python/tvm/arith/analyzer.py
+++ b/python/tvm/arith/analyzer.py
@@ -17,7 +17,7 @@
 # pylint: disable=invalid-name
 """Arithmetic data structure and utility"""
 import enum
-from typing import Union
+from typing import Dict, Optional, Union
 
 import tvm_ffi
 
 from tvm import ir, tir
@@ -108,22 +108,89 @@ class Analyzer:
 
     def __init__(self):
         _mod = _ffi_api.CreateAnalyzer()
-        self._const_int_bound = _mod("const_int_bound")
-        self._const_int_bound_update = _mod("const_int_bound_update")
-        self._const_int_bound_is_bound = _mod("const_int_bound_is_bound")
-        self._bind = _mod("bind")
-        self._modular_set = _mod("modular_set")
-        self._simplify = _mod("Simplify")
-        self._rewrite_simplify = _mod("rewrite_simplify")
-        self._get_rewrite_simplify_stats = _mod("get_rewrite_simplify_stats")
-        self._reset_rewrite_simplify_stats = _mod("reset_rewrite_simplify_stats")
-        self._canonical_simplify = _mod("canonical_simplify")
-        self._int_set = _mod("int_set")
-        self._enter_constraint_context = _mod("enter_constraint_context")
-        self._can_prove_equal = _mod("can_prove_equal")
-        self._can_prove = _mod("can_prove")
-        self._get_enabled_extensions = _mod("get_enabled_extensions")
-        self._set_enabled_extensions = _mod("set_enabled_extensions")
+        self._assign_functions(_mod)
+
+    def _assign_functions(self, mod_factory):
+        # Save the factory for later use (e.g., clone)
+        self._factory = mod_factory
+        self._const_int_bound = mod_factory("const_int_bound")
+        self._const_int_bound_update = mod_factory("const_int_bound_update")
+        self._const_int_bound_is_bound = mod_factory("const_int_bound_is_bound")
+        self._bind = mod_factory("bind")
+        self._modular_set = mod_factory("modular_set")
+        self._simplify = mod_factory("Simplify")
+        self._rewrite_simplify = mod_factory("rewrite_simplify")
+        self._get_rewrite_simplify_stats = mod_factory("get_rewrite_simplify_stats")
+        self._reset_rewrite_simplify_stats = mod_factory("reset_rewrite_simplify_stats")
+        self._canonical_simplify = mod_factory("canonical_simplify")
+        self._int_set = mod_factory("int_set")
+        self._enter_constraint_context = mod_factory("enter_constraint_context")
+        self._can_prove_equal = mod_factory("can_prove_equal")
+        self._can_prove = mod_factory("can_prove")
+        self._get_smtlib2 = mod_factory("get_smtlib2")
+        self._set_z3_timeout_ms = mod_factory("set_z3_timeout_ms")
+        self._set_z3_rlimit = mod_factory("set_z3_rlimit")
+        self._get_z3_stats = mod_factory("get_z3_stats")
+        self._get_enabled_extensions = mod_factory("get_enabled_extensions")
+        self._set_enabled_extensions = mod_factory("set_enabled_extensions")
+        # The clone factory returns another mod_factory when invoked
+        self._clone_factory = mod_factory("clone")
+
+    def get_smtlib2(self, expr: Optional[tir.PrimExpr] = None) -> str:
+        """Get the SMT-LIB2 representation of the current Z3 context.
+
+        Parameters
+        ----------
+        expr : Optional[tir.PrimExpr]
+            An optional expression to include as the proof goal.
+
+        Returns
+        -------
+        smtlib2 : str
+            The SMT-LIB2 string.
+        """
+        return self._get_smtlib2(expr)
+
+    def set_z3_timeout_ms(self, timeout_ms: int) -> None:
+        """Set the Z3 timeout in milliseconds.
+
+        Parameters
+        ----------
+        timeout_ms : int
+            The timeout in milliseconds.
+        """
+        self._set_z3_timeout_ms(timeout_ms)
+
+    def set_z3_rlimit(self, max_step: int) -> None:
+        """Set the Z3 resource limit.
+
+        Parameters
+        ----------
+        max_step : int
+            The maximum number of solver steps.
+        """
+        self._set_z3_rlimit(max_step)
+
+    def get_z3_stats(self) -> str:
+        """Get Z3 statistics.
+
+        Returns
+        -------
+        stats : str
+            The Z3 statistics.
+        """
+        return self._get_z3_stats()
+
+    def clone(self) -> "Analyzer":
+        """Create a deep copy of this Analyzer, including internal state.
+
+        Returns
+        -------
+        Analyzer
+            A new Analyzer instance with the same analysis state.
+        """
+        # _clone_factory() returns a new factory bound to the cloned C++ Analyzer
+        new_factory = self._clone_factory()
+        obj = Analyzer.__new__(Analyzer)
+        obj._assign_functions(new_factory)
+        return obj
 
     def const_int_bound(self, expr: tir.PrimExpr) -> ConstIntBound:
         """Find constant integer bound for expr.
@@ -227,7 +285,7 @@ def canonical_simplify(self, expr: tir.PrimExpr) -> tir.PrimExpr:
         """
         return self._canonical_simplify(expr)
 
-    def int_set(self, expr: tir.PrimExpr, dom_map: dict[tir.Var, IntSet]) -> IntSet:
+    def int_set(self, expr: tir.PrimExpr, dom_map: Dict[tir.Var, IntSet]) -> IntSet:
         """Compute a symbolic IntSet that covers expr for all values in dom_map.
 
         Parameters
diff --git a/python/tvm/base.py b/python/tvm/base.py
index 8e88364e2600..f5bdc215ce1e 100644
--- a/python/tvm/base.py
+++ b/python/tvm/base.py
@@ -26,8 +26,8 @@
 # ----------------------------
 # Python3 version.
 # ----------------------------
-if not (sys.version_info[0] >= 3 and sys.version_info[1] >= 9):
-    PY3STATEMENT = "The minimal Python requirement is Python 3.9"
+if sys.version_info < (3, 8):
+    PY3STATEMENT = "The minimal Python requirement is Python 3.8"
     raise Exception(PY3STATEMENT)
 
 # ----------------------------
@@ -42,7 +42,8 @@ def _load_lib():
     if sys.platform.startswith("win32"):
         for path in libinfo.get_dll_directories():
             os.add_dll_directory(path)
-    lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)
+    # RTLD_LAZY is POSIX-only, so fall back to 0 on platforms that lack it.
+    lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL | getattr(os, "RTLD_LAZY", 0))
     return lib, os.path.basename(lib_path[0])
 
diff --git a/python/tvm/libinfo.py b/python/tvm/libinfo.py
index 59caf7b2fd3a..5062e80997e2 100644
--- a/python/tvm/libinfo.py
+++ b/python/tvm/libinfo.py
@@ -53,7 +53,7 @@ def get_dll_directories():
     dll_path = []
 
     if os.environ.get("TVM_LIBRARY_PATH", None):
-        dll_path.append(os.environ["TVM_LIBRARY_PATH"])
+        dll_path.extend(os.environ["TVM_LIBRARY_PATH"].split(os.pathsep))
 
     if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"):
         dll_path.extend(split_env_var("LD_LIBRARY_PATH", ":"))
@@ -232,9 +232,13 @@ def find_include_path(name=None, search_path=None, optional=False):
         dmlc_include_path = []
     else:
        tvm_include_path = [os.path.join(p, "include") for p in header_path]
-        tvm_ffi_include_path = [
-            os.path.join(p, "3rdparty", "tvm-ffi", "include") for p in header_path
-        ]
+
+        # Use the include path reported by the installed tvm_ffi package
+        # instead of the in-tree 3rdparty copy.
+        from tvm_ffi import libinfo as _tvm_ffi_libinfo  # type: ignore
+
+        tvm_ffi_include_path = [_tvm_ffi_libinfo.find_include_path()]
+
         dlpack_include_path = [
             os.path.join(p, "3rdparty", "tvm-ffi", "3rdparty", "dlpack", "include")
             for p in header_path
diff --git a/python/tvm/runtime/support.py b/python/tvm/runtime/support.py
index 4a2e9ef50847..07145a74612f 100644
--- a/python/tvm/runtime/support.py
+++ b/python/tvm/runtime/support.py
@@ -18,7 +18,7 @@
 """Runtime support infra of TVM."""
 
 import re
-from typing import TypeVar
+from typing import TypeVar, Type
 
 import tvm_ffi
 
@@ -73,7 +73,7 @@ def _regex_match(regex_pattern: str, match_against: str) -> bool:
 T = TypeVar("T")
 
 
-def derived_object(cls: type[T]) -> type[T]:
+def derived_object(cls: Type[T]) -> Type[T]:
     """A decorator to register derived subclasses for TVM objects.
 
     Parameters
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index a08e66789fa3..11fd37ef2196 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -22,7 +22,7 @@
 import sys
 import threading
 from numbers import Integral
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 # isort: off
 from typing_extensions import Literal
@@ -1378,6 +1378,17 @@ def buffer_store(
     )
 
 
+def customized_code(code: str):
+    """Add a customized code block.
+
+    Parameters
+    ----------
+    code : str
+        The code block to be added.
+    """
+    return _ffi_api.CustomizedCode(code)  # type: ignore[attr-defined]  # pylint: disable=no-member
+
+
 def evaluate(value: PrimExpr) -> None:
     """Evaluate the input expression.
@@ -1419,173 +1430,321 @@ def func( return func - -# pylint: disable=invalid-name -int8 = func_gen(("Int8")) -int16 = func_gen(("Int16")) -int32 = func_gen(("Int32")) -int64 = func_gen(("Int64")) -int8x4 = func_gen(("Int8x4")) -int16x4 = func_gen(("Int16x4")) -int32x4 = func_gen(("Int32x4")) -int64x4 = func_gen(("Int64x4")) -int8x8 = func_gen(("Int8x8")) -int16x8 = func_gen(("Int16x8")) -int32x8 = func_gen(("Int32x8")) -int64x8 = func_gen(("Int64x8")) -int8x16 = func_gen(("Int8x16")) -int16x16 = func_gen(("Int16x16")) -int32x16 = func_gen(("Int32x16")) -int64x16 = func_gen(("Int64x16")) -int8x32 = func_gen(("Int8x32")) -int16x32 = func_gen(("Int16x32")) -int32x32 = func_gen(("Int32x32")) -int64x32 = func_gen(("Int64x32")) -int8x64 = func_gen(("Int8x64")) -int16x64 = func_gen(("Int16x64")) -int32x64 = func_gen(("Int32x64")) -int64x64 = func_gen(("Int64x64")) - -uint8 = func_gen(("UInt8")) -uint16 = func_gen(("UInt16")) -uint32 = func_gen(("UInt32")) -uint64 = func_gen(("UInt64")) -uint8x4 = func_gen(("UInt8x4")) -uint16x4 = func_gen(("UInt16x4")) -uint32x4 = func_gen(("UInt32x4")) -uint64x4 = func_gen(("UInt64x4")) -uint8x8 = func_gen(("UInt8x8")) -uint16x8 = func_gen(("UInt16x8")) -uint32x8 = func_gen(("UInt32x8")) -uint64x8 = func_gen(("UInt64x8")) -uint8x16 = func_gen(("UInt8x16")) -uint16x16 = func_gen(("UInt16x16")) -uint32x16 = func_gen(("UInt32x16")) -uint64x16 = func_gen(("UInt64x16")) -uint8x32 = func_gen(("UInt8x32")) -uint16x32 = func_gen(("UInt16x32")) -uint32x32 = func_gen(("UInt32x32")) -uint64x32 = func_gen(("UInt64x32")) -uint8x64 = func_gen(("UInt8x64")) -uint16x64 = func_gen(("UInt16x64")) -uint32x64 = func_gen(("UInt32x64")) -uint64x64 = func_gen(("UInt64x64")) - -float16 = func_gen(("Float16")) -float32 = func_gen(("Float32")) -float64 = func_gen(("Float64")) -float16x2 = func_gen(("Float16x2")) -float32x2 = func_gen(("Float32x2")) -float64x2 = func_gen(("Float64x2")) -float16x4 = func_gen(("Float16x4")) -float32x4 = func_gen(("Float32x4")) -float64x4 = func_gen(("Float64x4")) -float16x8 = func_gen(("Float16x8")) -float32x8 = func_gen(("Float32x8")) -float64x8 = func_gen(("Float64x8")) -float16x16 = func_gen(("Float16x16")) -float32x16 = func_gen(("Float32x16")) -float64x16 = func_gen(("Float64x16")) -float16x32 = func_gen(("Float16x32")) -float32x32 = func_gen(("Float32x32")) -float64x32 = func_gen(("Float64x32")) -float16x64 = func_gen(("Float16x64")) -float32x64 = func_gen(("Float32x64")) -float64x64 = func_gen(("Float64x64")) - -# Float8 variants -float8_e3m4 = func_gen(("Float8E3M4")) -float8_e3m4x2 = func_gen(("Float8E3M4x2")) -float8_e3m4x4 = func_gen(("Float8E3M4x4")) -float8_e3m4x8 = func_gen(("Float8E3M4x8")) -float8_e3m4x16 = func_gen(("Float8E3M4x16")) -float8_e3m4x32 = func_gen(("Float8E3M4x32")) -float8_e3m4x64 = func_gen(("Float8E3M4x64")) - -float8_e4m3 = func_gen(("Float8E4M3")) -float8_e4m3x2 = func_gen(("Float8E4M3x2")) -float8_e4m3x4 = func_gen(("Float8E4M3x4")) -float8_e4m3x8 = func_gen(("Float8E4M3x8")) -float8_e4m3x16 = func_gen(("Float8E4M3x16")) -float8_e4m3x32 = func_gen(("Float8E4M3x32")) -float8_e4m3x64 = func_gen(("Float8E4M3x64")) - -float8_e4m3b11fnuz = func_gen(("Float8E4M3B11FNUZ")) -float8_e4m3b11fnuzx2 = func_gen(("Float8E4M3B11FNUZx2")) -float8_e4m3b11fnuzx4 = func_gen(("Float8E4M3B11FNUZx4")) -float8_e4m3b11fnuzx8 = func_gen(("Float8E4M3B11FNUZx8")) -float8_e4m3b11fnuzx16 = func_gen(("Float8E4M3B11FNUZx16")) -float8_e4m3b11fnuzx32 = func_gen(("Float8E4M3B11FNUZx32")) -float8_e4m3b11fnuzx64 = func_gen(("Float8E4M3B11FNUZx64")) - 
-float8_e4m3fn = func_gen(("Float8E4M3FN")) -float8_e4m3fnx2 = func_gen(("Float8E4M3FNx2")) -float8_e4m3fnx4 = func_gen(("Float8E4M3FNx4")) -float8_e4m3fnx8 = func_gen(("Float8E4M3FNx8")) -float8_e4m3fnx16 = func_gen(("Float8E4M3FNx16")) -float8_e4m3fnx32 = func_gen(("Float8E4M3FNx32")) -float8_e4m3fnx64 = func_gen(("Float8E4M3FNx64")) - -float8_e4m3fnuz = func_gen(("Float8E4M3FNUZ")) -float8_e4m3fnuzx2 = func_gen(("Float8E4M3FNUZx2")) -float8_e4m3fnuzx4 = func_gen(("Float8E4M3FNUZx4")) -float8_e4m3fnuzx8 = func_gen(("Float8E4M3FNUZx8")) -float8_e4m3fnuzx16 = func_gen(("Float8E4M3FNUZx16")) -float8_e4m3fnuzx32 = func_gen(("Float8E4M3FNUZx32")) -float8_e4m3fnuzx64 = func_gen(("Float8E4M3FNUZx64")) - -float8_e5m2 = func_gen(("Float8E5M2")) -float8_e5m2x2 = func_gen(("Float8E5M2x2")) -float8_e5m2x4 = func_gen(("Float8E5M2x4")) -float8_e5m2x8 = func_gen(("Float8E5M2x8")) -float8_e5m2x16 = func_gen(("Float8E5M2x16")) -float8_e5m2x32 = func_gen(("Float8E5M2x32")) -float8_e5m2x64 = func_gen(("Float8E5M2x64")) - -float8_e5m2fnuz = func_gen(("Float8E5M2FNUZ")) -float8_e5m2fnuzx2 = func_gen(("Float8E5M2FNUZx2")) -float8_e5m2fnuzx4 = func_gen(("Float8E5M2FNUZx4")) -float8_e5m2fnuzx8 = func_gen(("Float8E5M2FNUZx8")) -float8_e5m2fnuzx16 = func_gen(("Float8E5M2FNUZx16")) -float8_e5m2fnuzx32 = func_gen(("Float8E5M2FNUZx32")) -float8_e5m2fnuzx64 = func_gen(("Float8E5M2FNUZx64")) - -float8_e8m0fnu = func_gen(("Float8E8M0FNU")) -float8_e8m0fnux2 = func_gen(("Float8E8M0FNUx2")) -float8_e8m0fnux4 = func_gen(("Float8E8M0FNUx4")) -float8_e8m0fnux8 = func_gen(("Float8E8M0FNUx8")) -float8_e8m0fnux16 = func_gen(("Float8E8M0FNUx16")) -float8_e8m0fnux32 = func_gen(("Float8E8M0FNUx32")) -float8_e8m0fnux64 = func_gen(("Float8E8M0FNUx64")) - -# Float6 variants -float6_e2m3fn = func_gen(("Float6E2M3FN")) -float6_e2m3fnx2 = func_gen(("Float6E2M3FNx2")) -float6_e2m3fnx4 = func_gen(("Float6E2M3FNx4")) -float6_e2m3fnx8 = func_gen(("Float6E2M3FNx8")) -float6_e2m3fnx16 = func_gen(("Float6E2M3FNx16")) -float6_e2m3fnx32 = func_gen(("Float6E2M3FNx32")) -float6_e2m3fnx64 = func_gen(("Float6E2M3FNx64")) - -float6_e3m2fn = func_gen(("Float6E3M2FN")) -float6_e3m2fnx2 = func_gen(("Float6E3M2FNx2")) -float6_e3m2fnx4 = func_gen(("Float6E3M2FNx4")) -float6_e3m2fnx8 = func_gen(("Float6E3M2FNx8")) -float6_e3m2fnx16 = func_gen(("Float6E3M2FNx16")) -float6_e3m2fnx32 = func_gen(("Float6E3M2FNx32")) -float6_e3m2fnx64 = func_gen(("Float6E3M2FNx64")) - -# Float4 variants -float4_e2m1fn = func_gen(("Float4E2M1FN")) -float4_e2m1fnx2 = func_gen(("Float4E2M1FNx2")) -float4_e2m1fnx4 = func_gen(("Float4E2M1FNx4")) -float4_e2m1fnx8 = func_gen(("Float4E2M1FNx8")) -float4_e2m1fnx16 = func_gen(("Float4E2M1FNx16")) -float4_e2m1fnx32 = func_gen(("Float4E2M1FNx32")) -float4_e2m1fnx64 = func_gen(("Float4E2M1FNx64")) - -bfloat16 = func_gen(("BFloat16")) -# pylint: enable=invalid-name +if TYPE_CHECKING: + class int8: ... + class int16: ... + class int32: ... + class int64: ... + class int8x4: ... + class int16x4: ... + class int32x4: ... + class int64x4: ... + class int8x8: ... + class int16x8: ... + class int32x8: ... + class int64x8: ... + class int8x16: ... + class int16x16: ... + class int32x16: ... + class int64x16: ... + class int8x32: ... + class int16x32: ... + class int32x32: ... + class int64x32: ... + class int8x64: ... + class int16x64: ... + class int32x64: ... + class int64x64: ... + class uint8: ... + class uint16: ... + class uint32: ... + class uint64: ... + class uint8x4: ... + class uint16x4: ... + class uint32x4: ... + class uint64x4: ... 
+ class uint8x8: ... + class uint16x8: ... + class uint32x8: ... + class uint64x8: ... + class uint8x16: ... + class uint16x16: ... + class uint32x16: ... + class uint64x16: ... + class uint8x32: ... + class uint16x32: ... + class uint32x32: ... + class uint64x32: ... + class uint8x64: ... + class uint16x64: ... + class uint32x64: ... + class uint64x64: ... + class float16: ... + class float32: ... + class float64: ... + class float16x2: ... + class float32x2: ... + class float64x2: ... + class float16x4: ... + class float32x4: ... + class float64x4: ... + class float16x8: ... + class float32x8: ... + class float64x8: ... + class float16x16: ... + class float32x16: ... + class float64x16: ... + class float16x32: ... + class float32x32: ... + class float64x32: ... + class float16x64: ... + class float32x64: ... + class float64x64: ... + class float8_e3m4: ... + class float8_e3m4x2: ... + class float8_e3m4x4: ... + class float8_e3m4x8: ... + class float8_e3m4x16: ... + class float8_e3m4x32: ... + class float8_e3m4x64: ... + class float8_e4m3: ... + class float8_e4m3x2: ... + class float8_e4m3x4: ... + class float8_e4m3x8: ... + class float8_e4m3x16: ... + class float8_e4m3x32: ... + class float8_e4m3x64: ... + class float8_e4m3b11fnuz: ... + class float8_e4m3b11fnuzx2: ... + class float8_e4m3b11fnuzx4: ... + class float8_e4m3b11fnuzx8: ... + class float8_e4m3b11fnuzx16: ... + class float8_e4m3b11fnuzx32: ... + class float8_e4m3b11fnuzx64: ... + class float8_e4m3fn: ... + class float8_e4m3fnx2: ... + class float8_e4m3fnx4: ... + class float8_e4m3fnx8: ... + class float8_e4m3fnx16: ... + class float8_e4m3fnx32: ... + class float8_e4m3fnx64: ... + class float8_e4m3fnuz: ... + class float8_e4m3fnuzx2: ... + class float8_e4m3fnuzx4: ... + class float8_e4m3fnuzx8: ... + class float8_e4m3fnuzx16: ... + class float8_e4m3fnuzx32: ... + class float8_e4m3fnuzx64: ... + class float8_e5m2: ... + class float8_e5m2x2: ... + class float8_e5m2x4: ... + class float8_e5m2x8: ... + class float8_e5m2x16: ... + class float8_e5m2x32: ... + class float8_e5m2x64: ... + class float8_e5m2fnuz: ... + class float8_e5m2fnuzx2: ... + class float8_e5m2fnuzx4: ... + class float8_e5m2fnuzx8: ... + class float8_e5m2fnuzx16: ... + class float8_e5m2fnuzx32: ... + class float8_e5m2fnuzx64: ... + class float8_e8m0fnu: ... + class float8_e8m0fnux2: ... + class float8_e8m0fnux4: ... + class float8_e8m0fnux8: ... + class float8_e8m0fnux16: ... + class float8_e8m0fnux32: ... + class float8_e8m0fnux64: ... + class float6_e2m3fn: ... + class float6_e2m3fnx2: ... + class float6_e2m3fnx4: ... + class float6_e2m3fnx8: ... + class float6_e2m3fnx16: ... + class float6_e2m3fnx32: ... + class float6_e2m3fnx64: ... + class float6_e3m2fn: ... + class float6_e3m2fnx2: ... + class float6_e3m2fnx4: ... + class float6_e3m2fnx8: ... + class float6_e3m2fnx16: ... + class float6_e3m2fnx32: ... + class float6_e3m2fnx64: ... + class float4_e2m1fn: ... + class float4_e2m1fnx2: ... + class float4_e2m1fnx4: ... + class float4_e2m1fnx8: ... + class float4_e2m1fnx16: ... + class float4_e2m1fnx32: ... + class float4_e2m1fnx64: ... + class bfloat16: ... 
+else: + # pylint: disable=invalid-name + int8 = func_gen(("Int8")) + int16 = func_gen(("Int16")) + int32 = func_gen(("Int32")) + int64 = func_gen(("Int64")) + int8x4 = func_gen(("Int8x4")) + int16x4 = func_gen(("Int16x4")) + int32x4 = func_gen(("Int32x4")) + int64x4 = func_gen(("Int64x4")) + int8x8 = func_gen(("Int8x8")) + int16x8 = func_gen(("Int16x8")) + int32x8 = func_gen(("Int32x8")) + int64x8 = func_gen(("Int64x8")) + int8x16 = func_gen(("Int8x16")) + int16x16 = func_gen(("Int16x16")) + int32x16 = func_gen(("Int32x16")) + int64x16 = func_gen(("Int64x16")) + int8x32 = func_gen(("Int8x32")) + int16x32 = func_gen(("Int16x32")) + int32x32 = func_gen(("Int32x32")) + int64x32 = func_gen(("Int64x32")) + int8x64 = func_gen(("Int8x64")) + int16x64 = func_gen(("Int16x64")) + int32x64 = func_gen(("Int32x64")) + int64x64 = func_gen(("Int64x64")) + + uint8 = func_gen(("UInt8")) + uint16 = func_gen(("UInt16")) + uint32 = func_gen(("UInt32")) + uint64 = func_gen(("UInt64")) + uint8x4 = func_gen(("UInt8x4")) + uint16x4 = func_gen(("UInt16x4")) + uint32x4 = func_gen(("UInt32x4")) + uint64x4 = func_gen(("UInt64x4")) + uint8x8 = func_gen(("UInt8x8")) + uint16x8 = func_gen(("UInt16x8")) + uint32x8 = func_gen(("UInt32x8")) + uint64x8 = func_gen(("UInt64x8")) + uint8x16 = func_gen(("UInt8x16")) + uint16x16 = func_gen(("UInt16x16")) + uint32x16 = func_gen(("UInt32x16")) + uint64x16 = func_gen(("UInt64x16")) + uint8x32 = func_gen(("UInt8x32")) + uint16x32 = func_gen(("UInt16x32")) + uint32x32 = func_gen(("UInt32x32")) + uint64x32 = func_gen(("UInt64x32")) + uint8x64 = func_gen(("UInt8x64")) + uint16x64 = func_gen(("UInt16x64")) + uint32x64 = func_gen(("UInt32x64")) + uint64x64 = func_gen(("UInt64x64")) + + float16 = func_gen(("Float16")) + float32 = func_gen(("Float32")) + float64 = func_gen(("Float64")) + float16x2 = func_gen(("Float16x2")) + float32x2 = func_gen(("Float32x2")) + float64x2 = func_gen(("Float64x2")) + float16x4 = func_gen(("Float16x4")) + float32x4 = func_gen(("Float32x4")) + float64x4 = func_gen(("Float64x4")) + float16x8 = func_gen(("Float16x8")) + float32x8 = func_gen(("Float32x8")) + float64x8 = func_gen(("Float64x8")) + float16x16 = func_gen(("Float16x16")) + float32x16 = func_gen(("Float32x16")) + float64x16 = func_gen(("Float64x16")) + float16x32 = func_gen(("Float16x32")) + float32x32 = func_gen(("Float32x32")) + float64x32 = func_gen(("Float64x32")) + float16x64 = func_gen(("Float16x64")) + float32x64 = func_gen(("Float32x64")) + float64x64 = func_gen(("Float64x64")) + + # Float8 variants + float8_e3m4 = func_gen(("Float8E3M4")) + float8_e3m4x2 = func_gen(("Float8E3M4x2")) + float8_e3m4x4 = func_gen(("Float8E3M4x4")) + float8_e3m4x8 = func_gen(("Float8E3M4x8")) + float8_e3m4x16 = func_gen(("Float8E3M4x16")) + float8_e3m4x32 = func_gen(("Float8E3M4x32")) + float8_e3m4x64 = func_gen(("Float8E3M4x64")) + + float8_e4m3 = func_gen(("Float8E4M3")) + float8_e4m3x2 = func_gen(("Float8E4M3x2")) + float8_e4m3x4 = func_gen(("Float8E4M3x4")) + float8_e4m3x8 = func_gen(("Float8E4M3x8")) + float8_e4m3x16 = func_gen(("Float8E4M3x16")) + float8_e4m3x32 = func_gen(("Float8E4M3x32")) + float8_e4m3x64 = func_gen(("Float8E4M3x64")) + + float8_e4m3b11fnuz = func_gen(("Float8E4M3B11FNUZ")) + float8_e4m3b11fnuzx2 = func_gen(("Float8E4M3B11FNUZx2")) + float8_e4m3b11fnuzx4 = func_gen(("Float8E4M3B11FNUZx4")) + float8_e4m3b11fnuzx8 = func_gen(("Float8E4M3B11FNUZx8")) + float8_e4m3b11fnuzx16 = func_gen(("Float8E4M3B11FNUZx16")) + float8_e4m3b11fnuzx32 = func_gen(("Float8E4M3B11FNUZx32")) + 
float8_e4m3b11fnuzx64 = func_gen(("Float8E4M3B11FNUZx64")) + + float8_e4m3fn = func_gen(("Float8E4M3FN")) + float8_e4m3fnx2 = func_gen(("Float8E4M3FNx2")) + float8_e4m3fnx4 = func_gen(("Float8E4M3FNx4")) + float8_e4m3fnx8 = func_gen(("Float8E4M3FNx8")) + float8_e4m3fnx16 = func_gen(("Float8E4M3FNx16")) + float8_e4m3fnx32 = func_gen(("Float8E4M3FNx32")) + float8_e4m3fnx64 = func_gen(("Float8E4M3FNx64")) + + float8_e4m3fnuz = func_gen(("Float8E4M3FNUZ")) + float8_e4m3fnuzx2 = func_gen(("Float8E4M3FNUZx2")) + float8_e4m3fnuzx4 = func_gen(("Float8E4M3FNUZx4")) + float8_e4m3fnuzx8 = func_gen(("Float8E4M3FNUZx8")) + float8_e4m3fnuzx16 = func_gen(("Float8E4M3FNUZx16")) + float8_e4m3fnuzx32 = func_gen(("Float8E4M3FNUZx32")) + float8_e4m3fnuzx64 = func_gen(("Float8E4M3FNUZx64")) + + float8_e5m2 = func_gen(("Float8E5M2")) + float8_e5m2x2 = func_gen(("Float8E5M2x2")) + float8_e5m2x4 = func_gen(("Float8E5M2x4")) + float8_e5m2x8 = func_gen(("Float8E5M2x8")) + float8_e5m2x16 = func_gen(("Float8E5M2x16")) + float8_e5m2x32 = func_gen(("Float8E5M2x32")) + float8_e5m2x64 = func_gen(("Float8E5M2x64")) + + float8_e5m2fnuz = func_gen(("Float8E5M2FNUZ")) + float8_e5m2fnuzx2 = func_gen(("Float8E5M2FNUZx2")) + float8_e5m2fnuzx4 = func_gen(("Float8E5M2FNUZx4")) + float8_e5m2fnuzx8 = func_gen(("Float8E5M2FNUZx8")) + float8_e5m2fnuzx16 = func_gen(("Float8E5M2FNUZx16")) + float8_e5m2fnuzx32 = func_gen(("Float8E5M2FNUZx32")) + float8_e5m2fnuzx64 = func_gen(("Float8E5M2FNUZx64")) + + float8_e8m0fnu = func_gen(("Float8E8M0FNU")) + float8_e8m0fnux2 = func_gen(("Float8E8M0FNUx2")) + float8_e8m0fnux4 = func_gen(("Float8E8M0FNUx4")) + float8_e8m0fnux8 = func_gen(("Float8E8M0FNUx8")) + float8_e8m0fnux16 = func_gen(("Float8E8M0FNUx16")) + float8_e8m0fnux32 = func_gen(("Float8E8M0FNUx32")) + float8_e8m0fnux64 = func_gen(("Float8E8M0FNUx64")) + + # Float6 variants + float6_e2m3fn = func_gen(("Float6E2M3FN")) + float6_e2m3fnx2 = func_gen(("Float6E2M3FNx2")) + float6_e2m3fnx4 = func_gen(("Float6E2M3FNx4")) + float6_e2m3fnx8 = func_gen(("Float6E2M3FNx8")) + float6_e2m3fnx16 = func_gen(("Float6E2M3FNx16")) + float6_e2m3fnx32 = func_gen(("Float6E2M3FNx32")) + float6_e2m3fnx64 = func_gen(("Float6E2M3FNx64")) + + float6_e3m2fn = func_gen(("Float6E3M2FN")) + float6_e3m2fnx2 = func_gen(("Float6E3M2FNx2")) + float6_e3m2fnx4 = func_gen(("Float6E3M2FNx4")) + float6_e3m2fnx8 = func_gen(("Float6E3M2FNx8")) + float6_e3m2fnx16 = func_gen(("Float6E3M2FNx16")) + float6_e3m2fnx32 = func_gen(("Float6E3M2FNx32")) + float6_e3m2fnx64 = func_gen(("Float6E3M2FNx64")) + + # Float4 variants + float4_e2m1fn = func_gen(("Float4E2M1FN")) + float4_e2m1fnx2 = func_gen(("Float4E2M1FNx2")) + float4_e2m1fnx4 = func_gen(("Float4E2M1FNx4")) + float4_e2m1fnx8 = func_gen(("Float4E2M1FNx8")) + float4_e2m1fnx16 = func_gen(("Float4E2M1FNx16")) + float4_e2m1fnx32 = func_gen(("Float4E2M1FNx32")) + float4_e2m1fnx64 = func_gen(("Float4E2M1FNx64")) + + bfloat16 = func_gen(("BFloat16")) + # pylint: enable=invalid-name def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimExpr: diff --git a/python/tvm/script/parser/core/doc.py b/python/tvm/script/parser/core/doc.py index 74174f066727..f8c400ad1667 100644 --- a/python/tvm/script/parser/core/doc.py +++ b/python/tvm/script/parser/core/doc.py @@ -18,6 +18,7 @@ import ast import inspect +import sys import typing from collections import defaultdict @@ -318,4 +319,150 @@ def __call__(self, node): ) + +def _py_version() -> typing.Tuple[int, int]: + return (sys.version_info.major, sys.version_info.minor) + + 
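+# Background for the compatibility shims below (a sketch of standard-library
+# behavior, not part of this module): Python 3.8 still wraps subscripts in
+# ast.Index / ast.ExtSlice nodes that were removed in 3.9, and Python < 3.8
+# parses literals into ast.Num / ast.Str / ast.Bytes / ast.NameConstant
+# instead of ast.Constant. The handlers registered here normalize those
+# legacy shapes into the unified `doc` representation. For example:
+#
+#   >>> import ast
+#   >>> node = ast.parse("a[1]").body[0].value
+#   >>> type(node.slice).__name__
+#   'Index'     # on Python 3.8
+#   'Constant'  # on Python 3.9+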
+def _register_constant_handling(): + if _py_version() not in [(3, 6), (3, 7)]: + return + + def as_constant(f) -> doc.Constant: + def to_doc_func(x: ast.AST) -> doc.Constant: + return doc.Constant( + value=getattr(x, f) if isinstance(f, str) else f(x), + kind=None, + lineno=x.lineno, + col_offset=x.col_offset, + end_lineno=x.lineno, + end_col_offset=x.col_offset, + ) + + return to_doc_func + + register_to_doc("Str")(as_constant("s")) + register_to_doc("NameConstant")(as_constant("value")) + register_to_doc("Num")(as_constant("n")) + register_to_doc("Bytes")(as_constant("s")) + register_to_doc("Ellipsis")(as_constant(lambda _: ...)) + + +def _register_subscription_handling(): + if _py_version() >= (3, 9): + return + + def subscript_to_doc(x: ast.Subscript) -> doc.Subscript: + if isinstance(x.slice, ast.Slice): + return doc.Subscript( + value=to_doc(x.value), + slice=doc.Slice( + lower=to_doc(x.slice.lower), + upper=to_doc(x.slice.upper), + step=to_doc(x.slice.step), + lineno=getattr(x.slice, "lineno", None), + col_offset=getattr(x.slice, "col_offset", None), + end_lineno=getattr(x.slice, "end_lineno", None), + end_col_offset=getattr(x.slice, "end_col_offset", None), + ), + ctx=to_doc(x.ctx), + lineno=getattr(x, "lineno", None), + col_offset=getattr(x, "col_offset", None), + end_lineno=getattr(x, "end_lineno", None), + end_col_offset=getattr(x, "end_col_offset", None), + ) + if isinstance(x.slice, ast.ExtSlice): + return doc.Subscript( + value=to_doc(x.value), + slice=doc.Tuple( + elts=[to_doc(i) for i in x.slice.dims], + ctx=doc.Load( + lineno=None, + col_offset=None, + end_lineno=None, + end_col_offset=None, + ), + lineno=getattr(x, "lineno", None), + col_offset=getattr(x, "col_offset", None), + end_lineno=getattr(x, "end_lineno", None), + end_col_offset=getattr(x, "end_col_offset", None), + ), + ctx=to_doc(x.ctx), + lineno=getattr(x, "lineno", None), + col_offset=getattr(x, "col_offset", None), + end_lineno=getattr(x, "end_lineno", None), + end_col_offset=getattr(x, "end_col_offset", None), + ) + if isinstance(x.slice, ast.Index): + return doc.Subscript( + value=to_doc(x.value), + slice=to_doc(x.slice.value), + ctx=to_doc(x.ctx), + lineno=getattr(x, "lineno", None), + col_offset=getattr(x, "col_offset", None), + end_lineno=getattr(x, "end_lineno", None), + end_col_offset=getattr(x, "end_col_offset", None), + ) + raise TypeError(f"Unknown subscript type: {type(x.slice)}") + + def subscript_from_doc(x: doc.Subscript) -> ast.Subscript: + if isinstance(x.slice, doc.Slice): + result = ast.Subscript( + value=from_doc(x.value), + slice=from_doc(x.slice), + ctx=from_doc(x.ctx), + ) + elif isinstance(x.slice, doc.Tuple): + + def remap_dim(doc_item: doc.Expr) -> ast.Expr: + ast_item = from_doc(doc_item) + if isinstance(ast_item, (ast.Index, ast.Slice)): + return ast_item + return ast.Index(value=ast_item) + + # ast.ExtSlice requires a non-empty list of dims, and each dim must be either + # a Slice or an Index. 
+ if x.slice.elts: + ast_slice = ast.ExtSlice(dims=[*map(remap_dim, x.slice.elts)]) + else: + ast_slice = ast.Index(value=ast.Tuple(elts=[], ctx=from_doc(x.ctx))) + result = ast.Subscript(value=from_doc(x.value), slice=ast_slice, ctx=from_doc(x.ctx)) + else: + result = ast.Subscript( + value=from_doc(x.value), + slice=ast.Index(value=from_doc(x.slice)), + ctx=from_doc(x.ctx), + ) + result.lineno = x.lineno + result.col_offset = x.col_offset + result.end_lineno = x.end_lineno + result.end_col_offset = x.end_col_offset + return result + + register_to_doc("Subscript")(subscript_to_doc) + register_from_doc("Subscript")(subscript_from_doc) + + +def _register_index_handling(): + if _py_version() >= (3, 9): + return + + def index_to_doc(x: ast.Index) -> doc.Expr: + return to_doc(x.value) + + def index_from_doc(x: doc.Expr) -> ast.Index: + result = ast.Index(value=from_doc(x), ctx=from_doc(x.ctx)) + result.lineno = x.lineno + result.col_offset = x.col_offset + result.end_lineno = x.end_lineno + result.end_col_offset = x.end_col_offset + return result + + register_to_doc("Index")(index_to_doc) + register_from_doc("Index")(index_from_doc) + + _register_default() +_register_constant_handling() +_register_subscription_handling() +_register_index_handling() diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 7668fa99e611..f23c69824bde 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -57,6 +57,7 @@ doc.Not: lambda a: not a, doc.UAdd: lambda a: +a, doc.USub: lambda a: -a, + doc.IfExp: tvm.tir.op.if_then_else, } @@ -174,8 +175,8 @@ def _visit(self, node: doc.AST) -> Any: if ( isinstance(node, doc.Call) and hasattr(node.func, "attr") - and node.func.attr not in ["reads", "writes", "match_buffer", "realize"] - ) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp, doc.IfExp)): + and node.func.attr not in ["reads", "writes", "match_buffer", "realize", "copy"] + ) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp)): if isinstance(node, doc.BinOp): args = [node.left, node.right] elif isinstance(node, doc.UnaryOp): @@ -529,14 +530,14 @@ def _eval_expr( def _eval_op( - op: doc.AST, + op_or_type: Union[doc.AST, Type], values: List[Any], ): """Operation expression evaluation implementation for TVMScript parser. Parameters ---------- - op : doc.AST + op_or_type : Union[doc.AST, Type] The root node of AST tree node of operation expression to evaluate. values : List[Any] @@ -547,7 +548,9 @@ def _eval_op( res : Any The evaluation result. 
""" - op_type = type(op) # pylint: disable=protected-access + op_type = ( + type(op_or_type) if isinstance(op_or_type, doc.AST) else op_or_type + ) # pylint: disable=protected-access for i, v in enumerate(values): v_type = getattr(type(v), "_dispatch_type", None) if v_type is None: diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index f8cbc0b4f5bc..92244d5a0472 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -22,7 +22,7 @@ import tvm from tvm.ir import GlobalVar, PrimType -from tvm.tir import Buffer, IterVar, PrimExpr, Var +from tvm.tir import Buffer, BufferLoad, IterVar, PrimExpr, Var from ...ir_builder import ir as I from ...ir_builder import tir as T @@ -138,6 +138,9 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - res = value.__enter__() IRBuilder.name(var_name, res) return res + elif isinstance(value, Buffer) and value.scope() == "local.var": + IRBuilder.name(var_name, value) + return BufferLoad(value, indices=[0]) elif isinstance(value, (Buffer, IterVar)) or ( isinstance(value, Var) and not self.var_table.exist(value) ): @@ -277,8 +280,21 @@ def visit_assign(self: Parser, node: doc.Assign) -> None: else: indices = self.eval_expr(lhs.slice) T.buffer_store(self.eval_expr(lhs.value), rhs, indices) - else: - self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value) + return + + # Handle local.var buffer store + if isinstance(lhs, doc.Name) and lhs.id in self.var_table.get(): + lhs_value = self.eval_expr(lhs) + if ( + isinstance(lhs_value, BufferLoad) + and lhs_value.buffer.scope() == "local.var" + and len(lhs_value.indices) == 1 + and lhs_value.indices[0] == 0 + ): + T.buffer_store(lhs_value.buffer, rhs, indices=[0]) + return + + self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value) @dispatch.register(token="tir", type_name="AugAssign") diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 259017608275..f333c14986f2 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -195,6 +195,8 @@ def __getitem__(self, indices): indices = [indices] has_slice = any(isinstance(i, slice) for i in indices) has_step = any(isinstance(i, slice) and i.step is not None for i in indices) + if has_step: + raise RuntimeError("Buffer slicing with step is not supported.") analyzer = Analyzer() if has_slice and not has_step: region = [] diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index f5476230c19b..ecfd90acc13b 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -27,7 +27,7 @@ assert(isinstance(y, tvm.tir.Add)) assert(y.a == x) """ -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import tvm_ffi import tvm.ir._ffi_api @@ -1257,6 +1257,9 @@ class Call(PrimExprWithOp): args : list of Expr The input arguments to the call + annotations : Optional[Dict[str, Object]] + Additional annotations about the call. + span : Optional[Span] The location of this expression in the source code. 
""" @@ -1265,7 +1268,12 @@ class Call(PrimExprWithOp): args: List[PrimExpr] def __init__( - self, dtype: str, op: Union[Op, str], args: List[PrimExpr], span: Optional[Span] = None + self, + dtype: str, + op: Union[Op, str], + args: List[PrimExpr], + annotations: Optional[Dict] = None, + span: Optional[Span] = None, ) -> None: if isinstance(op, str): if not op.startswith("tir."): @@ -1278,7 +1286,7 @@ def __init__( % op ) op = Op.get(op) - self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op, args, span) # type: ignore + self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op, args, annotations, span) # type: ignore @tvm_ffi.register_object("tir.Let") diff --git a/python/tvm/tir/functor.py b/python/tvm/tir/functor.py index c2594835fedf..d5bc20b76f9f 100644 --- a/python/tvm/tir/functor.py +++ b/python/tvm/tir/functor.py @@ -362,7 +362,6 @@ def visit_attr_stmt_(self, op: AttrStmt) -> None: op : AttrStmt The AttrStmt to be visited. """ - print("visit_attr_stmt_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_if_then_else_(self, op: IfThenElse) -> None: @@ -375,7 +374,6 @@ def visit_if_then_else_(self, op: IfThenElse) -> None: op : IfThenElse The IfThenElse to be visited. """ - print("visit_if_then_else_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_let_stmt_(self, op: LetStmt) -> None: @@ -388,7 +386,6 @@ def visit_let_stmt_(self, op: LetStmt) -> None: op : LetStmt The LetStmt to be visited. """ - print("visit_let_stmt_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_for_(self, op: For) -> None: @@ -401,7 +398,6 @@ def visit_for_(self, op: For) -> None: op : For The For to be visited. """ - print("visit_for_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_while_(self, op: While) -> None: @@ -414,7 +410,6 @@ def visit_while_(self, op: While) -> None: op : While The While to be visited. """ - print("visit_while_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_allocate_(self, op: Allocate) -> None: @@ -427,7 +422,6 @@ def visit_allocate_(self, op: Allocate) -> None: op : Allocate The Allocate to be visited. """ - print("visit_allocate_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_allocate_const_(self, op: AllocateConst) -> None: @@ -440,7 +434,6 @@ def visit_allocate_const_(self, op: AllocateConst) -> None: op : AllocateConst The AllocateConst to be visited. """ - print("visit_allocate_const_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_decl_buffer_(self, op: DeclBuffer) -> None: @@ -453,7 +446,6 @@ def visit_decl_buffer_(self, op: DeclBuffer) -> None: op : DeclBuffer The DeclBuffer to be visited. """ - print("visit_decl_buffer_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_buffer_store_(self, op: BufferStore) -> None: @@ -466,7 +458,6 @@ def visit_buffer_store_(self, op: BufferStore) -> None: op : BufferStore The BufferStore to be visited. """ - print("visit_buffer_store_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_buffer_realize_(self, op: BufferRealize) -> None: @@ -479,7 +470,6 @@ def visit_buffer_realize_(self, op: BufferRealize) -> None: op : BufferRealize The BufferRealize to be visited. 
""" - print("visit_buffer_realize_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_assert_stmt_(self, op: AssertStmt) -> None: @@ -492,7 +482,6 @@ def visit_assert_stmt_(self, op: AssertStmt) -> None: op : AssertStmt The AssertStmt to be visited. """ - print("visit_assert_stmt_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_seq_stmt_(self, op: SeqStmt) -> None: @@ -505,7 +494,6 @@ def visit_seq_stmt_(self, op: SeqStmt) -> None: op : SeqStmt The SeqStmt to be visited. """ - print("visit_seq_stmt_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_evaluate_(self, op: Evaluate) -> None: @@ -518,7 +506,6 @@ def visit_evaluate_(self, op: Evaluate) -> None: op : Evaluate The Evaluate to be visited. """ - print("visit_evaluate_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_block_(self, op: Block) -> None: @@ -531,7 +518,6 @@ def visit_block_(self, op: Block) -> None: op : Block The Block to be visited. """ - print("visit_block_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_block_realize_(self, op: BlockRealize) -> None: @@ -544,7 +530,6 @@ def visit_block_realize_(self, op: BlockRealize) -> None: op : BlockRealize The BlockRealize to be visited. """ - print("visit_block_realize_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore def visit_var_(self, op: Var) -> None: diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 7f3badcfebad..2e96d98489a8 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -42,7 +42,7 @@ def _pack_buffer(buf, span=None): const(0, dtype=buf.dtype), buf.elem_offset, ] - return Call("handle", Op.get("tir.tvm_stack_make_array"), pack_args, span) + return Call("handle", Op.get("tir.tvm_stack_make_array"), pack_args, span=span) def call_packed_lowered(*args, span=None): @@ -71,7 +71,7 @@ def call_packed_lowered(*args, span=None): te.extern : Create tensor with extern function call. """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] - return Call("int32", Op.get("tir.tvm_call_packed_lowered"), call_args, span) + return Call("int32", Op.get("tir.tvm_call_packed_lowered"), call_args, span=span) def call_cpacked_lowered(*args, span=None): @@ -97,7 +97,7 @@ def call_cpacked_lowered(*args, span=None): te.extern : Create tensor with extern function call. """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] - return Call("int32", Op.get("tir.tvm_call_cpacked_lowered"), call_args, span) + return Call("int32", Op.get("tir.tvm_call_cpacked_lowered"), call_args, span=span) def call_packed(*args, span=None): @@ -128,7 +128,7 @@ def call_packed(*args, span=None): te.extern : Create tensor with extern function call. """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] - return Call("int32", Op.get("tir.tvm_call_packed"), call_args, span) + return Call("int32", Op.get("tir.tvm_call_packed"), call_args, span=span) def call_cpacked(*args, span=None): @@ -155,10 +155,10 @@ def call_cpacked(*args, span=None): te.extern : Create tensor with extern function call. 
""" call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] - return Call("int32", Op.get("tir.tvm_call_cpacked"), call_args, span) + return Call("int32", Op.get("tir.tvm_call_cpacked"), call_args, span=span) -def call_intrin(dtype, func_name, *args, span=None): +def call_intrin(dtype, func_name, *args, annotations=None, span=None): """Build expression by calling an intrinsic function. Intrinsics can be overloaded with multiple data types via @@ -175,6 +175,9 @@ def call_intrin(dtype, func_name, *args, span=None): args : list Positional arguments. + annotations : Optional[Dict[str, Object]] + Additional annotations about the call. + span : Optional[Span] The location of this operator in the source code. @@ -183,7 +186,11 @@ def call_intrin(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return Call(dtype, func_name, args, span) + + # Convert to TVM Map + if annotations is not None: + annotations = {k: tir.const(v) if isinstance(v, (int, bool)) else v for k, v in annotations.items()} + return Call(dtype, func_name, args, annotations=annotations, span=span) def call_pure_extern(dtype, func_name, *args, span=None): @@ -208,7 +215,7 @@ def call_pure_extern(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return Call(dtype, Op.get("tir.call_pure_extern"), [func_name, *args], span) + return Call(dtype, Op.get("tir.call_pure_extern"), [func_name, *args], span=span) def call_extern(dtype, func_name, *args, span=None): @@ -571,11 +578,10 @@ def address_of(obj: Union[Buffer, BufferLoad], span: Optional[Span] = None) -> P The call expression. """ if isinstance(obj, Buffer): - n_dim = len(obj.shape) buffer_load = BufferLoad(obj, [0] * n_dim) return call_intrin("handle", "tir.address_of", buffer_load, span=span) - elif isinstance(obj, BufferLoad): + elif isinstance(obj, (BufferLoad, Var)): return call_intrin("handle", "tir.address_of", obj, span=span) else: raise ValueError(f"Invalid object type: {type(obj)}") @@ -1885,6 +1891,7 @@ def ret(val, span=None): def thread_return(span=None): """Return from a GPU thread + Parameters ---------- span : Optional[Span] diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 0d41ffe94307..95effb643fd7 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1910,7 +1910,8 @@ def resize_cache_index( @type_checked def reindex( - self, block: Union[BlockRV, str], buffer: Union[Tuple[str, int], str, Buffer] + self, block: Union[BlockRV, str], buffer: Union[Tuple[str, int], str, Buffer], + skip_simplify: bool = False, ) -> BlockRV: """Create a block that read/write a buffer region into a read/write cache with reindexing. The layout of the cache will be the same as by the iterators of the block that reads/writes @@ -1942,6 +1943,9 @@ def reindex( If `buffer` is a Buffer object, it must exist within the reads/writes of the block. + skip_simplify: bool + Whether to skip the simplification of the indices. 
+ Returns ------- reindex_block : BlockRV @@ -1997,7 +2001,7 @@ def after_reindex( assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type" buffer_index_type_enum = 0 if buffer_index_type == "read" else 1 return _ffi_api.ScheduleReIndex( # type: ignore # pylint: disable=no-member - self, block, buffer_index, buffer_index_type_enum + self, block, buffer_index, buffer_index_type_enum, skip_simplify ) ########## Schedule: Data movement ########## diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index f6f0b9f4d8df..3b5f9a3712d3 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include "./scalable_expression.h" #include "const_fold.h" @@ -38,7 +39,21 @@ Analyzer::Analyzer() modular_set(this), rewrite_simplify(this), canonical_simplify(this), - int_set(this) {} + int_set(this), + z3_prover(this) {} + +std::unique_ptr Analyzer::Clone() const { + auto cloned = std::make_unique(); + // Copy per-sub-analyzer states + cloned->const_int_bound.CopyFrom(this->const_int_bound); + cloned->modular_set.CopyFrom(this->modular_set); + cloned->rewrite_simplify.CopyFrom(this->rewrite_simplify); + cloned->canonical_simplify.CopyFrom(this->canonical_simplify); + cloned->int_set.CopyFrom(this->int_set); + cloned->transitive_comparisons.CopyFrom(this->transitive_comparisons); + cloned->z3_prover.CopyFrom(this->z3_prover); + return cloned; +} void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { PrimExpr new_expr = expr; @@ -51,6 +66,7 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { this->canonical_simplify.Update(var, new_expr, allow_override); this->int_set.Update(var, this->int_set(new_expr), allow_override); this->transitive_comparisons.Bind(var, expr, allow_override); + this->z3_prover.Bind(var, expr, allow_override); } void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { @@ -61,6 +77,7 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { this->const_int_bound.Bind(var, range, allow_override); this->int_set.Bind(var, range, allow_override); this->transitive_comparisons.Bind(var, range, allow_override); + this->z3_prover.Bind(var, range, allow_override); } // skip modular_set // skip rewrite simplify @@ -127,9 +144,10 @@ void ConstraintContext::EnterWithScope() { // entering the scope. recovery_functions_.push_back(analyzer_->const_int_bound.EnterConstraint(constraint_)); recovery_functions_.push_back(analyzer_->modular_set.EnterConstraint(constraint_)); - recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_)); + recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_, is_assume_)); recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_)); recovery_functions_.push_back(analyzer_->transitive_comparisons.EnterConstraint(constraint_)); + recovery_functions_.push_back(analyzer_->z3_prover.EnterConstraint(constraint_)); } void ConstraintContext::ExitWithScope() { @@ -195,7 +213,103 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { } PrimExpr simplified = Simplify(expr); const int64_t* as_int = tir::as_const_int(simplified); - if (as_int && *as_int) return true; + if (as_int && *as_int) { return true; } + + // Structured boolean reasoning for Or/And (and their bitwise counterparts on bool) + // Evaluate children with the same proof strength. 
+ if (const auto* not_node = simplified.as()) { + PrimExpr a = not_node->a; + // Try direct complements on common comparators + if (const auto* p = a.as()) { + return CanProve(tir::GE(p->a, p->b), strength); + } + if (const auto* p = a.as()) { + return CanProve(tir::GT(p->a, p->b), strength); + } + if (const auto* p = a.as()) { + return CanProve(tir::LE(p->a, p->b), strength); + } + if (const auto* p = a.as()) { + return CanProve(tir::LT(p->a, p->b), strength); + } + if (const auto* p = a.as()) { + return CanProve(tir::NE(p->a, p->b), strength); + } + if (const auto* p = a.as()) { + return CanProve(tir::EQ(p->a, p->b), strength); + } + // De Morgan on canonical boolean nodes + if (const auto* or_node = a.as()) { + PrimExpr lhs = tir::Not(or_node->a); + PrimExpr rhs = tir::Not(or_node->b); + return CanProve(tir::And(lhs, rhs), strength); + } + if (const auto* and_node = a.as()) { + PrimExpr lhs = tir::Not(and_node->a); + PrimExpr rhs = tir::Not(and_node->b); + return CanProve(tir::Or(lhs, rhs), strength); + } + // De Morgan on bitwise boolean calls + if (const auto* c = a.as()) { + using namespace tir; + if (c->op.same_as(builtin::bitwise_or()) && c->args.size() == 2 && a.dtype().is_bool()) { + PrimExpr lhs = tir::Not(c->args[0]); + PrimExpr rhs = tir::Not(c->args[1]); + return CanProve(tir::And(lhs, rhs), strength); + } + if (c->op.same_as(builtin::bitwise_and()) && c->args.size() == 2 && a.dtype().is_bool()) { + PrimExpr lhs = tir::Not(c->args[0]); + PrimExpr rhs = tir::Not(c->args[1]); + return CanProve(tir::Or(lhs, rhs), strength); + } + } + if (const auto* inner_not = a.as()) { + // Double negation + return CanProve(inner_not->a, strength); + } + // Fallback: if `a` simplifies to constant false, then Not(a) is true + PrimExpr a_simpl = Simplify(a); + const int64_t* a_const = tir::as_const_int(a_simpl); + if (a_const && *a_const == 0) { return true; } + // Otherwise, cannot conclude true + } + if (const auto* or_node = simplified.as()) { + if (CanProve(or_node->a, strength)) { + return true; + } + if (CanProve(or_node->b, strength)) { + return true; + } + } + if (const auto* and_node = simplified.as()) { + bool lhs = CanProve(and_node->a, strength); + bool rhs = CanProve(and_node->b, strength); + if (lhs && rhs) { + return true; + } + } + if (const auto* call = simplified.as()) { + using namespace tir; + if (call->op.same_as(builtin::bitwise_or()) && call->args.size() == 2 && + simplified.dtype().is_bool()) { + if (CanProve(call->args[0], strength) || CanProve(call->args[1], strength)) { + return true; + } + } + if (call->op.same_as(builtin::bitwise_and()) && call->args.size() == 2 && + simplified.dtype().is_bool()) { + bool lhs = CanProve(call->args[0], strength); + bool rhs = CanProve(call->args[1], strength); + if (lhs && rhs) { + return true; + } + } + if (call->op.same_as(builtin::bitwise_not()) && call->args.size() == 1 && + simplified.dtype().is_bool()) { + // Treat as logical not and reuse Not handling by constructing tir::Not + return CanProve(tir::Not(call->args[0]), strength); + } + } if (strength >= ProofStrength::kSymbolicBound) { // NOTE: we intentionally only pattern match common bound predicate i < bound // and put this implementation at the top-level. 
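The complement and De Morgan handling added above can be exercised from Python through the arithmetic analyzer. A minimal sketch (illustrative only; it assumes this patch is applied, and the first assertion may already succeed via the rewrite simplifier alone):

```python
import tvm
from tvm import tir

ana = tvm.arith.Analyzer()
i = tir.Var("i", "int32")
ana.bind(i, tvm.ir.Range(0, 16))  # i is known to lie in [0, 16)

# Complement rule: Not(i < 0) is proved by proving i >= 0.
assert ana.can_prove(tir.Not(i < 0))

# De Morgan: Not(i < 0 or i >= 16) decomposes into (i >= 0) and (i < 16),
# both of which follow from the bound on i.
assert ana.can_prove(tir.Not(tir.Or(i < 0, i >= 16)))
```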
@@ -221,11 +335,16 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { lower_bound = 0; } if (pos_diff) { - IntSet iset = this->int_set(this->Simplify(pos_diff.value())); + PrimExpr simplified_diff = this->Simplify(pos_diff.value()); + IntSet iset = this->int_set(simplified_diff); if (iset.HasLowerBound()) { ConstIntBound relaxed_lower_bound = this->const_int_bound(this->Simplify(iset.min())); if (relaxed_lower_bound->min_value >= lower_bound) return true; } + if (iset.HasUpperBound()) { + ConstIntBound relaxed_upper_bound = this->const_int_bound(this->Simplify(iset.max())); + if (relaxed_upper_bound->max_value < lower_bound) return false; + } } } @@ -238,14 +357,41 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { if (ContainsVscaleCall(simplified)) { if (TargetHasVLA(curr_target)) { auto kVScaleValues = GetVScaleValues(curr_target); - return CanProveVscaleExpressionFromKnownValues(this, simplified, kVScaleValues); + if(CanProveVscaleExpressionFromKnownValues(this, simplified, kVScaleValues)) { + return true; + } } - LOG(WARNING) - << "The expression contains scalable values. An attempt to prove by substituting " - "with known values of vscale was not performed. This proof currently only supports " - "VLA targets, but the target was " - << curr_target; + // LOG(WARNING) + // << "The expression contains scalable values. An attempt to prove by substituting " + // "with known values of vscale was not performed. This proof currently only supports " + // "VLA targets, but the target was " + // << curr_target; } + if(z3_prover.CanProve(simplified)) { + // auto msg = z3_prover.GetSMTLIB2(simplified); + // std::stringstream ss; + // ss << msg; + // std::stringstream out; + // std::string tmp; + // while(std::getline(ss, tmp)) { + // out << " " << tmp << "\n"; + // } + // LOG(INFO) << "Proved by Z3: " << simplified << "\n" << out.str(); + return true; + } + // if(strength >= ProofStrength::kSymbolicBound && z3_prover.CanProve(simplified)) { + // // The following debug logging is very useful when diagnosing issues with the Z3 prover. 
+ // auto msg = z3_prover.GetSMTLIB2(simplified); + // std::stringstream ss; + // ss << msg; + // std::stringstream out; + // std::string tmp; + // while(std::getline(ss, tmp)) { + // out << " " << tmp << "\n"; + // } + // LOG(INFO) << "Proved by Z3: " << simplified << "\n" << out.str(); + // return true; + // } return false; } @@ -270,100 +416,128 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { return res; } +namespace { +using FnFactory = tvm::ffi::TypedFunction; +static FnFactory BuildAnalyzerFactory(std::shared_ptr self) { + using tvm::ffi::Function; + return FnFactory([self](std::string name) -> Function { + if (name == "const_int_bound") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = self->const_int_bound(args[0].cast()); + }); + } else if (name == "modular_set") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = self->modular_set(args[0].cast()); + }); + } else if (name == "clone") { + return Function([self](tvm::ffi::PackedArgs, tvm::ffi::Any* ret) { + auto cloned_unique = self->Clone(); + auto cloned = std::shared_ptr(cloned_unique.release()); + *ret = BuildAnalyzerFactory(cloned); + }); + } else if (name == "const_int_bound_update") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + self->const_int_bound.Update(args[0].cast(), args[1].cast(), + args[2].cast()); + }); + } else if (name == "const_int_bound_is_bound") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = self->const_int_bound.IsBound(args[0].cast()); + }); + } else if (name == "Simplify") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + if (args.size() == 1) { + *ret = self->Simplify(args[0].cast()); + } else if (args.size() == 2) { + *ret = self->Simplify(args[0].cast(), args[1].cast()); + } else { + LOG(FATAL) << "Invalid size of argument (" << args.size() << ")"; + } + }); + } else if (name == "rewrite_simplify") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = self->rewrite_simplify(args[0].cast()); + }); + } else if (name == "get_rewrite_simplify_stats") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = self->rewrite_simplify.GetStatsCounters(); + }); + } else if (name == "reset_rewrite_simplify_stats") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + self->rewrite_simplify.ResetStatsCounters(); + }); + } else if (name == "canonical_simplify") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = self->canonical_simplify(args[0].cast()); + }); + } else if (name == "int_set") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = self->int_set(args[0].cast(), args[1].cast>()); + }); + } else if (name == "bind") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + if (auto opt_range = args[1].try_cast()) { + self->Bind(args[0].cast(), opt_range.value()); + } else { + self->Bind(args[0].cast(), args[1].cast()); + } + }); + } else if (name == "can_prove") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + int strength = args[1].cast(); + *ret = self->CanProve(args[0].cast(), static_cast(strength)); + }); + } else if (name == "enter_constraint_context") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + auto ctx = std::shared_ptr>( + new With(self.get(), args[0].cast())); + auto fexit = 
[ctx](tvm::ffi::PackedArgs, tvm::ffi::Any*) mutable { ctx.reset(); }; + *ret = tvm::ffi::Function::FromPacked(fexit); + }); + } else if (name == "can_prove_equal") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = self->CanProveEqual(args[0].cast(), args[1].cast()); + }); + } else if (name == "get_enabled_extensions") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = static_cast(self->rewrite_simplify.GetEnabledExtensions()); + }); + } else if (name == "set_enabled_extensions") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + int64_t flags = args[0].cast(); + self->rewrite_simplify.SetEnabledExtensions( + static_cast(flags)); + }); + } else if (name == "get_smtlib2") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + auto expr = args[0].cast>(); + *ret = self->z3_prover.GetSMTLIB2(expr); + }); + } else if (name == "get_z3_stats") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + *ret = self->z3_prover.GetStats(); + }); + } else if (name == "set_z3_timeout_ms") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + unsigned timeout_ms = args[0].cast(); + self->z3_prover.SetTimeoutMs(timeout_ms); + }); + } else if (name == "set_z3_rlimit") { + return Function([self](tvm::ffi::PackedArgs args, tvm::ffi::Any* ret) { + unsigned max_step = args[0].cast(); + self->z3_prover.SetRLimit(max_step); + }); + } + return Function(); + }); +} +} // namespace + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed("arith.CreateAnalyzer", [](ffi::PackedArgs args, ffi::Any* ret) { - using ffi::Function; - using ffi::TypedFunction; + refl::GlobalDef().def_packed("arith.CreateAnalyzer", [](ffi::PackedArgs, ffi::Any* ret) { auto self = std::make_shared(); - auto f = [self](std::string name) -> ffi::Function { - if (name == "const_int_bound") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->const_int_bound(args[0].cast()); - }); - } else if (name == "modular_set") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->modular_set(args[0].cast()); - }); - } else if (name == "const_int_bound_update") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - self->const_int_bound.Update(args[0].cast(), args[1].cast(), - args[2].cast()); - }); - } else if (name == "const_int_bound_is_bound") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->const_int_bound.IsBound(args[0].cast()); - }); - } else if (name == "Simplify") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - if (args.size() == 1) { - *ret = self->Simplify(args[0].cast()); - } else if (args.size() == 2) { - *ret = self->Simplify(args[0].cast(), args[1].cast()); - } else { - LOG(FATAL) << "Invalid size of argument (" << args.size() << ")"; - } - }); - } else if (name == "rewrite_simplify") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->rewrite_simplify(args[0].cast()); - }); - } else if (name == "get_rewrite_simplify_stats") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->rewrite_simplify.GetStatsCounters(); - }); - } else if (name == "reset_rewrite_simplify_stats") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - self->rewrite_simplify.ResetStatsCounters(); - }); - } else if (name == 
"canonical_simplify") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->canonical_simplify(args[0].cast()); - }); - } else if (name == "int_set") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->int_set(args[0].cast(), args[1].cast>()); - }); - } else if (name == "bind") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - if (auto opt_range = args[1].try_cast()) { - self->Bind(args[0].cast(), opt_range.value()); - } else { - self->Bind(args[0].cast(), args[1].cast()); - } - }); - } else if (name == "can_prove") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - int strength = args[1].cast(); - *ret = self->CanProve(args[0].cast(), static_cast(strength)); - }); - } else if (name == "enter_constraint_context") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - // can't use make_shared due to noexcept(false) decl in destructor, - // see https://stackoverflow.com/a/43907314 - auto ctx = std::shared_ptr>( - new With(self.get(), args[0].cast())); - auto fexit = [ctx](ffi::PackedArgs, ffi::Any*) mutable { ctx.reset(); }; - *ret = ffi::Function::FromPacked(fexit); - }); - } else if (name == "can_prove_equal") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->CanProveEqual(args[0].cast(), args[1].cast()); - }); - } else if (name == "get_enabled_extensions") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = static_cast(self->rewrite_simplify.GetEnabledExtensions()); - }); - } else if (name == "set_enabled_extensions") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - int64_t flags = args[0].cast(); - self->rewrite_simplify.SetEnabledExtensions( - static_cast(flags)); - }); - } - return ffi::Function(); - }; - *ret = ffi::TypedFunction(f); + *ret = BuildAnalyzerFactory(self); }); } diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index f321d761198c..66f8af178a17 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -1446,3 +1446,16 @@ CanonicalSimplifier::~CanonicalSimplifier() { delete impl_; } } // namespace arith } // namespace tvm + +// After class implementations have been defined above +namespace tvm { +namespace arith { + +// Deep copy internal state from another analyzer +void CanonicalSimplifier::CopyFrom(const CanonicalSimplifier& other) { + // Impl derives from RewriteSimplifier::Impl, reuse its copying logic + this->impl_->CopyFromImpl(*other.impl_); +} + +} // namespace arith +} // namespace tvm diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 6dd029e136ea..23765730ce48 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -103,6 +103,11 @@ class ConstIntBoundAnalyzer::Impl : public ExprFunctor { public: explicit Impl(Analyzer* parent) : parent_(parent) {} + void CopyFrom(const Impl& other) { + this->var_map_ = other.var_map_; + this->additional_info_ = other.additional_info_; + this->bound_ = nullptr; + } /*! \brief additional bound info about expr in bound */ struct BoundInfo { /*! 
\brief The expr */
@@ -292,7 +297,6 @@ class ConstIntBoundAnalyzer::Impl
     //
     // Example: expr = (bx * 2048 + tx * 16) % 7168
     // where bx in [0, 3584), tx in [0, 128)
-    // ModularSet(expr) = 16*k (coeff=16, base=0)
     // GCD(16, 7168) = 16
     // Result can only be {0, 16, 32, ..., 7152}
     // Without this optimization: bound = [0, 7167]
@@ -431,6 +435,10 @@ class ConstIntBoundAnalyzer::Impl
       return VisitLeftShift(op);
     } else if (op->op.same_as(tir::builtin::bitwise_and())) {
       return VisitBitwiseAnd(op);
+    } else if (op->op.same_as(tir::builtin::bitwise_or())) {
+      return VisitBitwiseOr(op);
+    } else if (op->op.same_as(tir::builtin::bitwise_xor())) {
+      return VisitBitwiseXor(op);
     } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasVLA(curr_target)) {
       auto kVScaleValues = GetVScaleValues(curr_target);
       unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end());
@@ -497,6 +505,66 @@ class ConstIntBoundAnalyzer::Impl
     }
   }

+  Entry VisitBitwiseOr(const CallNode* op) {
+    Entry a = VisitExpr(op->args[0]);
+    Entry b = VisitExpr(op->args[1]);
+    // For non-negative operands, the OR result is also non-negative and
+    // bounded by (1 << k) - 1, where k is the larger operand's bit width.
+    if (a.min_value >= 0 && b.min_value >= 0) {
+      auto bit_width = [](int64_t v) {
+        if (v <= 0) return 0;
+        int bw = 0;
+        while (v) {
+          ++bw;
+          v >>= 1;
+        }
+        return bw;
+      };
+      int bw_a = bit_width(a.max_value);
+      int bw_b = bit_width(b.max_value);
+      int k = std::max(bw_a, bw_b);
+      if (k >= 63) {
+        return Everything(op->dtype);
+      }
+      int64_t ub = (static_cast(1) << k) - 1;
+      return MakeBound(0, ub);
+    }
+    return Everything(op->dtype);
+  }
+
+  Entry VisitBitwiseXor(const CallNode* op) {
+    Entry a = VisitExpr(op->args[0]);
+    Entry b = VisitExpr(op->args[1]);
+    // For non-negative operands (common for index math),
+    // the result is within [0, (1 << k) - 1], where k is the maximum
+    // number of bits required to represent either operand's upper bound.
+    // This is a conservative but safe bound and is sufficient for layout
+    // index computations.
+    if (a.min_value >= 0 && b.min_value >= 0) {
+      // Compute bit width of the larger upper bound; cap at 63 to avoid UB.
+      auto bit_width = [](int64_t v) {
+        if (v <= 0) return 0;
+        int bw = 0;
+        while (v) {
+          ++bw;
+          v >>= 1;
+        }
+        return bw;
+      };
+      int bw_a = bit_width(a.max_value);
+      int bw_b = bit_width(b.max_value);
+      int k = std::max(bw_a, bw_b);
+      if (k >= 63) {
+        // Too wide; fall back to dtype limits.
+        return Everything(op->dtype);
+      }
+      int64_t ub = (static_cast(1) << k) - 1;
+      return MakeBound(0, ub);
+    }
+    // If signs are unknown, avoid incorrect assumptions.
+ return Everything(op->dtype); + } + std::function EnterConstraint(const PrimExpr& constraint) { std::vector info = DetectBoundInfo(constraint); if (info.size() == 0) return nullptr; @@ -869,5 +937,10 @@ ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl( ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; } +// Deep copy internal state from another analyzer +void ConstIntBoundAnalyzer::CopyFrom(const ConstIntBoundAnalyzer& other) { + this->impl_->CopyFrom(*other.impl_); +} + } // namespace arith } // namespace tvm diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 1433ceb70fc0..2e3c3cbdbe28 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -31,6 +31,7 @@ #include #include +#include #include #include "constraint_extract.h" @@ -426,6 +427,11 @@ class IntervalSetEvaluator : public ExprFunctor { IntervalSet VisitExpr_(const VarNode* op) final { Var var = ffi::GetRef(op); + // Detect cyclic dependency: if we're already visiting this var, return conservative estimate + if (visiting_vars_.count(op)) { + return IntervalSet::SinglePoint(var); + } + ffi::Array values; if (dom_constraints_) { for (const auto& constraint : *dom_constraints_) { @@ -456,9 +462,13 @@ class IntervalSetEvaluator : public ExprFunctor { if (res->min_value.same_as(var) && res->max_value.same_as(var)) { return res; } + // Mark this var as being visited to detect cycles + visiting_vars_.insert(op); // recursively evaluate mapped result // in case the domain contains variables to be relaxed. - return Eval(res); + IntervalSet result = Eval(res); + visiting_vars_.erase(op); + return result; } IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_(op); } @@ -609,12 +619,19 @@ class IntervalSetEvaluator : public ExprFunctor { const ffi::Map& dom_map_; const std::vector>* dom_constraints_; bool eval_vec_{false}; + // track variables being visited to detect cyclic dependencies + std::unordered_set visiting_vars_; }; class IntSetAnalyzer::Impl { public: explicit Impl(Analyzer* analyzer) : analyzer_(analyzer) {} + void CopyFrom(const Impl& other) { + this->dom_map_ = other.dom_map_; + this->dom_constraints_ = other.dom_constraints_; + } + IntSet Eval(const PrimExpr& expr, const ffi::Map& dom_map) const { return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr); } @@ -745,6 +762,11 @@ std::function IntSetAnalyzer::Impl::EnterConstraint(const PrimExpr& cons return frecover; } +// Deep copy internal state from another analyzer +void IntSetAnalyzer::CopyFrom(const IntSetAnalyzer& other) { + this->impl_->CopyFrom(*other.impl_); +} + // Quickly adapt to IntSet interface // TODO(tqchen): revisit IntSet interface as well. 
Range IntSet::CoverRange(Range max_range) const { diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 59b0b0546dab..ab811fd7548b 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -25,6 +25,7 @@ #include #include #include +#include "tvm/arith/analyzer.h" namespace tvm { namespace arith { @@ -64,6 +65,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) { Range dom = Range::FromMinExtent(op->min, op->extent); analyzer_->Bind(op->loop_var, dom); iter_vars_.Set(op->loop_var, dom); + With ctx(analyzer_, op->extent > 0); return StmtExprMutator::VisitStmt_(op); } @@ -140,9 +142,15 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { iter_vars_.Set(iv->var, dom); Stmt stmt = StmtExprMutator::VisitStmt_(op); return stmt; - } else { + } + else if(op->attr_key == tir::attr::tilelang_assume) { + auto condition = Downcast(op->node); + With constraint(analyzer_, condition); return StmtExprMutator::VisitStmt_(op); } + else { + return StmtExprMutator::VisitStmt_(op); + } } Stmt IRMutatorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) { diff --git a/src/arith/ir_visitor_with_analyzer.cc b/src/arith/ir_visitor_with_analyzer.cc index dba4567f88ec..031f0b17f296 100644 --- a/src/arith/ir_visitor_with_analyzer.cc +++ b/src/arith/ir_visitor_with_analyzer.cc @@ -69,8 +69,16 @@ void IRVisitorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { IterVar iv = Downcast(op->node); ICHECK_NE(iv->thread_tag.length(), 0U); analyzer_.Bind(iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), op->value)); + StmtExprVisitor::VisitStmt_(op); + } + else if(op->attr_key == tir::attr::tilelang_assume) { + auto condition = Downcast(op->node); + With constraint(&analyzer_, condition); + StmtExprVisitor::VisitStmt_(op); + } + else { + StmtExprVisitor::VisitStmt_(op); } - StmtExprVisitor::VisitStmt_(op); } void IRVisitorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) { diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index e69b8ad20e85..47d8acb14dc7 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -104,6 +104,8 @@ class ModularSetAnalyzer::Impl : public ExprFunctorimpl_->CopyFrom(*other.impl_); } + } // namespace arith } // namespace tvm diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 65b6e408e2cb..011d91177554 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -498,13 +498,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { return ret; } -std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) { +std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint, bool is_assume) { size_t old_literal_size = literal_constraints_.size(); // we will compare the already simplified result with the constraint, // so simplify the constraint as well PrimExpr new_constraint = operator()(constraint); for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint, false)) { - if (SideEffect(subconstraint) <= CallEffectKind::kPure) { + if (is_assume || SideEffect(subconstraint) <= CallEffectKind::kPure) { literal_constraints_.push_back(subconstraint); PrimExpr negation; if (subconstraint.dtype().is_bool()) { @@ -814,6 +814,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { return make_const(op->dtype, truncdiv(c1val, c2val)); } + // x % c1 // c2 => 0 if 0 < c1 < c2 && x >= 0 + TVM_TRY_REWRITE_IF(truncdiv(truncmod(x, c1), c2), 
ZeroWithTypeLike(x), + c1.Eval()->value > 0 && c2.Eval()->value > c1.Eval()->value && + CanProveGreaterEqual(x.Eval(), 0)); + // while it is always true for trunc div // restrict to common case(positive div) TVM_TRY_REWRITE_IF(truncdiv(truncdiv(x, c1), c2), truncdiv(x, c1 * c2), @@ -1159,7 +1164,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -1214,8 +1219,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { c2.Eval()->value % c1.Eval()->value == 0 && CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); + TVM_TRY_REWRITE_IF(floormod(x * c1 + y * c2 + z, c3), floormod(x * floordiv(c1, c2) + y, floordiv(c3, c2)) * c2 + z, + c2.Eval()->value > 0 && c3.Eval()->value > 0 && + c3.Eval()->value % c2.Eval()->value == 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveEqual(floordiv(z.Eval(), c2.Eval()), 0)); + TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2), - c2.Eval()->value > 0); + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); // (x + 5) % 2 -> (x + 1) %2, (x + 3) % 3 => x TVM_TRY_REWRITE_IF( @@ -2404,8 +2415,8 @@ void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool allow_ impl_->Update(var, info, allow_override); } -std::function RewriteSimplifier::EnterConstraint(const PrimExpr& constraint) { - return impl_->EnterConstraint(constraint); +std::function RewriteSimplifier::EnterConstraint(const PrimExpr& constraint, bool is_assume) { + return impl_->EnterConstraint(constraint, is_assume); } void RewriteSimplifier::SetEnabledExtensions(Extension flags) { @@ -2427,6 +2438,22 @@ RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) RewriteSimplifier::~RewriteSimplifier() { delete impl_; } +// Impl state copy +void RewriteSimplifier::Impl::CopyFromImpl(const RewriteSimplifier::Impl& other) { + this->var_map_ = other.var_map_; + this->literal_constraints_ = other.literal_constraints_; + this->enabled_extensions_ = other.enabled_extensions_; + this->maximum_rewrite_steps_ = other.maximum_rewrite_steps_; + this->stats_ = other.stats_; + this->recur_depth_ = 0; + this->recursively_visiting_boolean_ = false; +} + +// Deep copy internal state from another analyzer +void RewriteSimplifier::CopyFrom(const RewriteSimplifier& other) { + this->impl_->CopyFromImpl(*other.impl_); +} + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* ptr = node.as(); diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index e541970a2717..d27d750e0615 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -116,7 +116,10 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const CastNode* op) override; PrimExpr VisitExpr_(const LetNode* op) override; - std::function EnterConstraint(const PrimExpr& constraint); + std::function EnterConstraint(const PrimExpr& constraint, bool is_assume=false); + + // Copy internal state from another Impl instance (used by Analyzer cloning) + void CopyFromImpl(const Impl& other); /*! 
\brief Enable an optional extension or extensions * diff --git a/src/arith/transitive_comparison_analyzer.cc b/src/arith/transitive_comparison_analyzer.cc index b4cd7b260ebb..ec0173ca996e 100644 --- a/src/arith/transitive_comparison_analyzer.cc +++ b/src/arith/transitive_comparison_analyzer.cc @@ -82,6 +82,9 @@ class TransitiveComparisonAnalyzer::Impl { */ std::function EnterConstraint(const PrimExpr& expr); + // Copy internal state from another Impl (for Analyzer cloning) + void CopyFrom(const Impl& other); + private: /* \brief Internal representation of a PrimExpr * @@ -600,6 +603,11 @@ std::function TransitiveComparisonAnalyzer::Impl::EnterConstraint(const return frecover; } +// Deep copy internal state from another analyzer +void TransitiveComparisonAnalyzer::CopyFrom(const TransitiveComparisonAnalyzer& other) { + this->impl_->CopyFrom(*other.impl_); +} + CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr, const PrimExpr& rhs_expr, bool propagate_inequalities) const { @@ -872,5 +880,13 @@ CompareResult TransitiveComparisonAnalyzer::Impl::MergeComparisons( return result; } +// Implementation of the CopyFrom helper +void TransitiveComparisonAnalyzer::Impl::CopyFrom(const Impl& other) { + prev_bindings_ = other.prev_bindings_; + knowns_ = other.knowns_; + scoped_knowns_ = other.scoped_knowns_; + expr_to_key = other.expr_to_key; +} + } // namespace arith } // namespace tvm diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 3fee6b55f2e5..20ffbc1df450 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -178,6 +178,15 @@ class CUDAWrappedFunc { sptr_ = sptr; func_name_ = func_name; std::fill(fcache_.begin(), fcache_.end(), nullptr); + // Track whether this kernel uses dynamic shared memory and the last size set per device. + std::fill(dyn_smem_initialized_.begin(), dyn_smem_initialized_.end(), false); + use_dyn_shared_memory_ = false; + for (const auto& tag : launch_param_tags) { + if (tag == launch_param::kUseDynamicSharedMemoryTag) { + use_dyn_shared_memory_ = true; + break; + } + } launch_param_config_.Init(num_void_args, launch_param_tags); } // invoke the function with void arguments @@ -188,21 +197,56 @@ class CUDAWrappedFunc { if (fcache_[device_id] == nullptr) { fcache_[device_id] = m_->GetFunc(device_id, func_name_); - if (wl.dyn_shmem_size >= (48 << 10)) { - // Assumption: dyn_shmem_size doesn't change across different invocations of - // fcache_[device_id] - CUresult result = cuFuncSetAttribute( + } + + // If the kernel uses dynamic shared memory, we should ensure the attribute + // reflects the actual size needed for this launch. Some workloads vary the + // dynamic shared memory between invocations, in which case we cannot set it + // just once. Cache the last value per device to avoid redundant calls. 
+    bool need_dyn_attr = use_dyn_shared_memory_ || (wl.dyn_shmem_size > 0);
+    if (need_dyn_attr) {
+      if (!dyn_smem_initialized_[device_id] || dyn_smem_last_[device_id] != wl.dyn_shmem_size) {
+        CUresult attr_set = cuFuncSetAttribute(
             fcache_[device_id], CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, wl.dyn_shmem_size);
-        if (result != CUDA_SUCCESS) {
+        if (attr_set != CUDA_SUCCESS) {
           LOG(FATAL) << "Failed to set the allowed dynamic shared memory size to "
                      << wl.dyn_shmem_size;
         }
+        dyn_smem_last_[device_id] = wl.dyn_shmem_size;
+        dyn_smem_initialized_[device_id] = true;
+      }
     }
     CUstream strm = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id));
-    CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1),
-                                     wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1),
-                                     wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr);
+    CUresult result;
+
+    if (launch_param_config_.use_programmatic_dependent_launch()) {
+      CUlaunchConfig config{};
+      CUlaunchAttribute attribute[1]{};
+      attribute[0].id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION;
+      attribute[0].value.programmaticStreamSerializationAllowed = 1;
+
+      config.attrs = attribute;
+      config.numAttrs = 1;
+      config.hStream = strm;
+      config.gridDimX = wl.grid_dim(0);
+      config.gridDimY = wl.grid_dim(1);
+      config.gridDimZ = wl.grid_dim(2);
+      config.blockDimX = wl.block_dim(0);
+      config.blockDimY = wl.block_dim(1);
+      config.blockDimZ = wl.block_dim(2);
+      config.sharedMemBytes = wl.dyn_shmem_size;
+
+      result = cuLaunchKernelEx(&config, fcache_[device_id], void_args, nullptr);
+    } else if (launch_param_config_.use_cooperative_launch()) {
+      result = cuLaunchCooperativeKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1),
+                                         wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1),
+                                         wl.block_dim(2), wl.dyn_shmem_size, strm, void_args);
+    } else {
+      result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2),
+                              wl.block_dim(0), wl.block_dim(1), wl.block_dim(2), wl.dyn_shmem_size,
+                              strm, void_args, nullptr);
+    }
+
     if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) {
       const char* msg;
       cuGetErrorName(result, &msg);
@@ -210,7 +254,8 @@ class CUDAWrappedFunc {
       os << "CUDALaunch Error: " << msg << "\n"
          << " grid=(" << wl.grid_dim(0) << "," << wl.grid_dim(1) << "," << wl.grid_dim(2) << "), "
          << " block=(" << wl.block_dim(0) << "," << wl.block_dim(1) << "," << wl.block_dim(2)
-         << ")\n";
+         << ")"
+         << " dyn_smem_bytes=" << wl.dyn_shmem_size << "\n";
       std::string cuda = m_->InspectSource("");
       if (cuda.length() != 0) {
         os << "// func_name=" << func_name_ << "\n"
@@ -234,6 +279,13 @@ class CUDAWrappedFunc {
   mutable std::array fcache_;
   // launch parameters configuration
   LaunchParamConfig launch_param_config_;
+  // Whether this kernel uses dynamic shared memory
+  bool use_dyn_shared_memory_{false};
+  // Cached last dynamic shared memory size per device and whether it's initialized
+  mutable std::array dyn_smem_last_;
+  mutable std::array dyn_smem_initialized_;
+  // Whether programmatic dependent launch (PDL) was requested for this kernel
+  bool has_programmatic_dependent_launch_;
 };

 class CUDAPrepGlobalBarrier {
diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h
index 85b83289f4d3..aceb97b58374 100644
--- a/src/runtime/meta_data.h
+++ b/src/runtime/meta_data.h
@@ -48,6 +48,10 @@ namespace launch_param {

 /*! \brief A tag to specify whether or not dynamic shared memory is used */
 constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory";
+/*!
\brief A tag to specify whether or not to use programmatic dependent launch */
+constexpr const char* kUseProgramaticDependentLaunch = "tir.use_programtic_dependent_launch";
+/*! \brief A tag to specify whether or not to use cooperative launch */
+constexpr const char* kUseCooperativeLaunch = "tir.use_cooperative_launch";

 }  // namespace launch_param
diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h
index f10489826a5a..d16b3d8008fb 100644
--- a/src/runtime/metal/metal_common.h
+++ b/src/runtime/metal/metal_common.h
@@ -141,6 +141,19 @@ class Stream {
   std::string error_description_;
 };

+class CBStream final : public Stream {
+ public:
+  explicit CBStream(id commandBuffer) : Stream(nullptr) { buffer_ = commandBuffer; }
+
+  id GetCommandBuffer() { return buffer_; }
+
+ private:
+  id buffer_;
+};
+
 /*!
  * \brief Process global Metal workspace.
  */
diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm
index 9c0aa96257d4..767701a73c2e 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -33,6 +33,7 @@
 #include "../pack_args.h"
 #include "../thread_storage_scope.h"
 #include "metal_common.h"
+#include <tvm/runtime/device_api.h>

 namespace tvm {
 namespace runtime {
@@ -211,8 +212,9 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args)
     auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup;
     CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup);
     // attach error message directly in this functio
-    id cb = stream->GetCommandBuffer(/*label=*/"TVMKernel:" + func_name_,
-                                     /*attach_error_callback=*/false);
+    // id cb = stream->GetCommandBuffer(/*label=*/"TVMKernel:" + func_name_,
+    //                                  /*attach_error_callback=*/false);
+    id cb = static_cast(stream)->GetCommandBuffer();
     id encoder = [cb computeCommandEncoder];
     [encoder setComputePipelineState:scache_[device_id]];
     for (size_t i = 0; i < num_buffer_args_; ++i) {
@@ -239,7 +241,8 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args)
         stream->SetError(os.str());
       }
     }];
-    [cb commit];
+    // When we reuse torch's command buffer, torch will handle the commit/sync.
+    // [cb commit];
   };
 }

@@ -324,9 +327,19 @@ void operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args)
   return MetalModuleCreate(smap, fmap, fmt, "");
 }

+void SetMetalStream(TVMStreamHandle stream) {
+  metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
+  auto s = new metal::CBStream(static_cast>(stream));
+  // Grow the per-device stream table to device_id + 1 so that indexing
+  // with device_id below stays in bounds.
+  if (t->stream.size() <= t->device.device_id) {
+    t->stream.resize(t->device.device_id + 1);
+  }
+  t->stream[t->device.device_id] = static_cast(s);
+}
+
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("ffi.Module.load_from_bytes.metal", MetalModuleLoadFromBytes);
+  refl::GlobalDef()
+      .def("ffi.Module.load_from_bytes.metal", MetalModuleLoadFromBytes)
+      .def("metal.SetStream", SetMetalStream);
 }
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h
index 8929f90b0f09..e1b1fec0a39a 100644
--- a/src/runtime/pack_args.h
+++ b/src/runtime/pack_args.h
@@ -39,6 +39,10 @@
 namespace tvm {
 namespace runtime {
+
+/*! \brief TileLang Grid constant */
+constexpr unsigned int kDLGridConstant = 30U;
+
 /*!
  * \brief argument union type of 32bit.
  */
@@ -134,7 +138,8 @@ enum ArgConvertCode {
   FLOAT64_TO_FLOAT32,
   FLOAT64_TO_FLOAT64,
   HANDLE_TO_HANDLE,
-  HANDLE_TO_TENSORMAP
+  HANDLE_TO_TENSORMAP,
+  HANDLE_TO_REFERENCE,
 };

 inline ArgConvertCode GetArgConvertCode(DLDataType t) {
@@ -149,6 +154,8 @@ inline ArgConvertCode GetArgConvertCode(DLDataType t) {
     if (t.bits == 32U) return FLOAT64_TO_FLOAT32;
   } else if (t.code == kDLOpaqueHandle) {
     return HANDLE_TO_HANDLE;
+  } else if (t.code == kDLGridConstant) {
+    return HANDLE_TO_REFERENCE;
   }
   LOG(FATAL) << "Cannot handle " << t << " as device function argument";
 }
@@ -191,6 +198,9 @@ inline ffi::Function PackFuncVoidAddr_(F f, const std::vector& c
         addr[i] = raw_args[i].v_ptr;
         break;
       }
+      case HANDLE_TO_REFERENCE: {
+        addr[i] = raw_args[i].v_obj;
+        break;
+      }
     }
   }
   f(args, ret, addr);
@@ -231,6 +241,7 @@ inline ffi::Function PackFuncNonBufferArg_(F f, int base,
         break;
       }
       case HANDLE_TO_HANDLE:
+      case HANDLE_TO_REFERENCE:
       case HANDLE_TO_TENSORMAP: {
         LOG(FATAL) << "not reached";
         break;
       }
@@ -293,6 +304,7 @@ inline ffi::Function PackFuncPackedArgAligned_(F f, const std::vector arg_index_map_;
   /*! \brief Whether or not use dynamic shared memory. */
   bool use_dyn_shared_memory_{false};
+  /*! \brief Whether or not to use programmatic dependent launch. */
+  bool use_programmatic_dependent_launch_{false};
+  /*! \brief Whether or not to use cooperative launch. */
+  bool use_cooperative_launch_{false};
 };

 }  // namespace runtime
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 00f9c28475b4..a3a96d4a6e6f 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -22,6 +22,7 @@
 #include

 #include "./utils.h"
+#include <tvm/ffi/string.h>

 namespace tvm {
 namespace script {
@@ -132,7 +133,7 @@ Buffer MatchBuffer(ObjectRef param, ffi::Array shape, DataType dtype,
   Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align,
                              offset_factor, buffer_type_str, axis_separators);
   if (const auto* var = param.as()) {
-    PrimFuncFrame frame = FindPrimFuncFrame("T.match_buffer");
+    PrimFuncFrame frame = FindPrimFuncFrameRelaxed("T.match_buffer");
     Var v = ffi::GetRef(var);
     for (auto const& arg : frame->args) {
       if (arg.same_as(v)) {
@@ -222,9 +223,9 @@ void Writes(ffi::Array buffer_slices) {
 }

 /*!
 \brief Recursively merge two annotations, the new attrs will override the old ones */
-ffi::Map MergeAnnotations(const ffi::Map& new_attrs,
-                          const ffi::Map& old_attrs) {
-  ffi::Map result = old_attrs;
+ffi::Map MergeAnnotations(const ffi::Map& new_attrs,
+                          const ffi::Map& old_attrs) {
+  ffi::Map result = old_attrs;
   for (const auto& [key, value] : new_attrs) {
     auto old_value = old_attrs.Get(key);
     // Case 1: the key is not in the old annotations, set the key to the new value
@@ -235,15 +236,15 @@
     // Case 2: the key is in the old annotations
     // Case 2.1: both are dicts
-    auto old_dict = old_value->try_cast>();
-    auto new_dict = value.try_cast>();
+    auto old_dict = old_value->try_cast>();
+    auto new_dict = value.try_cast>();
     if (old_dict && new_dict) {
       // Recursively merge the two dicts
       auto merged_dict = MergeAnnotations(*old_dict, *new_dict);
       result.Set(key, merged_dict);
       continue;
     }
-    // Case 2.2: the values are not both dicts, check if the keys are the same
+    // Case 2.2: the values are not both dicts, check whether the values are the same
     if (!ffi::AnyEqual()(old_value.value(), value)) {
       LOG(FATAL) << "ValueError: Try to merge two annotations with different values for key `"
                  << key << "`, previous one is " << old_value.value() << ", new one is " << value;
@@ -252,14 +253,14 @@
   return result;
 }

-void BlockAttrs(ffi::Map attrs) {
+void BlockAttrs(ffi::Map attrs) {
   BlockFrame frame = FindBlockFrame("T.block_attr");
   // Case 1: the block has no annotations, set the new annotations
   if (!frame->annotations.defined()) {
     frame->annotations = attrs;
   } else {
     // Case 2: the block has annotations, merge the new annotations with the old ones
-    frame->annotations = MergeAnnotations(attrs, frame->annotations.value());
+    frame->annotations = Downcast>(
+        MergeAnnotations(Downcast>(attrs),
+                         Downcast>(frame->annotations.value())));
   }
 }

@@ -270,10 +271,14 @@ Buffer AllocBuffer(ffi::Array shape, DataType dtype, ffi::Optional
-  if (ffi::Optional frame = builder->FindFrame()) {
+  if (ffi::Optional frame = builder->GetLastFrame()) {
+    frame.value()->alloc_buffers.push_back(buffer);
+  } else if (ffi::Optional frame = builder->FindFrame()) {
     frame.value()->alloc_buffers.push_back(buffer);
   } else if (ffi::Optional frame = builder->GetLastFrame()) {
     frame.value()->root_alloc_buffers.push_back(buffer);
+  } else if (ffi::Optional frame = builder->FindFrame()) {
+    frame.value()->root_alloc_buffers.push_back(buffer);
   } else {
     LOG(FATAL) << "ValueError: Block frame or PrimFunc frame not find. Please ensure "
                   "'T.alloc_buffer' is called under T.block() or T.prim_func()";
@@ -677,8 +682,9 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable)
       int n = buffer->strides.size();
       for (int i = 0; i < n; ++i) {
         PrimExpr e = buffer->strides[i];
-        if (auto v = e.as()) {
-          Namer::Name(v.value(), name + "_s" + std::to_string(i));
+        if (const auto* v = e.as()) {
+          ffi::String new_name = !v->name_hint.empty() ?
v->name_hint : (name + "_s" + std::to_string(i)); + Namer::Name(ffi::GetRef(v), ffi::String(new_name)); } } }); @@ -687,7 +693,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable) .set_dispatch([](const ObjectRef& node, ffi::String name) -> void { using namespace tvm::tir; SizeVarNode* var = const_cast(node.as()); - var->name_hint = name; + var->name_hint = ffi::String(name); }); TVM_STATIC_IR_FUNCTOR(Namer, vtable) @@ -782,7 +788,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def(Prefix TVM_TMP_STR(64), DType##64) #define TVM_FFI_REFL_DEF_GLOBAL_LANES(Prefix, Func) \ - def(Prefix TVM_TMP_STR(x4), Func##x4) \ + def(Prefix TVM_TMP_STR(x2), Func##x2) \ + .def(Prefix TVM_TMP_STR(x4), Func##x4) \ .def(Prefix TVM_TMP_STR(x8), Func##x8) \ .def(Prefix TVM_TMP_STR(x16), Func##x16) \ .def(Prefix TVM_TMP_STR(x32), Func##x32) \ diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h index d7c272ae5138..655dea5fbda3 100644 --- a/src/script/ir_builder/tir/utils.h +++ b/src/script/ir_builder/tir/utils.h @@ -75,6 +75,24 @@ inline PrimFuncFrame FindPrimFuncFrame(const ffi::String& method) { throw; } +/*! + * \brief Find a PrimFuncFrame anywhere in the current builder stack (not necessarily the top). + * This relaxed variant enables certain APIs (e.g., T.match_buffer on a PrimFunc param) + * to be invoked after non-top-level frames (let/if/for) have been introduced, while + * still being inside a PrimFunc scope. + * \param method The method name to be printed when throwing exception. + * \return The PrimFuncFrame found in the builder stack. + */ +inline PrimFuncFrame FindPrimFuncFrameRelaxed(const ffi::String& method) { + if (ffi::Optional frame = IRBuilder::Current()->FindFrame()) { + return frame.value(); + } else { + LOG(FATAL) << "ValueError: " << method << " must be called under a T.prim_func(), " + << "but it occurred outside of any T.prim_func() frame"; + } + throw; +} + /*! * \brief Check whether the top frame in IRBuilder frame stack is BlockFrame. * \param method The method name to be printed when throwing exception. 
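For context, `FindPrimFuncFrameRelaxed` is meant to tolerate TVMScript like the following, where `T.match_buffer` runs after an `if` frame has been pushed on top of the PrimFunc frame. This is an illustrative sketch only; whether the parser accepts this exact form depends on the surrounding frontend changes in this patch:

```python
from tvm.script import tir as T

@T.prim_func
def kernel(a: T.handle, n: T.int32):
    # By the time T.match_buffer runs, the `if` below has pushed a frame on
    # top of the PrimFunc frame, so a strict top-of-stack lookup would fail.
    if n > 0:
        A = T.match_buffer(a, (n,), "float32")
        for i in range(n):
            A[i] = T.float32(0)
```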
diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc
index 3103e6f5b9c3..de9a8ce78a40 100644
--- a/src/target/intrin_rule.cc
+++ b/src/target/intrin_rule.cc
@@ -34,6 +34,9 @@ using tir::FLowerIntrinsic;
 TVM_REGISTER_OP("tir.exp").set_attr("default.FLowerIntrinsic",
                                     DispatchPureExtern);

+TVM_REGISTER_OP("tir.exp2")
+    .set_attr("default.FLowerIntrinsic", DispatchPureExtern);
+
 TVM_REGISTER_OP("tir.erf").set_attr("default.FLowerIntrinsic",
                                     DispatchPureExtern);

diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 52ad78166981..31c2763ef629 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -360,7 +360,22 @@ std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const Pri
     os << ")";
     return os.str();
   } else {
-    TVM_FFI_THROW(RuntimeError) << "Unsupported type index: " << kind;
+    ICHECK_LT(kind, builtin::kTVMValueKindBound_);
+    std::ostringstream os;
+    os << "(((TVMFFIAny*)";
+    this->PrintExpr(buffer, os);
+    os << ")[" << index << "].";
+    if (t.is_handle()) {
+      os << "v_ptr";
+    } else if (t.is_float()) {
+      os << "v_float64";
+    } else if (t.is_int()) {
+      os << "v_int64";
+    } else {
+      LOG(FATAL) << "Do not know how to handle type " << t;
+    }
+    os << ")";
+    return os.str();
   }
 }

@@ -918,13 +933,19 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) {
 }

 void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) {  // NOLINT(*)
-  auto it = let_binding_.find(op->var);
-  if (it != let_binding_.end()) {
-    ICHECK(deep_equal_(it->second->value, op->value))
-        << "Let cannot bind the same var to two different values";
-  } else {
-    let_binding_[op->var] = op;
-  }
+  // The strict duplicate-binding check is intentionally disabled in this flow;
+  // the original logic is kept here for reference:
+  // auto it = let_binding_.find(op->var);
+  // if (it != let_binding_.end()) {
+  //   ICHECK(deep_equal_(it->second->value, op->value))
+  //       << "Let cannot bind the same var to two different values: "
+  //       << op->var << " " << op->value;
+  // } else {
+  //   let_binding_[op->var] = op;
+  // }
   std::string value = PrintExpr(op->value);
   if (print_ssa_form_) {
     ICHECK(!var_idmap_.count(op->var.get()));
@@ -1215,6 +1236,21 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) {
       // cast int to enum
       cast = "(DLDeviceType)";
     }
+    // Special-case: assigning a string literal to the Any union's v_ptr
+    // triggers const-correctness issues when compiling as C++.
+    // If the destination is the Any union value (kTVMFFIAnyUnionValue),
+    // the store dtype is a handle (thus maps to v_ptr), and the source value
+    // is a StringImm, cast the string literal to (void*) to avoid
+    // discarding-const-qualifier errors under C++.
+ if (kind == builtin::kTVMFFIAnyUnionValue && store_dtype.is_handle()) { + if (const auto* str_imm = call->args[3].as()) { + (void)str_imm; // silence unused warning + // prepend cast if not already added + if (cast.empty()) { + cast = "(void*)"; + } + } + } this->PrintIndent(); this->stream << ref << " = " << cast << value << ";\n"; return; diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 12a8d66bba9b..15bee36e31d9 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -32,6 +32,9 @@ #include #include +// For escaping strings embedded into generated C sources +#include "../../support/str_escape.h" + namespace tvm { namespace codegen { @@ -50,6 +53,8 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_d decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n"; decl_stream << "#include \"tvm/ffi/c_api.h\"\n"; decl_stream << "#include \n"; + // snprintf for richer assert messages with actual values + decl_stream << "#include \n"; decl_stream << "#include \n"; CodeGenCHost::InitGlobalContext(); CodeGenC::Init(output_ssa); @@ -323,9 +328,33 @@ void CodeGenCHost::VisitStmt_(const AssertStmtNode* op) { // NOLINT(*) PrintIndent(); stream << "if (!(" << cond << ")) {\n"; int assert_if_scope = this->BeginScope(); - PrintIndent(); - stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \"" - << op->message.as()->value << "\", NULL);\n"; + { + // Prepare the base error message + const auto* msg_node = op->message.as(); + ICHECK(msg_node != nullptr) << "Assert message expected to be StringImm"; + const std::string& raw_msg = msg_node->value; + const std::string esc_msg = + tvm::support::StrEscape(raw_msg.c_str(), raw_msg.length(), /*use_octal_escape=*/true, + /*escape_whitespace_special_chars=*/true); + + // If the assertion is an equality check, append the actual LHS/RHS values + if (const auto* eq = op->condition.as()) { + std::string lhs = PrintExpr(eq->a); + std::string rhs = PrintExpr(eq->b); + PrintIndent(); + stream << "char __tvm_assert_msg_buf[512];\n"; + PrintIndent(); + stream << "snprintf(__tvm_assert_msg_buf, 512, \"%s; got: %lld, expected: %lld\", \"" + << esc_msg << "\", (long long)(" << lhs << "), (long long)(" << rhs + << "));\n"; + PrintIndent(); + stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", __tvm_assert_msg_buf);\n"; + } else { + PrintIndent(); + stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \"" << esc_msg + << "\");\n"; + } + } PrintIndent(); stream << "return -1;\n"; this->EndScope(assert_if_scope); @@ -359,7 +388,8 @@ inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, const char* compare, ffi::Module BuildCHost(IRModule mod, Target target) { bool output_ssa = false; - bool emit_asserts = false; + // Enable emission of runtime asserts in generated C host code + bool emit_asserts = true; bool emit_fwd_func_decl = true; std::unordered_set devices; diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index 56b575cc6c38..ee38ed63dc76 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -170,37 +170,37 @@ TVM_REGISTER_OP("tir.nearbyint") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tir.exp").set_attr("cuda.FLowerIntrinsic", - DispatchPureExtern); + DispatchPureExtern); TVM_REGISTER_OP("tir.exp2") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tir.exp10") - .set_attr("cuda.FLowerIntrinsic", 
DispatchPureExtern); + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tir.erf").set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tir.log").set_attr("cuda.FLowerIntrinsic", - DispatchPureExtern); + DispatchPureExtern); TVM_REGISTER_OP("tir.log2") - .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tir.log10") - .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tir.tan").set_attr("cuda.FLowerIntrinsic", - DispatchPureExtern); + DispatchPureExtern); TVM_REGISTER_OP("tir.cos").set_attr("cuda.FLowerIntrinsic", - DispatchPureExtern); + DispatchPureExtern); TVM_REGISTER_OP("tir.cosh") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tir.sin").set_attr("cuda.FLowerIntrinsic", - DispatchPureExtern); + DispatchPureExtern); TVM_REGISTER_OP("tir.sinh") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index d44173a2ae3c..99a5684af521 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -356,6 +356,19 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM) .set_default_keys({"rocm", "gpu"}) .set_target_parser(UpdateROCmAttrs); +TVM_REGISTER_TARGET_KIND("hip", kDLROCM) + .add_attr_option("mcpu") + .add_attr_option("mtriple") + .add_attr_option>("mattr") + // TODO(masahi): Support querying from a target device + // On RDNA cards, thread_warp_size should be 32 + .add_attr_option("max_num_threads", 256) + .add_attr_option("max_threads_per_block", 256) + .add_attr_option("max_shared_memory_per_block", 65536) + .add_attr_option("thread_warp_size", 64) + .set_default_keys({"hip", "gpu"}) + .set_target_parser(UpdateROCmAttrs); + TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) .add_attr_option("max_threads_per_block", 256) .add_attr_option("max_shared_memory_per_block", 16384) diff --git a/src/target/z3/z3_prover_off.cc b/src/target/z3/z3_prover_off.cc new file mode 100644 index 000000000000..1650e9261382 --- /dev/null +++ b/src/target/z3/z3_prover_off.cc @@ -0,0 +1,33 @@ +#include +#include +#include + +#include "tvm/ffi/string.h" +#include "tvm/ir/expr.h" +#include "tvm/tir/analysis.h" +#include "tvm/arith/analyzer.h" + +namespace tvm::arith { + +using namespace tir; +using namespace ffi; + +class Z3Prover::Impl {}; + +TVM_DLL bool Z3Prover::CanProve(const PrimExpr & expr) { return false; } +TVM_DLL void Z3Prover::Bind(const Var& var, const Range& new_range, bool allow_override) {} +TVM_DLL void Z3Prover::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {} +std::function Z3Prover::EnterConstraint(const PrimExpr& constraint, bool is_assume) { return [](){}; } +ffi::String Z3Prover::GetSMTLIB2(const ffi::Optional expr) { + return "; Z3 Prover is disabled."; +} +void Z3Prover::SetTimeoutMs(unsigned timeout_ms) {} +void Z3Prover::SetMaxStep(unsigned max_step) {} +void Z3Prover::CopyFrom(const Z3Prover & other) {} +ffi::String Z3Prover::GetStats() { + return "; Z3 Prover is disabled."; +} +Z3Prover::Z3Prover(Analyzer*): impl_(nullptr) {} +TVM_DLL Z3Prover::~Z3Prover() {} + +} // namespace tvm::arith \ No newline at end of file diff --git a/src/target/z3/z3_prover_on.cc b/src/target/z3/z3_prover_on.cc new file mode 100644 index 000000000000..76de2125e053 --- /dev/null +++ b/src/target/z3/z3_prover_on.cc @@ -0,0 +1,517 @@ +#include +#include +#include +#include "z3++.h" + +#include +#include 
+#include
+
+#include "tvm/ffi/cast.h"
+#include "tvm/ffi/object.h"
+#include "tvm/ffi/string.h"
+#include "tvm/ir/expr.h"
+#include "tvm/node/structural_equal.h"
+#include "tvm/node/structural_hash.h"
+#include "tvm/runtime/data_type.h"
+#include "tvm/tir/analysis.h"
+#include "tvm/tir/expr_functor.h"
+#include "tvm/arith/analyzer.h"
+#include "tvm/tir/op_attr_types.h"
+
+namespace tvm::arith {
+
+using namespace tir;
+using namespace ffi;
+
+namespace {
+
+struct Namespace {
+  std::unordered_set<std::string> used_names;
+  /// @brief Get a new name that has not been used before.
+  /// This function is used to generate z3 variable names.
+  ///
+  /// Z3 may deduplicate variables with the same name, which
+  /// causes issues when different TVM variables are mapped to
+  /// the same z3 variable.
+  ///
+  /// This function generates unique names by appending
+  /// suffixes to the original expression's string representation,
+  /// such as: "x", "x$1", "x$2", ...
+  std::string GetNewName(const PrimExpr& expr) {
+    std::stringstream ss;
+    ss << expr;
+    auto name = ss.str();
+    if (used_names.count(name) == 0) {
+      used_names.insert(name);
+      return name;
+    }
+    int idx = 1;
+    std::string check_name = name + "$" + std::to_string(idx);
+    while (used_names.count(check_name)) {
+      idx++;
+      check_name = name + "$" + std::to_string(idx);
+    }
+    used_names.insert(check_name);
+    return check_name;
+  }
+};
+
+}  // namespace
+
+class Z3Prover::Impl : ExprFunctor<z3::expr(const PrimExpr&)> {
+ public:
+  using Base = ExprFunctor<z3::expr(const PrimExpr&)>;
+  using Self = Z3Prover::Impl;
+
+  Analyzer* analyzer;
+  /// @brief Z3 context, held as a shared_ptr because TileLang wants to copy the Analyzer.
+  // We use a thread_local static Z3 context so all analyzers within the same thread
+  // can share a common context, because Z3 initialization is slow on some CPUs
+  // (e.g., AMD EPYC 7502 32-Core). Using thread_local ensures thread safety.
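+  // A note on cloning (illustrative): CopyFrom below copies this shared context
+  // pointer, so z3 expressions memoized by one Impl stay valid in a clone made on
+  // the same thread; only the solver and its assertions are rebuilt per instance.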
+  inline static thread_local std::shared_ptr<z3::context> ctx{new z3::context()};
+
+  /// @brief Z3 solver instance
+  z3::solver solver{*ctx};
+
+  /// @brief Memoized pure expressions
+  std::unordered_map memo_;
+
+  /// @brief Expressions memoized under an assume, cleaned up on scope exit
+  std::vector<PrimExpr> assume_overrides_;
+  bool is_assume = false;
+
+  /// @brief Namespace for variable naming
+  Namespace ns;
+
+  /// @brief Timeout in milliseconds
+  unsigned timeout_ms{UINT_MAX};
+
+  /// @brief Max steps
+  unsigned rlimit{UINT_MAX};
+
+  /// @brief Create a z3 solver with custom options
+  static z3::solver CreateSolver(z3::context& ctx) {
+    z3::solver solver(ctx);
+    // here we disable model generation to speed up the solving process
+    solver.set("model", false);
+    // ensure deterministic behavior
+    solver.set("random_seed", (unsigned)42);
+    return solver;
+  }
+
+  Impl(Analyzer* parent) : analyzer(parent) {
+    scope_stack_.push_back({});
+    solver = CreateSolver(*ctx);
+    // A 5ms default timeout was used previously; note that given a timeout of
+    // T ms, Z3's implementation may stop at T - 1 ms.
+    // SetTimeoutMs(5);
+    // Use rlimit, not timeout, to ensure deterministic behavior.
+    SetRLimit(10000);
+  }
+
+  /// @brief Create a free z3 expression from a PrimExprNode
+  z3::expr Create(const PrimExprNode* op) {
+    auto ref = ffi::GetRef<PrimExpr>(op);
+    auto dtype = op->dtype;
+    std::string name = ns.GetNewName(ref);
+    /// TVM's max_value can't represent the uint64 max correctly, so we special-case it here
+    if (dtype.is_bool()) {
+      return ctx->bool_const(name.c_str());
+    } else {
+      z3::expr e = ctx->int_const(name.c_str());
+      if (dtype.is_uint() && dtype.bits() == 64) {
+        solver.add(ctx->int_val(0) <= e && e <= ctx->int_val((uint64_t)UINT64_MAX));
+      } else {
+        auto min_val = Downcast<IntImm>(min_value(dtype))->value;
+        auto max_val = Downcast<IntImm>(max_value(dtype))->value;
+        solver.add(ctx->int_val(min_val) <= e && e <= ctx->int_val(max_val));
+      }
+      return e;
+    }
+  }
+
+  struct Scope {
+    enum Kind {
+      BindValue,
+      BindRange,
+      Constraint,
+    } kind;
+    Var var;
+    PrimExpr value;
+    PrimExpr min;
+    PrimExpr extent;
+    PrimExpr constraint;
+  };
+
+  /// @brief scope_stack_ records the existing constraints and bindings,
+  /// used to generate the SMTLIB2 representation with comments
+  std::vector<std::vector<Scope>> scope_stack_;
+
+  /// @brief Enter a constraint scope
+  std::function<void()> EnterConstraint(const PrimExpr& constraint, bool is_assume = false) {
+    scope_stack_.push_back({});
+    scope_stack_.back().push_back(
+        Scope{Scope::Constraint, Var(), PrimExpr(), PrimExpr(), PrimExpr(), constraint});
+    solver.push();
+    // is_assume affects the memoization behavior
+    this->is_assume = is_assume;
+    auto e = VisitBool(constraint);
+    this->is_assume = false;
+    solver.add(e);
+    auto overrides = std::move(assume_overrides_);
+    assume_overrides_.clear();
+    return [this, overrides]() {
+      solver.pop();
+      // erase the entries captured when this scope was entered, not the current
+      // assume_overrides_ (which belongs to a newer scope)
+      for (const auto& expr : overrides) {
+        memo_.erase(expr);
+      }
+      scope_stack_.pop_back();
+    };
+  }
+
+  /// @brief Check for trivially unprovable cases; return true if expr is such a case.
+  /// The Z3 prover may take a long time to initialize (at least 200us);
+  /// this optimization speeds up 30% of the cases in our unit tests.
+  bool CheckTrivialBadCases(const PrimExpr& expr) {
+    if (IsFreeNode(expr)) {
+      return true;
+    }
+    auto checkTrivialCmp = [this](const PrimExpr& lhs, const PrimExpr& rhs) {
+      if (IsFreeNode(lhs) && rhs->IsInstance<IntImmNode>()) {
+        return true;
+      }
+      if (IsFreeNode(rhs) && lhs->IsInstance<IntImmNode>()) {
+        return true;
+      }
+      if (IsFreeNode(lhs) && IsFreeNode(rhs)) {
+        return true;
+      }
+      // cast('xxx', free_var) == constant
+      if (auto cast = lhs.as<CastNode>()) {
+        if (IsFreeNode(cast->value) && rhs->IsInstance<IntImmNode>()) {
+          return true;
+        }
+      }
+      // constant == cast('xxx', free_var)
+      if (auto cast = rhs.as<CastNode>()) {
+        if (IsFreeNode(cast->value) && lhs->IsInstance<IntImmNode>()) {
+          return true;
+        }
+      }
+      return false;
+    };
+    if (auto eq = expr.as<EQNode>()) {
+      return checkTrivialCmp(eq->a, eq->b);
+    } else if (auto ne = expr.as<NENode>()) {
+      return checkTrivialCmp(ne->a, ne->b);
+    }
+    return false;
+  }
+
+  /// @brief Check if the expression can be proved
+  bool CanProve(const PrimExpr& expr) {
+    if (CheckTrivialBadCases(expr)) return false;
+    if (!IsValidDType(expr->dtype)) return false;
+    z3::expr_vector constr(*ctx);
+    constr.push_back(!VisitBool(expr));
+    auto result = solver.check(constr);
+    constr.pop_back();
+    return result == z3::unsat;
+  }
+
+  /// @brief Bind a variable to a value
+  void Bind(const Var& var, const PrimExpr& value, bool allow_override = false) {
+    if (!IsValidDType(var->dtype)) return;
+    scope_stack_.back().push_back(Scope{Scope::BindValue, var, value});
+    // we add the binding whenever the value is pure,
+    // because non-pure parts are handled by creating free variables in VisitExpr
+    memo_.emplace(var, VisitInt(value));
+  }
+
+  /// @brief Bind a variable to a range
+  void Bind(const Var& var, const Range& range, bool allow_override = false) {
+    if (!IsValidDType(var->dtype)) return;
+    scope_stack_.back().push_back(
+        Scope{Scope::BindRange, var, PrimExpr(), range->min, range->extent});
+    // 1. Create a placeholder for the var, and save it in the memo.
+    //    If the var is overridden later, we can just update the memo, and the
+    //    old placeholder will be ignored.
+    auto var_expr = Create(var.as<VarNode>());
+    memo_.emplace(var, var_expr);
+    // 2. Add constraints on the placeholder.
+    //    When min_expr >= max_expr the range is empty, which is undefined behavior;
+    //    instead of adding an unsat constraint, we skip the range constraint and
+    //    leave it a free var.
+    if (tir::is_const_int(range->min) && tir::is_const_int(range->min + range->extent)) {
+      int64_t min_value = *tir::as_const_int(range->min);
+      int64_t max_value = *tir::as_const_int(range->min + range->extent);
+      if (min_value < max_value) {
+        solver.add(ctx->int_val(min_value) <= var_expr);
+        solver.add(var_expr < ctx->int_val(max_value));
+      }
+    } else {
+      auto min_expr = VisitInt(range->min);
+      auto max_expr = VisitInt(analyzer->Simplify(range->min + range->extent));
+      solver.add(min_expr >= max_expr || (min_expr <= var_expr && var_expr < max_expr));
+    }
+  }
+
+  void CopyFrom(const Self& other_) {
+    // 1. create a new solver:
+    //    this->solver depends on this->ctx, so we must destroy the old solver
+    //    and create a new one that depends on other_.ctx
+    solver = CreateSolver(*other_.ctx);
+    // 2. copy the context: it is a shared_ptr, so we can just copy the pointer
+    ctx = other_.ctx;
+    // 3. copy other objects
+    ns = other_.ns;
+    for (auto& item : other_.memo_) {
+      memo_.emplace(item.first, item.second);
+    }
+    for (auto a : other_.solver.assertions()) {
+      solver.add(a);
+    }
+    // 4. copy the timeout options (other solver options are not copied)
+    SetTimeoutMs(other_.timeout_ms);
+    SetRLimit(other_.rlimit);
+    // 5. copy the scope stack, which contains the comments for SMTLIB2 generation
+    scope_stack_ = other_.scope_stack_;
+  }
+
+  /// @brief Set timeout in milliseconds
+  void SetTimeoutMs(unsigned timeout_ms) {
+    this->timeout_ms = timeout_ms;
+    solver.set("timeout", timeout_ms);
+  }
+
+  /// @brief Set max steps
+  void SetRLimit(unsigned rlimit) {
+    this->rlimit = rlimit;
+    solver.set("rlimit", rlimit);
+  }
+
+  /// @brief Get the SMTLIB2 representation of the current solver state
+  ffi::String GetSMTLIB2() {
+    std::stringstream ss;
+    ss << "(set-option :timeout " << timeout_ms << ")\n";
+    AddScopeDebugMsg(ss);
+    ss << solver.to_smt2();
+    return ss.str();
+  }
+
+  void AddScopeDebugMsg(std::ostream& ss) {
+    for (const auto& scope : scope_stack_) {
+      ss << "; Entering Scope\n";
+      for (const auto& s : scope) {
+        switch (s.kind) {
+          case Scope::Constraint:
+            ss << "; constraint: " << s.constraint << "\n";
+            break;
+          case Scope::BindValue:
+            ss << "; bind value: " << s.var << " = " << s.value << "\n";
+            break;
+          case Scope::BindRange:
+            ss << "; bind range: " << s.var << " in [" << s.min << ", " << s.min + s.extent << ")\n";
+            break;
+        }
+      }
+    }
+  }
+
+  /// @brief Get the SMTLIB2 representation of the current solver state,
+  /// together with an additional expr we are trying to prove
+  ffi::String GetSMTLIB2(const PrimExpr& expr) {
+    std::stringstream ss;
+    ss << "(set-option :timeout " << timeout_ms << ")\n";
+    AddScopeDebugMsg(ss);
+    ss << "; Trying to prove: " << expr << "\n";
+    solver.push();
+    solver.add(!VisitBool(expr));
+    ss << solver.to_smt2();
+    solver.pop();
+    return ss.str();
+  }
+
+  /// @brief Get the statistics of the solver
+  ffi::String GetStats() {
+    std::stringstream ss;
+    ss << solver.statistics();
+    return ss.str();
+  }
+
+ private:
+  using Z3BinOp = z3::expr (*)(const z3::expr&, const z3::expr&);
+
+  /// @brief Visit expression with memoization
+  z3::expr VisitExpr(const PrimExpr& e) override {
+    if (memo_.count(e)) {
+      return memo_.at(e);
+    }
+    auto res = Base::VisitExpr(e);
+    // inside an assume we memoize the expression whether it is pure or not
+    bool pure = SideEffect(e) <= CallEffectKind::kPure;
+    if (is_assume || pure) {
+      memo_.emplace(e, res);
+      // if we memoized a non-pure expr during an assume, record it for later cleanup
+      if (is_assume && !pure) {
+        assume_overrides_.emplace_back(e);
+      }
+    }
+    return res;
+  }
+
+  /// @brief Check if the expression is a free node with no constraints
+  bool IsFreeNode(const PrimExpr& e) {
+    if (memo_.count(e)) {
+      return false;
+    }
+    return e->IsInstance<VarNode>()
+        || e->IsInstance<BufferLoadNode>()
+        || e->IsInstance<ProducerLoadNode>()
+        || e->IsInstance<CallNode>()
+        || (e->IsInstance<CastNode>() && !IsValidDType(Downcast<Cast>(e)->value->dtype));
+  }
+
+  /// @brief Check if the dtype is valid for z3 integer operations
+  static bool IsValidDType(const DataType& dtype) {
+    return (dtype.is_int() || dtype.is_uint() || dtype.is_bool()) && dtype.lanes() == 1;
+  }
+
+  /// @brief Visit the expression and convert it into a z3 integer expression
+  z3::expr VisitInt(const PrimExpr& expr) {
+    auto e = VisitExpr(expr);
+    if (e.is_bool()) {
+      return z3::ite(e, ctx->int_val(1), ctx->int_val(0));
+    } else {
+      return e;
+    }
+  }
+
+  /// @brief Visit the expression and convert it into a z3 boolean expression
+  z3::expr VisitBool(const PrimExpr& e) {
+    auto expr = VisitExpr(e);
+    if (expr.is_bool()) {
+      return expr;
+    } else {
+      return expr != ctx->int_val(0);
+    }
+  }
+
+  /// @brief Helper function to visit binary arithmetic operations
+  z3::expr VisitArith(Z3BinOp signed_op, const PrimExprNode* op, const PrimExpr& a, const PrimExpr& b) {
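+    // If either operand's dtype cannot be modeled (float, vector lanes, etc.),
+    // fall back to a fresh free variable: the result is then merely
+    // unconstrained rather than unsound.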
+    if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) {
+      return signed_op(VisitInt(a), VisitInt(b));
+    } else {
+      return Create(op);
+    }
+  }
+
+  z3::expr VisitExpr_(const LetNode* op) override {
+    if (IsValidDType(op->var->dtype)) {
+      memo_.emplace(op->var, VisitInt(op->value));
+    }
+    return VisitExpr(op->body);
+  }
+  z3::expr VisitExpr_(const CastNode* op) override {
+    // if the inner dtype is valid, we just visit it
+    if (IsValidDType(op->value->dtype) && IsValidDType(op->dtype)) {
+      return VisitInt(op->value);
+    } else {
+      // otherwise, we create a new free z3 variable
+      return Create(op);
+    }
+  }
+  z3::expr VisitExpr_(const CallNode* op) override { return Create(op); }
+  z3::expr VisitExpr_(const VarNode* op) override { return Create(op); }
+  z3::expr VisitExpr_(const BufferLoadNode* op) override { return Create(op); }
+  z3::expr VisitExpr_(const ProducerLoadNode* op) override { return Create(op); }
+  z3::expr VisitExpr_(const ReduceNode* op) override { return Create(op); }
+  z3::expr VisitExpr_(const MinNode* op) override {
+    auto a = VisitInt(op->a);
+    auto b = VisitInt(op->b);
+    return z3::ite(a < b, a, b);
+  }
+  z3::expr VisitExpr_(const MaxNode* op) override {
+    auto a = VisitInt(op->a);
+    auto b = VisitInt(op->b);
+    return z3::ite(a > b, a, b);
+  }
+  static z3::expr floordiv(const z3::expr& a, const z3::expr& b) { return z3::ite(b > 0, a / b, -((-a) / b)); }
+  static z3::expr floormod(const z3::expr& a, const z3::expr& b) { return z3::ite(b > 0, a % b, -((-a) % b)); }
+  z3::expr VisitExpr_(const AddNode* op) override { return VisitArith(z3::operator+, op, op->a, op->b); }
+  z3::expr VisitExpr_(const SubNode* op) override { return VisitArith(z3::operator-, op, op->a, op->b); }
+  z3::expr VisitExpr_(const MulNode* op) override { return VisitArith(z3::operator*, op, op->a, op->b); }
+  z3::expr VisitExpr_(const DivNode* op) override { return VisitArith(z3::operator/, op, op->a, op->b); }
+  z3::expr VisitExpr_(const ModNode* op) override { return VisitArith(z3::operator%, op, op->a, op->b); }
+  z3::expr VisitExpr_(const FloorDivNode* op) override { return VisitArith(floordiv, op, op->a, op->b); }
+  z3::expr VisitExpr_(const FloorModNode* op) override { return VisitArith(floormod, op, op->a, op->b); }
+  z3::expr VisitExpr_(const EQNode* op) override { return VisitArith(z3::operator==, op, op->a, op->b); }
+  z3::expr VisitExpr_(const NENode* op) override { return VisitArith(z3::operator!=, op, op->a, op->b); }
+  z3::expr VisitExpr_(const LTNode* op) override { return VisitArith(z3::operator<, op, op->a, op->b); }
+  z3::expr VisitExpr_(const LENode* op) override { return VisitArith(z3::operator<=, op, op->a, op->b); }
+  z3::expr VisitExpr_(const GTNode* op) override { return VisitArith(z3::operator>, op, op->a, op->b); }
+  z3::expr VisitExpr_(const GENode* op) override { return VisitArith(z3::operator>=, op, op->a, op->b); }
+  z3::expr VisitExpr_(const AndNode* op) override { return VisitBool(op->a) && VisitBool(op->b); }
+  z3::expr VisitExpr_(const OrNode* op) override { return VisitBool(op->a) || VisitBool(op->b); }
+  z3::expr VisitExpr_(const NotNode* op) override { return !VisitBool(op->a); }
+  z3::expr VisitExpr_(const SelectNode* op) override { return z3::ite(VisitBool(op->condition), VisitInt(op->true_value), VisitInt(op->false_value)); }
+  z3::expr VisitExpr_(const IntImmNode* op) override { return ctx->int_val(op->value); }
+  z3::expr VisitExprDefault_(const Object* op) override {
+    LOG(FATAL) << "Z3Prover only supports integers, but got " << op->GetTypeKey() << ".";
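+    // LOG(FATAL) never returns; the marker below silences missing-return warnings.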
+ TVM_FFI_UNREACHABLE(); + } +}; + +TVM_DLL bool Z3Prover::CanProve(const PrimExpr & expr) { + return impl_->CanProve(expr); +} +TVM_DLL void Z3Prover::Bind(const Var& var, const Range& new_range, bool allow_override) { + return impl_->Bind(var, new_range, allow_override); +} +TVM_DLL void Z3Prover::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { + return impl_->Bind(var, expr, allow_override); +} +std::function Z3Prover::EnterConstraint(const PrimExpr& constraint, bool is_assume) { + return impl_->EnterConstraint(constraint, is_assume); +} +ffi::String Z3Prover::GetSMTLIB2(const ffi::Optional expr) { + if(expr.has_value()) { + return impl_->GetSMTLIB2(expr.value()); + } else { + return impl_->GetSMTLIB2(); + } +} +void Z3Prover::SetTimeoutMs(unsigned timeout_ms) { + impl_->SetTimeoutMs(timeout_ms); +} +void Z3Prover::SetRLimit(unsigned max_step) { + impl_->SetRLimit(max_step); +} +void Z3Prover::CopyFrom(const Z3Prover & other) { + impl_->CopyFrom(*other.impl_); +} +ffi::String Z3Prover::GetStats() { + return impl_->GetStats(); +} +Z3Prover::Z3Prover(Analyzer* parent): impl_(new Impl{parent}) {} +TVM_DLL Z3Prover::~Z3Prover() { + delete impl_; +} + +} // namespace tvm::arith \ No newline at end of file diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index aca06ad595bc..2dad012a163f 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -279,6 +279,7 @@ void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) { } Update(&writes_buffers_, &write_regions_, write->buffer, relaxed_region); } + StmtVisitor::VisitStmt_(op); } std::vector BlockReadWriteDetector::ConvertMatchedRegion( diff --git a/src/tir/analysis/buffer_access_lca_detector.cc b/src/tir/analysis/buffer_access_lca_detector.cc index 67e8bda6f670..f8665362fa5e 100644 --- a/src/tir/analysis/buffer_access_lca_detector.cc +++ b/src/tir/analysis/buffer_access_lca_detector.cc @@ -70,6 +70,29 @@ class LCADetector : public StmtExprVisitor { return buffer_lca; } + static ffi::Map> DetectVar(const PrimFunc& func) { + LCADetector detector; + for (const auto& kv : func->buffer_map) { + const Buffer& buffer = kv.second; + detector.buffer_var_map_.emplace(buffer->data.get(), buffer.get()); + } + + ScopeInfo root(nullptr, nullptr, 0); + detector.ancestor_scopes_.push_back(&root); + + detector(func->body); + + // Prepare the return + ffi::Map> var_lca; + for (const auto& kv : detector.buffer_var_lca_) { + const Var& var = ffi::GetRef(kv.first); + const ffi::Optional stmt = + kv.second ? ffi::GetRef>(kv.second->stmt) : std::nullopt; + var_lca.Set(var, stmt); + } + return var_lca; + } + private: /*! * \brief The AST node information for querying LCA. @@ -271,6 +294,7 @@ class LCADetector : public StmtExprVisitor { void VisitExpr_(const VarNode* op) final { VisitBufferVar(op); } void VisitBufferVar(const VarNode* op) { + UpdateVarLCA(op, ancestor_scopes_.back()); auto it = buffer_var_map_.find(op); if (it != buffer_var_map_.end()) { UpdateBufferLCA(it->second, ancestor_scopes_.back()); @@ -279,6 +303,8 @@ class LCADetector : public StmtExprVisitor { void UpdateBufferLCA(const BufferNode* buffer, const ScopeInfo* scope) { buffer_var_map_.emplace(buffer->data.get(), buffer); + // Also record LCA for the underlying data var to capture BufferLoad/Store cases. 
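+    // (Hypothetical example: if buffer A is accessed in two sibling loops, the
+    // recorded LCA for A->data is their common parent scope, matching Detect().)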
+ UpdateVarLCA(buffer->data.get(), scope); if (match_buffers_.find(buffer) == match_buffers_.end()) { // Ingore buffer created by block match_buffer const ScopeInfo*& lca = buffer_lca_[buffer]; @@ -286,6 +312,11 @@ class LCADetector : public StmtExprVisitor { } } + void UpdateVarLCA(const VarNode* var, const ScopeInfo* scope) { + const ScopeInfo*& lca = buffer_var_lca_[var]; + lca = LowestCommonAncestor(lca, scope); + } + void UpdateWithBlockidx() { for (const auto& it : buffer_lca_) { const runtime::StorageScope& scope = @@ -333,6 +364,8 @@ class LCADetector : public StmtExprVisitor { std::unordered_map buffer_lca_ = {}; /*! \brief The map from Buffer data to the Buffer. */ std::unordered_map buffer_var_map_ = {}; + /*! \brief The map from Buffer data var to its LCA ForNode/BlockNode. */ + std::unordered_map buffer_var_lca_ = {}; /*! \brief The match buffers inside blocks. */ std::unordered_set match_buffers_ = {}; /*! \brief The ForNodes/BlockNodes which contain immediate `blockIdx` launch. */ @@ -347,9 +380,14 @@ ffi::Map> DetectBufferAccessLCA(const PrimFunc& func return LCADetector::Detect(func); } +ffi::Map> DetectBufferVarAccessLCA(const PrimFunc& func) { + return LCADetector::DetectVar(func); +} + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tir.analysis.detect_buffer_access_lca", DetectBufferAccessLCA); + refl::GlobalDef().def("tir.analysis.detect_buffer_var_access_lca", DetectBufferVarAccessLCA); } } // namespace tir } // namespace tvm diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 5eee4ffd8bd5..e6ffd2f09b57 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -581,7 +581,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { } // Call -Call::Call(DataType dtype, RelaxExpr op, ffi::Array args, Span span) { +Call::Call(DataType dtype, RelaxExpr op, ffi::Array args, + ffi::Map annotations, Span span) { for (size_t i = 0; i < args.size(); ++i) { ICHECK(args[i].defined()) << "arg " << i << " is not defined()"; } @@ -590,6 +591,7 @@ Call::Call(DataType dtype, RelaxExpr op, ffi::Array args, Span span) { node->dtype = dtype; node->op = std::move(op); node->args = std::move(args); + node->annotations = std::move(annotations); node->span = std::move(span); data_ = std::move(node); } @@ -600,6 +602,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { "tir.Call", [](ffi::Optional dtype, RelaxExpr op, ffi::Array> args, + ffi::Optional> annotations, Span span) { ffi::Array prim_expr_args; for (const auto& it : args) { @@ -626,7 +629,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { prim_expr_args.push_back(Downcast(it)); } } - return Call(dtype.value_or(DataType::Void()), op, prim_expr_args, span); + return Call(dtype.value_or(DataType::Void()), op, prim_expr_args, + annotations.value_or(ffi::Map()), span); }); } @@ -789,10 +793,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { // BufferLoad void BufferLoadNode::LegalizeDType() { - for (int i = 0; i < static_cast(indices.size()) - 1; i++) { - ICHECK(indices[i].dtype().is_scalar()) - << "Only the last index of a buffer access may be a vector type."; - } + // for (int i = 0; i < static_cast(indices.size()) - 1; i++) { + // ICHECK(indices[i].dtype().is_scalar()) + // << "Only the last index of a buffer access may be a vector type."; + // } if (indices.empty()) { this->dtype = buffer->dtype; diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index cdd1d8ad56d8..84e701210247 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -98,7 +98,8 @@ std::pair IndexMapInverseImpl(const IndexMap& self, 
/*check_level=*/check_level, analyzer, /*simplify_trivial_iterators=*/false); CHECK(padded_iter_map->errors.empty()) << "Could not parse mapping as sum of iterators. " - << "Error: " << padded_iter_map->errors[0]; + << "\nIndex map: " << self->initial_indices << " -> " << self->final_indices + << "\nError: " << padded_iter_map->errors[0]; // Determine expressions for the input variables, in terms of the // output variables. diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index b7e28e84e748..d57196dc8d62 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -245,12 +245,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { // Allocate Allocate::Allocate(Var buffer_var, DataType dtype, ffi::Array extents, PrimExpr condition, Stmt body, ffi::Map annotations, Span span) { - CHECK(IsPointerType(buffer_var->type_annotation, dtype) || - (dtype.is_bool() && IsPointerType(buffer_var->type_annotation, DataType::Int(8)))) - << "The allocated data type (" << dtype - << ") does not match the type annotation of the buffer " << buffer_var << " (" - << buffer_var->type_annotation - << "). The data type should be an element of the pointer type."; for (size_t i = 0; i < extents.size(); ++i) { ICHECK(extents[i].defined()); @@ -717,7 +711,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { PrimExpr TypeAnnotation(DataType dtype, Span span) { static auto op = Op::Get("tir.type_annotation"); - return tir::Call(dtype, op, {}, span); + return tir::Call(dtype, op, {}, {}, span); } TVM_TIR_REGISTER_OP("type_annotation") diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 638340e0bd2f..b76234ecb856 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -203,8 +203,11 @@ void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, AccessPath path) { context.push_back(std::move(var)); } - } else if (auto expr = op->node.as()) { - Visit(expr.value(), path->Attr("node")); + } else if (op->node != nullptr) { + auto expr = op->node.as(); + if (expr) { + Visit(expr.value(), path->Attr("node")); + } } Visit(op->body, path->Attr("body")); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 51c0b64ed295..3ad2f0d62a45 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -114,14 +114,14 @@ Type GetTypeFromRuntimeDataType(const DataType& dtype) { PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high, Span span) { return tir::Call( t, tir::builtin::large_uint_imm(), - {make_const(DataType::UInt(32), low, span), make_const(DataType::UInt(32), high, span)}, + {make_const(DataType::UInt(32), low, span), make_const(DataType::UInt(32), high, span)}, {}, span); } // Q-multiplication PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span span) { return tir::Call(DataType::Int(32, x.dtype().lanes()), tir::builtin::q_multiply_shift(), - {x, y, q, s}, span); + {x, y, q, s}, {}, span); } void BroadcastToMatchLanes(PrimExpr& op_a, PrimExpr& op_b) { // NOLINT(*) @@ -249,19 +249,19 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) PrimExpr ret(PrimExpr value, Span span) { CHECK(value.defined()); - return tir::Call(value.dtype(), tir::builtin::ret(), {value}, span); + return tir::Call(value.dtype(), tir::builtin::ret(), {value}, {}, span); } PrimExpr thread_return(Span span) { - return tir::Call(DataType::Void(), tir::builtin::thread_return(), {}, span); + return tir::Call(DataType::Void(), tir::builtin::thread_return(), {}, {}, span); } PrimExpr continue_loop(Span span) { - return tir::Call(DataType::Void(), 
tir::builtin::continue_loop(), {}, span); + return tir::Call(DataType::Void(), tir::builtin::continue_loop(), {}, {}, span); } PrimExpr break_loop(Span span) { - return tir::Call(DataType::Void(), tir::builtin::break_loop(), {}, span); + return tir::Call(DataType::Void(), tir::builtin::break_loop(), {}, {}, span); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -497,7 +497,7 @@ PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) { value.dtype().bytes() * value.dtype().lanes() == t.bytes() * t.lanes())) << "Reinterpret requires size match " << t << " vs " << value.dtype(); } - return tir::Call(t, tir::builtin::reinterpret(), {value}, span); + return tir::Call(t, tir::builtin::reinterpret(), {value}, {}, span); } // operator+ @@ -639,13 +639,13 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, } return tir::Call(true_value.dtype(), tir::builtin::if_then_else(), - {cond, true_value, false_value}, span); + {cond, true_value, false_value}, {}, span); } // likely PrimExpr likely(PrimExpr cond, Span span) { if (is_const_int(cond)) return cond; - return tir::Call(cond.dtype(), tir::builtin::likely(), {cond}, span); + return tir::Call(cond.dtype(), tir::builtin::likely(), {cond}, {}, span); } // operator> @@ -771,7 +771,7 @@ PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span) { } }); - return tir::Call(a.dtype(), tir::builtin::shift_right(), {a, b}, span); + return tir::Call(a.dtype(), tir::builtin::shift_right(), {a, b}, {}, span); } // shift left @@ -790,7 +790,7 @@ PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span) { if (pb->value == 0) return a; } }); - return tir::Call(a.dtype(), tir::builtin::shift_left(), {a, b}, span); + return tir::Call(a.dtype(), tir::builtin::shift_left(), {a, b}, {}, span); } // bitwise and @@ -802,7 +802,7 @@ PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value & pb->value), span); }); - return tir::Call(a.dtype(), tir::builtin::bitwise_and(), {a, b}, span); + return tir::Call(a.dtype(), tir::builtin::bitwise_and(), {a, b}, {}, span); } // bitwise_or @@ -814,7 +814,7 @@ PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value | pb->value), span); }); - return tir::Call(a.dtype(), tir::builtin::bitwise_or(), {a, b}, span); + return tir::Call(a.dtype(), tir::builtin::bitwise_or(), {a, b}, {}, span); } // bitwise_xor @@ -826,7 +826,7 @@ PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value), span); }); - return tir::Call(a.dtype(), tir::builtin::bitwise_xor(), {a, b}, span); + return tir::Call(a.dtype(), tir::builtin::bitwise_xor(), {a, b}, {}, span); } // bitwise_not @@ -834,7 +834,7 @@ PrimExpr operator~(PrimExpr a) { return bitwise_neg(a); } PrimExpr bitwise_neg(PrimExpr a, Span span) { type_check_int_or_bool_args(a, "~ operator (bitwise NOT)"); - return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, span); + return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, {}, span); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -874,7 +874,7 @@ PrimExpr pow(PrimExpr x, PrimExpr y, Span span) { } static auto op = Op::Get("tir.pow"); - return tir::Call(x.dtype(), op, {x, y}, span); + return tir::Call(x.dtype(), op, {x, y}, {}, span); } TVM_TIR_REGISTER_PURE_BINARY_OP("pow").set_attr("TVectorizable", true); @@ -895,7 +895,7 @@ PrimExpr abs(PrimExpr x, Span span) { return 
FloatImm(x.dtype(), std::fabs(fx->value), fx->span); } static auto op = Op::Get("tir.fabs"); - return tir::Call(x.dtype(), op, {x}, span); + return tir::Call(x.dtype(), op, {x}, {}, span); } else if (x.dtype().is_uint()) { return x; } else { @@ -920,9 +920,9 @@ PrimExpr isnan(PrimExpr x, Span span) { } static auto op = Op::Get("tir.isnan"); if (x.dtype().bits() == 16) { - return tir::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x), span)}, span); + return tir::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x), span)}, {}, span); } else { - return tir::Call(t, op, {x}, span); + return tir::Call(t, op, {x}, {}, span); } } else { LOG(FATAL) << "Data type " << x.dtype() << " not supported for isnan op. Skipping isnan op..."; @@ -1000,7 +1000,7 @@ PrimExpr fmod(PrimExpr x, PrimExpr y, Span span) { BinaryOpMatchTypes(x, y, span); ICHECK(x.dtype().is_float()) << "fmod only applies to float"; static auto op = Op::Get("tir.fmod"); - return tir::Call(x.dtype(), op, {x, y}, span); + return tir::Call(x.dtype(), op, {x, y}, {}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("fmod"); @@ -1014,7 +1014,7 @@ PrimExpr floor(PrimExpr x, Span span) { const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::floor(fx->value), fx->span); static auto op = Op::Get("tir.floor"); - return tir::Call(x.dtype(), op, {x}, span); + return tir::Call(x.dtype(), op, {x}, {}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("floor").set_attr("TVectorizable", true); @@ -1028,7 +1028,7 @@ PrimExpr ceil(PrimExpr x, Span span) { const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::ceil(fx->value), fx->span); static auto op = Op::Get("tir.ceil"); - return tir::Call(x.dtype(), op, {x}, span); + return tir::Call(x.dtype(), op, {x}, {}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("ceil").set_attr("TVectorizable", true); @@ -1042,7 +1042,7 @@ PrimExpr round(PrimExpr x, Span span) { const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value), fx->span); static auto op = Op::Get("tir.round"); - return tir::Call(x.dtype(), op, {x}, span); + return tir::Call(x.dtype(), op, {x}, {}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("round").set_attr("TVectorizable", true); @@ -1056,7 +1056,7 @@ PrimExpr nearbyint(PrimExpr x, Span span) { const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value), fx->span); static auto op = Op::Get("tir.nearbyint"); - return tir::Call(x.dtype(), op, {x}, span); + return tir::Call(x.dtype(), op, {x}, {}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("nearbyint"); @@ -1073,7 +1073,7 @@ PrimExpr trunc(PrimExpr x, Span span) { fx->span); } static auto op = Op::Get("tir.trunc"); - return tir::Call(x.dtype(), op, {x}, span); + return tir::Call(x.dtype(), op, {x}, {}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("trunc").set_attr("TVectorizable", true); @@ -1191,9 +1191,19 @@ TVM_FFI_STATIC_INIT_BLOCK() { bool lhs_is_int = args[0].type_index() == ffi::TypeIndex::kTVMFFIInt; \ bool rhs_is_int = args[1].type_index() == ffi::TypeIndex::kTVMFFIInt; \ if (lhs_is_int) { \ - *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ + auto arg1 = args[1].cast(); \ + if(arg1.dtype().is_uint()) { \ + *ret = Func(make_const(arg1.dtype(), args[0].cast()), arg1, args[2].cast()); \ + } else { \ + *ret = Func(make_const(arg1.dtype(), args[0].cast()), arg1, args[2].cast()); \ + } \ } else if (rhs_is_int) { \ - *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ + auto arg0 = args[0].cast(); \ + 
if(arg0.dtype().is_uint()) { \ + *ret = Func(arg0, make_const(arg0.dtype(), args[1].cast()), args[2].cast()); \ + } else { \ + *ret = Func(arg0, make_const(arg0.dtype(), args[1].cast()), args[2].cast()); \ + } \ } else { \ *ret = (Func(args[0].cast(), args[1].cast(), args[2].cast())); \ } \ diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 00f421e733e2..eae4c64a15a7 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -745,10 +745,10 @@ ffi::Array ConcreteScheduleNode::CacheIndex(const BlockRV& block_rv, } BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, - BufferIndexType buffer_index_type) { + BufferIndexType buffer_index_type, bool skip_simplify) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); - result = tir::ReIndex(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type); + result = tir::ReIndex(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, skip_simplify); TVM_TIR_SCHEDULE_END("reindex", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 7ee54961415b..64d27fc10a1d 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -134,7 +134,7 @@ class ConcreteScheduleNode : public ScheduleNode { ffi::Array CacheIndex(const BlockRV& block_rv, const ffi::String& storage_scope, int cse_thresh) override; BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, - BufferIndexType buffer_index_type) override; + BufferIndexType buffer_index_type, bool skip_simplify) override; /******** Schedule: Data movement ********/ BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, const ffi::String& storage_scope) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 1af0033791f4..b031266211ed 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -428,7 +428,7 @@ TVM_DLL ffi::Array CacheIndex(ScheduleState self, const StmtSRef& bloc * \return The reindex stage block. 
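+ * \note When `skip_simplify` is true, the collected indices are kept as-is
+ *       (e.g., to preserve unit iterators) instead of being simplified.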
 */
TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
-                         BufferIndexType buffer_index_type);
+                         BufferIndexType buffer_index_type, bool skip_simplify = false);
 
 /******** Schedule: Data movement ********/
 
diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index a2479a0d28ff..9a883c11359b 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -2241,7 +2241,7 @@ ffi::Array CacheInplace(ScheduleState self, const StmtSRef& block_sref
 }
 
 StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
-                 BufferIndexType buffer_index_type) {
+                 BufferIndexType buffer_index_type, bool skip_simplify) {
   const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
   Block block = ffi::GetRef<Block>(block_ptr);
   Buffer buffer = GetNthAccessBuffer(self, block, buffer_index, buffer_index_type);
@@ -2252,11 +2252,14 @@ StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_inde
   // Load/Store and the buffer is not accessed opaquely
   ffi::Array<PrimExpr> original_indices = ReIndexCollector::Collect(self->mod, buffer, block);
   // Simplify the indices if possible
-  for (const IterVar& iter : block->iter_vars) {
-    analyzer.Bind(iter->var, iter->dom);
+  if (!skip_simplify) {
+    // Simplification may be skipped, e.g., to preserve unit loops.
+    for (const IterVar& iter : block->iter_vars) {
+      analyzer.Bind(iter->var, iter->dom);
+    }
+    original_indices.MutateByApply(
+        [&analyzer](const PrimExpr& expr) { return SimplifyNonTrivialExpr(expr, &analyzer); });
   }
-  original_indices.MutateByApply(
-      [&analyzer](const PrimExpr& expr) { return SimplifyNonTrivialExpr(expr, &analyzer); });
 
   // Collect block iters appearing in the original_indices
   std::unordered_set covered;
@@ -2418,23 +2421,26 @@ struct ReIndexTraits : public UnpackedInstTraits<ReIndexTraits> {
  private:
   static constexpr size_t kNumInputs = 1;
-  static constexpr size_t kNumAttrs = 2;
+  static constexpr size_t kNumAttrs = 3;
   static constexpr size_t kNumDecisions = 0;
 
   static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer buffer_index,
-                                         Integer buffer_index_type) {
+                                         Integer buffer_index_type, Bool skip_simplify) {
     return sch->ReIndex(block, buffer_index.IntValue(),
-                        static_cast<BufferIndexType>(buffer_index_type->value));
+                        static_cast<BufferIndexType>(buffer_index_type->value),
+                        skip_simplify.operator bool());
   }
 
   static ffi::String UnpackedAsPython(ffi::Array outputs, ffi::String block,
-                                      Integer buffer_index, Integer buffer_index_type) {
+                                      Integer buffer_index, Integer buffer_index_type,
+                                      Bool skip_simplify) {
     PythonAPICall py("reindex");
    py.Input("block", block);
    std::ostringstream os;
    os << "(\"" << BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value))
       << "\", " << buffer_index << ")";
    py.Input("buffer", ffi::String(os.str()));
+    py.Input("skip_simplify", skip_simplify.operator bool());
    py.SingleOutput(outputs);
    return py.Str();
  }
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 35b221561978..d15b43afb965 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -208,9 +208,9 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       .def_method("tir.schedule.ScheduleCacheInplace", &ScheduleNode::CacheInplace)
       .def_method("tir.schedule.ScheduleCacheIndex", &ScheduleNode::CacheIndex)
       .def("tir.schedule.ScheduleReIndex",
-           [](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type) {
+           [](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type,
bool skip_simplify) { return self->ReIndex(block_rv, buffer_index, - static_cast(buffer_index_type)); + static_cast(buffer_index_type), skip_simplify); }); } /******** (FFI) Data movement ********/ diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 72606f243d69..ad9e65a643cd 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -448,13 +448,13 @@ ffi::Array TracedScheduleNode::CacheIndex(const BlockRV& block_rv, } BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, - BufferIndexType buffer_index_type) { - BlockRV result = ConcreteScheduleNode::ReIndex(block_rv, buffer_index, buffer_index_type); + BufferIndexType buffer_index_type, bool skip_simplify) { + BlockRV result = ConcreteScheduleNode::ReIndex(block_rv, buffer_index, buffer_index_type, skip_simplify); static const InstructionKind& kind = InstructionKind::Get("ReIndex"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{block_rv}, - /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type)}, + /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type), Bool(skip_simplify)}, /*outputs=*/{result})); return result; } diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 8c7b16a47e8d..cfe9b83e7cc6 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -20,6 +20,7 @@ #define TVM_TIR_SCHEDULE_TRACED_SCHEDULE_H_ #include "./concrete_schedule.h" +#include namespace tvm { namespace tir { @@ -94,9 +95,9 @@ class TracedScheduleNode : public ConcreteScheduleNode { ffi::Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, const ffi::String& storage_scope) final; BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, - BufferIndexType buffer_index_type) final; + BufferIndexType buffer_index_type, bool skip_simplify) final; ffi::Array CacheIndex(const BlockRV& block_rv, const ffi::String& storage_scope, - int cse_thresh) final; + int cse_thresh) final; /******** Schedule: Data movement ********/ BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, const ffi::String& storage_scope) final; diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index f4258fc479d6..950e3fb8c850 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -113,7 +113,7 @@ class PipelineOpaqueAccessRewriter { ffi::Array new_args = call->args; const Buffer& new_buffer = (*it).second; new_args.Set(4, RewriteWmmaFragmentIndex(buffer, new_buffer, call->args[4])); - return Call(call->dtype, call->op, new_args, call->span); + return Call(call->dtype, call->op, new_args, call->annotations, call->span); } } else if (call->op.same_as(mma_sync)) { ffi::Array new_args = call->args; @@ -127,7 +127,7 @@ class PipelineOpaqueAccessRewriter { new_args.Set(i * 2 + 1, new_index); } } - return Call(call->dtype, call->op, new_args, call->span); + return Call(call->dtype, call->op, new_args, call->annotations, call->span); } else if (call->op.same_as(access_ptr)) { return RewriteBufferAccess(call, {1}); } else if (call->op.same_as(ptx_mma)) { @@ -190,7 +190,7 @@ class PipelineOpaqueAccessRewriter { new_args.Set(i + 1, new_index); } } - return Call(call->dtype, call->op, new_args, call->span); + return Call(call->dtype, call->op, new_args, call->annotations, call->span); } const ffi::Map& buffer_data_to_buffer_; diff --git 
a/src/tir/transforms/inject_torch_mps_stream.cc b/src/tir/transforms/inject_torch_mps_stream.cc
new file mode 100644
index 000000000000..22f968bd2f13
--- /dev/null
+++ b/src/tir/transforms/inject_torch_mps_stream.cc
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Inject a Torch MPS stream into generated host functions.
+ * \file tir/transforms/inject_torch_mps_stream.cc
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include "ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+// Injects the Torch MPS stream into host functions.
+// NOTE: currently a placeholder that returns the function unchanged.
+class InjectMPSStream : public StmtExprMutator {
+ public:
+  static PrimFunc Build(PrimFunc func) {
+    return func;
+  }
+};
+
+namespace transform {
+
+Pass InjectTorchMPSStream() {
+  auto pass_func = [](PrimFunc func, IRModule m, PassContext ctx) {
+    if (IsHostFunc(func).value_or(false)) {
+      func = InjectMPSStream::Build(func);
+      VLOG(2) << "InjectTorchMPSStream: " << func;
+    }
+    return func;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.InjectTorchMPSStream", {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("tir.transform.InjectTorchMPSStream", InjectTorchMPSStream);
+}
+
+}  // namespace transform
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 8bcb2077c677..0e83b9113b98 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -183,8 +183,30 @@ class IRConvertSSA final : public StmtExprMutator {
       value = VisitExpr(ffi::GetRef<PrimExpr>(expr));
     } else if (auto* stmt = value.as<StmtNode>()) {
       value = VisitStmt(ffi::GetRef<Stmt>(stmt));
+    } else if (auto opt_arr = value.try_cast<ffi::Array<ObjectRef>>()) {
+      // Handle container types like Array[...]
that may contain Vars/Buffers/Exprs/Stmts + auto arr = opt_arr.value(); + bool arr_changed = false; + std::vector rewritten; + rewritten.reserve(arr.size()); + for (const ObjectRef& elem : arr) { + ObjectRef new_elem = elem; + if (auto* e = elem.as()) { + new_elem = VisitExpr(ffi::GetRef(e)); + } else if (auto* s = elem.as()) { + new_elem = VisitStmt(ffi::GetRef(s)); + } else if (auto* v = elem.as()) { + new_elem = GetRemappedVar(ffi::GetRef(v)); + } else if (auto* b = elem.as()) { + new_elem = GetRemappedBuffer(ffi::GetRef(b)); + } + arr_changed = arr_changed || !new_elem.same_as(elem); + rewritten.push_back(new_elem); + } + if (arr_changed) { + value = ffi::Array(rewritten); + } } - made_change = made_change || !value.same_as(old_value); dict.Set(key, value); } @@ -195,9 +217,7 @@ class IRConvertSSA final : public StmtExprMutator { return func->attrs; } }(); - auto body = VisitStmt(func->body); - // If anything changed, update the returned function if (!params.same_as(func->params) || !buffer_map.same_as(func->buffer_map) || !attrs.same_as(func->attrs) || !body.same_as(func->body)) { @@ -213,6 +233,7 @@ class IRConvertSSA final : public StmtExprMutator { } PrimExpr VisitExpr_(const VarNode* op) final { return GetRemappedVar(ffi::GetRef(op)); } + PrimExpr VisitExpr_(const LetNode* op) final { const Var& v = op->var; if (defined_.count(v.get())) { diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc index fcf85ce6b445..da187fd8c2f0 100644 --- a/src/tir/transforms/lower_device_kernel_launch.cc +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -58,6 +58,11 @@ struct KernelInfo { // (e.g. a function that computes the average of `N` elements, and // which must be launched with `N` CUDA threads). ffi::Array launch_args; + + // The extent of each thread + ffi::Map thread_extent; + // The amount of dynamic shared memory used + ffi::Optional dyn_shmem_size{std::nullopt}; }; /*! 
@@ -85,6 +90,8 @@ class DeviceInfoCollector : public StmtVisitor { collector.info_.launch_args = collector.info_.launch_params.Map( [&](const auto& param) { return collector.GetArgument(param); }); + collector.info_.dyn_shmem_size = collector.dyn_shmem_size; + collector.info_.thread_extent = collector.thread_extent; return collector.info_; } @@ -233,6 +240,12 @@ class DeviceKernelMutator : public StmtExprMutator { func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint); } + const auto& info = device_info_map_.at(gvar.get()); + const auto& thread_extent = info.thread_extent; + func = WithAttr(std::move(func), "thread_extent", thread_extent); + if (info.dyn_shmem_size.defined()) { + func = WithAttr(std::move(func), "dyn_shared_memory_buf", info.dyn_shmem_size.value()); + } return func; } diff --git a/src/tir/transforms/merge_shared_memory_allocations.cc b/src/tir/transforms/merge_shared_memory_allocations.cc index 4a2b8698d8cf..132f200ba638 100644 --- a/src/tir/transforms/merge_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_shared_memory_allocations.cc @@ -168,9 +168,8 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { for (const auto& index : load->indices) { this->VisitExpr(index); } - } else { - StmtExprVisitor::VisitExpr_(op); } + StmtExprVisitor::VisitExpr_(op); } void VisitExpr_(const VarNode* buf) final { @@ -215,6 +214,10 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { VisitNewScope(op); } else if (op->attr_key == attr::virtual_thread) { VisitNewScope(op); + } else if (op->attr_key == "kWarpSpecializationScope") { + IfThenElse body = Downcast(op->body); + this->VisitStmt(body->then_case); + this->VisitStmt(body->else_case.value()); } else { StmtExprVisitor::VisitStmt_(op); } diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 7f19a8992998..3ad05337b591 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -77,7 +77,7 @@ class DataTypeVisitor final : public StmtExprVisitor { explicit DataTypeVisitor(int target_bits) : bits_(target_bits), target_bits_(target_bits) {} void VisitExpr(const PrimExpr& e) { - if (e.dtype().is_int()) { + if (e.dtype().is_int() || e.dtype().is_uint()) { int bits = max_bits_; if (bound_.find(e) == bound_.end()) { analyzer_.const_int_bound(e, &bound_);