Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
136 commits
Select commit Hold shift + click to select a range
f77817b
TVM Patch for TileLang
Hzfengsy Jun 11, 2025
3427445
Update CMakeLists.txt to include Python include directory and clean u…
LeiWang1999 Jul 8, 2025
d230129
phaseout ck dependency
LeiWang1999 Jul 8, 2025
2139f47
phaseout flashinfer
LeiWang1999 Jul 8, 2025
9249de3
phase out vta
LeiWang1999 Jul 8, 2025
39d113b
support T.address_of(B[i, j])
LeiWang1999 Jul 9, 2025
3c72b8f
Fix CMakeLists.txt to remove unnecessary '-I' flag from Python build …
LeiWang1999 Jul 16, 2025
ce08d9c
Merge branch 'main' of https://github.com/apache/tvm into upstream-dev
LeiWang1999 Jul 24, 2025
9611cc7
c api fix
LeiWang1999 Jul 24, 2025
493f937
[FFI] Remove unused Grid constant and add HANDLE_TO_REFERENCE conversion
LeiWang1999 Jul 24, 2025
9a00cd6
preserve unit loop for reindex scheduling.
LeiWang1999 Jan 20, 2024
fc29e7b
Add skip_simplify option to reindex method for improved index handling
LeiWang1999 Jul 28, 2025
5cc56c9
fix
LeiWang1999 Jul 28, 2025
763f196
Update LetFrameNode to allow mutable value and register reflection ac…
LeiWang1999 Jul 28, 2025
ab733d1
Refactor argument extraction in ExprEvaluator to streamline handling …
LeiWang1999 Jul 22, 2025
ccc68f5
Enhance error reporting in IndexMapInverseImpl by including index map…
LeiWang1999 May 29, 2025
555cc71
Remove redundant type check in Allocate constructor for improved clar…
LeiWang1999 May 21, 2025
d39953f
Change annotations type in Allocate constructor from Map<String, Obje…
LeiWang1999 Jul 29, 2025
9574805
Update minimum Python version requirement from 3.9 to 3.8 for compati…
LeiWang1999 Jul 29, 2025
a08b7c3
Revert "Update minimum Python version requirement from 3.9 to 3.8 for…
LeiWang1999 Jul 29, 2025
cb0fd6d
Refactor stride naming in Namer to use name_hint when defined, improv…
LeiWang1999 Aug 11, 2025
e11521e
Refactor MergeAnnotations function to accept Map<Any, Any> instead of…
LeiWang1999 Aug 12, 2025
e5558ac
Merge branch 'tilelang_main' of https://github.com/TileLang/tvm into …
LeiWang1999 Aug 12, 2025
5a433cc
phaseout legacy components
LeiWang1999 Aug 12, 2025
a64a592
Add support for 'tir.exp2' operation and register 'hip' target kind w…
Alex4210987 Aug 12, 2025
1a07fda
Add tilelang assume attribute to support custom assumption (#9)
kurisu6912 Sep 5, 2025
ee6d522
Add tl.assume attr in tvm (#10)
kurisu6912 Sep 5, 2025
1fc7578
kurisu add assume attr patch 1 (#11)
kurisu6912 Sep 5, 2025
eddefbd
Refactor buffer allocation logic in IRBuilder to use GetLastFrame for…
LeiWang1999 Sep 9, 2025
87b845f
Refactor BlockReadWriteDetector analysis on BlockRealizeNode
chengyupku Sep 14, 2025
b56420b
if_then_else_support
Hzfengsy Sep 16, 2025
9d467c8
ml_dtypes fix
LeiWang1999 Sep 17, 2025
6051f6d
Remove redundant division simplification for FloatImm in RewriteSimpl…
LeiWang1999 Sep 18, 2025
adc0e48
Add simplification for division by FloatImm in RewriteSimplifier
LeiWang1999 Sep 18, 2025
24072d2
Update Python version requirement to 3.8 and enhance type hinting in …
LeiWang1999 Sep 19, 2025
872e32c
Merge commit '24072d2b1' into tilelang_main
LeiWang1999 Sep 19, 2025
0506337
Add modular set analysis for tighter bounds in ConstIntBoundAnalyzer …
LeiWang1999 Sep 21, 2025
0524f76
update
LeiWang1999 Sep 24, 2025
7a71ee3
Refactor ExprEvaluator to improve expression evaluation logic and add…
Hzfengsy Sep 24, 2025
16f16b6
Refactor CUDA intrinsic registrations to use CUDAMath for consistency…
LeiWang1999 Sep 25, 2025
883e96b
Merge commit '16f16b6a7' into tilelang_main
LeiWang1999 Sep 25, 2025
5bf17a3
Workaround limit api too high in tvm (#12)
oraluben Oct 12, 2025
86aef0a
Improve Analyzer symbolic bounds handling and reuse recorded ranges
LeiWang1999 Oct 19, 2025
e66a5cd
Revert "Improve Analyzer symbolic bounds handling and reuse recorded …
LeiWang1999 Oct 19, 2025
3e41c80
Patch for TileLang
Hzfengsy Oct 22, 2025
635e8c3
rebase
LeiWang1999 Oct 22, 2025
548cadf
merge rebase
LeiWang1999 Oct 22, 2025
d7f0118
rebase fix
LeiWang1999 Oct 22, 2025
a78f6f5
rebase fix
LeiWang1999 Oct 22, 2025
ea34153
Remove dlpack submodule
LeiWang1999 Oct 23, 2025
3345fde
rebasefix
LeiWang1999 Oct 23, 2025
3085bc4
Refactor GCD computation and update annotation merging to use ffi::Ma…
LeiWang1999 Oct 23, 2025
0f1ebab
bug fix
LeiWang1999 Oct 23, 2025
00f7d7b
fix type linting warning in tl.float32
kurisu6912 Oct 27, 2025
fa576ec
Merge branch 'tilelang_main' of https://github.com/TileLang/tvm into …
LeiWang1999 Oct 31, 2025
a57f651
Add support for finding PrimFuncFrame in buffer allocation
LeiWang1999 Nov 2, 2025
1815c3e
Merge branch 'tilelang_main' of https://github.com/TileLang/tvm into …
LeiWang1999 Nov 2, 2025
2e24fc1
Add VisitBitwiseXor method to ConstIntBoundAnalyzer for handling bitw…
LeiWang1999 Nov 12, 2025
1b54bb0
Add VisitBitwiseOr method to ConstIntBoundAnalyzer for handling bitwi…
LeiWang1999 Nov 12, 2025
f0bbd3b
remove 3rdparty
LeiWang1999 Nov 12, 2025
093b2cd
Remove dlpack subproject from 3rdparty directory
LeiWang1999 Nov 12, 2025
cdc2ace
Refactor CUDA function attribute setting and enhance error message ha…
LeiWang1999 Nov 14, 2025
f4105f8
Enhance find_include_path function to include system-installed tvm_ff…
LeiWang1999 Nov 16, 2025
49e650b
[DataType] Update to use explicit Bool Type Aligning with DLPack (#18…
tqchen Nov 15, 2025
f4affc7
Revert "[DataType] Update to use explicit Bool Type Aligning with DLP…
LeiWang1999 Nov 19, 2025
2adf5ea
Reapply "[DataType] Update to use explicit Bool Type Aligning with DL…
LeiWang1999 Nov 19, 2025
18a30cd
Relax constraint side effect check in EnterConstraint (#14)
LJC00118 Nov 20, 2025
70808bc
Implement dynamic shared memory handling in CUDA kernel launches. Tra…
LeiWang1999 Nov 20, 2025
7e4da6d
Merge branch 'tilelang_main' of https://github.com/TileLang/tvm into …
LeiWang1999 Nov 20, 2025
713e6ad
Revert "Reapply "[DataType] Update to use explicit Bool Type Aligning…
LeiWang1999 Nov 20, 2025
ead90f6
Add missing int32x2 and other dtypex2
kurisu6912 Nov 21, 2025
bc31e7a
remove unused let_binding_ in CodeGenC
kurisu6912 Nov 21, 2025
3eb4938
Support analyzer clone
LeiWang1999 Nov 21, 2025
cd2b2b6
Merge branch 'tilelang_main' of https://github.com/TileLang/tvm into …
LeiWang1999 Nov 21, 2025
3354ada
disable narrowing uint in NarrowDataType pass
kurisu6912 Nov 24, 2025
e3af400
disable strided buffer load in tvm
kurisu6912 Nov 25, 2025
fc7ed0b
Fix const correctness issues when assigning string literals to Any un…
LeiWang1999 Nov 28, 2025
e633295
integrate z3 with tvm
kurisu6912 Nov 28, 2025
075e08a
Merge commit 'fc7ed0b9c' into z3
kurisu6912 Nov 28, 2025
e8b0261
Remove debug print statements from PyStmtExprVisitor methods to clean…
LeiWang1999 Dec 2, 2025
50ec055
Merge commit 'e8b02611f' into HEAD
kurisu6912 Dec 2, 2025
f86ab53
fix many bugs in z3_prover
kurisu6912 Dec 2, 2025
36e4074
Add better debug print functionality
kurisu6912 Dec 2, 2025
1be49b8
Enhance Z3 prover and analyzer integration with improved constraints …
kurisu6912 Dec 3, 2025
7517ab6
Add methods to set Z3 max step and retrieve Z3 statistics in Analyzer
kurisu6912 Dec 3, 2025
3a32b76
introduce var_lca
LeiWang1999 Dec 5, 2025
0297c0b
Make z3 an optional dependency
kurisu6912 Dec 8, 2025
250827c
make z3 an optional feature
kurisu6912 Dec 8, 2025
0b352a1
build system debug
kurisu6912 Dec 8, 2025
3a8b894
build system debug
kurisu6912 Dec 8, 2025
e6f891c
build system debug
kurisu6912 Dec 8, 2025
7019e85
build system debug
kurisu6912 Dec 8, 2025
46c5427
build system debug
kurisu6912 Dec 8, 2025
e6a6694
build system debug
kurisu6912 Dec 8, 2025
a6088da
build system debug
kurisu6912 Dec 8, 2025
68f5e91
Merge commit '3a32b763e' into z3
kurisu6912 Dec 8, 2025
877f20c
add ,ossomg z3_header dependency
kurisu6912 Dec 8, 2025
a9c22ee
Add structured boolean reasoning in Analyzer::CanProve method
LeiWang1999 Dec 10, 2025
90581fe
Merge branch 'tilelang_main' of https://github.com/TileLang/tvm into …
LeiWang1999 Dec 10, 2025
afc0793
Merge branch 'main' of https://github.com/apache/tvm into tilelang_main
LeiWang1999 Dec 10, 2025
2b1ead1
Implement relaxed PrimFuncFrame retrieval in IRBuilder
LeiWang1999 Dec 12, 2025
e9f9392
use statically linked z3
kurisu6912 Dec 12, 2025
dd834fd
merge branch z3 into tilelang_main
kurisu6912 Dec 12, 2025
afb0370
update z3 build steps
kurisu6912 Dec 12, 2025
f6bcb0b
update build include directory
kurisu6912 Dec 12, 2025
cb9736f
update build system
kurisu6912 Dec 12, 2025
1059552
fix bug in build system
kurisu6912 Dec 12, 2025
3537ef7
fix bug in build system
kurisu6912 Dec 12, 2025
185bba7
minor fix
kurisu6912 Dec 12, 2025
790e793
Enhance IRConvertSSA to handle container types in VisitExpr
LeiWang1999 Dec 14, 2025
68aa846
Merge branch 'tilelang_main' of https://github.com/TileLang/tvm into …
LeiWang1999 Dec 14, 2025
20a5922
fix bool bug in z3
kurisu6912 Dec 15, 2025
d730446
remove z3
kurisu6912 Dec 15, 2025
050815c
simplify z3 integration
kurisu6912 Dec 15, 2025
7514242
delete z3 include in z3_prover_off.cc
kurisu6912 Dec 15, 2025
78b4caf
fix z3 for macos (#15)
oraluben Dec 15, 2025
1dde5c8
patch z3 when building tvm
kurisu6912 Dec 15, 2025
d9ccc03
fix typo
kurisu6912 Dec 15, 2025
c43fd9b
add comment to print z3 soname
kurisu6912 Dec 15, 2025
4d3ec92
Merge branch 'z3' into tilelang_main
kurisu6912 Dec 16, 2025
0a7a6ea
Analyzer: require loop extent > 0 when entering loop
kurisu6912 Dec 17, 2025
8f4da61
fix floordiv & floormod converting in z3 prover
kurisu6912 Dec 17, 2025
88778fa
fix when patchelf not found (#16)
oraluben Dec 17, 2025
6dc8b76
use static Z3 context
LeiWang1999 Dec 19, 2025
79ed747
Update Z3 context to be thread-local for improved thread safety
LeiWang1999 Dec 19, 2025
03ad7cc
Update library loading to use lazy loading
LeiWang1999 Dec 19, 2025
1eeadc6
Add cyclic dependency detection in IntervalSetEvaluator
LeiWang1999 Dec 22, 2025
315036d
Merge branch 'tilelang_main' of https://github.com/TileLang/tvm into …
LeiWang1999 Dec 22, 2025
d9d3e9d
Remove Z3 subproject as it is no longer needed in the repository.
LeiWang1999 Dec 22, 2025
62af333
Add a rewrite pattern
kurisu6912 Dec 25, 2025
9bb866e
[Cherry-pick][CUDA][FFI] Extend kernel launch config to support Progr…
silentCoder-dev Dec 26, 2025
ce96c60
[Z3] change z3 timeout to determinstic `rlimit`
kurisu6912 Dec 26, 2025
b487ec4
Merge branch 'tilelang_main' of https://github.com/TileLang/tvm into …
kurisu6912 Dec 26, 2025
8ae9be3
Add annotations to CallNode and Call classes
LeiWang1999 Dec 26, 2025
23bce01
Merge commit '8ae9be35a' into tilelang_main
LeiWang1999 Dec 26, 2025
e1d4a29
POC for metal w. tvm-ffi
oraluben Jan 5, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,15 @@ include(cmake/modules/Git.cmake)
include(cmake/modules/LibInfo.cmake)
include(cmake/modules/contrib/Mrvl.cmake)

tvm_option(USE_Z3 "Build with Z3 SMT solver support" OFF)

if (USE_Z3)
find_package(Z3 REQUIRED)
list(APPEND COMPILER_SRCS src/target/z3/z3_prover_on.cc)
else()
list(APPEND COMPILER_SRCS src/target/z3/z3_prover_off.cc)
endif()

set(LIBINFO_FILE ${CMAKE_CURRENT_LIST_DIR}/src/support/libinfo.cc)
add_lib_info(${LIBINFO_FILE})
list(REMOVE_ITEM COMPILER_SRCS ${LIBINFO_FILE})
Expand Down Expand Up @@ -541,6 +550,42 @@ else()
target_link_libraries(tvm_runtime PUBLIC tvm_ffi_shared)
endif()

if(USE_Z3)
target_include_directories(tvm_objs PRIVATE ${Z3_CXX_INCLUDE_DIRS})
target_include_directories(tvm_runtime_objs PRIVATE ${Z3_CXX_INCLUDE_DIRS})
target_include_directories(tvm_libinfo_objs PRIVATE ${Z3_CXX_INCLUDE_DIRS})
target_link_libraries(tvm PRIVATE z3::libz3)

if (APPLE)
# `libz3.dylib` from z3-solver on pypi have a "wrong" name `libz3.dylib`,
# so it won't be searched in rpath. We patch it to `@rpath/libz3.dylib` here.
# `POST_BUILD` command needs to be in same cmake file where the target's created.
add_custom_command(TARGET tvm POST_BUILD
COMMAND install_name_tool -change "libz3.dylib" "@rpath/libz3.dylib" $<TARGET_FILE:tvm>
COMMENT "Patching libz3 reference to use @rpath"
)
else()
find_program(PATCHELF_EXECUTABLE patchelf)
if (PATCHELF_EXECUTABLE)
execute_process(
COMMAND ${PATCHELF_EXECUTABLE} --print-soname ${Z3_LIBRARY}
OUTPUT_VARIABLE Z3_SONAME
OUTPUT_STRIP_TRAILING_WHITESPACE
RESULT_VARIABLE Z3_SONAME_RESULT
)
if(NOT Z3_SONAME_RESULT EQUAL "0")
message(FATAL_ERROR "Failed to get Z3 soname using patchelf")
endif()
message("-- Z3 SONAME: ${Z3_SONAME}")
add_custom_command(TARGET tvm POST_BUILD
COMMAND ${PATCHELF_EXECUTABLE} --replace-needed ${Z3_SONAME} libz3.so $<TARGET_FILE:tvm>
COMMENT "Patching libz3 reference to use soname ${Z3_SONAME}"
)
else()
message("patchelf not found, skip.")
endif()
endif()
endif()

target_include_directories(tvm_runtime PUBLIC "$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>")
set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS "${TVM_VISIBILITY_FLAG}")
Expand Down Expand Up @@ -748,7 +793,7 @@ endif()

# Change relative paths in backtrace to absolute ones
if(TVM_IS_DEBUG_BUILD)
set(FILE_PREFIX_MAP_FLAG "-ffile-prefix-map=..=${CMAKE_CURRENT_SOURCE_DIR}")
# set(FILE_PREFIX_MAP_FLAG "-ffile-prefix-map=..=${CMAKE_CURRENT_SOURCE_DIR}")
target_compile_options(tvm PRIVATE "${FILE_PREFIX_MAP_FLAG}")
CHECK_CXX_COMPILER_FLAG("${FILE_PREFIX_MAP_FLAG}" FILE_PREFIX_MAP_SUPPORTED)
if(FILE_PREFIX_MAP_SUPPORTED)
Expand Down
99 changes: 96 additions & 3 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <memory>
#include <unordered_map>
#include <vector>
#include "tvm/ffi/object.h"

namespace tvm {
/*! \brief namespace of arithmetic analysis. */
Expand Down Expand Up @@ -175,6 +176,8 @@ class ConstIntBoundAnalyzer {
friend class ConstraintContext;
explicit ConstIntBoundAnalyzer(Analyzer* parent);
TVM_DLL ~ConstIntBoundAnalyzer();
// Deep-copy internal state from another instance (for Analyzer::Clone)
void CopyFrom(const ConstIntBoundAnalyzer& other);
/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
Expand Down Expand Up @@ -254,6 +257,8 @@ class ModularSetAnalyzer {
friend class ConstraintContext;
explicit ModularSetAnalyzer(Analyzer* parent);
TVM_DLL ~ModularSetAnalyzer();
// Deep-copy internal state from another instance (for Analyzer::Clone)
void CopyFrom(const ModularSetAnalyzer& other);
/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
Expand Down Expand Up @@ -294,7 +299,7 @@ class RewriteSimplifier {
*
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);
TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint, bool is_assume=false);

/*! \brief Flags to enable more computationally-intensive simplifications
*
Expand Down Expand Up @@ -407,6 +412,8 @@ class RewriteSimplifier {
friend class CanonicalSimplifier;
explicit RewriteSimplifier(Analyzer* parent);
TVM_DLL ~RewriteSimplifier();
// Deep-copy internal state from another instance (for Analyzer::Clone)
void CopyFrom(const RewriteSimplifier& other);
class Impl;
/*! \brief Internal impl */
Impl* impl_;
Expand Down Expand Up @@ -438,6 +445,8 @@ class CanonicalSimplifier {
friend class ConstraintContext;
explicit CanonicalSimplifier(Analyzer* parent);
TVM_DLL ~CanonicalSimplifier();
// Deep-copy internal state from another instance (for Analyzer::Clone)
void CopyFrom(const CanonicalSimplifier& other);
class Impl;
/*! \brief Internal impl */
Impl* impl_;
Expand Down Expand Up @@ -523,6 +532,8 @@ class TransitiveComparisonAnalyzer {
friend class ConstraintContext;
TransitiveComparisonAnalyzer();
TVM_DLL ~TransitiveComparisonAnalyzer();
// Deep-copy internal state from another instance (for Analyzer::Clone)
void CopyFrom(const TransitiveComparisonAnalyzer& other);
class Impl;
/*! \brief Internal impl */
std::unique_ptr<Impl> impl_;
Expand Down Expand Up @@ -553,8 +564,8 @@ class ConstraintContext {
* \param analyzer The analyzer.
* \param constraint The constraint to be applied.
*/
ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
: analyzer_(analyzer), constraint_(constraint) {}
ConstraintContext(Analyzer* analyzer, PrimExpr constraint, bool is_assume=false)
: analyzer_(analyzer), constraint_(constraint), is_assume_(is_assume) {}
// enter the scope.
void EnterWithScope();
// exit the scope.
Expand All @@ -565,6 +576,7 @@ class ConstraintContext {
PrimExpr constraint_;
/*! \brief function to be called in recovery */
std::vector<std::function<void()>> recovery_functions_;
bool is_assume_;
};

/*!
Expand Down Expand Up @@ -616,11 +628,85 @@ class IntSetAnalyzer {
friend class Analyzer;
explicit IntSetAnalyzer(Analyzer* parent);
TVM_DLL ~IntSetAnalyzer();
// Deep-copy internal state from another instance (for Analyzer::Clone)
void CopyFrom(const IntSetAnalyzer& other);
class Impl;
/*! \brief Internal impl */
Impl* impl_;
};

class Z3Prover {
public:
/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param new_range The range of allowed values for this var.
* \param allow_override whether we allow override of existing information.
*/
TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false);

/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param expr The bound expression
* \param allow_override whether we allow override of existing information.
*/
TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);

/*!
* \brief Whether can we prove expr is always true.
*
* \param expr The expression.
* \return Whether we can prove it.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
TVM_DLL bool CanProve(const PrimExpr & expr);

/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
*
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
std::function<void()> EnterConstraint(const PrimExpr& constraint, bool is_assume=false);

/*!
* \brief Get the SMTLIB2 representation of the current context
* \param expr The optional expression to check
* \return The SMTLIB2 string
*/
ffi::String GetSMTLIB2(const ffi::Optional<PrimExpr> expr);

/*!
* \brief Get statistics about Z3 prover
* \return The statistics string
*/
ffi::String GetStats();

/*!
* \brief Set timeout in milliseconds for Z3 prover
* \param timeout_ms The timeout in milliseconds
*/
void SetTimeoutMs(unsigned timeout_ms);

/*!
* \brief Set resource limitation for Z3 prover
* \param rlimit the resource limitation (like maxinum step or sth.)
*/
void SetRLimit(unsigned rlimit);

private:
friend class Analyzer;
explicit Z3Prover(Analyzer* parent);
TVM_DLL ~Z3Prover();
void CopyFrom(const Z3Prover & other);
class Impl;
Impl* impl_;
};

/*!
* \brief Analyzer that contains bunch of sub-analyzers.
*
Expand Down Expand Up @@ -650,8 +736,15 @@ class TVM_DLL Analyzer {
IntSetAnalyzer int_set;
/*! \brief sub-analyzer transitive comparisons */
TransitiveComparisonAnalyzer transitive_comparisons;
/*! \brief analyzer using z3 */
Z3Prover z3_prover;
/*! \brief constructor */
Analyzer();
/*!
* \brief Create a deep copy of this Analyzer, including all sub-analyzer states.
* \return A new Analyzer with copied internal state.
*/
std::unique_ptr<Analyzer> Clone() const;
/*!
* \brief Mark the value as non-negative value globally in analyzer.
*
Expand Down
1 change: 0 additions & 1 deletion include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,5 @@ class TensorMapType : public Type {
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE_WITHOUT_DEFAULT_CONSTRUCTOR(TensorMapType, Type,
TensorMapTypeNode);
};

} // namespace tvm
#endif // TVM_IR_TYPE_H_
3 changes: 2 additions & 1 deletion include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,11 +350,12 @@ class LetFrameNode : public TIRFrameNode {
/*! \brief The value we bind var to */
PrimExpr value;


static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<LetFrameNode>()
.def_ro("var", &LetFrameNode::var)
.def_ro("value", &LetFrameNode::value);
.def_rw("value", &LetFrameNode::value);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.LetFrame", LetFrameNode, TIRFrameNode);

Expand Down
2 changes: 2 additions & 0 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(UInt, DataType::UInt);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int);

#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(FuncName, FDType, Size) \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x2, FDType(Size, 2)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x4, FDType(Size, 4)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x8, FDType(Size, 8)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x16, FDType(Size, 16)); \
Expand All @@ -507,6 +508,7 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);

#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES_FIXED_SIZE(DType, FDType) \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType, FDType(1)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x2, FDType(2)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x4, FDType(4)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x8, FDType(8)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##x16, FDType(16)); \
Expand Down
18 changes: 16 additions & 2 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -731,9 +731,21 @@ class CallNode : public PrimExprNode {
/*! \brief The arguments. */
ffi::Array<PrimExpr> args;

/*!
* \brief Additional annotations about the call.
*
* These annotations can be used to pass additional metadata
* to lowering passes. For tile operators, this can include
* coalesced_width, disable_tma, eviction_policy, etc.
*/
ffi::Map<ffi::String, ObjectRef> annotations;

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<CallNode>().def_ro("op", &CallNode::op).def_ro("args", &CallNode::args);
refl::ObjectDef<CallNode>()
.def_ro("op", &CallNode::op)
.def_ro("args", &CallNode::args)
.def_ro("annotations", &CallNode::annotations);
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Call", CallNode, PrimExprNode);
};
Expand All @@ -744,7 +756,9 @@ class CallNode : public PrimExprNode {
*/
class Call : public PrimExpr {
public:
TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args, Span span = Span());
TVM_DLL Call(DataType dtype, RelaxExpr op, ffi::Array<PrimExpr> args,
ffi::Map<ffi::String, ObjectRef> annotations = {},
Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Call, PrimExpr, CallNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
};
Expand Down
24 changes: 12 additions & 12 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -722,17 +722,17 @@ TVM_DLL PrimExpr fast_erf_float_expr(PrimExpr arg, int bits);

// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
static const Op& op = Op::Get("tir." #OpName); \
if (x.dtype().is_bfloat16()) { \
DataType bf16_dtype = x.dtype(); \
DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \
PrimExpr x_fp32 = tir::Cast(fp32_dtype, {x}, span); \
PrimExpr result_fp32 = tir::Call(fp32_dtype, op, {x_fp32}, span); \
return tir::Cast(bf16_dtype, {result_fp32}, span); \
} else { \
return tir::Call(x.dtype(), op, {x}, span); \
} \
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
static const Op& op = Op::Get("tir." #OpName); \
if (x.dtype().is_bfloat16()) { \
DataType bf16_dtype = x.dtype(); \
DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \
PrimExpr x_fp32 = tir::Cast(fp32_dtype, {x}, span); \
PrimExpr result_fp32 = tir::Call(fp32_dtype, op, {x_fp32}, {}, span); \
return tir::Cast(bf16_dtype, {result_fp32}, span); \
} else { \
return tir::Call(x.dtype(), op, {x}, {}, span); \
} \
}

TVM_DECLARE_INTRIN_UNARY(exp);
Expand Down Expand Up @@ -764,7 +764,7 @@ TVM_DECLARE_INTRIN_UNARY(clz);
#define TVM_DECLARE_INTRIN_BINARY(OpName) \
inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) { \
static const Op& op = Op::Get("tir." #OpName); \
return tir::Call(x.dtype(), op, {x, y}, span); \
return tir::Call(x.dtype(), op, {x, y}, {}, span); \
}

TVM_DECLARE_INTRIN_BINARY(atan2);
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ class ScheduleNode : public runtime::Object {
* \return The reindex stage block.
*/
virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) = 0;
BufferIndexType buffer_index_type, bool skip_simplify = false) = 0;
/******** Schedule: Data movement ********/
virtual BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index,
const ffi::String& storage_scope) = 0;
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1319,6 +1319,8 @@ constexpr const char* explicit_read_region = "explicit_read_region";
*/
constexpr const char* explicit_write_region = "explicit_write_region";

constexpr const char* tilelang_assume = "tl.assume";

/*! \brief ,ark a ForNode represent an irregular loop of non-structural control flow edges. */
constexpr const char* irregular_loop_mark = "irregular_loop_mark";

Expand Down
7 changes: 5 additions & 2 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1285,12 +1285,15 @@ inline Tensor take(const Tensor& a, ffi::Variant<Tensor, PrimExpr> indices, int
for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
real_indices.push_back(out_index[j]);
}
auto idx = truncmod(truncmod(get_index(indices_position), axis_dim) + axis_dim, axis_dim);
PrimExpr idx = get_index(indices_position);
real_indices.push_back(idx);
for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
real_indices.push_back(out_index[j]);
}
return a(real_indices);
PrimExpr in_bounds = idx >= 0 && idx < axis_dim;
return tvm::if_then_else(
in_bounds, a(real_indices),
tvm::tir::make_const(a->dtype, std::numeric_limits<float>::quiet_NaN()));
},
name, tag);
}
Expand Down
Loading