Skip to content

Commit b01ab9e

Browse files
Hzfengsyjunrushaovinx13MasterJH5574jinhongyii
authored
[TensorIR][M2a] CacheRead/Write (apache#8863)
Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Wuwei Lin <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Siyuan Feng <[email protected]> Co-authored-by: Bohan Hou <[email protected]>
1 parent 7b91e62 commit b01ab9e

18 files changed

+1840
-23
lines changed

Diff for: include/tvm/tir/schedule/schedule.h

+22
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,28 @@ class ScheduleNode : public runtime::Object {
282282
*/
283283
virtual void Unroll(const LoopRV& loop_rv) = 0;
284284
/******** Schedule: Insert cache stages ********/
285+
/*!
286+
* \brief Create a block that reads a buffer region into a read cache. It requires:
287+
* 1) There is at most one block who writes the buffer in the scope.
288+
* 2) The scope block have stage-pipeline property.
289+
* \param block_rv The consumer block of the target buffer.
290+
* \param read_buffer_index The index of the buffer in block's read region.
291+
* \param storage_scope The target storage scope.
292+
* \return The cache stage block.
293+
*/
294+
virtual BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
295+
const String& storage_scope) = 0;
296+
/*!
297+
* \brief Create a block that writes a buffer region into a write cache. It requires:
298+
* 1) There is only one block who writes the target buffer.
299+
* 2) The scope block have stage-pipeline property.
300+
* \param block_rv The producer of the buffer
301+
* \param write_buffer_index The index of the buffer in block's write region
302+
* \param storage_scope The target storage scope
303+
* \return The cache stage block.
304+
*/
305+
virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
306+
const String& storage_scope) = 0;
285307
/******** Schedule: Compute location ********/
286308
/*!
287309
* \brief Inline a block into its consumer(s). It requires:

Diff for: include/tvm/tir/schedule/state.h

+5
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ class ScheduleStateNode : public Object {
128128
*/
129129
TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt,
130130
const Map<Block, Block>& block_sref_reuse);
131+
/*!
132+
* \brief Recalculate the `affine_binding` flag of the scope block info.
133+
* \param scope_sref The sref to the interested scope block.
134+
*/
135+
TVM_DLL void UpdateAffineFlag(const StmtSRef& scope_sref);
131136
/*!
132137
* \brief Trigger the verification according to the `debug_mask` bitmask.
133138
* 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree.

Diff for: python/tvm/tir/schedule/schedule.py

+135
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,141 @@ def after_unroll(a: ty.handle, b: ty.handle) -> None:
790790

791791
########## Schedule: Insert cache stages ##########
792792

793+
def cache_read(self, block: BlockRV, read_buffer_index: int, storage_scope: str) -> BlockRV:
794+
"""Create a block that reads a buffer region into a read cache. It requires:
795+
796+
1) There is at most one block who write the buffer in the scope.
797+
798+
2) The scope block have stage-pipeline property.
799+
800+
Parameters
801+
----------
802+
block : BlockRV
803+
The consumer block of the target buffer.
804+
805+
read_buffer_index: int
806+
The index of the buffer in block's read region.
807+
808+
storage_scope: str
809+
The target storage scope.
810+
811+
Returns
812+
-------
813+
cached_block : BlockRV
814+
The block of the cache stage
815+
816+
Examples
817+
--------
818+
Before cache_read, in TensorIR, the IR is:
819+
820+
.. code-block:: python
821+
822+
@tvm.script.tir
823+
def before_cache_read(a: ty.handle, b: ty.handle) -> None:
824+
A = tir.match_buffer(a, (128, 128))
825+
B = tir.match_buffer(b, (128, 128))
826+
for i, j in tir.grid(128, 128):
827+
with tir.block([128, 128], "B") as [vi, vj]:
828+
B[vi, vj] = A[vi, vj] * 2.0
829+
830+
Create the schedule and cache_read:
831+
832+
.. code-block:: python
833+
834+
sch = tir.Schedule(before_cache_read)
835+
block_b = sch.get_block("B")
836+
sch.cache_read(block_b, 0, "local")
837+
print(tvm.script.asscript(sch.mod["main"]))
838+
839+
After applying cache_read, the IR becomes:
840+
841+
.. code-block:: python
842+
843+
@tvm.script.tir
844+
def after_cache_read(a: ty.handle, b: ty.handle) -> None:
845+
A = tir.match_buffer(a, (128, 128))
846+
B = tir.match_buffer(b, (128, 128))
847+
A_local = tir.alloc_buffer((128, 128), scope="local")
848+
for i, j in tir.grid(128, 128):
849+
with tir.block([128, 128], "A_local") as [vi, vj]:
850+
A_local[vi, vj] = A[vi, vj]
851+
for i, j in tir.grid(128, 128):
852+
with tir.block([128, 128], "B") as [vi, vj]:
853+
B[vi, vj] = A_local[vi, vj] * 2.0
854+
855+
"""
856+
return _ffi_api.ScheduleCacheRead( # type: ignore # pylint: disable=no-member
857+
self, block, read_buffer_index, storage_scope
858+
)
859+
860+
def cache_write(self, block: BlockRV, write_buffer_index: int, storage_scope: str) -> BlockRV:
861+
"""Create a block that reads a buffer region into a write cache. It requires:
862+
863+
1) There is only one block who write the buffer in the scope.
864+
865+
2) The scope block have stage-pipeline property.
866+
867+
Parameters
868+
----------
869+
block : BlockRV
870+
The producer block of the target buffer.
871+
872+
write_buffer_index: int
873+
The index of the buffer in block's write region.
874+
875+
storage_scope: str
876+
The target storage scope.
877+
878+
879+
Returns
880+
-------
881+
cached_block : BlockRV
882+
The block of the cache stage
883+
884+
Examples
885+
--------
886+
Before cache_write, in TensorIR, the IR is:
887+
888+
.. code-block:: python
889+
890+
@tvm.script.tir
891+
def before_cache_write(a: ty.handle, b: ty.handle) -> None:
892+
A = tir.match_buffer(a, (128, 128))
893+
B = tir.match_buffer(b, (128, 128))
894+
for i, j in tir.grid(128, 128):
895+
with tir.block([128, 128], "B") as [vi, vj]:
896+
B[vi, vj] = A[vi, vj] * 2.0
897+
898+
Create the schedule and cache_write:
899+
900+
.. code-block:: python
901+
902+
sch = tir.Schedule(before_cache_write)
903+
block_b = sch.get_block("B")
904+
sch.cache_write(block_b, 0, "local")
905+
print(tvm.script.asscript(sch.mod["main"]))
906+
907+
After applying cache_write, the IR becomes:
908+
909+
.. code-block:: python
910+
911+
@tvm.script.tir
912+
def after_cache_write(a: ty.handle, b: ty.handle) -> None:
913+
A = tir.match_buffer(a, (128, 128))
914+
B = tir.match_buffer(b, (128, 128))
915+
B_local = tir.alloc_buffer((128, 128), scope="local")
916+
for i, j in tir.grid(128, 128):
917+
with tir.block([128, 128], "A_local") as [vi, vj]:
918+
B_local[vi, vj] = A[vi, vj] * 2.0
919+
for i, j in tir.grid(128, 128):
920+
with tir.block([128, 128], "B") as [vi, vj]:
921+
B[vi, vj] = B_local[vi, vj]
922+
923+
"""
924+
return _ffi_api.ScheduleCacheWrite( # type: ignore # pylint: disable=no-member
925+
self, block, write_buffer_index, storage_scope
926+
)
927+
793928
########## Schedule: Compute location ##########
794929

795930
def compute_inline(self, block: BlockRV) -> None:

Diff for: src/tir/schedule/analysis.h

+14-7
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ void VerifyCachedFlags(const ScheduleState& self);
5656
const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
5757
GlobalVar* result_g_var);
5858

59+
/*!
60+
* \brief Get the root node of the sref tree, which is the root block of the PrimFunc.
61+
* \param sref The given sref.
62+
* \return The root node of the sref tree which contains the given node.
63+
*/
64+
StmtSRef GetSRefTreeRoot(const StmtSRef& sref);
65+
5966
/******** Scope ********/
6067
/*!
6168
* \brief Checks if scope the specified sref is in is a stage-pipeline and return it
@@ -228,15 +235,15 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
228235
/******** Block-buffer relation ********/
229236

