@@ -3097,3 +3097,75 @@ def _kernel_matmul_fp8_row_non_persistent(
         tl.store(C, acc, mask=mask)
     else:
         tl.atomic_add(C, acc, mask=mask)
+
+
+@triton.jit
+def _kernel_dequantize_fp8_block(
+    xq_ptr,
+    x_scale_ptr,
+    x_dequant_ptr,
+    M,
+    K,
+    BLOCK_M: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    """
+    Kernel to dequantize FP8 tensor to BF16 tensor.
+    Args:
+        xq_ptr (tl.constexpr): Pointer to FP8 tensor.
+        x_scale_ptr (tl.constexpr): Pointer to FP8 scale tensor.
+        x_dequant_ptr (tl.constexpr): Pointer to BF16 tensor.
+        M (tl.constexpr): M dimension of input tensor.
+        K (tl.constexpr): K dimension of input tensor.
+        BLOCK_M (tl.constexpr): Block size for the M dimension.
+        BLOCK_K (tl.constexpr): Block size for the K dimension.
+    """
+    pid_m = tl.program_id(axis=0)
+    pid_k = tl.program_id(axis=1)
+    k = tl.cdiv(K, BLOCK_K)
+    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
+    offs = offs_m[:, None] * K + offs_k[None, :]
+    mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
+    xq = tl.load(xq_ptr + offs, mask=mask).to(tl.bfloat16)
+    x_scale = tl.load(x_scale_ptr + pid_m * k + pid_k)
+    x_dequant = xq * x_scale
+    tl.store(x_dequant_ptr + offs, x_dequant, mask=mask)
+
+
+def dequantize_fp8_block(
+    xq: torch.Tensor,
+    x_scale: torch.Tensor,
+    block_m: int = 256,
+    block_k: int = 256,
+) -> torch.Tensor:
+    """
+    Dequantize FP8 tensor to BF16 tensor.
+
+    Args:
+        xq (torch.Tensor): FP8 tensor to be dequantized.
+        x_scale (torch.Tensor): FP8 scale tensor.
+        block_m (int): Block size for the M dimension.
+        block_k (int): Block size for the K dimension.
+
+    Returns:
+        torch.Tensor: Dequantized BF16 tensor.
+    """
+
+    assert (
+        xq.is_contiguous() and x_scale.is_contiguous()
+    ), "Input tensors must be contiguous"
+    assert xq.dim() == 2 and x_scale.dim() == 2, "Input tensors must have 2 dimensions"
+    M, K = xq.size()
+    x_dequant = torch.empty_like(xq, dtype=torch.bfloat16)
+
+    def grid(meta):
+        return (
+            triton.cdiv(M, meta["BLOCK_M"]),
+            triton.cdiv(K, meta["BLOCK_K"]),
+        )
+
+    _kernel_dequantize_fp8_block[grid](
+        xq, x_scale, x_dequant, M, K, BLOCK_M=block_m, BLOCK_K=block_k  # pyre-ignore[6]
+    )
+    return x_dequant
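
A hypothetical round-trip sketch (not part of the commit) showing how the new dequantize_fp8_block wrapper could be exercised: it hand-builds a block-quantized FP8 tensor with one scale per (block_m x block_k) tile, matching the row-major scale layout the kernel indexes, then dequantizes it. The e4m3 dtype, the CUDA device, the host-side quantization loop, and all variable names below are assumptions for illustration, not part of the diff.

# Hypothetical usage sketch, assuming CUDA and PyTorch float8_e4m3fn support.
import torch
import triton

M, K = 512, 300          # K not a multiple of block_k, to exercise the masking path
block_m, block_k = 256, 256

x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

# One scale per (block_m x block_k) tile, laid out as (ceil(M/block_m), ceil(K/block_k)),
# which matches the kernel's x_scale_ptr + pid_m * k + pid_k indexing.
grid_m, grid_k = triton.cdiv(M, block_m), triton.cdiv(K, block_k)
x_scale = torch.rand(grid_m, grid_k, device="cuda", dtype=torch.float32) + 0.5

# Quantize on the host (assumed scheme): divide each tile by its scale, cast to FP8.
xq = torch.empty(M, K, device="cuda", dtype=torch.float8_e4m3fn)
for i in range(grid_m):
    for j in range(grid_k):
        rm = slice(i * block_m, min((i + 1) * block_m, M))
        rk = slice(j * block_k, min((j + 1) * block_k, K))
        xq[rm, rk] = (x[rm, rk] / x_scale[i, j]).to(torch.float8_e4m3fn)

# The kernel multiplies each tile back by its scale, so x_dq should approximate x
# up to FP8 rounding error.
x_dq = dequantize_fp8_block(xq, x_scale, block_m=block_m, block_k=block_k)
assert x_dq.dtype == torch.bfloat16 and x_dq.shape == (M, K)
print((x_dq.float() - x.float()).abs().max())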