update to tune for small ms and quantized gemv #3712

Open · wants to merge 2 commits into base: main
@@ -21,20 +21,46 @@ namespace fbgemm_gpu {
 // problem sizes we care about and selected the best elapsed time/bw
 // combination. See more in
 // deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/sweep_utils.py
 namespace {
 dim3 get_best_block_dim(int m, int n, int k) {
   if (m == 1 && n == 1280 && k == 8192) {
-    return dim3(128, 4);
+    return dim3(128, 2);
   } else if (m == 1 && n == 8192 && k == 1024) {
-    return dim3(32, 8);
+    return dim3(64, 2);
   } else if (m == 1 && n == 7168 && k == 8192) {
-    return dim3(256, 1);
+    return dim3(128, 1);
+  } else if (m == 1 && n == 8192 && k == 3584) {
+    return dim3(64, 2);
+  } else if (m == 2 && n == 1280 && k == 8192) {
+    return dim3(256, 1);
+  } else if (m == 2 && n == 8192 && k == 1024) {
+    return dim3(64, 2);
+  } else if (m == 2 && n == 7168 && k == 8192) {
+    return dim3(256, 1);
+  } else if (m == 2 && n == 8192 && k == 3584) {
+    return dim3(64, 2);
+  } else if (m == 3 && n == 1280 && k == 8192) {
+    return dim3(256, 1);
+  } else if (m == 3 && n == 8192 && k == 1024) {
+    return dim3(64, 2);
+  } else if (m == 3 && n == 7168 && k == 8192) {
+    return dim3(256, 1);
+  } else if (m == 3 && n == 8192 && k == 3584) {
+    return dim3(64, 2);
+  } else if (m == 4 && n == 1280 && k == 8192) {
+    return dim3(256, 1);
+  } else if (m == 4 && n == 8192 && k == 1024) {
+    return dim3(64, 2);
+  } else if (m == 4 && n == 7168 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 4 && n == 8192 && k == 3584) {
+    return dim3(64, 2);
   } else {
     // Default block dimensions
     return dim3(32, 4);
   }
 }
 } // namespace

 at::Tensor bf16_fast_gemv(at::Tensor X, at::Tensor W) {
   // X: M x K
@@ -62,6 +88,8 @@ at::Tensor bf16_fast_gemv(at::Tensor X, at::Tensor W) {
       reinterpret_cast<__nv_bfloat16*>(X.data_ptr()), // vec
       reinterpret_cast<__nv_bfloat16*>(Y.data_ptr()), // res
       k,
+      m,
+      n,
       num_per_thread);

   C10_CUDA_KERNEL_LAUNCH_CHECK();
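For context on how a tuned block shape turns into a launch configuration on this path: a minimal host-side sketch, mirroring the grid and per-thread arithmetic visible in the pre-existing fp8 file below. The `launch_bf16_gemv` and `gemv_kernel` names are hypothetical stand-ins, not the actual declarations in the fast_gemv headers.

// Sketch only: how the tuned table above maps to a launch configuration.
// `gemv_kernel` is a hypothetical stand-in for the real bf16 kernel.
#include <cuda_bf16.h>
#include <cuda_runtime.h>

dim3 get_best_block_dim(int m, int n, int k); // the heuristic table above

void launch_bf16_gemv(
    const __nv_bfloat16* mat,
    const __nv_bfloat16* vec,
    __nv_bfloat16* res,
    unsigned int m, unsigned int n, unsigned int k,
    cudaStream_t stream) {
  dim3 block_dim = get_best_block_dim(m, n, k);
  // block_dim.x threads cooperate along K; block_dim.y output rows share one
  // block, so the grid needs n / block_dim.y blocks along y.
  dim3 grid_dim(1, n / block_dim.y);
  // Each thread walks k / block_dim.x elements of its K slice.
  unsigned int num_per_thread = k / block_dim.x;
  // Launch commented out because gemv_kernel is a stand-in:
  // gemv_kernel<<<grid_dim, block_dim, 0, stream>>>(
  //     mat, vec, res, k, m, n, num_per_thread);
}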
@@ -24,13 +24,37 @@ namespace fbgemm_gpu {
 namespace {
 dim3 get_best_block_dim(int m, int n, int k) {
   if (m == 1 && n == 1280 && k == 8192) {
-    return dim3(128, 1);
+    return dim3(128, 2);
   } else if (m == 1 && n == 8192 && k == 1024) {
-    return dim3(32, 4);
+    return dim3(32, 8);
   } else if (m == 1 && n == 7168 && k == 8192) {
     return dim3(128, 1);
   } else if (m == 1 && n == 8192 && k == 3584) {
     return dim3(64, 2);
+  } else if (m == 2 && n == 1280 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 2 && n == 8192 && k == 1024) {
+    return dim3(32, 8);
+  } else if (m == 2 && n == 7168 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 2 && n == 8192 && k == 3584) {
+    return dim3(64, 2);
+  } else if (m == 3 && n == 1280 && k == 8192) {
+    return dim3(128, 2);
+  } else if (m == 3 && n == 8192 && k == 1024) {
+    return dim3(32, 8);
+  } else if (m == 3 && n == 7168 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 3 && n == 8192 && k == 3584) {
+    return dim3(64, 2);
+  } else if (m == 4 && n == 1280 && k == 8192) {
+    return dim3(128, 2);
+  } else if (m == 4 && n == 8192 && k == 1024) {
+    return dim3(32, 8);
+  } else if (m == 4 && n == 7168 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 4 && n == 8192 && k == 3584) {
+    return dim3(64, 2);
   } else {
     // Default block dimensions
     return dim3(32, 4);
@@ -65,6 +89,8 @@ bf16fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, at::Tensor w_scale) {
       reinterpret_cast<__nv_bfloat16*>(X.data_ptr()), // vec
       reinterpret_cast<__nv_bfloat16*>(Y.data_ptr()), // res
       k,
+      m,
+      n,
       reinterpret_cast<float const*>(w_scale.data_ptr()),
       num_per_thread);

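The heuristic tables in these files come from the sweep referenced in the comment (sweep_utils.py). As a rough illustration of how such numbers are typically gathered, not the actual script, a CUDA-event timing loop over candidate block shapes might look like this; the `launch` callable and the candidate list are assumptions.

// Rough illustration of a block-shape sweep; the real tooling is
// sweep_utils.py in the tree. Assumes a generic launch(dim3 block)
// callable that runs the gemv once for a fixed (m, n, k).
#include <cuda_runtime.h>
#include <cstdio>
#include <functional>

float time_ms(const std::function<void(dim3)>& launch, dim3 block, int iters = 50) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start);
  for (int i = 0; i < iters; ++i) {
    launch(block); // run the candidate configuration
  }
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);
  float ms = 0.f;
  cudaEventElapsedTime(&ms, start, stop);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms / iters;
}

void sweep(const std::function<void(dim3)>& launch) {
  // Candidate (block_dim.x, block_dim.y) shapes, mirroring the values that
  // appear in the heuristic tables above.
  const dim3 candidates[] = {{32, 1}, {32, 4}, {32, 8}, {64, 1},
                             {64, 2}, {128, 1}, {128, 2}, {256, 1}};
  for (const dim3& b : candidates) {
    std::printf("block=(%u,%u): %.4f ms\n", b.x, b.y, time_ms(launch, b));
  }
}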
@@ -17,57 +17,159 @@

 namespace fbgemm_gpu {
 
+using SizeType32 = std::size_t;
+
 // The heuristics are derived by sweeping over 4 different
 // problem sizes we care about and selected the best elapsed time/bw
 // combination. See more in
 // deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/sweep_utils.py
 namespace {
 dim3 get_best_block_dim(int m, int n, int k) {
   if (m == 1 && n == 1280 && k == 8192) {
-    return dim3(128, 1);
+    return dim3(256, 1);
   } else if (m == 1 && n == 8192 && k == 1024) {
     return dim3(32, 32);
-  } else if (m == 1 && n == 7168 && k == 8192) {
-    return dim3(128, 1);
+  } else if (m == 1 && n == 7168 && k == 8192) {
+    return dim3(256, 1);
   } else if (m == 1 && n == 8192 && k == 3584) {
-    return dim3(64, 2);
+    return dim3(128, 1);
+  } else if (m == 2 && n == 1280 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 2 && n == 8192 && k == 1024) {
+    return dim3(64, 1);
+  } else if (m == 2 && n == 7168 && k == 8192) {
+    return dim3(256, 1);
+  } else if (m == 2 && n == 8192 && k == 3584) {
+    return dim3(128, 1);
+  } else if (m == 3 && n == 1280 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 3 && n == 8192 && k == 1024) {
+    return dim3(64, 1);
+  } else if (m == 3 && n == 7168 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 3 && n == 8192 && k == 3584) {
+    return dim3(128, 1);
+  } else if (m == 4 && n == 1280 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 4 && n == 8192 && k == 1024) {
+    return dim3(64, 1);
+  } else if (m == 4 && n == 7168 && k == 8192) {
+    return dim3(128, 1);
+  } else if (m == 4 && n == 8192 && k == 3584) {
+    return dim3(128, 1);
   } else {
     // Default block dimensions
-    return dim3(32, 4);
+    return dim3(32, 1);
   }
 }
 } // namespace
 
-at::Tensor fp8fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, at::Tensor scale) {
-  // X: M x K
-  // W: N x K
-  auto m = X.size(0);
-  auto n = W.size(0);
-  auto k = W.size(1);
+template <SizeType32 TILE_M, SizeType32 TILE_N>
+void fp8fp8FastGemvKernel(
+    cutlass::float_e4m3_t* mat,
+    cutlass::float_e4m3_t* vec,
+    __nv_bfloat16* res,
+    const unsigned int k,
+    const unsigned int m,
+    const unsigned int n,
+    float const* scale) {
+  // each threadblock handles TILE_M * TILE_N dot products in the resulting
+  // matrix.
+  // block_size is represented as (block_dim.x, block_dim.y).
+  // grid_dim is accordingly calculated based on the number of threadblocks
+  // needed to cover the given problem size
+  dim3 block_dim = get_best_block_dim(m, n, k);
+  dim3 grid_dim(m / TILE_M, n / TILE_N * block_dim.y);
+  // total number of memory loads needed per thread
+  unsigned int num_iter_per_thread = ((k >> 4) + block_dim.x - 1) / block_dim.x;
+
+  check_if_valid_input_dimensions_fp8fp8(m, n, k, TILE_N, block_dim);
 
-  TORCH_CHECK(X.is_cuda() && X.is_contiguous());
-  TORCH_CHECK(W.is_cuda() && W.is_contiguous());
+  auto stream = at::cuda::getCurrentCUDAStream();
 
-  dim3 block_dim = get_best_block_dim(m, n, k);
+  if (block_dim.x == 128) {
+    gemv_quantized_fp8_fp8<TILE_M, TILE_N, 128>
+        <<<grid_dim, block_dim, 0, stream>>>(
+            mat, vec, res, k, m, n, scale, num_iter_per_thread);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  } else if (block_dim.x == 64) {
+    gemv_quantized_fp8_fp8<TILE_M, TILE_N, 64>
+        <<<grid_dim, block_dim, 0, stream>>>(
+            mat, vec, res, k, m, n, scale, num_iter_per_thread);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  } else if (block_dim.x == 256) {
+    gemv_quantized_fp8_fp8<TILE_M, TILE_N, 256>
+        <<<grid_dim, block_dim, 0, stream>>>(
+            mat, vec, res, k, m, n, scale, num_iter_per_thread);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  } else {
+    gemv_quantized_fp8_fp8<TILE_M, TILE_N, 32>
+        <<<grid_dim, block_dim, 0, stream>>>(
+            mat, vec, res, k, m, n, scale, num_iter_per_thread);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+}
+
+template <SizeType32 TILE_M, SizeType32 TILE_N>
+bool fastGemvTemplateCaller(
+    cutlass::float_e4m3_t* mat,
+    cutlass::float_e4m3_t* vec,
+    __nv_bfloat16* res,
+    const unsigned int k,
+    const unsigned int m,
+    const unsigned int n,
+    float const* scale) {
+  if (m == TILE_M) {
+    fp8fp8FastGemvKernel<TILE_M, TILE_N>(mat, vec, res, k, m, n, scale);
+    return true;
+  }
 
-  check_if_valid_block_dimensions(m, n, k, block_dim);
+  if constexpr (TILE_M < MAX_M_SIZE) {
+    return fastGemvTemplateCaller<TILE_M + 1, TILE_N>(
+        mat, vec, res, k, m, n, scale);
+  }
+  return false;
+}
 
-  dim3 grid_dim(1, n / block_dim.y);
-  unsigned int num_per_thread = k / block_dim.x;
+bool fastGemvLauncher(
+    cutlass::float_e4m3_t* mat,
+    cutlass::float_e4m3_t* vec,
+    __nv_bfloat16* res,
+    const unsigned int k,
+    const unsigned int m,
+    const unsigned int n,
+    float const* scale) {
+  // Note: based on sweeping results, the heuristic TILE_N = 2 here gives the
+  // best performance over larger TILE_N values. This is potentially because a
+  // smaller TILE_N leads to more threadblocks and thus increases the block
+  // concurrency.
+  return fastGemvTemplateCaller</* TILE_M=*/1, /* TILE_N=*/2>(
+      mat, vec, res, k, m, n, scale);
+}
 
-  auto stream = at::cuda::getCurrentCUDAStream();
+at::Tensor
+fp8fp8bf16_fast_gemv(at::Tensor XQ, at::Tensor WQ, at::Tensor scale) {
+  const unsigned int m = XQ.size(0);
+  const unsigned int n = WQ.size(0);
+  const unsigned int k = WQ.size(1);
+
+  TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
+  TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());
+  TORCH_CHECK(XQ.size(-1) == k);
 
-  auto Y = at::empty({m, n}, X.options().dtype(at::kBFloat16));
+  auto Y = at::empty({m, n}, XQ.options().dtype(at::kBFloat16));
 
-  gemv_quantized_fp8_fp8<<<grid_dim, block_dim, 0, stream>>>(
-      reinterpret_cast<cutlass::float_e4m3_t*>(W.data_ptr()), // mat
-      reinterpret_cast<cutlass::float_e4m3_t*>(X.data_ptr()), // vec
+  bool dispatched = fastGemvLauncher(
+      reinterpret_cast<cutlass::float_e4m3_t*>(WQ.data_ptr()), // mat
+      reinterpret_cast<cutlass::float_e4m3_t*>(XQ.data_ptr()), // vec
       reinterpret_cast<__nv_bfloat16*>(Y.data_ptr()), // res
       k,
-      reinterpret_cast<float const*>(scale.data_ptr()),
-      num_per_thread);
+      m,
+      n,
+      reinterpret_cast<float const*>(scale.data_ptr()));
 
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
+  if (!dispatched) {
+    throw std::runtime_error("f8f8bf16_fast_gemv cannot run.");
+  }
 
   return Y;
 }
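A note on the dispatch pattern above: fastGemvTemplateCaller walks TILE_M upward from 1 at compile time until it matches the runtime m, and `if constexpr` stops the template recursion at MAX_M_SIZE. A self-contained sketch of the same pattern, where MAX_M_SIZE = 4 mirrors the new m <= 4 input check and `run` is a hypothetical stand-in for the kernel launch:

// Self-contained sketch of the compile-time TILE_M dispatch used above.
// `run<TILE_M>` stands in for fp8fp8FastGemvKernel; MAX_M_SIZE = 4 mirrors
// the m <= 4 check added in this PR.
#include <cstdio>

constexpr unsigned MAX_M_SIZE = 4;

template <unsigned TILE_M>
void run(unsigned m) {
  // In the real code this launches a kernel specialized for TILE_M rows.
  std::printf("dispatched m=%u to TILE_M=%u\n", m, TILE_M);
}

template <unsigned TILE_M = 1>
bool dispatch(unsigned m) {
  if (m == TILE_M) {
    run<TILE_M>(m); // found the matching compile-time tile size
    return true;
  }
  if constexpr (TILE_M < MAX_M_SIZE) {
    return dispatch<TILE_M + 1>(m); // try the next specialization
  }
  return false; // m > MAX_M_SIZE: no specialization exists
}

int main() {
  for (unsigned m = 1; m <= 5; ++m) {
    if (!dispatch(m)) {
      std::printf("m=%u cannot be dispatched\n", m);
    }
  }
}

Compared with a runtime switch, this keeps TILE_M a template parameter, so each m in [1, 4] gets its own fully specialized kernel instantiation.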
@@ -12,6 +12,8 @@ namespace fbgemm_gpu {

 namespace {
 
+using SizeType32 = std::size_t;
+
 void check_if_valid_block_dimensions(int m, int n, int k, dim3 block_dim) {
   TORCH_CHECK(
       n % block_dim.y == 0,
@@ -82,5 +84,82 @@ void check_if_valid_block_dimensions(int m, int n, int k, dim3 block_dim) {
       block_dim.y,
       ".");
 }
+void check_if_valid_input_dimensions_fp8fp8(
+    int m,
+    int n,
+    int k,
+    SizeType32 TILE_N,
+    dim3 block_dim) {
+  TORCH_CHECK(
+      m <= 4,
+      "Invalid value for m: m (",
+      m,
+      ") must not be greater than 4. The kernel cannot be run with the current value of m."
+      " Please use an `m` less than or equal to 4.");
+  TORCH_CHECK(
+      k % 16 == 0,
+      "Invalid value for k: (",
+      k,
+      ") must be divisible by 16.",
+      " Please use a `k` that is divisible by 16."
+      " All current params - m: ",
+      m,
+      ", n: ",
+      n,
+      ", k: ",
+      k,
+      ", block_dim.x: ",
+      block_dim.x,
+      ", block_dim.y: ",
+      block_dim.y,
+      ".");
+  TORCH_CHECK(
+      k % block_dim.x == 0,
+      "Invalid block dimensions: k (",
+      k,
+      ") must be divisible by block_dim.x (",
+      block_dim.x,
+      "). Received k: ",
+      k,
+      ", block_dim.x: ",
+      block_dim.x,
+      ". Please either use a `k` which is divisible by `block_dim.x`, or update "
+      "`get_best_block_dim()` heuristics to choose another `block_dim.x`."
+      " All current params - m: ",
+      m,
+      ", n: ",
+      n,
+      ", k: ",
+      k,
+      ", block_dim.x: ",
+      block_dim.x,
+      ", block_dim.y: ",
+      block_dim.y,
+      ".");
+  TORCH_CHECK(
+      n % (TILE_N * block_dim.y) == 0,
+      "Invalid block dimensions: n (",
+      n,
+      ") must be divisible by TILE_N * block_dim.y (",
+      TILE_N * block_dim.y,
+      "). Received n: ",
+      n,
+      ", block_dim.y: ",
+      block_dim.y,
+      ", TILE_N: ",
+      TILE_N,
+      ". Please use an `n` which is divisible by `TILE_N * block_dim.y`."
+      " All current params - m: ",
+      m,
+      ", n: ",
+      n,
+      ", k: ",
+      k,
+      ", block_dim.x: ",
+      block_dim.x,
+      ", block_dim.y: ",
+      block_dim.y,
+      ".");
+}
 } // namespace
 } // namespace fbgemm_gpu
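To make the new fp8 shape constraints easy to eyeball, a small sketch that mirrors the checks as a plain predicate. TILE_N = 2 matches the launcher default; the block shapes used below are taken from the heuristic table, and `shapes_ok` itself is an illustrative helper, not part of the PR.

// Sketch: the new fp8 input constraints as a plain predicate.
#include <cstdio>

bool shapes_ok(int m, int n, int k, int tile_n, int bx, int by) {
  return m <= 4                    // only m in [1, 4] is dispatched
      && k % 16 == 0               // fp8 path loads 16 elements at a time (k >> 4)
      && k % bx == 0               // each thread gets a whole number of loads
      && n % (tile_n * by) == 0;   // the grid covers n exactly
}

int main() {
  // (m=1, n=7168, k=8192) with block_dim (256, 1): a tuned shape, passes.
  std::printf("%d\n", shapes_ok(1, 7168, 8192, 2, 256, 1)); // 1
  // k = 8200 is not divisible by 16: fails.
  std::printf("%d\n", shapes_ok(1, 7168, 8200, 2, 256, 1)); // 0
  // m = 5 exceeds the maximum supported m of 4: fails.
  std::printf("%d\n", shapes_ok(5, 7168, 8192, 2, 256, 1)); // 0
}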