-// Inspired by TRT-LLM.
-// Modified by Shang Yang and Haotian Tang.
+// Inspired by QServe https://github.com/mit-han-lab/qserve/tree/main.
+// Modified by Yuming Lou.
// @article{lin2024awq,
//   title={AWQ: Activation-aware Weight Quantization for On-Device LLM Compression and Acceleration},
//   author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Chen, Wei-Ming and Wang, Wei-Chen and Xiao, Guangxuan and Dang, Xingyu and Gan, Chuang and Han, Song},
// }
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
-
#include "dispatch_utils.h"
#include "utils.cuh"
#include "reduction_utils.cuh"
@@ -41,18 +40,19 @@ __inline__ __device__ Tf compute_layernorm(Tf val, float s_mean, float s_variance
 * First pass (loop) computes the mean.
 * Second computes the variance via Var[x] = E[(x - E[x])²].
 * Third pass computes and writes normed_output
- *
- * with USE_DIFF_OF_SQUARES set to true (may be faster but less accurate):
+ * For a better speedup, we set USE_DIFF_OF_SQUARES to true by default
+ * (potentially faster but less accurate); it turns out the accuracy doesn't drop.
 * First pass (loop) computes the mean and variance via Var[x] = E[x²] - E[x]²
 * Second pass computes and writes normed_output
+ *
 *
 * use_shmem controls if we cache input values into shared memory
 *
 * Optional: with dynamic scaling, the last pass doesn't write immediately but finds the
 * amax per row. A final pass scales to int8 accordingly, and writes output to
 * normed_output_quant.
 */
-template <typename T, typename scale_type, bool USE_DIFF_OF_SQUARES = false>
+template <typename T, typename scale_type, bool USE_DIFF_OF_SQUARES = true>
__global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta, T* normed_output, const float eps,
    int tokens, int hidden_dim, const scale_type* scale_orig_quant_per_tensor, scale_type* scale_orig_quant_per_token,
    int8_t* normed_output_quant, bool use_shmem)
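// ---------------------------------------------------------------------------
// Illustration: the fast path relies on the diff-of-squares identity
// Var[x] = E[x²] - E[x]². A minimal host-side check of both formulas
// (plain C++, not part of this file):
//
//   #include <cstdio>
//   #include <vector>
//
//   int main() {
//       std::vector<float> x = {1.f, 2.f, 3.f, 4.f};
//       float sum = 0.f, sum_sq = 0.f;
//       for (float v : x) { sum += v; sum_sq += v * v; }   // single pass
//       float mean = sum / x.size();
//       float var_fast = sum_sq / x.size() - mean * mean;  // E[x²] - E[x]²
//       float var_ref = 0.f;
//       for (float v : x) var_ref += (v - mean) * (v - mean);
//       var_ref /= x.size();                               // E[(x - E[x])²]
//       printf("%f %f\n", var_fast, var_ref);              // both print 1.250000
//       return 0;
//   }
//
// The "less accurate" caveat refers to cancellation when E[x²] and E[x]² are
// large and nearly equal; for typical normalized activations the two results
// match, which is why the default is now true.
// ---------------------------------------------------------------------------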
@@ -74,7 +74,6 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,
    float variance = 0.0f;
    float local_sum = 0.0f;
    float local_var_sum = 0.0f;
-
    const int n_elems = hidden_dim / num_elems_T;
    for (int i = tidx; i < n_elems; i += blockDim.x)
    {
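// Note: each iteration handles one packed value of type T (e.g. half2), i.e.
// num_elems_T scalar elements at once, so a row of hidden_dim scalars takes
// only hidden_dim / num_elems_T block-strided steps per thread.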
@@ -83,15 +82,14 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,
        {
            shmem[i] = val;
        }
-
        const float_packed_t val_f = cuda_cast<float_packed_t>(val);
        local_sum += cuda_sum<float>(val_f);
        if (USE_DIFF_OF_SQUARES)
        {
            local_var_sum += cuda_sum<float>(val_f * val_f);
        }
    }
-
+    // Compute mean (and variance, on the fast path) via block reduction
    if (USE_DIFF_OF_SQUARES)
    {
        float packed[2] = {local_sum, local_var_sum};
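// Packing {local_sum, local_var_sum} lets one block reduction produce both
// E[x] and E[x²], so the fast path pays for a single synchronization. A hedged
// sketch of the warp-level building block such a reduction typically uses
// (the real one lives in reduction_utils.cuh; this is not that code):
//
//   __inline__ __device__ float warpReduceSumSketch(float v) {
//       for (int mask = 16; mask > 0; mask >>= 1)
//           v += __shfl_xor_sync(0xffffffffu, v, mask); // butterfly exchange
//       return v;
//   }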
@@ -116,12 +114,13 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,
    }
    __syncthreads();

+
    if (!USE_DIFF_OF_SQUARES)
    {
        for (int i = tidx; i < n_elems; i += blockDim.x)
        {
            const T val = use_shmem ? shmem[i] : input[bidx * n_elems + i];
-            float_packed_t diff = cuda_cast<float_packed_t>(val) - s_mean;
+            float_packed_t diff = cuda_cast<float_packed_t>(val); // - s_mean;
            local_var_sum += cuda_sum<float>(diff * diff);
        }
        variance = blockReduceSum(local_var_sum);
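// Note on the change above: with "- s_mean" commented out, this branch
// accumulates E[x²] rather than Var[x] = E[(x - E[x])²], i.e. it normalizes by
// the root mean square instead of the standard deviation, matching the RMSNorm
// entry point rms_norm_general below.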
@@ -133,6 +132,7 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,
        __syncthreads();
    }

+    // Compute LN and Quantize
    const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr;
    const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
    const float_packed_t scale_orig_quant
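// Hedged sketch of the dynamic per-token quantization described in the header
// comment; the actual implementation sits in the lines elided between these
// hunks, and the names below are illustrative only:
//
//   float amax  = blockAllReduceMax(local_abs_max);      // row-wise abs-max (hypothetical helper)
//   float scale = 127.f / amax;                          // map amax to the int8 maximum
//   normed_output_quant[idx] = float_to_int8_rn(val * scale);
//   if (tidx == 0)
//       scale_orig_quant_per_token[bidx] = amax / 127.f; // per-row dequant scale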
@@ -186,51 +186,21 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,
        }
    }
}
-}

-// TODO(woosuk): Further optimize this kernel.
-template <typename scalar_t, typename out_type, bool use_quant>
-__global__ void
-rms_norm_kernel(out_type *__restrict__ out,          // [..., hidden_size]
-                const scalar_t *__restrict__ input,  // [..., hidden_size]
-                const scalar_t *__restrict__ weight, // [hidden_size]
-                const float epsilon, const int num_tokens,
-                const int hidden_size) {
-  __shared__ float s_variance;
-  float variance = 0.0f;

-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * hidden_size + idx];
-    variance += x * x;
-  }
-  variance = blockReduceSum<float>(variance);
-  if (threadIdx.x == 0) {
-    s_variance = rsqrtf(variance / hidden_size + epsilon);
-  }
-  __syncthreads();
-
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * hidden_size + idx];
-    if constexpr (use_quant) {
-      out[blockIdx.x * hidden_size + idx] = float_to_int8_rn(
-          ((float)(x * s_variance)) * (float)(weight[idx]));
-    } else {
-      out[blockIdx.x * hidden_size + idx] =
-          ((scalar_t)(x * s_variance)) * weight[idx];
-    }
-  }
-}
+} // namespace vllm

void rms_norm_general(torch::Tensor &out,     // [..., hidden_size]
                      torch::Tensor &input,   // [..., hidden_size]
                      torch::Tensor &weight,  // [hidden_size]
+                     torch::Tensor &bias,    // [hidden_size]
                      torch::Tensor &scaling, // [tokens] or [1]
                      float epsilon,
-                     bool use_per_token_quant) {
+                     bool use_per_token_quant = true) {
    int hidden_size = input.size(-1);
    int num_tokens = input.numel() / hidden_size;
    dim3 grid(num_tokens);
-   dim3 block(std::min(hidden_size, 1024));
+   dim3 block(std::min(hidden_size, 128)); // smaller blocks reduce the chance of idle threads
    block.x = 32 * ((block.x + 31) / 32);
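    // Rounding block.x up to a multiple of 32 keeps whole warps: CUDA issues
    // threads in warps of 32, so a remainder would launch a partially idle warp.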

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -240,7 +210,8 @@ void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
        // per-token
        vllm::generalLayerNorm<T, at::Half><<<grid, block, 0, stream>>>(
            reinterpret_cast<T*>(input.data_ptr<scalar_t>()),
-           reinterpret_cast<T*>(weight.data_ptr<scalar_t>()), nullptr,
+           reinterpret_cast<T*>(weight.data_ptr<scalar_t>()),
+           reinterpret_cast<T*>(bias.data_ptr<scalar_t>()),
            nullptr, epsilon, num_tokens, hidden_size, nullptr, scaling.data_ptr<at::Half>(),
            out.data_ptr<int8_t>(), false
        );
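        // Argument mapping for the launch above, in the kernel's parameter
        // order: input -> input, weight -> gamma, bias -> beta, nullptr ->
        // normed_output (the fp16 result is not materialized), then eps,
        // num_tokens, hidden_size, nullptr -> per-tensor scale, scaling ->
        // per-token scale output, out -> int8 output, false -> use_shmem.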
@@ -258,4 +229,4 @@ void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
        );
    }
    });
-}
+}
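// ---------------------------------------------------------------------------
// Hedged usage sketch: how rms_norm_general might be called from C++ once the
// extension is built. Shapes follow the parameter comments above; the dtype
// and option choices here are assumptions, not part of this commit:
//
//   int tokens = 4, hidden = 4096;
//   auto fp16 = torch::dtype(torch::kHalf).device(torch::kCUDA);
//   auto input   = torch::randn({tokens, hidden}, fp16);
//   auto weight  = torch::ones({hidden}, fp16);
//   auto bias    = torch::zeros({hidden}, fp16);
//   auto scaling = torch::empty({tokens}, fp16); // per-token scales, written by the kernel
//   auto out     = torch::empty({tokens, hidden},
//                               torch::dtype(torch::kInt8).device(torch::kCUDA));
//   rms_norm_general(out, input, weight, bias, scaling, 1e-6f,
//                    /*use_per_token_quant=*/true);
// ---------------------------------------------------------------------------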