230237
/*!
231-
* \brief Get the BlockRealize of the single child block of the block or loop specified by
232-
* `parent_sref` on SRef tree, or throw an exception if there is 0 or multiple child blocks
233-
* \param self The schedule state
234-
* \param block The queried block
235-
* \param n The index of the queried buffer
236-
* \return The buffer of the n-th write region of the block.
238+
* \brief Get the n-th read or write buffer of the given block.
239+
* \param self The schedule state.
240+
* \param block The queried block.
241+
* \param n The index of the queried buffer.
242+
* \param is_write A boolean flag to indicate querying write buffer or read buffer.
243+
* \return The buffer of the n-th read/write region of the block.
237244
* \throw ScheduleError If the buffer index is out of bound.
238245
*/
239-
Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n);
246+
Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write);
240247

241248
/******** Commutative Reducer ********/
242249

Diff for: src/tir/schedule/analysis/analysis.cc

+36-14
Original file line numberDiff line numberDiff line change
@@ -588,25 +588,37 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
588588

589589
/******** Block-buffer relation ********/
590590

591-
Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) {
592-
class WriteBufferIndexOutOfRangeError : public ScheduleError {
591+
Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write) {
592+
class BufferIndexOutOfRangeError : public ScheduleError {
593593
public:
594-
explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index)
595-
: mod_(std::move(mod)), block_(std::move(block)), buffer_index_(buffer_index) {}
594+
explicit BufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index, bool is_write)
595+
: mod_(std::move(mod)),
596+
block_(std::move(block)),
597+
buffer_index_(buffer_index),
598+
is_write_(is_write) {}
596599

597600
String FastErrorString() const final {
598-
return "ScheduleError: The input `buffer_index` is out of range. It is required to be in "
599-
"range [0, num_write_regions) where `num_write_regions` is the number of buffer "
600-
"regions written by the block.";
601+
if (is_write_) {
602+
return "ScheduleError: The input `buffer_index` is out of range. It is required to be in "
603+
"range "
604+
"[0, num_write_regions) where `num_write_regions` is the number of buffer regions "
605+
"written by the block.";
606+
} else {
607+
return "ScheduleError: The input `buffer_index` is out of range. It is required to be in "
608+
"range "
609+
"[0, num_read_regions) where `num_read_regions` is the number of buffer regions "
610+
"read by the block.";
611+
}
601612
}
602613

603614
String DetailRenderTemplate() const final {
604615
std::ostringstream os;
605-
size_t num_writes = block_->writes.size();
606-
os << "The block {0} has " << num_writes
607-
<< " write regions, so `buffer_index` is required to be in [0, " << num_writes
616+
size_t num = is_write_ ? block_->writes.size() : block_->reads.size();
617+
std::string access_type = is_write_ ? "write" : "read";
618+
os << "The block {0} has " << num << " " << access_type
619+
<< " regions, so `buffer_index` is required to be in [0, " << num
608620
<< "). However, the input `buffer_index` is " << buffer_index_
609-
<< ", which is out of the expected range";
621+
<< ", which is out of the expected range.";
610622
return os.str();
611623
}
612624

@@ -617,12 +629,15 @@ Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) {
617629
IRModule mod_;
618630
Block block_;
619631
int buffer_index_;
632+
bool is_write_;
620633
};
621634

