
Commit e085e0d

NF4 & FP4 blockwise quant (#1802)
1 parent 0b2ee91 commit e085e0d

16 files changed: +1301 −0 lines

csrc/lc/common.h

+40 lines
@@ -0,0 +1,40 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef COMMON_H
#define COMMON_H

// Quantization data types: 8-bit with an explicit codebook, or 4-bit FP4/NF4.
typedef enum LC_DataType_t
{
    General8bit = 0,
    FP4 = 1,
    NF4 = 2,
} LC_DataType_t;

template <typename T, int DATA_TYPE> void quantize_blockwise(const float *code, const T *A, float *absmax, unsigned char *out, int blocksize, int n);
template <typename T, int DATA_TYPE> void dequantize_blockwise(const float *code, const unsigned char *A, const float *absmax, T *out, int blocksize, int n);

#define CUDA_CHECK_RETURN(value) { \
    cudaError_t _m_cudaStat = value; \
    if (_m_cudaStat != cudaSuccess) { \
        fprintf(stderr, "Error %s at line %d in file %s\n", \
                cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
        exit(1); \
    } }

#endif
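
The header only declares the host entry points; below is a minimal host-side sketch of the intended call pattern. The helper function and buffer names are hypothetical, not part of this commit; it assumes the device buffers were already allocated and filled.

#include <cuda_runtime.h>
#include "common.h"

// Hypothetical usage sketch: d_packed holds n/2 packed NF4 bytes,
// d_absmax one float scale per quantization block, d_out n floats.
void dequantize_nf4_example(const unsigned char *d_packed,
                            const float *d_absmax,
                            float *d_out,
                            int blocksize, int n)
{
    // NF4 and FP4 use hard-coded codebooks, so no `code` table is passed
    // (the Paddle op below dispatches them with NULL as well).
    dequantize_blockwise<float, NF4>(/*code=*/NULL, d_packed, d_absmax, d_out, blocksize, n);
    CUDA_CHECK_RETURN(cudaDeviceSynchronize());
}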

csrc/lc/dequantize_blockwise.cu

+270 lines
@@ -0,0 +1,270 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <stdio.h>
#include <algorithm>
#include <cub/device/device_scan.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/warp/warp_reduce.cuh>
#include <cub/block/block_load.cuh>
#include <cub/block/block_discontinuity.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/cub.cuh>
#include <math_constants.h>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <mma.h>
#include "common.h"

__device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
{
    float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
    if((val & 0b0100) == 4) // 1
        if((val & 0b0010) == 2) // 11
            if((val & 0b0001) == 1) // 111
                return 0.25000000f*absmax*sign; // 1111
            else
                return 0.16666667f*absmax*sign; // 1110
        else
            if((val & 0b0001) == 1) // 110
                return 0.50000000f*absmax*sign; // 1101
            else
                return 0.33333333f*absmax*sign; // 1100
    else
        if((val & 0b0010) == 2) // 10
            if((val & 0b0001) == 1) // 101
                return 1.00000000f*absmax*sign; // 1011
            else
                return 0.66666667f*absmax*sign; // 1010
        else
            if((val & 0b0001) == 1) // 100
                return 5.208333333e-03f*absmax*sign; // 1001
            else
                return 0.00000000f*absmax*sign; // 1000
}

__device__ float dDequantizeNF4(unsigned char val)
{
    // the values for this tree were generated by test_normal_map_tree
    // in the file tests/test_functional.py
    if((val & 0b1000) == 8)
        if((val & 0b0100) == 4) // 1
            if((val & 0b0010) == 2) // 11
                if((val & 0b0001) == 1) // 111
                    return 1.0f;
                else
                    return 0.7229568362236023f;
            else
                if((val & 0b0001) == 1) // 110
                    return 0.5626170039176941f;
                else
                    return 0.44070982933044434f;
        else
            if((val & 0b0010) == 2) // 10
                if((val & 0b0001) == 1) // 101
                    return 0.33791524171829224f;
                else
                    return 0.24611230194568634f;
            else
                if((val & 0b0001) == 1) // 100
                    return 0.16093020141124725f;
                else
                    return 0.07958029955625534f;
    else
        if((val & 0b0100) == 4) // 0
            if((val & 0b0010) == 2) // 01
                if((val & 0b0001) == 1) // 011
                    return 0.0f;
                else
                    return -0.09105003625154495f;
            else
                if((val & 0b0001) == 1) // 010
                    return -0.18477343022823334f;
                else
                    return -0.28444138169288635f;
        else
            if((val & 0b0010) == 2) // 00
                if((val & 0b0001) == 1) // 001
                    return -0.39491748809814453f;
                else
                    return -0.5250730514526367f;
            else
                if((val & 0b0001) == 1) // 000
                    return -0.6961928009986877f;
                else
                    return -1.0f;
}
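
The four code bits fully determine the leaf, so the tree above is just a branchy encoding of a 16-entry codebook, presumably chosen to keep the lookup in registers and predication rather than memory. An equivalent table, shown only for illustration (the array name is not part of this commit):

// dDequantizeNF4(v) == kNF4Codebook[v & 0x0F] for all 16 codes.
static const float kNF4Codebook[16] = {
    -1.0f,                 -0.6961928009986877f,  -0.5250730514526367f,  -0.39491748809814453f,
    -0.28444138169288635f, -0.18477343022823334f, -0.09105003625154495f,  0.0f,
     0.07958029955625534f,  0.16093020141124725f,  0.24611230194568634f,  0.33791524171829224f,
     0.44070982933044434f,  0.5626170039176941f,   0.7229568362236023f,   1.0f};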

template<typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
__global__ void kDequantizeBlockwise(const float *code, const unsigned char *A, const float *absmax, T *out, int blocksize, int n)
{
    const int n_load = (gridDim.x * TILE_SIZE);
    int valid_items_load = 0;
    int valid_items_store = 0;
    const int base_idx = (blockIdx.x * TILE_SIZE);

    // 4-bit types (DATA_TYPE > 0) unpack two output values per loaded byte.
    T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)];
    unsigned char qvals[NUM_PER_TH];
    float local_abs_max = -FLT_MAX;

    typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
    typedef cub::BlockStore<T, THREADS, NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;

    __shared__ typename LoadChar::TempStorage loadchar;
    __shared__ typename StoreT::TempStorage storet;

    for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE)
    {
        if(DATA_TYPE > 0)
        {
            // i indexes packed bytes: load up to TILE_SIZE bytes, store 2x elements.
            valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i;
            valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2;
        }
        else
        {
            valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i;
            valid_items_store = n - i > TILE_SIZE ? TILE_SIZE : n - i;
        }
        local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]);

        __syncthreads();
        LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128);

        switch(DATA_TYPE)
        {
            case General8bit:
                // load code through read-only cache via __ldg
                #pragma unroll NUM_PER_TH
                for(int j = 0; j < NUM_PER_TH; j++)
                    vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
                break;
            case FP4:
                #pragma unroll NUM_PER_TH
                for(int j = 0; j < NUM_PER_TH; j++)
                {
                    vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
                    vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
                }
                break;
            case NF4:
                #pragma unroll NUM_PER_TH
                for(int j = 0; j < NUM_PER_TH; j++)
                {
                    vals[j*2] = dDequantizeNF4(qvals[j] >> 4)*local_abs_max;
                    vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)*local_abs_max;
                }
                break;
        }

        __syncthreads();
        StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store);
    }
}
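
In the FP4/NF4 branches above, vals[j*2] comes from the high nibble and vals[j*2 + 1] from the low nibble, so the earlier element of each pair lives in the byte's high half. A hypothetical packing helper, not part of this commit, that pins down that convention:

// Packs two 4-bit codes the way the kernel unpacks them:
// element 2k -> high nibble, element 2k+1 -> low nibble,
// mirroring `qvals[j] >> 4` and `qvals[j] & 0x0F` above.
unsigned char pack_two_codes(unsigned char first, unsigned char second)
{
    return (unsigned char)((first << 4) | (second & 0x0F));
}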

//template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(const float *code, const unsigned char *A, const float *absmax, half *out, int blocksize, int n);
//template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(const float *code, const unsigned char *A, const float *absmax, half *out, int blocksize, int n);
//template __global__ void kDequantizeBlockwise<half, 512, 64, 8, NF4>(const float *code, const unsigned char *A, const float *absmax, half *out, int blocksize, int n);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, FP4>(const float *code, const unsigned char *A, const float *absmax, float *out, int blocksize, int n);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(const float *code, const unsigned char *A, const float *absmax, float *out, int blocksize, int n);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, NF4>(const float *code, const unsigned char *A, const float *absmax, float *out, int blocksize, int n);
//template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, FP4>(const float *code, const unsigned char *A, const float *absmax, __nv_bfloat16 *out, int blocksize, int n);
//template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General8bit>(const float *code, const unsigned char *A, const float *absmax, __nv_bfloat16 *out, int blocksize, int n);
//template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(const float *code, const unsigned char *A, const float *absmax, __nv_bfloat16 *out, int blocksize, int n);

template<typename T, int DATA_TYPE> void dequantize_blockwise(const float *code, const unsigned char *A, const float *absmax, T *out, int blocksize, int n)
{
    int num_blocks = n/blocksize;
    num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
    int tile_size = (DATA_TYPE > 0) ? 1024 : 512;

    if(DATA_TYPE > 0)
        // 4-bit: the kernel indexes A and absmax in packed bytes, so it
        // receives the per-block byte count (blocksize/2).
        kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n);
    else
        kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n);

    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
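
For the 4-bit types the launcher doubles the tile to 1024 elements (each block's 64 threads still load 8 bytes apiece, i.e. 512 packed bytes) and halves blocksize to match the byte indexing. A small illustrative check of that arithmetic; the function and the concrete values are assumptions, not from this commit:

// Mirrors the grid computation above for NF4 with n = 4096, blocksize = 64.
int example_grid_size()
{
    const int n = 4096;         // number of 4-bit elements to dequantize
    const int tile_size = 1024; // elements covered per thread block
    // ceil(4096 / 1024) = 4 blocks of 64 threads, each covering 512 bytes;
    // the kernel reads absmax once per 32-byte (= 64-element) block.
    return (n + tile_size - 1) / tile_size;
}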

template void dequantize_blockwise<float, General8bit>(const float *code, const unsigned char *A, const float *absmax, float *out, int blocksize, int n);
template void dequantize_blockwise<float, FP4>(const float *code, const unsigned char *A, const float *absmax, float *out, int blocksize, int n);
template void dequantize_blockwise<float, NF4>(const float *code, const unsigned char *A, const float *absmax, float *out, int blocksize, int n);
//template void dequantize_blockwise<half, General8bit>(const float *code, const unsigned char *A, const float *absmax, half *out, int blocksize, int n);
//template void dequantize_blockwise<half, FP4>(const float *code, const unsigned char *A, const float *absmax, half *out, int blocksize, int n);
//template void dequantize_blockwise<half, NF4>(const float *code, const unsigned char *A, const float *absmax, half *out, int blocksize, int n);
//template void dequantize_blockwise<__nv_bfloat16, General8bit>(const float *code, const unsigned char *A, const float *absmax, __nv_bfloat16 *out, int blocksize, int n);
//template void dequantize_blockwise<__nv_bfloat16, FP4>(const float *code, const unsigned char *A, const float *absmax, __nv_bfloat16 *out, int blocksize, int n);
//template void dequantize_blockwise<__nv_bfloat16, NF4>(const float *code, const unsigned char *A, const float *absmax, __nv_bfloat16 *out, int blocksize, int n);

std::vector<paddle::Tensor> DequantizeBlockwise(const paddle::Tensor& input, const paddle::Tensor& code, const paddle::Tensor& absmax, int blocksize, std::string quant_type) {
    int64_t input_numel = input.numel();
    int n = input_numel;
    std::vector<int64_t> out_shape = input.shape();
    if (quant_type != "8bit") { // 4-bit: each packed byte expands to two elements
        out_shape = {input_numel * 2, 1};
        n = n * 2;
    }
    auto out = paddle::empty(out_shape, paddle::DataType::FLOAT32, input.place());

    if (quant_type == "8bit")
        dequantize_blockwise<float, General8bit>(code.data<float>(), input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
    else if (quant_type == "nf4")
        dequantize_blockwise<float, NF4>(NULL, input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
    else if (quant_type == "fp4")
        dequantize_blockwise<float, FP4>(NULL, input.data<unsigned char>(), absmax.data<float>(), out.data<float>(), blocksize, n);
    else
        PD_THROW("Unsupported quant type; only 8bit, nf4 and fp4 are supported.");
    return {out};
};

std::vector<std::vector<int64_t>> GetDequantizeBlockwiseInferShape(const std::vector<int64_t>& input_shape, const std::vector<int64_t>& code_shape, const std::vector<int64_t>& abs_max_shape, int blocksize, std::string quant_type){
    // 4-bit input is a packed 2-D tensor; its dequantized form has twice the elements.
    int64_t first_shape = input_shape[0] * input_shape[1] * 2;
    if (quant_type != "8bit")
        return {{first_shape, 1}};
    else
        return {input_shape};
}

std::vector<paddle::DataType> GetDequantizeBlockwiseInferDtype(const paddle::DataType& input_dtype, const paddle::DataType& code_dtype, const paddle::DataType& abs_max_dtype){
    return {paddle::DataType::FLOAT32};
}

PD_BUILD_OP(dequant_blockwise)
    .Inputs({"input", "code", "abs_max"})
    .Outputs({"output"})
    .Attrs({"blocksize: int", "quant_type: std::string"})
    .SetKernelFn(PD_KERNEL(DequantizeBlockwise))
    .SetInferShapeFn(PD_INFER_SHAPE(GetDequantizeBlockwiseInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(GetDequantizeBlockwiseInferDtype));
