
Commit 12fbd58

hukongyi, liuchenbing, and guanyuzhu committed
[BugFix] This PR aims to fix the precision issue of the LoRA feature in vllm-ascend.

Co-authored-by: liuchenbing <[email protected]>
Co-authored-by: guanyuzhu <[email protected]>
vLLM version: v0.11.0
vLLM main: vllm-project/vllm
Signed-off-by: hukongyi <[email protected]>
1 parent 25534b7 commit 12fbd58

4 files changed, +12 -4 lines changed

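All four kernels below receive the same two-part change: the "#if (__CCE_AICORE__ >= 220)" guard around the bfloat16_t kernel declaration becomes "#if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)", and the BF16 launch site in each *_impl dispatcher is wrapped in the same condition. Because the preprocessor evaluates an undefined identifier in #if as 0, the old guard silently dropped the bfloat16_t kernels whenever __CCE_AICORE__ was not defined; the new guard keeps them in that case and only excludes them when the macro is defined and reports an AI Core version below 220. The snippet below is a minimal, self-contained sketch of that guard pattern, not the vllm-ascend sources: the stub functions, the FP16 enum value, and dispatch() are stand-ins for illustration; only the preprocessor condition mirrors the actual change.

// Minimal sketch of the guard pattern (stand-in names, not the real kernels).
#include <cstdio>

enum class AscendType { FP16, BF16 };

// The toolchain normally defines __CCE_AICORE__; leave it undefined here
// (or define it to e.g. 200) to see the BF16 path appear or disappear.

void launch_half_stub() { std::puts("fp16 kernel launched"); }

#if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)
void launch_bfloat16_stub() { std::puts("bf16 kernel launched"); }
#endif

void dispatch(AscendType type) {
    if (type == AscendType::FP16) {
        launch_half_stub();
    } else if (type == AscendType::BF16) {
#if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)
        // Compiled only when __CCE_AICORE__ is undefined or >= 220.
        launch_bfloat16_stub();
#endif
        // Otherwise this branch is empty: a BF16 request becomes a silent no-op.
    } else {
        return;
    }
}

int main() {
    dispatch(AscendType::BF16);
    return 0;
}

Wrapping the launch sites as well keeps the dispatchers compiling on older AI Core targets where the bfloat16_t kernels are never declared; on those targets the BF16 branch simply compiles down to an empty block.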

csrc/kernels/bgmv_expand.cpp

Lines changed: 3 additions & 1 deletion
@@ -342,7 +342,7 @@ class BGMVExpand {
 
 // declare all dtype kernel
 BGMV_EXPAND_TYPE_DECLARE(half)
-#if (__CCE_AICORE__ >= 220)
+#if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)
 BGMV_EXPAND_TYPE_DECLARE(bfloat16_t)
 #endif
 

@@ -356,9 +356,11 @@ extern void bgmv_expand_impl(AscendType type, void* stream, void* x, void* weigh
         bgmv_expand_half<<<blockDim, nullptr, stream>>>(x, weight, indices, indicesSize, yIn, yOut, batchSize, numTokensPerCore,
                                                         maxLoRARank, outputHiddenDim, sliceOffset, outputFullDim);
     } else if (type == AscendType::BF16) {
+#if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)
         bgmv_expand_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, indices, indicesSize, yIn, yOut, batchSize,
                                                               numTokensPerCore, maxLoRARank, outputHiddenDim,
                                                               sliceOffset, outputFullDim);
+#endif
     } else {
         return;
     }

csrc/kernels/bgmv_shrink.cpp

Lines changed: 3 additions & 1 deletion
@@ -226,7 +226,7 @@ class BGMVShrink {
 
 // declare all dtype kernel
 BGMV_SHRINK_TYPE_DECLARE(half)
-#if (__CCE_AICORE__ >= 220)
+#if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)
 BGMV_SHRINK_TYPE_DECLARE(bfloat16_t)
 #endif
 

@@ -240,8 +240,10 @@ extern void bgmv_shrink_impl(AscendType type, void* stream, void* x, void* weigh
         bgmv_shrink_half<<<blockDim, nullptr, stream>>>(x, weight, indices, indicesSize, y, batchSize, numTokensPerCore,
                                                         inputHiddenDim, maxLoRARank, scale);
     } else if (type == AscendType::BF16) {
+#if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)
         bgmv_shrink_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, indices, indicesSize, y, batchSize, numTokensPerCore,
                                                               inputHiddenDim, maxLoRARank, scale);
+#endif
     } else {
         return;
     }

csrc/kernels/sgmv_expand.cpp

Lines changed: 3 additions & 1 deletion
@@ -357,7 +357,7 @@ class SGMVExpand {
 
 // declare all dtype kernel
 SGMV_EXPAND_TYPE_DECLARE(half)
-#if (__CCE_AICORE__ >= 220)
+#if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)
 SGMV_EXPAND_TYPE_DECLARE(bfloat16_t)
 #endif
 

@@ -375,10 +375,12 @@ extern void sgmv_expand_impl(AscendType type, void* stream, void* x, void* weigh
                                                         numTokensPerCore, maxLoRARank, outputHiddenDim, sliceOffset,
                                                         outputFullDim);
     } else if (type == AscendType::BF16) {
+#if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)
         sgmv_expand_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, loraIndicesSize,
                                                               seqLen, seqLenSize, yIn, yOut, batchSize,
                                                               numTokensPerCore, maxLoRARank, outputHiddenDim,
                                                               sliceOffset, outputFullDim);
+#endif
     } else {
         return;
     }

csrc/kernels/sgmv_shrink.cpp

Lines changed: 3 additions & 1 deletion
@@ -242,7 +242,7 @@ class SGMVShrink {
 
 // declare all dtype kernel
 SGMV_SHRINK_TYPE_DECLARE(half)
-#if (__CCE_AICORE__ >= 220)
+#if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)
 SGMV_SHRINK_TYPE_DECLARE(bfloat16_t)
 #endif
 

@@ -260,11 +260,13 @@ extern void sgmv_shrink_impl(AscendType type, void* stream, void* x, void* weigh
                                                         numTokensPerCore, inputHiddenDim, maxLoRARank,
                                                         scale);
     } else if (type == AscendType::BF16) {
+#if !defined(__CCE_AICORE__) || (__CCE_AICORE__ >= 220)
         sgmv_shrink_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, loraIndicesSize,
                                                               seqLen, seqLenSize,
                                                               y, batchSize,
                                                               numTokensPerCore, inputHiddenDim, maxLoRARank,
                                                               scale);
+#endif
     } else {
         return;
     }