622-
if (n < 0 || static_cast<size_t>(n) >= block->writes.size()) {
623-
throw WriteBufferIndexOutOfRangeError(self->mod, block, n);
635+
const Array<BufferRegion>& access_region = is_write ? block->writes : block->reads;
636+
637+
if (n < 0 || static_cast<int>(access_region.size()) <= n) {
638+
throw BufferIndexOutOfRangeError(self->mod, block, n, is_write);
624639
}
625-
return block->writes[n]->buffer;
640+
return access_region[n]->buffer;
626641
}
627642

628643
/******** Pattern Matcher ********/
@@ -941,5 +956,12 @@ bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner,
941956
return false;
942957
}
943958

959+
/******** SRef Tree Related ********/
960+
StmtSRef GetSRefTreeRoot(const StmtSRef& sref) {
961+
const StmtSRefNode* p = sref.get();
962+
for (; p->parent != nullptr; p = p->parent) {
963+
}
964+
return GetRef<StmtSRef>(p);
965+
}
944966
} // namespace tir
945967
} // namespace tvm

Diff for: src/tir/schedule/concrete_schedule.cc

+21
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,27 @@ void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) {
416416
}
417417

418418
/******** Schedule: Insert cache stages ********/
419+
420+
BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index,
421+
const String& storage_scope) {
422+
StmtSRef result{nullptr};
423+
TVM_TIR_SCHEDULE_BEGIN();
424+
result = tir::CacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope);
425+
TVM_TIR_SCHEDULE_END("cache-read", this->error_render_level_);
426+
this->state_->DebugVerify();
427+
return CreateRV<BlockRV>(result);
428+
}
429+
430+
BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index,
431+
const String& storage_scope) {
432+
StmtSRef result{nullptr};
433+
TVM_TIR_SCHEDULE_BEGIN();
434+
result = tir::CacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope);
435+
TVM_TIR_SCHEDULE_END("cache-write", this->error_render_level_);
436+
this->state_->DebugVerify();
437+
return CreateRV<BlockRV>(result);
438+
}
439+
419440
/******** Schedule: Compute location ********/
420441

