@@ -98,125 +98,85 @@ tile_store(tile_t& tile, payload_t& payload) {
9898
9999 static constexpr uint32_t num_block_x = tile_desc::num_block_x;
100100 static constexpr uint32_t num_block_y = tile_desc::num_block_y;
101- // static constexpr uint32_t num_block = tile_desc::num_block;
102101
103- using load_store_attr = typename arch_attr_t <
104- payload_t ::arch_tag>::template load_store_attr<msg_type::block_2d>;
105-
106- static constexpr uint32_t max_block_width =
107- load_store_attr::max_store_width_in_bytes / sizeof (dtype);
108- static constexpr uint32_t max_block_height =
109- load_store_attr::max_store_height_in_elem;
110102 static_assert (
111- (max_block_width % block_size_x) == 0 ,
112- " max_block_width should be a multiply of block_size_x." );
113- static constexpr uint32_t elems_per_CL =
114- load_store_attr::cache_line_size_in_bytes / sizeof (dtype);
115- static constexpr uint32_t st_block_size_y =
116- std::min (block_size_y, max_block_height);
103+ (payload_t ::max_store_width_in_elem % block_size_x) == 0 ,
104+ " max_store_width_in_elem should be a multiply of block_size_x." );
105+
106+ static constexpr uint32_t st_blk_size_y =
107+ std::min (block_size_y, payload_t ::max_store_height_in_elem);
117108
118109 // to make sure full CL store
119- static constexpr uint32_t st_block_size_x =
120- ((tile_size_x % elems_per_CL) == 0 )
121- ? elems_per_CL
122- : (((elems_per_CL % tile_size_x) == 0 ) ? tile_size_x : block_size_x);
110+ static constexpr uint32_t st_blk_size_x =
111+ ((tile_size_x % payload_t ::elems_per_CL) == 0 )
112+ ? payload_t ::elems_per_CL
113+ : (((payload_t ::elems_per_CL % tile_size_x) == 0 ) ? tile_size_x
114+ : block_size_x);
123115
124- static constexpr uint8_t arr_len_candidate = st_block_size_x / block_size_x;
116+ static constexpr uint8_t arr_len_candidate = st_blk_size_x / block_size_x;
125117 static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1 ) ||
126118 (arr_len_candidate == 2 ) || (arr_len_candidate == 4 );
127119
128120 static constexpr uint8_t arr_len =
129121 is_valid_arr_len_candidate ? arr_len_candidate : 1 ;
130122
131- // auto payload_2d = payload.payloads.xetla_format<uint32_t, num_block, 16>();
123+ constexpr uint32_t store_block_elems = block_elems * arr_len;
124+ constexpr uint32_t store_elems = st_blk_size_y * st_blk_size_x;
132125#pragma unroll
133126 for (uint32_t i = 0 ; i < num_block_y; ++i) {
134127 int32_t offset_y = i * block_size_y;
135- constexpr uint32_t store_block_elems = block_elems * arr_len;
136- // auto payload_row =
137- // payload_2d.xetla_select<num_block_x, 1, 16, 1>(i * num_block_x, 0);
138- // detail::reset_tile_desc_core<
139- // num_block_x,
140- // block_size_x * arr_len,
141- // st_block_size_y,
142- // 1,
143- // 1,
144- // false>(payload_row);
145128#pragma unroll
146129 for (uint32_t j = 0 ; j < num_block_x; j += arr_len) {
147130 int32_t offset_x = j * block_size_x;
148- // xetla_tdescriptor tdesc = payload_row.row(j);
149131 auto reg_blk = tile.reg .xetla_select <store_block_elems, 1 >(
150132 (i * num_block_x + j) * block_elems);
151133 xetla_vector<dtype, store_block_elems> combine_blk;
152134 auto combine_blk_2d = combine_blk.xetla_format <
153135 native_type_t <dtype>,
154136 block_size_y,
155137 block_size_x * arr_len>();
156- #pragma unroll
157- for (uint32_t combine_i = 0 ; combine_i < arr_len; ++combine_i) {
138+ /* combine_blk_2d
139+ ____________ ____________
140+ | || |
141+ | block || block |
142+ | || |
143+ |____________||____________|
144+ */
145+ #pragma unroll
146+ for (uint32_t block_id = 0 ; block_id < arr_len; ++block_id) {
158147 combine_blk_2d.xetla_select <block_size_y, 1 , block_size_x, 1 >(
159- 0 , combine_i * block_size_x) =
160- reg_blk.xetla_select <block_elems, 1 >(combine_i * block_elems);
148+ 0 , block_id * block_size_x) =
149+ reg_blk.xetla_select <block_elems, 1 >(block_id * block_elems);
161150 }
162151#pragma unroll
163- for (uint32_t ii = 0 ; ii < block_size_y / st_block_size_y; ++ii) {
164- constexpr uint32_t store_elems =
165- st_block_size_y * block_size_x * arr_len;
152+ for (uint32_t ii = 0 ; ii < block_size_y; ii += st_blk_size_y) {
166153 auto st_blk =
167- combine_blk.xetla_select <store_elems, 1 >(ii * store_elems);
168- // xetla_tstore_global<dtype, store_elems, L1, L2, payload_t::arch_tag>(
169- // tdesc, st_blk);
170- xetla_store_global<
171- dtype,
172- block_size_x * arr_len,
173- st_block_size_y,
174- L1,
175- L2>(
154+ combine_blk.xetla_select <store_elems, 1 >(ii * st_blk_size_x);
155+ xetla_store_global<dtype, st_blk_size_x, st_blk_size_y, L1, L2>(
176156 reinterpret_cast <dtype*>(payload.base_ptr ),
177157 payload.surface_width ,
178158 payload.surface_height ,
179159 payload.surface_pitch ,
180160 payload.offset_x + offset_x,
181- payload.offset_y + offset_y + ii * st_block_size_y,
182- // ::gpu::xetla::detail::xetla_get_tensor_offset_x(tdesc),
183- // ::gpu::xetla::detail::xetla_get_tensor_offset_y(tdesc),
161+ payload.offset_y + offset_y + ii,
184162 st_blk);
185- // xetla_update_tdesc_offsety(
186- // tdesc.xetla_format<uint32_t>(), st_block_size_y);
187163 }
188164 // exceed hardware limitation
189- if constexpr ((block_size_y % st_block_size_y) != 0 ) {
190- constexpr uint32_t blk_remained_start = block_size_y / st_block_size_y *
191- st_block_size_y * block_size_x * arr_len;
192- constexpr uint8_t blk_remained_y = block_size_y % st_block_size_y;
193- constexpr uint8_t blk_remained_elems =
194- blk_remained_y * block_size_x * arr_len;
165+ if constexpr ((block_size_y % st_blk_size_y) != 0 ) {
166+ constexpr uint32_t blk_remained_start =
167+ block_size_y / st_blk_size_y * st_blk_size_y * st_blk_size_x;
168+ constexpr uint8_t blk_remained_y = block_size_y % st_blk_size_y;
169+ constexpr uint8_t blk_remained_elems = blk_remained_y * st_blk_size_x;
195170 auto st_blk =
196171 combine_blk.xetla_select <blk_remained_elems, 1 >(blk_remained_start);
197- // constexpr uint32_t block_widthx_widthy_arrlen =
198- // (block_size_x * arr_len - 1) | ((blk_remained_y - 1) << 8);
199- // gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
200- // tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
201- // xetla_tstore_global<
202- // dtype,
203- // blk_remained_elems,
204- // L1,
205- // L2,
206- // payload_t::arch_tag>(tdesc, st_blk);
207- xetla_store_global<
208- dtype,
209- block_size_x * arr_len,
210- blk_remained_y,
211- L1,
212- L2>(
172+ xetla_store_global<dtype, st_blk_size_x, blk_remained_y, L1, L2>(
213173 reinterpret_cast <dtype*>(payload.base_ptr ),
214174 payload.surface_width ,
215175 payload.surface_height ,
216176 payload.surface_pitch ,
217177 payload.offset_x + offset_x,
218178 payload.offset_y + offset_y +
219- block_size_y / st_block_size_y * st_block_size_y ,
179+ block_size_y / st_blk_size_y * st_blk_size_y ,
220180 st_blk);
221181 }
222182 }
@@ -227,47 +187,34 @@ tile_store(tile_t& tile, payload_t& payload) {
227187 constexpr uint32_t processed_elems =
228188 num_block_y * num_block_x * block_elems;
229189 constexpr uint32_t remained_st_blk_size_y =
230- st_block_size_y > remained_size_y ? remained_size_y : st_block_size_y;
231- // auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
232- // num_block_y * num_block_x, 0);
233- // detail::reset_tile_desc_core<
234- // num_block_x,
235- // block_size_x * arr_len,
236- // remained_st_blk_size_y,
237- // 1,
238- // 1,
239- // false>(payload_row);
190+ std::min (st_blk_size_y, remained_size_y);
240191#pragma unroll
241192 for (uint32_t j = 0 ; j < num_block_x; j += arr_len) {
242193 int offset_x = j * block_size_x;
243- // xetla_tdescriptor tdesc = payload_row.row(j);
244194 auto reg_blk = tile.reg .xetla_select <remained_block_elems * arr_len, 1 >(
245195 processed_elems + j * remained_block_elems);
246196 // Do combination
247197 xetla_vector<dtype, remained_block_elems * arr_len> combine_blk;
248198 auto combine_blk_2d = combine_blk.xetla_format <
249199 native_type_t <dtype>,
250200 remained_size_y,
251- block_size_x * arr_len >();
201+ st_blk_size_x >();
252202#pragma unroll
253- for (uint32_t combine_i = 0 ; combine_i < arr_len; ++combine_i ) {
203+ for (uint32_t block_id = 0 ; block_id < arr_len; ++block_id ) {
254204 combine_blk_2d.xetla_select <remained_size_y, 1 , block_size_x, 1 >(
255- 0 , combine_i * block_size_x) =
205+ 0 , block_id * block_size_x) =
256206 reg_blk.xetla_select <remained_block_elems, 1 >(
257- combine_i * remained_block_elems);
207+ block_id * remained_block_elems);
258208 }
259209#pragma unroll
260- for (uint32_t ii = 0 ; ii < remained_size_y / remained_st_blk_size_y;
261- ++ii) {
262- constexpr uint32_t store_elems =
263- remained_st_blk_size_y * block_size_x * arr_len;
210+ for (uint32_t ii = 0 ; ii < remained_size_y;
211+ ii += remained_st_blk_size_y) {
212+ constexpr uint32_t store_elems = remained_st_blk_size_y * st_blk_size_x;
264213 auto st_blk =
265- combine_blk.xetla_select <store_elems, 1 >(ii * store_elems);
266- // xetla_tstore_global<dtype, store_elems, L1, L2, payload_t::arch_tag>(
267- // tdesc, st_blk);
214+ combine_blk.xetla_select <store_elems, 1 >(ii * st_blk_size_x);
268215 xetla_store_global<
269216 dtype,
270- block_size_x * arr_len ,
217+ st_blk_size_x ,
271218 remained_st_blk_size_y,
272219 L1,
273220 L2>(
@@ -276,38 +223,19 @@ tile_store(tile_t& tile, payload_t& payload) {
276223 payload.surface_height ,
277224 payload.surface_pitch ,
278225 payload.offset_x + offset_x,
279- payload.offset_y + num_block_y * block_size_y +
280- ii * remained_st_blk_size_y,
226+ payload.offset_y + num_block_y * block_size_y + ii,
281227 st_blk);
282- // xetla_update_tdesc_offsety(
283- // tdesc.xetla_format<uint32_t>(), remained_st_blk_size_y);
284228 }
285229 constexpr uint32_t final_st_blk_size_y =
286230 remained_size_y % remained_st_blk_size_y;
287231 if constexpr (final_st_blk_size_y != 0 ) {
288232 constexpr uint32_t final_start = remained_size_y /
289- remained_st_blk_size_y * remained_st_blk_size_y * block_size_x *
290- arr_len;
233+ remained_st_blk_size_y * remained_st_blk_size_y * st_blk_size_x;
291234 constexpr uint32_t final_store_elems =
292- final_st_blk_size_y * block_size_x * arr_len ;
235+ final_st_blk_size_y * st_blk_size_x ;
293236 auto st_blk =
294237 combine_blk.xetla_select <final_store_elems, 1 >(final_start);
295- // constexpr uint32_t block_widthx_widthy_arrlen =
296- // (block_size_x * arr_len - 1) | ((final_st_blk_size_y - 1) << 8);
297- // gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
298- // tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
299- // xetla_tstore_global<
300- // dtype,
301- // final_store_elems,
302- // L1,
303- // L2,
304- // payload_t::arch_tag>(tdesc, st_blk);
305- xetla_store_global<
306- dtype,
307- block_size_x * arr_len,
308- final_st_blk_size_y,
309- L1,
310- L2>(
238+ xetla_store_global<dtype, st_blk_size_x, final_st_blk_size_y, L1, L2>(
311239 reinterpret_cast <dtype*>(payload.base_ptr ),
312240 payload.surface_width ,
313241 payload.surface_height ,
0 commit comments