diff --git a/src/mcore_bridge/model/mm_gpts/kimi_vl.py b/src/mcore_bridge/model/mm_gpts/kimi_vl.py index df8a7cb..54024fb 100644 --- a/src/mcore_bridge/model/mm_gpts/kimi_vl.py +++ b/src/mcore_bridge/model/mm_gpts/kimi_vl.py @@ -72,7 +72,6 @@ class KimiK25Vit(HuggingFaceVit): module_mapping = {'vision_tower': 'vision_tower', 'mm_projector': 'mm_projector'} _vision_tower = ['vision_tower'] _aligner = ['mm_projector'] - test_mm_type = 'text' def prepare_model(self, hf_config: PretrainedConfig): output = [] @@ -85,10 +84,44 @@ def prepare_model(self, hf_config: PretrainedConfig): self.vision_tower = MoonViT3dPretrainedModel._from_config(vit_config) self.mm_projector = PatchMergerMLP(proj_config).to(self.vision_tower.dtype) + def _encode_images(self, pixel_values, grid_thws): + # vision_tower returns a list of un-projected feature tensors; mm_projector + # (PatchMergerMLP) maps them to the language hidden size. Mirrors + # KimiK25ForConditionalGeneration.forward. + image_features = self.vision_tower(pixel_values, grid_thws) + image_features = self.mm_projector(image_features) + return torch.cat(image_features, dim=0) + def get_inputs_embeds(self, inputs_embeds, **kwargs): - pixel_values = kwargs.pop('pixel_values', None) - if pixel_values is not None: - raise NotImplementedError('Kimi-K25 currently only supports plain text training.') + input_ids = kwargs['input_ids'] + pixel_values = kwargs.get('pixel_values') + grid_thws = kwargs.get('grid_thws') + dtype = next(self.vision_tower.parameters()).dtype + if pixel_values is not None and pixel_values.size(0) > 0: + if grid_thws is None: + raise KeyError('pixel_values present in inputs but grid_thws is missing') + pixel_values = pixel_values.to(device=inputs_embeds.device, dtype=dtype) + grid_thws = grid_thws.to(inputs_embeds.device) + image_features = self._encode_images(pixel_values, grid_thws) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + media_token_id = self.hf_config.media_placeholder_token_id + image_mask = (input_ids == media_token_id).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features) + else: + # plain-text batch: still run the vision graph on a dummy image so that + # DP ranks stay in sync during gradient all-reduce. + vision_config = self.hf_config.vision_config + patch_size = vision_config.patch_size + merge_kernel_size = vision_config.merge_kernel_size + kernel = merge_kernel_size[0] if isinstance(merge_kernel_size, (list, tuple)) else merge_kernel_size + h = w = kernel * 2 + dummy_pixels = torch.zeros((h * w, 3, patch_size, patch_size), dtype=dtype, device=inputs_embeds.device) + dummy_grid = input_ids.new_tensor([[1, h, w]]) + image_features = self._encode_images(dummy_pixels, dummy_grid) + # nan_to_num guards against a non-finite value from the all-zero dummy pass + # leaking into the text batch (NaN * 0 == NaN in IEEE-754). + zero_term = torch.nan_to_num(image_features.mean() * 0.) + inputs_embeds = inputs_embeds + zero_term.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype) return inputs_embeds diff --git a/tests/test_mllm.py b/tests/test_mllm.py index 5eb07d2..efb10e9 100644 --- a/tests/test_mllm.py +++ b/tests/test_mllm.py @@ -87,6 +87,10 @@ def test_kimi_vl(): _test_model('moonshotai/Kimi-VL-A3B-Thinking-2506') +def test_kimi_k25(): + _test_model('moonshotai/Kimi-K2.6') + + def test_qwen3_vl(): _test_model('Qwen/Qwen3-VL-4B-Instruct') @@ -134,6 +138,7 @@ def test_gemma4(): # test_glm4_6v_flash() # test_ovis2_5() # test_kimi_vl() + # test_kimi_k25() # test_qwen3_vl() # test_qwen3_vl_moe() # test_qwen3_omni()