@@ -108,81 +108,86 @@ tile_load(tile_t& tile, payload_t& payload) {
108108
109109 using load_store_attr = load_store_attr_t <msg_type::block_2d, arch_tag>;
110110
111- // static constexpr uint32_t max_load_width_in_elem = trans
112- // ? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
113- // : load_store_attr::max_load_width_in_bytes / sizeof(dtype);
111+ // static constexpr uint32_t max_load_width_in_elem = trans
112+ // ? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
113+ // : load_store_attr::max_load_width_in_bytes / sizeof(dtype);
114114 // static constexpr uint32_t max_load_height_in_elem = trans
115115 // ? load_store_attr::max_trans_load_height_in_elem
116116 // : load_store_attr::max_load_height_in_elem;
117- static constexpr uint32_t max_trans_load_width_in_elem =
118- load_store_attr::max_trans_load_width_in_bytes / sizeof (dtype);
119- static constexpr uint32_t max_load_width_in_elem =
120- load_store_attr::max_load_width_in_bytes / sizeof (dtype);
117+ // static constexpr uint32_t max_trans_load_width_in_elem =
118+ // load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype);
119+ // static constexpr uint32_t max_load_width_in_elem =
120+ // load_store_attr::max_load_width_in_bytes / sizeof(dtype);
121121
122122 // static constexpr uint32_t max_trans_load_height_in_elem =
123123 // load_store_attr::max_trans_load_height_in_elem;
124- static constexpr uint32_t max_load_height_in_elem =
125- load_store_attr::max_load_height_in_elem;
124+
125+ // static constexpr uint32_t max_load_height_in_elem =
126+ // load_store_attr::max_load_height_in_elem;
126127
127128 static constexpr uint32_t elems_per_CL =
128129 load_store_attr::cache_line_size_in_bytes / sizeof (dtype);
129130
130131 static constexpr uint32_t elems_per_reg =
131132 register_bytes_t <arch_tag>::reg_in_bytes / sizeof (dtype);
132133
133- static constexpr uint32_t ld_blk_size_y_limit =
134- mem_transpose ? max_trans_load_width_in_elem : max_load_height_in_elem;
135- static constexpr uint32_t ld_blk_size_y = reg_transpose
136- ? block_size_y
137- : std::min (ld_blk_size_y_limit, block_size_y);
134+ static constexpr uint32_t max_load_width_in_elem = trans
135+ ? load_store_attr::max_trans_load_width_in_bytes / sizeof (dtype)
136+ : load_store_attr::max_load_width_in_bytes / sizeof (dtype);
137+
138+ static constexpr uint32_t max_load_blk_height_in_elem = trans
139+ ? load_store_attr::max_trans_load_height_in_elem
140+ : load_store_attr::max_load_height_in_elem;
141+
142+ static constexpr uint32_t ld_blk_width = std::min (
143+ (mem_transpose ? block_size_y : block_size_x), max_load_width_in_elem);
144+
145+ static constexpr uint32_t ld_blk_height = std::min (
146+ (mem_transpose ? block_size_x : block_size_y),
147+ max_load_blk_height_in_elem);
148+
149+ static constexpr uint32_t ld_blk_size_y =
150+ mem_transpose ? ld_blk_width : ld_blk_height;
151+
152+ static constexpr uint32_t ld_blk_size_y_limit = mem_transpose
153+ ? load_store_attr::max_trans_load_width_in_bytes / sizeof (dtype)
154+ : load_store_attr::max_load_height_in_elem;
138155
139156 // array len is used to make sure memory load is cache line aligned
140157 // disabled while register or memory transpose
141158 static constexpr uint8_t arr_len_candidate =
142- (reg_transpose ||
143- mem_transpose
159+ ((reg_transpose || mem_transpose)
144160 // block elements should be integer
145161 // times of register bytes
146- || ((block_size_y * block_size_x ) % elems_per_reg != 0 )
162+ || ((block_elems ) % elems_per_reg != 0 )
147163 // tail blocks also need to meet above condition
148- ||
149- (((tile_size_y % block_size_y) * block_size_x) % elems_per_reg != 0 )) ||
150- (block_size_y > ld_blk_size_y_limit)
164+ || (((tile_size_y % block_size_y) * block_size_x) % elems_per_reg != 0 ))
165+ // || (block_size_y > load_store_attr::max_load_height_in_elem)
151166 ? 1
152167 : (((tile_size_x % elems_per_CL) == 0 )
153168 ? (((elems_per_CL % block_size_x) == 0 )
154169 ? elems_per_CL / block_size_x
155170 : 1 )
156171 : ((tile_size_x < elems_per_CL) ? (tile_size_x / block_size_x)
157172 : 1 ));
158- static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1 ) ||
159- (arr_len_candidate == 2 ) || (arr_len_candidate == 4 );
160-
161- static constexpr uint8_t arr_len =
162- is_valid_arr_len_candidate ? arr_len_candidate : 1 ;
163-
164- static_assert (
165- reg_transpose || mem_transpose ||
166- (!mem_transpose &&
167- (block_size_x * arr_len) <= max_load_width_in_elem),
168- " When reg_transpose was disabled, check 2d block width "
169- " restriction" );
170- static_assert (
171- !reg_transpose ||
172- (!mem_transpose &&
173- (block_size_x * arr_len) <= max_trans_load_width_in_elem) ||
174- (mem_transpose && (block_size_y * arr_len) <= max_load_width_in_elem),
175- " When reg_transpose was enabled, check 2d block width "
176- " restriction" );
177- static_assert (
178- !reg_transpose ||
179- (!mem_transpose && (block_size_y <= max_load_height_in_elem)) ||
180- (mem_transpose && (block_size_x) <= max_load_height_in_elem),
181- " When reg_transpose was enabled, check 2d block height "
182- " restriction" );
183- static_assert (
184- tile_size_x % (block_size_x * arr_len) == 0 ,
185- " tile_size_x should be a multiple of (block_size_x * arr_len)" );
173+ // NBlocks must be {1,2,4} for bytes and words, {1,2} for dwords, 1 for
174+ // qwords.
175+ static constexpr bool arr_len =
176+ ((arr_len_candidate == 1 ) ||
177+ (arr_len_candidate == 2 && sizeof (dtype) <= 4 ) ||
178+ (arr_len_candidate == 4 && sizeof (dtype) <= 2 ))
179+ ? arr_len_candidate
180+ : 1 ;
181+
182+ if constexpr (!trans && !mem_transform) {
183+ static_assert (
184+ (ld_blk_width * arr_len) <= max_load_width_in_elem,
185+ " When Transposed and Transformed are both set to false, BlockWidth * NBlocks must not exceed 64 for bytes, 32 for words, 16 for dwords, and 8 for qwords" );
186+ } else if constexpr (mem_transform) {
187+ static_assert (
188+ (ld_blk_width * arr_len) <= max_load_width_in_elem,
189+ " When Transformed is true then, BlockWidth * NBlocks must not exceed 64 for bytes and 32 for words." );
190+ }
186191 static_assert (
187192 (reg_transpose &&
188193 ((block_size_x * sizeof (dtype)) % sizeof (load_dtype) == 0 )) ||
@@ -198,10 +203,7 @@ tile_load(tile_t& tile, payload_t& payload) {
198203 constexpr uint32_t load_block_elems = block_elems * arr_len;
199204 auto reg_blk = tile.reg .xetla_select <load_block_elems, 1 >(
200205 (i * num_block_x + j) * block_elems);
201- constexpr uint32_t ld_blk_height = (reg_transpose && trans)
202- ? detail::getNextPowerOf2<ld_blk_size_y>()
203- : ld_blk_size_y;
204- constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
206+ constexpr uint32_t tmp_size = ld_blk_width * ld_blk_height * arr_len;
205207 xetla_vector<dtype, tmp_size> reg_tmp;
206208#pragma unroll
207209 for (uint32_t ii = 0 ; ii < block_size_y / ld_blk_size_y; ++ii) {
@@ -213,10 +215,8 @@ tile_load(tile_t& tile, payload_t& payload) {
213215 mem_transpose ? offset_x : (offset_y + ii * ld_blk_size_y);
214216 reg_tmp.xetla_format <native_type_t <load_dtype>>() = xetla_load_global<
215217 native_type_t <load_dtype>,
216- (trans ? ld_blk_size_y : block_size_x) / scale_factor,
217- (trans ? block_size_x : ld_blk_size_y),
218- // block_size_x / scale_factor,
219- // ld_blk_size_y,
218+ ld_blk_width / scale_factor,
219+ ld_blk_height,
220220 arr_len,
221221 trans,
222222 mem_transform,
@@ -261,11 +261,6 @@ tile_load(tile_t& tile, payload_t& payload) {
261261 (mem_transpose ? remained_blk_size_y : block_size_x) / scale_factor;
262262 constexpr uint8_t block_height =
263263 mem_transpose ? block_size_x : remained_blk_size_y;
264- // constexpr uint32_t block_widthx_widthy_arrlen =
265- // (block_width - 1) | ((block_height - 1) << 8);
266- // gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
267- // tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
268-
269264 reg_blk.xetla_select <load_elems, 1 >(remained_start)
270265 .xetla_format <native_type_t <load_dtype>>() = xetla_load_global<
271266 native_type_t <load_dtype>,
@@ -283,15 +278,6 @@ tile_load(tile_t& tile, payload_t& payload) {
283278 payload.surface_pitch ,
284279 payload.offset_x + offset_x / scale_factor,
285280 payload.offset_y + offset_y + remained_start_y);
286-
287- // xetla_tload_global<
288- // load_dtype,
289- // (load_elems / scale_factor),
290- // L1,
291- // L2,
292- // trans,
293- // mem_transform,
294- // arch_tag>(tdesc);
295281 }
296282 }
297283 }
@@ -304,24 +290,16 @@ tile_load(tile_t& tile, payload_t& payload) {
304290 (!reg_transpose && (remained_size_y > ld_blk_size_y_limit))
305291 ? ld_blk_size_y_limit
306292 : remained_size_y;
307- // auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
308- // num_block_y * num_block_x, 0);
309- // detail::reset_tile_desc_core<
310- // num_block_x,
311- // block_size_x,
312- // remained_ld_blk_size_y,
313- // scale_factor,
314- // arr_len,
315- // mem_transpose>(payload_row);
293+
316294#pragma unroll
317295 for (uint32_t j = 0 ; j < num_block_x; j += arr_len) {
318296 int32_t offset_x = j * block_size_x;
319297 // xetla_tdescriptor tdesc = payload_row.row(j);
320298 auto reg_blk = tile.reg .xetla_select <remained_block_elems * arr_len, 1 >(
321299 processed_elems + j * remained_block_elems);
322- constexpr uint32_t ld_blk_height = (reg_transpose && trans)
323- ? detail::getNextPowerOf2<remained_ld_blk_size_y>()
324- : remained_ld_blk_size_y;
300+ // constexpr uint32_t ld_blk_height = (reg_transpose && trans)
301+ // ? detail::getNextPowerOf2<remained_ld_blk_size_y>()
302+ // : remained_ld_blk_size_y;
325303 constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
326304 xetla_vector<dtype, tmp_size> reg_tmp;
327305#pragma unroll
@@ -490,7 +468,8 @@ tile_load(tile_t& tile, payload_t& payload) {
490468
491469// / @brief This function loads data from unaligned-2D memory surface.
492470// / Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into
493- // / registers. Each block will be loaded serially by its corresponding payload.
471+ // / registers. Each block will be loaded serially by its corresponding
472+ // / payload.
494473// / @tparam tile_t Is the tile_t struct contains registers.
495474// / These registers will be the destination of load operation.
496475// / @tparam payload_t Is the mem_payload_t struct describing the memory
@@ -614,7 +593,8 @@ tile_load(tile_t& tile, payload_t& payload) {
614593
615594// / @brief This function loads data from unaligned-2D memory surface.
616595// / Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into
617- // / registers. Each block will be loaded serially by its corresponding payload.
596+ // / registers. Each block will be loaded serially by its corresponding
597+ // / payload.
618598// / @tparam tile_t Is the tile_t struct contains registers.
619599// / These registers will be the destination of load operation.
620600// / @tparam payload_t Is the mem_payload_t struct describing the memory
@@ -679,7 +659,8 @@ tile_load(tile_t& tile, payload_t& payload) {
679659
680660// / @brief This function loads data from unaligned-2D memory surface.
681661// / Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into
682- // / registers. Each block will be loaded serially by its corresponding payload.
662+ // / registers. Each block will be loaded serially by its corresponding
663+ // / payload.
683664// / @tparam tile_t Is the tile_t struct contains registers.
684665// / These registers will be the destination of load operation.
685666// / @tparam payload_t Is the mem_payload_t struct describing the memory
@@ -819,8 +800,8 @@ tile_load(
819800}
820801
821802// / @brief Is the data load func from local shared memory to register file,
822- // / which supports the memory surface is 1d or 2d scenario. And we always assume
823- // / data in SLM is row major.
803+ // / which supports the memory surface is 1d or 2d scenario. And we always
804+ // / assume data in SLM is row major.
824805// / @tparam tile_t Is the tile_t struct contains registers
825806// / These registers will be the destination of load operation.
826807// / @tparam payload_t Is the mem_payload_t struct describing the memory
@@ -902,8 +883,8 @@ tile_load(tile_t& tile, payload_t& payload) {
902883}
903884
904885// / @brief Is the data load func from shared local memory to register file,
905- // / which supports the memory surface is 1d scenario. And the src memory layout
906- // / is always row major.
886+ // / which supports the memory surface is 1d scenario. And the src memory
887+ // / layout is always row major.
907888// / @tparam tile_t Is the tile_t struct contains registers.
908889// / These registers will be the destination of load operation.
909890// / @tparam payload_t Is the mem_payload_t struct describing the memory
0 commit comments