Bug Description
LoRA fine-tuning of GLM-OCR (mlx-community/GLM-OCR-bf16) fails because image tokens are handled differently during training vs inference, resulting in a model that generates garbage after fine-tuning.
Root Cause
There are two separate issues:
Issue 1: transform_dataset_to_messages uses wrong image placeholder
In `lora.py`, the `transform_dataset_to_messages` function's else branch (which GLM-OCR falls into, since `"glm_ocr"` is not in `vlm_message_model_prefixes`) formats messages as:

```python
{"role": "user", "content": f"<image>{q}"}
```

But GLM-OCR's tokenizer does not convert `<image>` into the image token (ID 59280); it treats it as regular text tokens. The correct format for GLM-OCR is:

```python
{"role": "user", "content": f"<|begin_of_image|><|image|><|end_of_image|>{q}"}
```

Evidence: with `<image>`, the image token count in `input_ids` is 0. With `<|begin_of_image|><|image|><|end_of_image|>`, it is 798 (correct).
This means the model trains without ever seeing the actual images — it only learns text patterns, producing a useless adapter.
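The silent failure is easy to catch with a quick token count before training. A minimal sanity-check sketch, using the image token ID 59280 reported above; the helper function and sample ID lists are illustrative, not part of the mlx-vlm API:

```python
# Sanity-check sketch: count image placeholder tokens in an encoded prompt.
# IMAGE_TOKEN_ID = 59280 is taken from the report; the helper and the sample
# token lists below are illustrative, not part of the mlx-vlm API.
IMAGE_TOKEN_ID = 59280

def count_image_tokens(input_ids, image_token_id=IMAGE_TOKEN_ID):
    """Return how many image placeholder tokens appear in input_ids."""
    return sum(1 for t in input_ids if t == image_token_id)

# With "<image>" the tokenizer emits only ordinary text token IDs, so the
# count is 0; with the GLM-OCR special tokens the placeholder ID appears.
wrong_format_ids = [27, 1805, 29, 4823, 420]          # plain-text tokens only
right_format_ids = [100, IMAGE_TOKEN_ID, 101, 4823]   # contains the image token

print(count_image_tokens(wrong_format_ids))  # → 0
print(count_image_tokens(right_format_ids))  # → 1
```

A count of 0 after tokenizing a training prompt that should contain an image is the red flag that the placeholder was never converted.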
Issue 2: get_rope_index shape mismatch when image tokens ARE present
When using the correct image token format, training crashes with:
```
ValueError: [broadcast_shapes] Shapes (3,1,32) and (3,1,803) cannot be broadcast.
```

at `language.py:452` in `get_rope_index`.
Cause: `get_rope_index` computes `llm_positions` by expanding each image token into a grid of `llm_grid_t * llm_grid_h * llm_grid_w` positions (e.g., 798 positions for one image). But `merge_input_ids_with_image_features` in `glm_ocr.py` merges image embeddings in place, keeping the sequence length the same. So:

- `llm_positions` shape: `(3, 803)` (32 text tokens + 798 expanded image grid - 1 image token)
- `attention_mask` shape: `(32,)` (original sequence length, not expanded)
- `position_ids` shape: `(3, 1, 32)` (original sequence length)
The `mx.where` call on line 452 then tries to broadcast these incompatible shapes.
During inference this works because `input_ids` already contains the expanded image tokens (798 tokens). During training, the dataset's `prepare_inputs` returns `input_ids` with the original placeholder token (1 token), and the embeddings are merged in place later.
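The crash can be reproduced outside mlx with NumPy, which follows the same broadcasting rules as `mx.where`. The shapes below are the ones from the traceback; the arrays themselves are just zero-filled stand-ins:

```python
import numpy as np

# Stand-ins with the shapes from the traceback; NumPy applies the same
# broadcasting rules as mlx, so np.where fails in the same way.
mask = np.ones((3, 1, 32), dtype=bool)   # attention mask, original length
llm_positions = np.zeros((3, 1, 803))    # expanded grid positions
position_ids = np.zeros((3, 1, 32))      # original sequence length

try:
    np.where(mask, llm_positions, position_ids)
except ValueError as e:
    # The trailing dimensions 32 and 803 are incompatible.
    print("broadcast error:", e)
```

The only way to satisfy broadcasting here is for all three arrays to agree on the sequence dimension, which is exactly what breaks when the expansion happens in one place but not the other.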
Steps to Reproduce
```shell
pip install mlx-vlm==0.4.0
```

Create a simple training dataset (`training_dataset/train.jsonl`):

```json
{"messages": [{"role": "user", "content": "<|begin_of_image|><|image|><|end_of_image|>OCR this image"}, {"role": "assistant", "content": "{\"text\": \"hello\"}"}], "image": "/path/to/image.jpg"}
```

```shell
python3 -m mlx_vlm.lora \
  --model-path mlx-community/GLM-OCR-bf16 \
  --dataset training_dataset \
  --split train \
  --iters 20 \
  --batch-size 1 \
  --learning-rate 1e-4 \
  --lora-rank 8 \
  --max-seq-length 1024 \
  --train-on-completions \
  --output-path glm_ocr_adapter
```

With `<image>`: training completes, but the adapter is useless (0 image tokens processed).
With `<|begin_of_image|><|image|><|end_of_image|>`: crashes with the broadcast shape error.
Environment
- mlx-vlm: 0.4.0
- mlx: 0.31.1
- Python: 3.14
- Hardware: Apple M3 Max, 64 GB
- Model: mlx-community/GLM-OCR-bf16
Suggested Fix
- Add `"glm_ocr"` to `vlm_message_model_prefixes` in `lora.py` (or handle it in the else branch with the correct image token format).
- Fix `get_rope_index` in `models/glm_ocr/language.py` to handle the non-expanded image token case during training, where `input_ids` contains single placeholder tokens but positions need to account for the merged embeddings.
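For the second point, one possible shape of the fix is to expand each single image placeholder in the training-time `input_ids` to the number of merged image embeddings before computing rope indices, so all three arrays share one sequence length. A sketch with illustrative names (this is not the actual mlx-vlm code):

```python
# Hypothetical sketch of the get_rope_index fix: expand each lone image
# placeholder so the training-time sequence length matches the merged
# embeddings. Function and variable names are illustrative, not the
# actual mlx-vlm API.
def expand_image_placeholders(input_ids, image_token_id, tokens_per_image):
    """Replace each single image placeholder with tokens_per_image copies."""
    expanded = []
    for tok in input_ids:
        if tok == image_token_id:
            expanded.extend([image_token_id] * tokens_per_image)
        else:
            expanded.append(tok)
    return expanded

# One placeholder (ID 59280) expands to the 798 grid positions reported
# above, so positions, mask, and input IDs would all see the same length.
ids = [100, 59280, 101]
print(len(expand_image_placeholders(ids, 59280, 798)))  # → 800
```

The attention mask would need the same expansion, so that `llm_positions`, `attention_mask`, and `position_ids` all agree before the `mx.where` at `language.py:452`.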