421442
void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) {

Diff for: src/tir/schedule/concrete_schedule.h

+4
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ class ConcreteScheduleNode : public ScheduleNode {
103103
void Bind(const LoopRV& loop_rv, const String& thread_axis) override;
104104
void Unroll(const LoopRV& loop_rv) override;
105105
/******** Schedule: Insert cache stages ********/
106+
BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
107+
const String& storage_scope) override;
108+
BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
109+
const String& storage_scope) override;
106110
/******** Schedule: Compute location ********/
107111
void ComputeInline(const BlockRV& block) override;
108112
void ReverseComputeInline(const BlockRV& block) override;

Diff for: src/tir/schedule/primitive.h

+24
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,30 @@ TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar&
135135
*/
136136
TVM_DLL void Unroll(ScheduleState self, const StmtSRef& loop_sref);
137137
/******** Schedule: Insert cache stages ********/
138+
/*!
139+
* \brief Create a block that reads a buffer region into a read cache. It requires:
140+
* 1) There is at most one block who writes the buffer in the scope.
141+
* 2) The scope block have stage-pipeline property.
142+
* \param self The state of the schedule
143+
* \param block_sref The consumer block of the target buffer.
144+
* \param read_buffer_index The index of the buffer in block's read region.
145+
* \param storage_scope The target storage scope.
146+
* \return The cache stage block.
147+
*/
148+
TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
149+
const String& storage_scope);
150+
/*!
151+
* \brief Create a block that writes a buffer region into a write cache. It requires:
152+
* 1) There is only one block that writes the target buffer.
153+
* 2) The scope block have stage-pipeline property.
154+
* \param self The state of the schedule
155+
* \param block_sref The producer of the buffer
156+
* \param write_buffer_index The index of the buffer in block's write region
157+
* \param storage_scope The target storage scope
158+
* \return The cache stage block.
159+
*/
160+
TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
161+
const String& storage_scope);
138162
/******** Schedule: Compute location ********/
139163
/*!
140164
* \brief Inline a block into its consumer(s). It requires:

Diff for: src/tir/schedule/primitive/block_annotate.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
* specific language governing permissions and limitations
1717
* under the License.
1818
*/
19-
#include "../transform.h"
2019
#include "../utils.h"
2120

2221
namespace tvm {
@@ -237,7 +236,8 @@ class StorageAlignInvalidAnnotationError : public ScheduleError {
237236
void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis,
238237
int factor, int offset) {
239238
const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
240-
Buffer buffer = GetNthWriteBuffer(self, GetRef<Block>(block_ptr), buffer_index);
239+
Buffer buffer =
240+
GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, /*is_write=*/true);
241241
StorageAlignInvalidFactorError::Check(self->mod, factor);
242242
axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis);
243243
NonAllocatedBufferError::CheckBufferAllocated(self->mod, block_sref, buffer);

0 commit comments

Comments
 (0)