#include <cub/block/block_radix_sort.cuh>
#include <cub/warp/warp_reduce.cuh>
#include <cub/block/block_load.cuh>
#include <cub/block/block_discontinuity.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/cub.cuh>
#include <math_constants.h>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <mma.h>
#include <cfloat>       // FLT_MAX
#include <cuda_fp16.h>  // half
#include <cuda_bf16.h>  // __nv_bfloat16
#include "helper.h"
#include <iostream>
using namespace std;

#define HLF_MAX 65504
#define TH 1024
#define NUM 4
#define NUM_BLOCK 4096

__device__ unsigned char dQuantizeNF4(float x)
{
  // the values for this tree were generated by test_normal_map_tree
  // in the file tests/test_functional.py
  if (x > 0.03979014977812767f)
    if (x > 0.3893125355243683f) // 1
      if (x > 0.6427869200706482f) // 11
        if (x > 0.8614784181118011f) // 111
          return 0b1111;
        else
          return 0b1110;
      else
        if (x > 0.5016634166240692f) // 110
          return 0b1101;
        else
          return 0b1100;
    else
      if (x > 0.2035212516784668f) // 10
        if (x > 0.2920137718319893f) // 101
          return 0b1011;
        else
          return 0b1010;
      else
        if (x > 0.1202552504837513f) // 100
          return 0b1001;
        else
          return 0b1000;
  else
    if (x > -0.33967943489551544f) // 0
      if (x > -0.13791173323988914f) // 01
        if (x > -0.045525018125772476f) // 011
          return 0b0111;
        else
          return 0b0110;
      else
        if (x > -0.23460740596055984f) // 010
          return 0b0101;
        else
          return 0b0100;
    else
      if (x > -0.6106329262256622f) // 00
        if (x > -0.4599952697753906f) // 001
          return 0b0011;
        else
          return 0b0010;
      else
        if (x > -0.8480964004993439f) // 000
          return 0b0001;
        else
          return 0b0000;
}
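
// For context: every threshold in dQuantizeNF4 above is the midpoint between two
// adjacent NF4 code values, so the 16 code values can be recovered from the tree.
// The helper below is only an illustrative sketch of the reverse (dequantization)
// mapping; it is not used by the kernels in this file and its name is made up here.
__device__ __forceinline__ float dDequantizeNF4Sketch(unsigned char code)
{
  // The 16 NF4 code values in [-1, 1], indexed by the 4-bit code.
  const float kNF4Values[16] = {
    -1.0f,                  // 0b0000
    -0.6961928009986877f,   // 0b0001
    -0.5250730514526367f,   // 0b0010
    -0.39491748809814453f,  // 0b0011
    -0.28444138169288635f,  // 0b0100
    -0.18477343022823334f,  // 0b0101
    -0.09105003625154495f,  // 0b0110
     0.0f,                  // 0b0111
     0.07958029955625534f,  // 0b1000
     0.16093020141124725f,  // 0b1001
     0.24611230194568634f,  // 0b1010
     0.33791524171829224f,  // 0b1011
     0.44070982933044434f,  // 0b1100
     0.5626170039176941f,   // 0b1101
     0.7229568362236023f,   // 0b1110
     1.0f                   // 0b1111
  };
  // The caller rescales the result by the block's absmax, mirroring the
  // 1/absmax scaling applied before dQuantizeNF4 in the kernel below.
  return kNF4Values[code & 0x0F];
}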

template <typename T, int BLOCK_SIZE, int NUM_PER_TH>
// __launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwiseNF4(const T* A, float *absmax, unsigned char *out, const int n)
{
  // total number of elements processed by all CUDA blocks
  const int n_full = gridDim.x * BLOCK_SIZE;
  int valid_items = 0;
  // start index of the elements handled by the current CUDA block
  const int base_idx = (blockIdx.x * BLOCK_SIZE);
  // input elements handled by the current CUDA thread
  T vals[NUM_PER_TH];
  // number of output elements produced by the current CUDA thread
  const int output_num_per_thread = NUM_PER_TH/2;
  // output elements (packed bytes) produced by the current CUDA thread
  unsigned char qvals[output_num_per_thread];
  // float local_abs_max = -FLT_MAX;
  float local_abs_max = 0.0f;

  typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH/2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
  typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
  typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;

  __shared__ typename LoadT::TempStorage loadt;
  __shared__ typename LoadFloat::TempStorage loadf;
  __shared__ typename StoreChar::TempStorage storec;
  __shared__ typename BlockReduce::TempStorage reduce;
  // absmax of each CUDA block (which is also each quantization block)
  __shared__ float smem_absmax_value[1];

  for (unsigned int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE)
  {
    valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
    local_abs_max = -FLT_MAX;

    __syncthreads();
    LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f);

    // 1. compute local max
    // 2. broadcast local max
    // 3. normalize inputs and quantize

    #pragma unroll NUM_PER_TH
    for (int j = 0; j < NUM_PER_TH; j++)
      local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));

    local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);

    if (threadIdx.x == 0)
      smem_absmax_value[0] = local_abs_max;

    __syncthreads();

    if (threadIdx.x == 0)
      absmax[i/BLOCK_SIZE] = local_abs_max;
    else
      local_abs_max = smem_absmax_value[0];

    __syncwarp();

    local_abs_max = 1.0f/local_abs_max;

    unsigned char packed_4bit = 0;

    #pragma unroll NUM_PER_TH
    for (int j = 0; j < NUM_PER_TH/2; j++)
    {
      // reset per pair (assignment, not |=) so bits of the previous byte do not leak in
      packed_4bit = dQuantizeNF4(((float)vals[2*j]) * local_abs_max) << 4;
      packed_4bit |= dQuantizeNF4(((float)vals[2*j+1]) * local_abs_max);
      qvals[j] = packed_4bit;
    }

    __syncthreads();
    StoreChar(storec).Store(&(out[i/2]), qvals, (valid_items+1)/2);
  }
}
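
// Worked example of the kernel geometry above (illustrative numbers only):
// with BLOCK_SIZE = 64 and NUM_PER_TH = 2, each CUDA block runs 64/2 = 32 threads.
// Every thread loads 2 input values, the block reduces them to a single absmax,
// each value is scaled by 1/absmax and mapped to a 4-bit NF4 code, and every
// thread packs its 2 codes into 1 output byte. A 64-element quantization block
// therefore becomes 32 packed bytes of output plus 1 float in absmax.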

#define MAKE_kQuantizeBlockwiseNF4(dtype, blocksize, num_per_thread) \
template __global__ void kQuantizeBlockwiseNF4<dtype, blocksize, num_per_thread>(const dtype* A, float *absmax, unsigned char *out, const int n);

MAKE_kQuantizeBlockwiseNF4(half, 4096, 4)
MAKE_kQuantizeBlockwiseNF4(half, 1024, 4)
MAKE_kQuantizeBlockwiseNF4(half, 512, 2)
MAKE_kQuantizeBlockwiseNF4(half, 256, 2)
MAKE_kQuantizeBlockwiseNF4(half, 128, 2)
MAKE_kQuantizeBlockwiseNF4(half, 64, 2)

MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 4096, 4)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 1024, 4)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 512, 2)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 256, 2)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 128, 2)
MAKE_kQuantizeBlockwiseNF4(__nv_bfloat16, 64, 2)

MAKE_kQuantizeBlockwiseNF4(float, 4096, 4)
MAKE_kQuantizeBlockwiseNF4(float, 1024, 4)
MAKE_kQuantizeBlockwiseNF4(float, 512, 2)
MAKE_kQuantizeBlockwiseNF4(float, 256, 2)
MAKE_kQuantizeBlockwiseNF4(float, 128, 2)
MAKE_kQuantizeBlockwiseNF4(float, 64, 2)

template <paddle::DataType D>
std::vector<paddle::Tensor> LaunchQuantizeNF4(const paddle::Tensor& input, int block_size) {
    cout << "LaunchQuantizeNF4 begin-------" << endl;
    typedef PDTraits<D> traits_;
    typedef typename traits_::DataType DataType_;
    typedef typename traits_::data_t data_t;

    auto input_shape = input.shape();
    auto output = paddle::full(input_shape, 1, paddle::DataType::UINT8, input.place());
    const int n = input.numel();
    int num_blocks = n / block_size;
    num_blocks = n % block_size == 0 ? num_blocks : num_blocks + 1;

    auto abs_max = paddle::full({num_blocks}, 1, paddle::DataType::FLOAT32, input.place());

    const DataType_ *in_ptr = reinterpret_cast<const DataType_*>(input.data<data_t>());
    unsigned char *out_ptr = output.mutable_data<unsigned char>();
    float *abs_max_ptr = abs_max.mutable_data<float>();

    if (block_size == 2048) {
        kQuantizeBlockwiseNF4<DataType_, 2048, 4><<<num_blocks, 512>>>(in_ptr, abs_max_ptr, out_ptr, n);
    } else if (block_size == 1024) {
        kQuantizeBlockwiseNF4<DataType_, 1024, 4><<<num_blocks, 256>>>(in_ptr, abs_max_ptr, out_ptr, n);
    } else if (block_size == 512) {
        kQuantizeBlockwiseNF4<DataType_, 512, 2><<<num_blocks, 256>>>(in_ptr, abs_max_ptr, out_ptr, n);
    } else if (block_size == 256) {
        kQuantizeBlockwiseNF4<DataType_, 256, 2><<<num_blocks, 128>>>(in_ptr, abs_max_ptr, out_ptr, n);
    } else if (block_size == 128) {
        kQuantizeBlockwiseNF4<DataType_, 128, 2><<<num_blocks, 64>>>(in_ptr, abs_max_ptr, out_ptr, n);
    } else if (block_size == 64) {
        kQuantizeBlockwiseNF4<DataType_, 64, 2><<<num_blocks, 32>>>(in_ptr, abs_max_ptr, out_ptr, n);
    } else {
        PD_THROW("quantize_nf4 only supports block_size in {64, 128, 256, 512, 1024, 2048}.");
    }
    return {output, abs_max};
}
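
// Size check, derived from the code above (numbers are illustrative only): `out`
// keeps the full input shape, i.e. n uint8 elements, while the kernel writes only
// (n + 1) / 2 packed bytes, and `abs_max` holds ceil(n / block_size) floats.
// For n = 4096 and block_size = 64 that is 2048 meaningful bytes of `out` and
// 64 absmax values.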

std::vector<paddle::Tensor> QuantizeNF4(const paddle::Tensor& input, int block_size) {
    cout << "QuantizeNF4 begin-------" << endl;
    switch (input.type()) {
        case paddle::DataType::BFLOAT16: {
            return LaunchQuantizeNF4<paddle::DataType::BFLOAT16>(input, block_size);
        }
        case paddle::DataType::FLOAT16: {
            return LaunchQuantizeNF4<paddle::DataType::FLOAT16>(input, block_size);
        }
        case paddle::DataType::FLOAT32: {
            return LaunchQuantizeNF4<paddle::DataType::FLOAT32>(input, block_size);
        }
        default: {
            PD_THROW(
                "Unsupported data type. "
                "Only bfloat16, float16 and float32 are supported.");
            break;
        }
    }
}

PD_BUILD_OP(quantize_nf4)
    .Inputs({"input"})
    .Outputs({"out", "abs_max"})
    .Attrs({"block_size: int"})
    .SetKernelFn(PD_KERNEL(QuantizeNF4));