
Commit f81c164

Author: Yuming Lou

[Major] More efficient ViT with W8A8 GEMM and kernel fusion (#254)

* [Minor] Fused some kernels
* Add act-quant fusion
* [Fix] W8A8 GEMM NaN
* [Minor] Delete unnecessary files
* [Minor] Reset torch version
1 parent 99174c5 commit f81c164

File tree

14 files changed (+375, -655 lines)


awq/kernels/csrc/fused_layernorm/utils.cuh

Lines changed: 0 additions & 469 deletions
This file was deleted.

awq/kernels/csrc/pybind.cpp

Lines changed: 7 additions & 4 deletions
@@ -10,7 +10,8 @@
 #include "rope_new/fused_rope_with_pos.h"
 #include "w8a8/w8a8_gemm_cuda.h"
 #include "w8a8/quantization.h"
-// #include "fused_layernorm/layernorm.h"
+#include "w8a8/layernorm.h"
+#include "w8a8/act.h"
 
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
@@ -29,7 +30,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("w8a8_gemm_forward_cuda", &w8a8_gemm_forward_cuda, "our w8a8 gemm kernel");
     m.def("w8a8_gemm_fuse_bias_forward_cuda", &w8a8_gemm_fuse_bias_forward_cuda, "our w8a8 gemm fused bias kernel");
     m.def("invoke_quant", &invoke_quant, "fp16->int8 quantization");
-    // m.def("rms_norm_general", &rms_norm_general, py::arg("out"), py::arg("input"),
-    //       py::arg("weight"), py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = false,
-    //       "Apply Root Mean Square (RMS) Normalization to the input tensor (TRTLLM kernel).");
+    m.def("rms_norm_general", &rms_norm_general, py::arg("out"), py::arg("input"),
+          py::arg("weight"), py::arg("bias"), py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = true,
+          "Apply Root Mean Square (RMS) Normalization to the input tensor (TRTLLM kernel).");
+    m.def("silu_and_mul", &silu_and_mul, "Activation function.");
+    m.def("gelu_and_quant", &gelu_and_quant, "Apply GELU activation and quantize the output.");
 }
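For orientation, here is a minimal Python-side calling sketch for the newly exposed bindings. It is an illustration only: it assumes the extension is importable as awq_inference_engine (substitute whatever TORCH_EXTENSION_NAME resolves to in your build), a CUDA device is available, and the shapes/dtypes follow the comments in the kernels (fp16 inputs, int8 outputs, fp16 per-token scales).

# Minimal calling sketch (assumed module name: awq_inference_engine).
import torch
import awq_inference_engine as ext  # assumption: adjust to your build's extension name

tokens, hidden = 197, 768
x = torch.randn(tokens, hidden, dtype=torch.float16, device="cuda")
w = torch.ones(hidden, dtype=torch.float16, device="cuda")
b = torch.zeros(hidden, dtype=torch.float16, device="cuda")

# Fused LayerNorm -> per-token int8 quantization (one fp16 scale per token).
x_q = torch.empty(tokens, hidden, dtype=torch.int8, device="cuda")
scales = torch.empty(tokens, dtype=torch.float16, device="cuda")
ext.rms_norm_general(x_q, x, w, b, scales, 1e-6, True)

# Fused GELU -> per-token int8 quantization; tmp buffers the fp16 activations.
h = torch.randn(tokens, 4 * hidden, dtype=torch.float16, device="cuda")
h_q = torch.empty_like(h, dtype=torch.int8)
h_scales = torch.empty(tokens, dtype=torch.float16, device="cuda")
tmp = torch.empty_like(h)
ext.gelu_and_quant(h_q, h, h_scales, tmp)

# SiLU-and-multiply over a [tokens, 2*d] gate/up projection; returns [tokens, d].
gate_up = torch.randn(tokens, 2 * hidden, dtype=torch.float16, device="cuda")
y = ext.silu_and_mul(gate_up)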

awq/kernels/csrc/w8a8/act.cu

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include <cuda_fp16.h>
+
+#include "dispatch_utils.h"
+#include "utils.cuh"
+#include "reduction_utils.cuh"
+
+namespace vllm {
+
+template <typename T> __device__ __forceinline__ T silu(const T &x) {
+  // x * sigmoid(x)
+  return (T)(((float)x) / (1.0f + expf((float)-x)));
+}
+
+template <typename T> __device__ __forceinline__ T gelu_new(const T &x) {
+  const half x3 = (half)(x * x * x);
+  const T t = (T)tanhf((T)((T)0.79788456f * (half)(x + (T)((T)0.044715f * x3))));
+  return ((T)0.5) * x * (((T)1.0) + t);
+}
+
+template <typename T>
+__device__ __forceinline__ T gelu_fast(const T &x) {
+  const half f = (half)x;
+  const T t =
+      (T)tanhf(((T)(f * (T)0.79788456f)) * (((T)1.0) + (T)((T)0.044715f * f) * x));
+  return ((T)0.5) * x * (((T)1.0) + t);
+}
+
+
+
+// Apply fast GELU to the fp16 input, then quantize to int8 with per-token scales.
+template <typename scale_type, bool use_per_token_quant>
+__global__ void gelu_and_quant_kernel(
+    int8_t *__restrict__ out,        // [..., d]
+    half *__restrict__ input,        // [..., d]
+    const int d,
+    scale_type *scale_out,           // [num_tokens]
+    half *__restrict__ tmp = nullptr // [num_tokens, d]
+) {
+  const int token_idx = blockIdx.x;
+  const float max_value = 127.0f;
+  if constexpr (use_per_token_quant) {
+    float amax_val = 0.0f;
+    const half zero = 0.0001f;
+
+    for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+      const half x =
+          (half)__ldg(&input[token_idx * d + idx]);
+      half t = gelu_fast(x);
+      tmp[token_idx * d + idx] = t;
+      t = t > zero ? t : -t;
+      if ((float)t > amax_val)
+        amax_val = (float)t;
+    }
+
+    __shared__ float s_amax;
+    const float block_amax_val = blockReduceMax(amax_val);
+    if (threadIdx.x == 0) {
+      s_amax = block_amax_val;
+      scale_out[token_idx] = half(block_amax_val / max_value);
+    }
+    __syncthreads();
+
+    float tmp_scale = max_value / s_amax;
+    for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+      out[token_idx * d + idx] =
+          float_to_int8_rn((half)tmp_scale * tmp[token_idx * d + idx]);
+    }
+  } else {
+    for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+      const float x =
+          (float)__ldg(&input[token_idx * d + idx]);
+      out[token_idx * d + idx] = float_to_int8_rn((half)gelu_fast(x) / scale_out[0]);
+    }
+  }
+}
+} // namespace vllm
+
+
+
+void gelu_and_quant(
+    torch::Tensor &out,       // [..., d]
+    torch::Tensor &input,     // [..., d]
+    torch::Tensor &scale_out, // [...]
+    torch::Tensor &tmp        // [num_tokens, d]
+) {
+  int64_t num_tokens = input.numel() / input.size(-1);
+  int d = input.size(-1);
+  dim3 grid(num_tokens);
+  dim3 block(std::min(d, 128));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  vllm::gelu_and_quant_kernel<half, true><<<grid, block, 0, stream>>>(
+      out.data_ptr<int8_t>(), reinterpret_cast<half *>(input.data_ptr<at::Half>()), d, reinterpret_cast<half *>(scale_out.data_ptr<at::Half>()), reinterpret_cast<half *>(tmp.data_ptr<at::Half>()));
+}
+
+
+
+namespace vllm {
+
+template <typename scalar_t>
+__global__ void silu_and_mul_kernel(
+    scalar_t *__restrict__ out,         // [..., d]
+    const scalar_t *__restrict__ input, // [..., 2 * d]
+    const int d) {
+
+  const int token_idx = blockIdx.x;
+  const int64_t token_idx_d = token_idx * int64_t(d);
+  const int64_t token_idx_2d = token_idx_d * 2;
+  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    const scalar_t x = __ldg(&input[token_idx_2d + idx]);
+    const scalar_t y = __ldg(&input[token_idx_2d + d + idx]);
+    out[token_idx_d + idx] = silu(x) * y;
+  }
+}
+} // namespace vllm
+
+
+
+torch::Tensor silu_and_mul(
+    torch::Tensor &input) // [..., 2 * d]
+{
+  int64_t num_tokens = input.numel() / input.size(-1);
+  int d = input.size(-1) / 2;
+
+  std::vector<int64_t> output_shape = input.sizes().vec();
+  output_shape[output_shape.size() - 1] = d;
+  auto options =
+      torch::TensorOptions().dtype(input.dtype()).device(input.device());
+  at::Tensor output = torch::empty(output_shape, options);
+
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(d, 256));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_kernel", [&] {
+    vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
+  });
+  return output;
+}
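As a sanity check, the per-token quantization that gelu_and_quant performs can be mirrored in a few lines of PyTorch. This is a reference sketch only: it uses the tanh-approximate GELU in fp32, floors the per-row amax (mirroring the kernel's small 1e-4 guard), and matches the kernel only up to half-precision rounding and the exact saturation behavior of float_to_int8_rn (assumed round-to-nearest with saturation to [-128, 127]).

# Reference sketch of gelu_and_quant's per-token quantization (fp32 approximation).
import torch

def gelu_and_quant_ref(x: torch.Tensor):
    # x: [num_tokens, d]
    g = torch.nn.functional.gelu(x.float(), approximate="tanh")   # ~ gelu_fast
    amax = g.abs().amax(dim=-1, keepdim=True).clamp_min(1e-4)     # per-token amax
    scale = amax / 127.0                                          # dequant scale, one per token
    q = torch.clamp(torch.round(g / scale), -128, 127).to(torch.int8)
    return q, scale.squeeze(-1).half()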

awq/kernels/csrc/w8a8/act.h

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+// Inspired by TRT-LLM.
+// Modified by Shang Yang and Haotian Tang.
+// @article{lin2024awq,
+//   title={AWQ: Activation-aware Weight Quantization for On-Device LLM Compression and Acceleration},
+//   author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Chen, Wei-Ming and Wang, Wei-Chen and Xiao, Guangxuan and Dang, Xingyu and Gan, Chuang and Han, Song},
+//   journal={Proceedings of Machine Learning and Systems},
+//   volume={6},
+//   pages={87--100},
+//   year={2024}
+// }
+
+#include <torch/extension.h>
+#include <cuda_fp16.h>
+// Inspired by vLLM-SmoothQuant: https://github.com/vllm-project/vllm/pull/1112.
+#include <torch/extension.h>
+
+
+void gelu_and_quant(torch::Tensor &out,       // [..., d]
+                    torch::Tensor &input,     // [..., d]
+                    torch::Tensor &scale_out, // [num_tokens]
+                    torch::Tensor &tmp        // [num_tokens, d]
+);
+
+torch::Tensor silu_and_mul(torch::Tensor &input // [..., 2 * d]
+);
+
+
+

awq/kernels/csrc/fused_layernorm/layernorm_kernels.cu renamed to awq/kernels/csrc/w8a8/layernorm.cu

Lines changed: 17 additions & 46 deletions
@@ -1,5 +1,5 @@
-// Inspired by TRT-LLM.
-// Modified by Shang Yang and Haotian Tang.
+// Inspired by QServe https://github.com/mit-han-lab/qserve/tree/main.
+// Modified by Yuming Lou.
 // @article{lin2024awq,
 //   title={AWQ: Activation-aware Weight Quantization for On-Device LLM Compression and Acceleration},
 //   author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Chen, Wei-Ming and Wang, Wei-Chen and Xiao, Guangxuan and Dang, Xingyu and Gan, Chuang and Han, Song},
@@ -10,7 +10,6 @@
 // }
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
-
 #include "dispatch_utils.h"
 #include "utils.cuh"
 #include "reduction_utils.cuh"
@@ -41,18 +40,19 @@ __inline__ __device__ Tf compute_layernorm(Tf val, float s_mean, float s_varianc
  * First pass (loop) computes the mean.
  * Second computes the variance via Var[x] = E[(x - E[x])²].
  * Third pass computes and writes normed_output
- *
- * with USE_DIFF_OF_SQUARES set to true (may be faster but less accurate):
+ * For better speedup, we set USE_DIFF_OF_SQUARES to true (may be faster but less accurate):
+ * it turns out the accuracy doesn't drop.
  * First pass (loop) computes the mean and variance via Var[x] = E[x²] - E[x]²
 * Second pass computes and writes normed_output
+ *
 *
 * use_shmem controls if we cache input values into shared memory
 *
 * Optional: with dynamic scaling, the last pass doesn't write immediately but finds the
 * amax per row. A final pass scales to int8 accordingly, and writes output to
 * normed_output_quant.
 */
-template <typename T, typename scale_type, bool USE_DIFF_OF_SQUARES = false>
+template <typename T, typename scale_type, bool USE_DIFF_OF_SQUARES = true>
 __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta, T* normed_output, const float eps,
     int tokens, int hidden_dim, const scale_type* scale_orig_quant_per_tensor, scale_type* scale_orig_quant_per_token,
     int8_t* normed_output_quant, bool use_shmem)
@@ -74,7 +74,6 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,
     float variance = 0.0f;
     float local_sum = 0.0f;
     float local_var_sum = 0.0f;
-
     const int n_elems = hidden_dim / num_elems_T;
     for (int i = tidx; i < n_elems; i += blockDim.x)
     {
@@ -83,15 +82,14 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,
         {
             shmem[i] = val;
         }
-
         const float_packed_t val_f = cuda_cast<float_packed_t>(val);
         local_sum += cuda_sum<float>(val_f);
         if (USE_DIFF_OF_SQUARES)
         {
            local_var_sum += cuda_sum<float>(val_f * val_f);
        }
    }
-
+    // Compute mean
    if (USE_DIFF_OF_SQUARES)
    {
        float packed[2] = {local_sum, local_var_sum};
@@ -116,12 +114,13 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,
    }
    __syncthreads();
 
+
    if (!USE_DIFF_OF_SQUARES)
    {
        for (int i = tidx; i < n_elems; i += blockDim.x)
        {
            const T val = use_shmem ? shmem[i] : input[bidx * n_elems + i];
-            float_packed_t diff = cuda_cast<float_packed_t>(val) - s_mean;
+            float_packed_t diff = cuda_cast<float_packed_t>(val); // - s_mean;
            local_var_sum += cuda_sum<float>(diff * diff);
        }
        variance = blockReduceSum(local_var_sum);
@@ -133,6 +132,7 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,
        __syncthreads();
    }
 
+    // Compute LN and quantize
    const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr;
    const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
    const float_packed_t scale_orig_quant
@@ -186,51 +186,21 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,
            }
        }
    }
-}
 
-// TODO(woosuk): Further optimize this kernel.
-template <typename scalar_t, typename out_type, bool use_quant>
-__global__ void
-rms_norm_kernel(out_type *__restrict__ out, // [..., hidden_size]
-                const scalar_t *__restrict__ input, // [..., hidden_size]
-                const scalar_t *__restrict__ weight, // [hidden_size]
-                const float epsilon, const int num_tokens,
-                const int hidden_size) {
-  __shared__ float s_variance;
-  float variance = 0.0f;
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * hidden_size + idx];
-    variance += x * x;
-  }
-  variance = blockReduceSum<float>(variance);
-  if (threadIdx.x == 0) {
-    s_variance = rsqrtf(variance / hidden_size + epsilon);
-  }
-  __syncthreads();
-
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * hidden_size + idx];
-    if constexpr (use_quant) {
-      out[blockIdx.x * hidden_size + idx] = float_to_int8_rn(
-          ((float)(x * s_variance)) * (float)(weight[idx]));
-    } else {
-      out[blockIdx.x * hidden_size + idx] =
-          ((scalar_t)(x * s_variance)) * weight[idx];
-    }
-  }
-}
+} // namespace vllm
 
 void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
                      torch::Tensor &input, // [..., hidden_size]
                      torch::Tensor &weight, // [hidden_size]
+                      torch::Tensor &bias, // [hidden_size]
                      torch::Tensor &scaling, // [tokens] or [1]
                      float epsilon,
-                      bool use_per_token_quant) {
+                      bool use_per_token_quant = true) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;
  dim3 grid(num_tokens);
-  dim3 block(std::min(hidden_size, 1024));
+  dim3 block(std::min(hidden_size, 128)); // Reduce the probability of idle threads
  block.x = 32 * ((block.x + 31) / 32);
 
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -240,7 +210,8 @@ void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
    // per-token
    vllm::generalLayerNorm<T, at::Half><<<grid, block, 0, stream>>>(
        reinterpret_cast<T*>(input.data_ptr<scalar_t>()),
-        reinterpret_cast<T*>(weight.data_ptr<scalar_t>()), nullptr,
+        reinterpret_cast<T*>(weight.data_ptr<scalar_t>()),
+        reinterpret_cast<T*>(bias.data_ptr<scalar_t>()),
        nullptr, epsilon, num_tokens, hidden_size, nullptr, scaling.data_ptr<at::Half>(),
        out.data_ptr<int8_t>(), false
    );
@@ -258,4 +229,4 @@ void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
        );
    }
  });
-}
+}
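Despite its name, rms_norm_general now drives generalLayerNorm: a full LayerNorm with gamma and the newly added beta, fused with per-token int8 quantization, and with USE_DIFF_OF_SQUARES defaulting to true so mean and variance come out of a single pass via Var[x] = E[x²] - E[x]². The following is a rough fp32 reference of the per-token branch only (an approximation; the kernel's fp16 arithmetic and float_to_int8_rn saturation will differ slightly, and the eps value is up to the caller).

# Reference sketch of fused LayerNorm + per-token int8 quantization
# (fp32 approximation of rms_norm_general's per-token branch).
import torch

def layernorm_quant_ref(x, gamma, beta, eps=1e-6):
    # x: [num_tokens, hidden]; gamma, beta: [hidden]
    y = torch.nn.functional.layer_norm(
        x.float(), (x.shape[-1],), gamma.float(), beta.float(), eps)
    amax = y.abs().amax(dim=-1, keepdim=True)   # per-token amax
    scale = amax / 127.0                        # written to the `scaling` tensor
    q = torch.clamp(torch.round(y / scale), -128, 127).to(torch.int8)
    return q, scale.squeeze(-1).half()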

awq/kernels/csrc/fused_layernorm/layernorm.h renamed to awq/kernels/csrc/w8a8/layernorm.h

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@
 void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
                       torch::Tensor &input, // [..., hidden_size]
                       torch::Tensor &weight, // [hidden_size]
+                      torch::Tensor &bias, // [hidden_size]
                       torch::Tensor &scaling, // [tokens] or [1]
                       float epsilon,
                       bool use_per_token_quant);
