86 changes: 85 additions & 1 deletion tests/ut/ops/test_rotary_embedding.py
@@ -6,13 +6,14 @@
from transformers.configuration_utils import PretrainedConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding)

from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled

MODEL = "Qwen3-0.6B"
MODEL_VL = "Qwen/Qwen2.5-VL-3B-Instruct"
MAX_NUM_BATCHED_TOKEND = 10000


@@ -376,3 +377,86 @@ def test_yarn_get_mscale(self, mock_npuplatform):
expected,
places=6,
msg=f"Failed for scale={scale}, mscale={mscale}")


class TestAscendMRotaryEmbedding(unittest.TestCase):

    def setUp(self):
        # Common setup for tests
        self.number_tokens = 3
        self.num_head = 8
        self.num_kvhead = 8
        self.head_size = 128
        self.max_position_embeddings = 128000
        self.is_neox_style = True
        self.rope_theta = 1000000.0
        self.positions_1d = torch.tensor([1, 2, 3])
        self.positions_2d = torch.randint(1, 10, (3, self.number_tokens))

        self.query = torch.randn(
            (self.number_tokens, self.num_head * self.head_size),
            dtype=torch.bfloat16)
        self.key = torch.randn(
            (self.number_tokens, self.num_kvhead * self.head_size),
            dtype=torch.bfloat16)

        # Qwen2.5-VL mrope section case
        self.mrope_section = [16, 24, 24]

        self.layer = MRotaryEmbedding(self.head_size,
                                      self.head_size,
                                      self.max_position_embeddings,
                                      base=self.rope_theta,
                                      is_neox_style=self.is_neox_style,
                                      dtype=torch.bfloat16,
                                      mrope_section=self.mrope_section)
Comment on lines +406 to +412
Contributor

critical: The test case instantiates MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding instead of vllm_ascend.ops.rotary_embedding.AscendMRotaryEmbedding, so the test never exercises the Ascend-specific implementation.

To fix this, import AscendMRotaryEmbedding from vllm_ascend.ops.rotary_embedding and use that class to instantiate self.layer.

Suggested change
        self.layer = MRotaryEmbedding(self.head_size,
                                      self.head_size,
                                      self.max_position_embeddings,
                                      base=self.rope_theta,
                                      is_neox_style=self.is_neox_style,
                                      dtype=torch.bfloat16,
                                      mrope_section=self.mrope_section)
        from vllm_ascend.ops.rotary_embedding import AscendMRotaryEmbedding
        self.layer = AscendMRotaryEmbedding(self.head_size,
                                            self.head_size,
                                            self.max_position_embeddings,
                                            base=self.rope_theta,
                                            is_neox_style=self.is_neox_style,
                                            dtype=torch.bfloat16,
                                            mrope_section=self.mrope_section)


        self.mock_config = MagicMock()
        self.mock_config.torchair_graph_config.enabled = False

    def _create_vllm_config(self):
        vllm_config = VllmConfig()
        model_config = ModelConfig(MODEL_VL,
                                   tokenizer=MODEL_VL,
                                   max_model_len=MAX_NUM_BATCHED_TOKEND)
        model_config.hf_config = PretrainedConfig()
        vllm_config.model_config = model_config
        return vllm_config

    @patch('torch_npu.npu_mrope')
    @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
    @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
    @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
    @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
    def test_forward_oot_1d_positions(self, mock_npu_mrope):
        mock_npu_mrope.return_value = (torch.zeros_like(self.query),
                                       torch.zeros_like(self.key))

        vllm_config = self._create_vllm_config()
        with set_ascend_forward_context(None, vllm_config):
            result_q, result_k = self.layer.forward_oot(
                self.positions_1d, self.query, self.key)

        mock_npu_mrope.assert_called_once()
        self.assertFalse(torch.isnan(result_q).any().item())
        self.assertFalse(torch.isnan(result_k).any().item())
        self.assertEqual(result_q.shape, self.query.shape)

    @patch('torch_npu.npu_mrope')
    @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
    @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
    @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
    @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
    def test_forward_oot_2d_positions(self, mock_npu_mrope):
        mock_npu_mrope.return_value = (torch.zeros_like(self.query),
                                       torch.zeros_like(self.key))

        vllm_config = self._create_vllm_config()
        with set_ascend_forward_context(None, vllm_config):
            result_q, result_k = self.layer.forward_oot(
                self.positions_2d, self.query, self.key)

        mock_npu_mrope.assert_called_once()
        self.assertFalse(torch.isnan(result_q).any().item())
        self.assertFalse(torch.isnan(result_k).any().item())
        self.assertEqual(result_q.shape, self.query.shape)
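
For readers double-checking the numbers in these tests: the Qwen2.5-VL section [16, 24, 24] sums to 64, i.e. head_size // 2, which matches the number of rotary frequencies available per position. A minimal standalone sanity check of that invariant (illustrative only, not part of the PR):

```python
# Illustrative only: the mrope section splits the head_size // 2 rotary
# frequencies into temporal / height / width groups, so it must sum to
# exactly half the head size used by the tests above.
head_size = 128
mrope_section = [16, 24, 24]  # Qwen2.5-VL split assumed by the tests
assert sum(mrope_section) == head_size // 2
print("mrope_section covers", sum(mrope_section), "of", head_size // 2, "rotary dims")
```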
36 changes: 35 additions & 1 deletion vllm_ascend/ops/rotary_embedding.py
@@ -22,7 +22,7 @@
import torch_npu
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, RotaryEmbedding,
    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
    YaRNScalingRotaryEmbedding)

from vllm_ascend.platform import NPUPlatform
@@ -395,3 +395,37 @@ def forward(self,
        q_pe, k_pe = _rope_forward_oot(self, positions, query, key,
                                       is_neox_style, offsets)
        return q_pe, k_pe


class AscendMRotaryEmbedding(MRotaryEmbedding):

    def forward_oot(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ):
        if self.mrope_section != [16, 24, 24]:
            return super().forward_oot(positions, query, key)

        import torch_npu
        mrope_section = [0, 0, 0
                         ] if positions.ndim == 1 else self.mrope_section

        if self.cos_sin_cache.device != query.device:  # type: ignore
            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
                query.device)  # type: ignore

        if self.cos_sin_cache.dtype != query.dtype:  # type: ignore
            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
                query.dtype)  # type: ignore

        query, key = torch_npu.npu_mrope(positions,
                                         query.contiguous(),
                                         key.contiguous(),
                                         self.cos_sin_cache.contiguous(),
                                         self.head_size,
                                         mrope_section=mrope_section,
                                         rotary_mode='half')

        return query, key
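
A note on the mrope_section handling above: 1-D positions correspond to text-only requests, for which the code passes a zero section (presumably so npu_mrope applies plain RoPE with no multimodal split), while 2-D (3, num_tokens) positions keep the Qwen2.5-VL [16, 24, 24] split. A minimal sketch of just that selection logic; the helper name _select_mrope_section is illustrative, not part of the PR:

```python
import torch


def _select_mrope_section(positions: torch.Tensor,
                          mrope_section: list[int]) -> list[int]:
    # 1-D positions: plain text request with no (t, h, w) decomposition,
    # so a zero section is passed through to the fused kernel.
    # 2-D positions: multimodal request with per-axis indices, so the
    # configured split is forwarded unchanged.
    return [0, 0, 0] if positions.ndim == 1 else mrope_section


# Usage mirroring the two unit tests above.
text_positions = torch.tensor([1, 2, 3])        # shape (num_tokens,)
mm_positions = torch.randint(1, 10, (3, 3))     # shape (3, num_tokens)
assert _select_mrope_section(text_positions, [16, 24, 24]) == [0, 0, 0]
assert _select_mrope_section(mm_positions, [16, 24, 24]) == [16, 24, 24]
```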
5 changes: 3 additions & 2 deletions vllm_ascend/utils.py
@@ -537,8 +537,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
        AscendReplicatedLinear,
        AscendRowParallelLinear)
    from vllm_ascend.ops.rotary_embedding import (
        AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding,
        AscendYaRNRotaryEmbedding)
        AscendDeepseekScalingRotaryEmbedding, AscendMRotaryEmbedding,
        AscendRotaryEmbedding, AscendYaRNRotaryEmbedding)
    from vllm_ascend.ops.vocab_parallel_embedding import (
        AscendLogitsProcessor, AscendParallelLMHead,
        AscendVocabParallelEmbedding)
@@ -548,6 +548,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
"QuickGELU": AscendQuickGELU,
"SiluAndMul": AscendSiluAndMul,
"RotaryEmbedding": AscendRotaryEmbedding,
"MRotaryEmbedding": AscendMRotaryEmbedding,
"ColumnParallelLinear": AscendColumnParallelLinear,
"RowParallelLinear": AscendRowParallelLinear,
"YaRNScalingRotaryEmbedding": AscendYaRNRotaryEmbedding,