
Commit c131985

[AMD][ROCm] Improve support of AMD
This patch delivers several fixes for build issues in the CUDA part of the DeepSpeed library. The percentage of passing unit tests improved (tested on RDNA hardware, gfx110x and gfx12x).

Before: collected 5298 items / 15 skipped; 2773 failed, 862 passed, 1665 skipped, 13 errors
After: collected 5851 items / 11 skipped; 4187 failed, 1373 passed, 292 skipped, 10 errors

Signed-off-by: Artem Kuzmitckii <[email protected]>
1 parent 3292e07 commit c131985

File tree

9 files changed (+35, -14 lines)


csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h

Lines changed: 1 addition & 1 deletion
@@ -233,7 +233,7 @@ struct call_conditional<false, TA, TB> {

 CUTLASS_DEVICE int32_t warp_uniform(int32_t value)
 {
-    return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0);
+    return (int32_t)__shfl_sync(static_cast<uint64_t>(0xffffffff), (unsigned)value, 0);
 }

 template <typename T>
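
For context on this change: HIP declares the active-lane mask of __shfl_sync as a 64-bit integer (AMD wavefronts are up to 64 lanes wide), so the 32-bit literal 0xffffffff is widened explicitly. A minimal sketch of the difference, under that assumption; lane_mask_t and warp_uniform_sketch are illustrative names, not part of the patch:

    // Illustrative sketch only; the mask-width difference is the point.
    #if defined(__HIP_PLATFORM_AMD__)
    using lane_mask_t = unsigned long long;  // ROCm: 64-bit active-lane mask
    #else
    using lane_mask_t = unsigned;            // CUDA: 32-bit active-lane mask
    #endif

    __device__ __forceinline__ int warp_uniform_sketch(int value)
    {
        // All-ones mask: every lane participates; broadcast lane 0's value.
        const lane_mask_t full_mask = static_cast<lane_mask_t>(-1);
        return (int)__shfl_sync(full_mask, (unsigned)value, 0);
    }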

csrc/fp_quantizer/fp_quantize.cpp

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 #include "fp_quantize.h"

 #include <c10/cuda/CUDAStream.h>
+#include <hip/hip_fp16.h>
 #include <torch/extension.h>
 #include <vector>
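
The extra include makes the __half/__half2 types visible when this file is built for ROCm. A common way to keep such an include portable is to guard it on the platform macro; this is only a sketch of that pattern, not what the patch does (the patch adds the include unconditionally):

    // Sketch of a platform-guarded half-precision include; illustrative only.
    #if defined(__HIP_PLATFORM_AMD__)
    #include <hip/hip_fp16.h>   // __half / __half2 on ROCm
    #else
    #include <cuda_fp16.h>      // __half / __half2 on CUDA
    #endif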

csrc/fp_quantizer/fp_quantize.cu

Lines changed: 2 additions & 1 deletion
@@ -4,7 +4,7 @@
 // DeepSpeed Team

 #include <stdexcept>
-#include "context.h"
+#include "fp_context.h"
 #include "fp_quantize.h"
 #include "memory_access_utils.h"
 #include "reduction_utils.h"
@@ -14,6 +14,7 @@

 #include <cuda_fp16.h>
 #include <curand_kernel.h>
+#include <hip/hip_fp16.h>

 #ifdef BF16_AVAILABLE
 #include <cuda_bf16.h>

File renamed without changes.

csrc/includes/reduction_utils.h

Lines changed: 26 additions & 10 deletions
@@ -526,12 +526,28 @@ here (fold is C++17 only and I don't think helps and recursion feels like
 huge overkill that harms readability) that would be wonderful.
 */

+template <typename T>
+DS_D_INLINE T shfl_xor_helper(cg::thread_block_tile<hw_warp_size>& warp, const T& value, int i)
+{
+    return warp.shfl_xor(value, i);
+}
+
+#if defined(__HIP_PLATFORM_AMD__)
+template <>
+DS_D_INLINE __half shfl_xor_helper<__half>(cg::thread_block_tile<hw_warp_size>& warp,
+                                           const __half& value,
+                                           int i)
+{
+    return __half(warp.shfl_xor(float(value), i));
+}
+#endif
+
 template <typename T, ROpType Op, int reduce_width = hw_warp_size>
 DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
 {
 #pragma unroll
     for (int i = 1; i < reduce_width; i *= 2) {
-        data[0] = element<Op>(data[0], warp.shfl_xor(data[0], i));
+        data[0] = element<Op>(data[0], shfl_xor_helper(warp, data[0], i));
     }
 }

@@ -540,8 +556,8 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
 {
 #pragma unroll
     for (int i = 1; i < reduce_width; i *= 2) {
-        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
-        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
+        data[0] = element<Op1>(data[0], shfl_xor_helper(warp, data[0], i));
+        data[1] = element<Op2>(data[1], shfl_xor_helper(warp, data[1], i));
     }
 }

@@ -550,9 +566,9 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
 {
 #pragma unroll
     for (int i = 1; i < reduce_width; i *= 2) {
-        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
-        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
-        data[2] = element<Op3>(data[2], warp.shfl_xor(data[2], i));
+        data[0] = element<Op1>(data[0], shfl_xor_helper(warp, data[0], i));
+        data[1] = element<Op2>(data[1], shfl_xor_helper(warp, data[1], i));
+        data[2] = element<Op3>(data[2], shfl_xor_helper(warp, data[2], i));
     }
 }

@@ -566,10 +582,10 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
 {
 #pragma unroll
     for (int i = 1; i < reduce_width; i *= 2) {
-        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
-        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
-        data[2] = element<Op3>(data[2], warp.shfl_xor(data[2], i));
-        data[3] = element<Op4>(data[3], warp.shfl_xor(data[3], i));
+        data[0] = element<Op1>(data[0], shfl_xor_helper(warp, data[0], i));
+        data[1] = element<Op2>(data[1], shfl_xor_helper(warp, data[1], i));
+        data[2] = element<Op3>(data[2], shfl_xor_helper(warp, data[2], i));
+        data[3] = element<Op4>(data[3], shfl_xor_helper(warp, data[3], i));
     }
 }
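
The shfl_xor_helper indirection exists because, on ROCm, the cooperative-groups shuffle apparently has no native __half path, so the specialization round-trips the value through float. A small sketch of how the helper slots into a butterfly (XOR) warp reduction; warp_sum_sketch is a hypothetical name, and DS_D_INLINE plus hw_warp_size are the aliases already defined in this header:

    // Sketch: butterfly sum reduction built on shfl_xor_helper.
    #include <cooperative_groups.h>
    namespace cg = cooperative_groups;

    template <typename T>
    DS_D_INLINE T warp_sum_sketch(cg::thread_block_tile<hw_warp_size>& warp, T value)
    {
    #pragma unroll
        for (int i = 1; i < hw_warp_size; i *= 2) {
            // Each step pairs lane L with lane L ^ i; on ROCm the __half
            // overload above converts through float before shuffling.
            value = value + shfl_xor_helper(warp, value, i);
        }
        return value;  // every lane now holds the full warp sum
    }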

deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ __device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales,
 #pragma unroll
     for (int i = 0; i < 4; i++) {
         // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize);
-        Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4);
+        Scales[i] = __shfl_sync(static_cast<uint64_t>(0xffffffff), tmpReg, i, 4);
     }
 }
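
Same 64-bit mask cast as in gemm_kernel_utils.h; here the shuffle also uses the width argument, which splits the warp into independent groups of 4 lanes so each group broadcasts from its own source lane. A short illustrative sketch of that sub-group broadcast (group_broadcast_sketch is a made-up name; the cast mirrors the 64-bit ROCm mask used in the patch):

    // Sketch: broadcast within groups of 4 lanes via the width argument.
    __device__ __forceinline__ unsigned group_broadcast_sketch(unsigned v, int src_lane)
    {
        // With width = 4, src_lane is interpreted relative to each 4-lane group,
        // so lanes 0-3 read from their group's lane src_lane, lanes 4-7 from theirs, etc.
        return __shfl_sync(static_cast<unsigned long long>(0xffffffff), v, src_lane, 4);
    }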

deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu

Lines changed: 2 additions & 1 deletion
@@ -45,7 +45,8 @@ static void Kernel_Ex(cudaStream_t stream,
     static size_t SHMEM_SZ =
         max(TilingConfig::SMEM_SIZE_B_TILE + SMEM_SIZE_A1_TILE + SMEM_SIZE_A2_TILE,
             TilingConfig::SMEM_SIZE_C_TILE);
-    cudaFuncSetAttribute(QUANT_GEMM_Kernel<TilingConfig, OutputDataType>,
+    auto kernel = QUANT_GEMM_Kernel<TilingConfig, OutputDataType>;
+    cudaFuncSetAttribute(reinterpret_cast<const void*>(kernel),
                          cudaFuncAttributeMaxDynamicSharedMemorySize,
                          SHMEM_SZ);
     size_t dimN = (N_Global - 1) / TilingConfig::TILE_N + 1;
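
The temporary plus reinterpret_cast is there because hipFuncSetAttribute (which this call becomes after hipification) takes the kernel as a const void* rather than as a typed kernel pointer, so the template instantiation is bound to a variable first and then cast. A minimal sketch of the pattern under those assumptions; dummy_kernel and the 64 KiB value are made up for illustration:

    // Sketch: raising a kernel's dynamic shared-memory limit in a hipify-friendly way.
    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void dummy_kernel(float* out)
    {
        extern __shared__ float smem[];          // dynamic shared memory
        if (threadIdx.x == 0) out[0] = smem[0];
    }

    void configure_dynamic_smem()
    {
        // Bind the kernel to a variable, then pass it as const void*; this form
        // compiles with nvcc and after hipify, where the setter expects an
        // untyped pointer instead of a typed kernel pointer.
        auto kernel = dummy_kernel;
        cudaError_t err = cudaFuncSetAttribute(reinterpret_cast<const void*>(kernel),
                                               cudaFuncAttributeMaxDynamicSharedMemorySize,
                                               64 * 1024);  // e.g. 64 KiB
        if (err != cudaSuccess) {
            std::printf("cudaFuncSetAttribute failed: %d\n", static_cast<int>(err));
        }
    }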

deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.cu

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 // DeepSpeed Team

 #include <c10/cuda/CUDAStream.h>
+#include <hip/hip_bf16.h>
 #include "mixed_gemm.h"
 #include "mixed_gemm_api.h"
 #include "weight_variant.h"

deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.cu

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 // DeepSpeed Team

 #include <c10/cuda/CUDAStream.h>
+#include <hip/hip_bf16.h>
 #include "moe_gemm.h"
 #include "moe_gemm_api.h"
 #include "weight_variant.h"
