@@ -458,7 +458,24 @@ __XETLA_API xetla_vector<T, N> xetla_load_global(
458458 size_t SurfacePitch,
459459 int X,
460460 int Y) {
461- if constexpr (BlockWidth * sizeof (T) < sizeof (uint32_t )) {
461+ if constexpr (std::is_same_v<T, bf16 >) {
462+ auto ret = xetla_load_global<
463+ fp16,
464+ BlockWidth,
465+ BlockHeight,
466+ NBlocks,
467+ Transposed,
468+ Transformed,
469+ L1H,
470+ L2H>(
471+ reinterpret_cast <const fp16*>(Ptr),
472+ SurfaceWidth,
473+ SurfaceHeight,
474+ SurfacePitch,
475+ X,
476+ Y);
477+ return ret.xetla_format <T>();
478+ } else if constexpr (BlockWidth * sizeof (T) < sizeof (uint32_t )) {
462479 constexpr auto scale_factor = sizeof (uint32_t ) / sizeof (T);
463480 xetla_vector<uint32_t , N> ret = __ESIMD_ENS::lsc_load_2d<
464481 uint32_t ,
@@ -754,13 +771,25 @@ __XETLA_API void xetla_store_global(
754771 int X,
755772 int Y,
756773 xetla_vector<T, N> Vals) {
757- __ESIMD_ENS::lsc_store_2d<
758- T,
759- BlockWidth,
760- BlockHeight,
761- gpu::xetla::detail::get_cache_hint (L1H),
762- gpu::xetla::detail::get_cache_hint (L2H)>(
763- Ptr, SurfaceWidth, SurfaceHeight, SurfacePitch, X, Y, Vals);
774+ if constexpr (std::is_same_v<T, bf16 >) {
775+ xetla_vector<fp16, N> Vals_fp16 = Vals.xetla_format <fp16>();
776+ xetla_store_global<fp16, BlockWidth, BlockHeight, L1H, L2H>(
777+ reinterpret_cast <fp16*>(Ptr),
778+ SurfaceWidth,
779+ SurfaceHeight,
780+ SurfacePitch,
781+ X,
782+ Y,
783+ Vals_fp16);
784+ } else {
785+ __ESIMD_ENS::lsc_store_2d<
786+ T,
787+ BlockWidth,
788+ BlockHeight,
789+ gpu::xetla::detail::get_cache_hint (L1H),
790+ gpu::xetla::detail::get_cache_hint (L2H)>(
791+ Ptr, SurfaceWidth, SurfaceHeight, SurfacePitch, X, Y, Vals);
792+ }
764793}
765794// / template <typename T, int N, int VS = 1, typename OffsetT,
766795// / typename PropertyListT = empty_properties_t>
0 commit comments