From f432a5c3ba0335da6dea91ad4d4101302d54c14f Mon Sep 17 00:00:00 2001 From: liyuhang Date: Wed, 10 Dec 2025 03:20:06 +0000 Subject: [PATCH 1/6] [model] add glm-asr support --- convert_hf_to_gguf.py | 83 ++++++++++++++++++++++++++++++++++++--- gguf-py/gguf/constants.py | 1 + tools/mtmd/clip-impl.h | 2 + tools/mtmd/clip.cpp | 44 ++++++++++++++++++++- 4 files changed, 123 insertions(+), 7 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2cc2a38823..bab1e8b352 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -736,9 +736,10 @@ def __init__(self, *args, **kwargs): else: self.hf_arch = "" - if "text_config" in self.hparams: + llm_config_key = "lm_config" if "lm_config" in self.hparams else "text_config" + if llm_config_key in self.hparams: # move the text_config to the root level - self.hparams = {**self.hparams, **self.hparams["text_config"]} + self.hparams = {**self.hparams, **self.hparams[llm_config_key]} self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) @@ -1604,7 +1605,7 @@ class MmprojModel(ModelBase): preprocessor_config: dict[str, Any] global_config: dict[str, Any] - n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"] + n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"] has_vision_encoder: bool = True # by default has_audio_encoder: bool = False @@ -1621,11 +1622,12 @@ def __init__(self, *args, **kwargs): # get n_embd of the text model if not self.is_mistral_format: - if "text_config" not in self.hparams: + llm_config_key = "lm_config" if "lm_config" in self.hparams else "text_config" + if llm_config_key not in self.hparams: self.hparams["text_config"] = {} if "audio_config" not in self.hparams: self.hparams["audio_config"] = {} - text_config = {**self.hparams, **self.hparams["text_config"]} + text_config = {**self.hparams, **self.hparams[llm_config_key]} self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0)) else: text_config = { @@ -1680,7 +1682,8 @@ def get_vision_config(self) -> dict[str, Any] | None: return self.global_config.get(config_name) def get_audio_config(self) -> dict[str, Any] | None: - return self.global_config.get("audio_config") + mm_config_key = "whisper_config" if "whisper_config" in self.hparams else "audio_config" + return self.global_config.get(mm_config_key) def set_type(self): self.gguf_writer.add_type(gguf.GGUFType.MMPROJ) @@ -2356,6 +2359,7 @@ def prepare_tensors(self): "VLlama3ForCausalLM", "LlavaForConditionalGeneration", "VoxtralForConditionalGeneration", + "GlmasrModel", "LlamaModel") class LlamaModel(TextModel): model_arch = gguf.MODEL_ARCH.LLAMA @@ -2407,6 +2411,17 @@ def set_vocab(self): # Apply to granite small models only if self.hparams.get("vocab_size", 32000) == 49152: self.gguf_writer.add_add_bos_token(False) + + if isinstance(self.hparams.get("eos_token_id"), list): + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("bos", 
tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab.add_to_gguf(self.gguf_writer) + special_vocab.chat_template = "glmedge" def set_gguf_parameters(self): super().set_gguf_parameters() @@ -2443,6 +2458,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter "vision_language_adapter.", "patch_merger.", "pre_mm_projector_norm", + "audio_encoder.", ] is_multimodal_tensor = "vision_tower" in name \ @@ -8998,6 +9014,61 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) raise NotImplementedError("Ultravox does not have text decoder. Instead, it uses Llama or other models for text. If you want to get the audio encoder, please use --mmproj argument") +@ModelBase.register("GlmasrModel") +class GlmASRWhisperEncoderModel(MmprojModel): + has_vision_encoder = False + has_audio_encoder = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if "hidden_size" not in self.hparams and "intermediate_size" not in self.hparams: + self.hparams["hidden_size"] = self.hparams["d_model"] + self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"] + self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"] + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GLMA) + self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"]) + self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5)) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + if ".conv" in name and ".weight" in name: + return gguf.GGMLQuantizationType.F16 + return super().tensor_force_quant(name, new_name, bid, n_dims) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.startswith("model.") or name.startswith("lm_head."): + # skip language model tensors + return [] + + if name.startswith("audio_encoder.whisper."): + name = name.replace("audio_encoder.whisper.","audio_tower.") + if "audio_encoder.layer_norm." in name or "audio_encoder.proj." in name: + name = name.replace("audio_encoder.", "audio_encoder.adapting.") + + if name.startswith("audio_encoder.audio_bos_eos_token."): + return [(self.map_tensor_name("model.vision.boi"), data_torch[0]), (self.map_tensor_name("model.vision.eoi"), data_torch[1])] + + if name.startswith("audio_encoder.adapting."): + name = name.replace("audio_encoder.adapting.","audio.multi_modal_projector.") + if ".layer_norm." in name: + name = name.replace(".layer_norm.", ".ln_pre.") + if ".0." in name: + name = name.replace(".0.", ".linear_1.") + if ".2." in name: + name = name.replace(".2.", ".linear_2.") + if ".proj." 
in name: + print("skip proj") + return [] + + if "conv1.bias" in name or "conv2.bias" in name: + # transpose conv1 and conv2 bias + data_torch = data_torch.unsqueeze(-1) + + return [(self.map_tensor_name(name), data_torch)] @ModelBase.register("Qwen2AudioForConditionalGeneration") class WhisperEncoderModel(MmprojModel): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 2b8489c591..8ef4a23a10 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -3320,6 +3320,7 @@ class VisionProjectorType: ULTRAVOX = "ultravox" INTERNVL = "internvl" QWEN2A = "qwen2a" # audio + GLMA = "glma" # audio QWEN25O = "qwen2.5o" # omni VOXTRAL = "voxtral" LFM2 = "lfm2" diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index cd47865bf4..93153226d4 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -149,6 +149,7 @@ enum projector_type { PROJECTOR_TYPE_INTERNVL, PROJECTOR_TYPE_LLAMA4, PROJECTOR_TYPE_QWEN2A, + PROJECTOR_TYPE_GLMA, PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx PROJECTOR_TYPE_VOXTRAL, PROJECTOR_TYPE_LFM2, @@ -175,6 +176,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_INTERNVL, "internvl"}, { PROJECTOR_TYPE_LLAMA4, "llama4"}, { PROJECTOR_TYPE_QWEN2A, "qwen2a"}, + { PROJECTOR_TYPE_GLMA, "glma"}, { PROJECTOR_TYPE_QWEN25O, "qwen2.5o"}, { PROJECTOR_TYPE_VOXTRAL, "voxtral"}, { PROJECTOR_TYPE_LFM2, "lfm2"}, diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 3ed08a0fec..32318aec5d 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -388,6 +388,7 @@ struct clip_model { ggml_tensor * conv1d_2_w = nullptr; ggml_tensor * conv1d_2_b = nullptr; ggml_tensor * mm_norm_pre_w = nullptr; + ggml_tensor * mm_norm_pre_b = nullptr; ggml_tensor * mm_norm_mid_w = nullptr; // cogvlm @@ -1829,7 +1830,6 @@ struct clip_graph { GGML_ASSERT(model.layers[0].q_b); GGML_ASSERT(model.layers[0].v_b); GGML_ASSERT(!model.layers[0].k_b); // no bias for k - GGML_ASSERT(model.post_ln_w && model.post_ln_b); ggml_tensor * pos_embd_selected = ggml_view_2d( ctx0, model.position_embeddings, @@ -1891,6 +1891,18 @@ struct clip_graph { cur = ggml_gelu_erf(ctx0, cur); cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); + } else if (ctx->proj_type() == PROJECTOR_TYPE_GLMA) { + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w); + cur = ggml_add(ctx0, cur, model.mm_norm_pre_b); + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * 4, cur->ne[1] / 4); + cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); + cur = ggml_add(ctx0, cur, model.mm_1_b); + cur = ggml_gelu_erf(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); + cur = ggml_add(ctx0, cur, model.mm_2_b); + cur = ggml_concat(ctx0, model.mm_boi, cur, 1); + cur = ggml_concat(ctx0, cur, model.mm_eoi, 1); } else { GGML_ABORT("%s: unknown projector type", __func__); } @@ -2518,6 +2530,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_GLMA: { res = graph.build_whisper_enc(); } break; @@ -3225,6 +3238,21 @@ struct clip_model_loader { model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); } break; + case PROJECTOR_TYPE_GLMA: + { + model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); + model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias")); + model.conv1d_2_w = 
get_tensor(string_format(TN_CONV1D, 2, "weight")); + model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias")); + model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight")); + model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias")); + model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias")); + model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight")); + model.mm_norm_pre_b = get_tensor(string_format(TN_MM_NORM_PRE, "bias")); + model.mm_boi = get_tensor(string_format(TN_TOK_BOI, "weight")); + model.mm_eoi = get_tensor(string_format(TN_TOK_EOI, "weight")); + } break; case PROJECTOR_TYPE_LLAMA4: { model.mm_model_proj = get_tensor(TN_MM_PROJECTOR); @@ -4606,6 +4634,16 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im n_patches /= 2; } } break; + case PROJECTOR_TYPE_GLMA: + { + n_patches = img->nx; + // whisper downscales input token by half after conv1d + n_patches /= 2; + // reshape by merge_factor + n_patches /= 4; + // for BOI and EOI token embeddings + n_patches += 2; + } break; case PROJECTOR_TYPE_COGVLM: { n_patches += 2; // for BOI and EOI token embeddings @@ -4941,6 +4979,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_GLMA: case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_VOXTRAL: @@ -5051,6 +5090,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_model_proj->ne[1]; case PROJECTOR_TYPE_QWEN2A: return ctx->model.mm_fc_w->ne[1]; + case PROJECTOR_TYPE_GLMA: + return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: return ctx->model.mm_2_w->ne[1]; @@ -5097,6 +5138,7 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) { bool clip_has_whisper_encoder(const struct clip_ctx * ctx) { return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A + || ctx->proj_type() == PROJECTOR_TYPE_GLMA || ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL; } From c382d64af4e264c02ed9cc67e0cd446a274861c9 Mon Sep 17 00:00:00 2001 From: liyuhang Date: Wed, 10 Dec 2025 03:54:53 +0000 Subject: [PATCH 2/6] fix format for ci --- convert_hf_to_gguf.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index bab1e8b352..6dfb771c92 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2411,7 +2411,6 @@ def set_vocab(self): # Apply to granite small models only if self.hparams.get("vocab_size", 32000) == 49152: self.gguf_writer.add_add_bos_token(False) - if isinstance(self.hparams.get("eos_token_id"), list): from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) @@ -9014,6 +9013,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) raise NotImplementedError("Ultravox does not have text decoder. Instead, it uses Llama or other models for text. 
If you want to get the audio encoder, please use --mmproj argument") + @ModelBase.register("GlmasrModel") class GlmASRWhisperEncoderModel(MmprojModel): has_vision_encoder = False @@ -9025,7 +9025,7 @@ def __init__(self, *args, **kwargs): self.hparams["hidden_size"] = self.hparams["d_model"] self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"] self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"] - + def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GLMA) @@ -9043,12 +9043,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.startswith("model.") or name.startswith("lm_head."): # skip language model tensors return [] - + if name.startswith("audio_encoder.whisper."): name = name.replace("audio_encoder.whisper.","audio_tower.") if "audio_encoder.layer_norm." in name or "audio_encoder.proj." in name: name = name.replace("audio_encoder.", "audio_encoder.adapting.") - + if name.startswith("audio_encoder.audio_bos_eos_token."): return [(self.map_tensor_name("model.vision.boi"), data_torch[0]), (self.map_tensor_name("model.vision.eoi"), data_torch[1])] @@ -9061,7 +9061,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if ".2." in name: name = name.replace(".2.", ".linear_2.") if ".proj." in name: - print("skip proj") return [] if "conv1.bias" in name or "conv2.bias" in name: From e8a1ec511ddaa328028c41d320714a43342aa926 Mon Sep 17 00:00:00 2001 From: liyuhang Date: Wed, 10 Dec 2025 03:57:50 +0000 Subject: [PATCH 3/6] fix convert format for ci --- convert_hf_to_gguf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 6dfb771c92..4acca39d88 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -9069,6 +9069,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] + @ModelBase.register("Qwen2AudioForConditionalGeneration") class WhisperEncoderModel(MmprojModel): has_vision_encoder = False # no vision encoder From 103e894780c954bc3a29f591d8e3a6a1b510f59f Mon Sep 17 00:00:00 2001 From: liyuhang Date: Fri, 12 Dec 2025 04:41:43 +0000 Subject: [PATCH 4/6] update glm_asr convert script & use build_ffn for glm_asr clip & use build_stack for padding and review --- convert_hf_to_gguf.py | 28 ++++++++++++++---------- tools/mtmd/clip.cpp | 50 ++++++++++++++++++++++++++++--------------- 2 files changed, 50 insertions(+), 28 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 4acca39d88..90af7e67e6 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2359,7 +2359,6 @@ def prepare_tensors(self): "VLlama3ForCausalLM", "LlavaForConditionalGeneration", "VoxtralForConditionalGeneration", - "GlmasrModel", "LlamaModel") class LlamaModel(TextModel): model_arch = gguf.MODEL_ARCH.LLAMA @@ -2411,16 +2410,6 @@ def set_vocab(self): # Apply to granite small models only if self.hparams.get("vocab_size", 32000) == 49152: self.gguf_writer.add_add_bos_token(False) - if isinstance(self.hparams.get("eos_token_id"), list): - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) - special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) - special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) - special_vocab._set_special_token("eot", 
tokenizer.get_added_vocab()["<|user|>"]) - special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) - special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) - special_vocab.add_to_gguf(self.gguf_writer) - special_vocab.chat_template = "glmedge" def set_gguf_parameters(self): super().set_gguf_parameters() @@ -2575,6 +2564,22 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) +@ModelBase.register("GlmasrModel") +class GlmasrModel(LlamaModel): + model_arch = gguf.MODEL_ARCH.LLAMA + + def set_vocab(self): + super().set_vocab() + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab.add_to_gguf(self.gguf_writer) + + @ModelBase.register("AfmoeForCausalLM") class AfmoeModel(LlamaModel): model_arch = gguf.MODEL_ARCH.AFMOE @@ -9031,6 +9036,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GLMA) self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"]) self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5)) + self.gguf_writer.add_audio_stack_factor(self.global_config["merge_factor"]) def tensor_force_quant(self, name, new_name, bid, n_dims): if ".conv" in name and ".weight" in name: diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 32318aec5d..d4f8928eba 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1848,15 +1848,7 @@ struct clip_graph { if (model.audio_has_stack_frames()) { // StackAudioFrames // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py - int64_t stride = n_embd * hparams.proj_stack_factor; - int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride); - int64_t pad = padded_len - ggml_nelements(cur); - if (pad > 0) { - cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0); - cur = ggml_pad(ctx0, cur, pad, 0, 0, 0); - } - cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride, - ggml_row_size(cur->type, stride), 0); + cur = build_stack(cur, hparams.proj_stack_factor, n_embd); cb(cur, "after_stacked", -1); } @@ -1895,12 +1887,8 @@ struct clip_graph { cur = ggml_norm(ctx0, cur, hparams.eps); cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w); cur = ggml_add(ctx0, cur, model.mm_norm_pre_b); - cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * 4, cur->ne[1] / 4); - cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); - cur = ggml_add(ctx0, cur, model.mm_1_b); - cur = ggml_gelu_erf(ctx0, cur); - cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); - cur = ggml_add(ctx0, cur, model.mm_2_b); + cur = build_stack(cur, hparams.proj_stack_factor, n_embd); + cur = build_ffn(cur, model.mm_1_w, model.mm_1_b, nullptr, nullptr, model.mm_2_w, model.mm_2_b, hparams.ffn_op, 0); cur = ggml_concat(ctx0, model.mm_boi, cur, 1); cur = ggml_concat(ctx0, cur, model.mm_eoi, 1); } else { @@ -2486,6 +2474,32 @@ struct clip_graph { return cur; } + // Generic function to stack frames for audio processing + // Abstracts out the StackAudioFrames logic 
used by ultravox + ggml_tensor * build_stack(ggml_tensor * cur, int32_t stack_factor, int32_t n_embed) { + if (stack_factor <= 1) { + return cur; + } + + int64_t total_elements = ggml_nelements(cur); + int64_t stride = n_embed * stack_factor; + + // Calculate padded length + int64_t padded_len = GGML_PAD(total_elements, stride); + int64_t pad = padded_len - total_elements; + + if (pad > 0) { + // Pad the tensor to make it divisible by stride + cur = ggml_view_1d(ctx0, cur, total_elements, 0); + cur = ggml_pad(ctx0, cur, pad, 0, 0, 0); + } + + // Reshape to [stride, padded_len / stride] + cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride, + ggml_row_size(cur->type, stride), 0); + return cur; + } + }; static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) { @@ -2864,10 +2878,12 @@ struct clip_model_loader { } break; case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_GLMA: case PROJECTOR_TYPE_VOXTRAL: { bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX || - model.proj_type == PROJECTOR_TYPE_VOXTRAL; + model.proj_type == PROJECTOR_TYPE_VOXTRAL || + model.proj_type == PROJECTOR_TYPE_GLMA; get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack); if (hparams.n_mel_bins != 128) { throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__)); @@ -4640,7 +4656,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im // whisper downscales input token by half after conv1d n_patches /= 2; // reshape by merge_factor - n_patches /= 4; + n_patches /= ctx->model.hparams.proj_stack_factor; // for BOI and EOI token embeddings n_patches += 2; } break; From 98cf99f55c2bab265c5c629301a5acbb8558e705 Mon Sep 17 00:00:00 2001 From: liyuhang Date: Sat, 13 Dec 2025 08:04:55 +0000 Subject: [PATCH 5/6] check root architecture for convert hf script --- convert_hf_to_gguf.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 90af7e67e6..cc155db272 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1476,6 +1476,16 @@ def _try_set_pooling_type(self) -> None: raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported") self.gguf_writer.add_pooling_type(pooling_type) + def _set_vocab_glmedge(self): + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab.add_to_gguf(self.gguf_writer) + def _set_vocab_interns1(self): tokens: list[str] = [] toktypes: list[int] = [] @@ -2359,6 +2369,7 @@ def prepare_tensors(self): "VLlama3ForCausalLM", "LlavaForConditionalGeneration", "VoxtralForConditionalGeneration", + "GlmasrModel", "LlamaModel") class LlamaModel(TextModel): model_arch = gguf.MODEL_ARCH.LLAMA @@ -2410,6 +2421,8 @@ def set_vocab(self): # Apply to granite small models only if self.hparams.get("vocab_size", 32000) == 49152: self.gguf_writer.add_add_bos_token(False) + if self.hf_arch == "GlmasrModel": + self._set_vocab_glmedge() 
def set_gguf_parameters(self): super().set_gguf_parameters() @@ -2564,22 +2577,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) -@ModelBase.register("GlmasrModel") -class GlmasrModel(LlamaModel): - model_arch = gguf.MODEL_ARCH.LLAMA - - def set_vocab(self): - super().set_vocab() - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) - special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) - special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) - special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) - special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) - special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) - special_vocab.add_to_gguf(self.gguf_writer) - - @ModelBase.register("AfmoeForCausalLM") class AfmoeModel(LlamaModel): model_arch = gguf.MODEL_ARCH.AFMOE From 86339b0e6f163a6cfd420437d6bea88874c1adc8 Mon Sep 17 00:00:00 2001 From: liyuhang Date: Sat, 13 Dec 2025 10:20:32 +0000 Subject: [PATCH 6/6] fix conficlt with upstream --- tools/mtmd/clip-graph.h | 4 +++ tools/mtmd/clip-model.h | 1 + tools/mtmd/clip.cpp | 54 +++++++++++++++---------------- tools/mtmd/models/whisper-enc.cpp | 19 ++++++----- 4 files changed, 40 insertions(+), 38 deletions(-) diff --git a/tools/mtmd/clip-graph.h b/tools/mtmd/clip-graph.h index 6d303b4e48..17f90e8aa8 100644 --- a/tools/mtmd/clip-graph.h +++ b/tools/mtmd/clip-graph.h @@ -112,4 +112,8 @@ struct clip_graph { // aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL) // support dynamic resolution ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor); + + // Generic function to stack frames for audio processing + // Abstracts out the StackAudioFrames logic used by ultravox + ggml_tensor * build_stack(ggml_tensor * cur, int32_t stack_factor, int32_t n_embed); }; diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index 51bcce1ebb..bdd4f63e0f 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -256,6 +256,7 @@ struct clip_model { ggml_tensor * conv1d_2_w = nullptr; ggml_tensor * conv1d_2_b = nullptr; ggml_tensor * mm_norm_pre_w = nullptr; + ggml_tensor * mm_norm_pre_b = nullptr; ggml_tensor * mm_norm_mid_w = nullptr; // cogvlm diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 2f57f8e970..b4aa0218d5 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -720,6 +720,32 @@ ggml_tensor * clip_graph::build_rope_2d( return cur; } +// Generic function to stack frames for audio processing +// Abstracts out the StackAudioFrames logic used by ultravox +ggml_tensor * clip_graph::build_stack(ggml_tensor * cur, int32_t stack_factor, int32_t n_embed) { + if (stack_factor <= 1) { + return cur; + } + + int64_t total_elements = ggml_nelements(cur); + int64_t stride = n_embed * stack_factor; + + // Calculate padded length + int64_t padded_len = GGML_PAD(total_elements, stride); + int64_t pad = padded_len - total_elements; + + if (pad > 0) { + // Pad the tensor to make it divisible by stride + cur = ggml_view_1d(ctx0, cur, total_elements, 0); + cur = ggml_pad(ctx0, cur, pad, 0, 0, 0); + } + + // Reshape to [stride, padded_len / stride] + cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride, + ggml_row_size(cur->type, stride), 0); + return cur; +} + // aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL) 
// support dynamic resolution ggml_tensor * clip_graph::build_patch_merge_permute(ggml_tensor * cur, int scale_factor) { @@ -753,34 +779,6 @@ ggml_tensor * clip_graph::build_patch_merge_permute(ggml_tensor * cur, int scale return cur; } - // Generic function to stack frames for audio processing - // Abstracts out the StackAudioFrames logic used by ultravox - ggml_tensor * build_stack(ggml_tensor * cur, int32_t stack_factor, int32_t n_embed) { - if (stack_factor <= 1) { - return cur; - } - - int64_t total_elements = ggml_nelements(cur); - int64_t stride = n_embed * stack_factor; - - // Calculate padded length - int64_t padded_len = GGML_PAD(total_elements, stride); - int64_t pad = padded_len - total_elements; - - if (pad > 0) { - // Pad the tensor to make it divisible by stride - cur = ggml_view_1d(ctx0, cur, total_elements, 0); - cur = ggml_pad(ctx0, cur, pad, 0, 0, 0); - } - - // Reshape to [stride, padded_len / stride] - cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride, - ggml_row_size(cur->type, stride), 0); - return cur; - } - -}; - static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) { GGML_ASSERT(imgs.entries.size() == 1 && "n_batch > 1 is not supported"); diff --git a/tools/mtmd/models/whisper-enc.cpp b/tools/mtmd/models/whisper-enc.cpp index 07d378b095..2870d854ab 100644 --- a/tools/mtmd/models/whisper-enc.cpp +++ b/tools/mtmd/models/whisper-enc.cpp @@ -30,7 +30,6 @@ ggml_cgraph * clip_graph_whisper_enc::build() { GGML_ASSERT(model.layers[0].q_b); GGML_ASSERT(model.layers[0].v_b); GGML_ASSERT(!model.layers[0].k_b); // no bias for k - GGML_ASSERT(model.post_ln_w && model.post_ln_b); ggml_tensor * pos_embd_selected = ggml_view_2d( ctx0, model.position_embeddings, @@ -49,15 +48,7 @@ ggml_cgraph * clip_graph_whisper_enc::build() { if (model.audio_has_stack_frames()) { // StackAudioFrames // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py - int64_t stride = n_embd * hparams.proj_stack_factor; - int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride); - int64_t pad = padded_len - ggml_nelements(cur); - if (pad > 0) { - cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0); - cur = ggml_pad(ctx0, cur, pad, 0, 0, 0); - } - cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride, - ggml_row_size(cur->type, stride), 0); + cur = build_stack(cur, hparams.proj_stack_factor, n_embd); cb(cur, "after_stacked", -1); } @@ -95,6 +86,14 @@ ggml_cgraph * clip_graph_whisper_enc::build() { FFN_GELU_ERF, -1); + } else if (proj_type == PROJECTOR_TYPE_GLMA) { + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w); + cur = ggml_add(ctx0, cur, model.mm_norm_pre_b); + cur = build_stack(cur, hparams.proj_stack_factor, n_embd); + cur = build_ffn(cur, model.mm_1_w, model.mm_1_b, nullptr, nullptr, model.mm_2_w, model.mm_2_b, hparams.ffn_op, 0); + cur = ggml_concat(ctx0, model.mm_boi, cur, 1); + cur = ggml_concat(ctx0, cur, model.mm_eoi, 1); } else { GGML_ABORT("%s: unknown projector type", __func__); }