Adds multimodal support and MMMU pro #675


Merged
merged 43 commits into from
May 19, 2025
Changes from 30 commits
Commits
43 commits
409b0c0
init
NathanHB Apr 15, 2025
ee334c5
init
NathanHB Apr 15, 2025
e988f6f
init
NathanHB Apr 15, 2025
5fddc82
Naive implementation
qubvel Apr 21, 2025
7ce9c97
Fix choices + change metric
qubvel Apr 22, 2025
e08731a
refactor prompt function
qubvel Apr 22, 2025
8d4543b
style
qubvel Apr 22, 2025
05df4b6
FIx typing
qubvel May 6, 2025
16a9e97
Merge branch 'main' into nathan-adds-multimodal
qubvel May 6, 2025
de60add
Update max length
qubvel May 6, 2025
5fd52f5
Remove docs
qubvel May 6, 2025
10b4e0b
Update auto processor
qubvel May 6, 2025
bc7610d
add quantization config, transformers config
qubvel May 6, 2025
49e4986
Update generation size
qubvel May 7, 2025
75c900c
Add batching
qubvel May 7, 2025
4e5fdd3
Style
qubvel May 7, 2025
d1ae8b7
Add images to requests
qubvel May 7, 2025
f855158
nit
qubvel May 7, 2025
641819e
nit
qubvel May 7, 2025
aa0acb7
Clean up a bit
qubvel May 7, 2025
56f962b
nit
qubvel May 7, 2025
8e99388
Fix batch size
qubvel May 7, 2025
418840d
Add images for Doc class
qubvel May 7, 2025
e35db98
clean-up prompt manager
qubvel May 7, 2025
57c18f7
Style
qubvel May 7, 2025
7cd35c2
Style
qubvel May 7, 2025
e13cac9
Clean up prompt manager
qubvel May 7, 2025
fa18ec2
Add dtype
qubvel May 7, 2025
c59e5af
Update prompt function
qubvel May 7, 2025
8f31f1b
Refactor to pass ruff check
qubvel May 7, 2025
3675066
fix the CI
NathanHB May 12, 2025
30e22ab
fix the CI
NathanHB May 12, 2025
924bf13
Fit typing
qubvel May 12, 2025
b909259
Fix system content
qubvel May 12, 2025
665474a
Split to vision and standard tasks
qubvel May 13, 2025
1a73dd0
Data parallel
qubvel May 13, 2025
b618af7
Clean up config docs, tokenizer -> processor
qubvel May 13, 2025
79e222d
Add fast image processor option
qubvel May 13, 2025
bd2c595
Fix style
qubvel May 13, 2025
831f95e
commit
NathanHB May 19, 2025
80568e7
commit
NathanHB May 19, 2025
9fb75a6
commit
NathanHB May 19, 2025
62165a8
commit
NathanHB May 19, 2025
2 changes: 1 addition & 1 deletion src/lighteval/data.py
@@ -260,7 +260,7 @@ def init_split_limits(self, num_dataset_splits):
splits_indices = [tuple(e) for e in splits_indices]
return num_dataset_splits, splits_indices

def _sorting_criteria(self, request: GreedyUntilRequest) -> tuple[bool, bool, list, int, int]:
def _sorting_criteria(self, request: GreedyUntilRequest) -> tuple[bool, bool, tuple, int, int]:
"""
Collate function for generating batches.

3 changes: 2 additions & 1 deletion src/lighteval/models/model_loader.py
@@ -42,6 +42,7 @@
from lighteval.models.transformers.adapter_model import AdapterModel, AdapterModelConfig
from lighteval.models.transformers.delta_model import DeltaModel, DeltaModelConfig
from lighteval.models.transformers.transformers_model import TransformersModel, TransformersModelConfig
from lighteval.models.transformers.vlm_transformers import VLMTransformersModel
from lighteval.models.vllm.vllm_model import VLLMModel, VLLMModelConfig
from lighteval.utils.imports import (
NO_LITELLM_ERROR_MSG,
@@ -198,7 +199,7 @@ def load_model_with_accelerate_or_default(
model = VLLMModel(config=config)
return model
else:
model = TransformersModel(config=config)
model = VLMTransformersModel(config=config)

return model

417 changes: 417 additions & 0 deletions src/lighteval/models/transformers/vlm_transformers.py
Member Author:
Needed work will mainly be here; the first step is to have the greedy_until function working.

Large diffs are not rendered by default.
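
Since this file's diff is collapsed, here is a minimal illustrative sketch of what a greedy-until generation step for a vision-language model typically looks like with transformers. It is not the code in vlm_transformers.py; the checkpoint name, message layout, and generation settings are assumptions made for the example.

import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

# Illustrative sketch only -- not the actual VLMTransformersModel implementation.
model_id = "HuggingFaceTB/SmolVLM-Instruct"  # assumed example checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

prompt = "Which option matches [image 1]?\nA. a cat\nB. a dog\nAnswer with the option letter."
images = [Image.new("RGB", (64, 64))]  # placeholder image

# Same message structure the prompt manager builds for multimodal docs (see get_examples below).
messages = [{
    "role": "user",
    "content": [{"type": "text", "text": prompt}] + [{"type": "image", "image": img} for img in images],
}]
text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=[text], images=images, return_tensors="pt").to(model.device)

with torch.no_grad():
    generated_ids = model.generate(**inputs, max_new_tokens=30, do_sample=False)  # greedy decoding

# Decode only the continuation, dropping the prompt tokens.
continuation = generated_ids[:, inputs["input_ids"].shape[1]:]
print(processor.batch_decode(continuation, skip_special_tokens=True)[0])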

49 changes: 49 additions & 0 deletions src/lighteval/tasks/default_prompts.py
@@ -26,6 +26,7 @@
import random
import re
import string
from typing import Optional

import numpy as np
import pycountry
@@ -43,6 +44,54 @@
# fmt: on


def mmmu_pro(line, task_name: Optional[str] = None):
# fmt: off
question = line["question"] # "What is the capital of France?"
choices_string = line["options"] # "[Paris, London, Berlin, Madrid]"
answer = line["answer"] # "A"
# fmt: on

# TODO: Should be different for "vision"/"standard (4 options)" subsets
instructions = (
"Answer with the option letter from the given choices directly. "
"The last line of your response should be of the following format: "
"'Answer: $LETTER' (without quotes) where LETTER is one of options."
)

# Preprocess choices
# "[Paris, London, Berlin, Madrid]" -> ["A. Paris", "B. London", "C. Berlin", "D. Madrid"]
choices = ast.literal_eval(str(choices_string))
choices_letters = [chr(ord("A") + i) for i in range(len(choices))] # ["A", "B", "C", "D"]
choices = [f"{letter}. {choice}" for letter, choice in zip(choices_letters, choices)]

# Construct prompt
formatted_choices = "\n".join(choices)
prompt = f"{question}\n{formatted_choices}\n{instructions}"

# Collect images
image_order = []
for num in re.findall(r"<image\s+(\d+)>", prompt):
num = int(num)
if num not in image_order:
image_order.append(num)
images = [line[f"image_{i}"] for i in image_order]

gold_index = string.ascii_uppercase.index(answer)

# Replace image placeholders in prompt <image 1>, <image 2>, ... with [image 1], [image 2], ...
prompt = re.sub(r"<image\s+(\d+)>", "[image \\1]", prompt)
choices = [re.sub(r"<image\s+(\d+)>", "[image \\1]", choice) for choice in choices]

return Doc(
task_name=task_name,
query=prompt,
choices=choices,
gold_index=gold_index,
images=images,
specific={"id": line["id"]},
)


def simpleqa(line, task_name: str = None):
query = line["problem"]
choices = [line["answer"]]
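
For clarity, a small illustrative example of what the mmmu_pro prompt function above produces for one row (assuming this branch is installed); only the field names match what mmmu_pro() reads, the values and task name are made up.

from PIL import Image
from lighteval.tasks.default_prompts import mmmu_pro

# Hypothetical row shaped like the fields mmmu_pro() reads above (values are made up).
line = {
    "id": "test_Art_1",
    "question": "Which style best describes <image 1>?",
    "options": "['Impressionism', 'Cubism', 'Baroque', 'Pop art']",
    "answer": "B",
    "image_1": Image.new("RGB", (32, 32)),
}

doc = mmmu_pro(line, task_name="mmmu_pro")  # task name is illustrative
# doc.query      -> question with "[image 1]", the lettered choices, then the instruction line
# doc.choices    -> ["A. Impressionism", "B. Cubism", "C. Baroque", "D. Pop art"]
# doc.gold_index -> 1
# doc.images     -> [the PIL image collected from the "<image 1>" placeholder]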
16 changes: 16 additions & 0 deletions src/lighteval/tasks/default_tasks.py
Expand Up @@ -24,6 +24,22 @@
from lighteval.tasks.lighteval_task import LightevalTaskConfig


mmmu_pro = LightevalTaskConfig(
name="mmmu_pro",
suite=["lighteval"],
prompt_function=prompt.mmmu_pro,
hf_repo="MMMU/MMMU_pro",
hf_subset="standard (4 options)",
hf_avail_splits=["test"],
evaluation_splits=["test"],
few_shots_split=None,
few_shots_select=None,
generation_size=30,  # expects an answer in the format 'Answer: B'
metric=[Metrics.gpqa_instruct_metric],
stop_sequence=None,
trust_dataset=True,
version=0,
)
abstract_narrative_understanding_bigbench = LightevalTaskConfig(
name="abstract_narrative_understanding",
suite=["bigbench", "bigbench_json"],
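
The mmmu_pro config above caps generation at 30 tokens because only a short reply ending in an "Answer: X" line is needed, and gpqa_instruct_metric scores the generative answer against the gold letter. A rough illustration of the expected output format (not the metric's actual implementation):

import re

# Illustration of the expected answer format only -- not the metric implementation.
model_output = "The brushwork uses fragmented geometric planes.\nAnswer: B"
match = re.search(r"Answer\s*:\s*([A-Z])", model_output)
predicted_letter = match.group(1) if match else None
print(predicted_letter)  # -> "B"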
20 changes: 19 additions & 1 deletion src/lighteval/tasks/lighteval_task.py
@@ -364,6 +364,7 @@ def construct_requests(
context=context,
choice=gold,
metric_categories=[MetricCategory.TARGET_PERPLEXITY],
images=formatted_doc.images,
)
for i, gold in enumerate(golds)
]
@@ -375,12 +376,13 @@
request_index=0,
context=context,
metric_categories=[MetricCategory.PERPLEXITY],
images=formatted_doc.images,
)
]
if self.has_metric_category[MetricCategory.GENERATIVE_SAMPLING]:
# All the possible sampling tasks require the same generation process - we can do them in one step
# so we select the maximum number of samples and the metrics will select only the
# relevant number of tiems
# relevant number of items
requests[RequestType.GREEDY_UNTIL] += [
GreedyUntilRequest(
task_name=current_task_name,
@@ -394,6 +396,7 @@
do_sample=True,
use_logits=False,
metric_categories=[MetricCategory.GENERATIVE_SAMPLING],
images=formatted_doc.images,
)
]
if (
@@ -420,6 +423,7 @@
]
if self.has_metric_category[c]
],
images=formatted_doc.images,
)
]
if (
@@ -438,6 +442,7 @@
for c in [MetricCategory.MULTICHOICE, MetricCategory.MULTICHOICE_PMI]
if self.has_metric_category[c]
],
images=formatted_doc.images,
)
for i, choice in enumerate(formatted_doc.choices)
]
@@ -454,6 +459,7 @@
context=formatted_doc.unconditioned_query,
choice=choice,
metric_categories=[MetricCategory.MULTICHOICE_PMI],
images=formatted_doc.images,
)
for i, choice in enumerate(formatted_doc.choices)
]
@@ -466,6 +472,7 @@
context=context,
choices=formatted_doc.choices,
metric_categories=[MetricCategory.MULTICHOICE_ONE_TOKEN],
images=formatted_doc.images,
)
]
if self.has_metric_category[MetricCategory.LLM_AS_JUDGE_MULTI_TURN]:
@@ -478,6 +485,7 @@
stop_sequence=self.stop_sequence,
generation_size=self.generation_size,
metric_categories=[MetricCategory.LLM_AS_JUDGE_MULTI_TURN],
images=formatted_doc.images,
)
]
if self.has_metric_category[MetricCategory.LLM_AS_JUDGE]:
@@ -492,6 +500,7 @@
generation_grammar=self.generation_grammar,
num_samples=1,
metric_categories=[MetricCategory.LLM_AS_JUDGE],
images=formatted_doc.images,
)
]

@@ -569,6 +578,15 @@ def load_datasets(tasks: list["LightevalTask"], dataset_loading_processes: int =
],
)

# TODO: debug purpose, to remove later
import os

debug_samples = int(os.getenv("DATASET_SAMPLES", 0))
if debug_samples > 0:
for dataset in datasets:
for split in dataset.keys():
dataset[split] = dataset[split].select(range(debug_samples))

for task, dataset in zip(tasks, datasets):
task.dataset = dataset

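
The debug hook above truncates every split when the DATASET_SAMPLES environment variable is set, which is handy for smoke tests. For example (illustrative), setting it before datasets are loaded caps each split at 8 rows:

import os

# Illustrative: must be set before load_datasets() runs, e.g. at the top of a driver script.
os.environ["DATASET_SAMPLES"] = "8"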
42 changes: 31 additions & 11 deletions src/lighteval/tasks/prompt_manager.py
@@ -210,15 +210,20 @@ def _single_turn_context(
system_prompt=system_prompt,
use_chat_template=use_chat_template,
cot_prompt=cot_prompt,
doc=doc,
)
if not use_chat_template:
toks = self.model.tok_encode(output)
else:
toks = [self.model.tok_encode(msg["content"]) for msg in output]
toks = [t for ts in toks for t in ts]

if truncate_few_shots and doc.images is not None:
raise NotImplementedError("Few shot evaluation is not supported for multi-modal tasks yet.")

# If we need to truncate few-shots to fit in the context
if truncate_few_shots and self.model.max_length is not None and self.model.tokenizer is not None:
if not use_chat_template:
toks = self.model.tok_encode(output)
else:
toks = [self.model.tok_encode(msg["content"]) for msg in output]
toks = [t for ts in toks for t in ts]

# If self.generation_size is None, the maximum allowed generation size depends
# on the model maximum context length, not on the task - we don't take it into account here
# but we probably should
@@ -258,8 +263,27 @@ def get_examples(
system_prompt: Union[str | None],
use_chat_template: bool,
cot_prompt: Union[str | None],
doc: Doc,
):
is_multimodal = doc.images is not None

if is_multimodal and not use_chat_template:
raise NotImplementedError("Multi-modal tasks do not support formatting without chat template yet.")

if is_multimodal and fewshot_ex:
raise NotImplementedError("Multi-modal tasks do not support fewshot evaluation yet.")

content = example + cot_prompt if cot_prompt is not None else example

if is_multimodal:
text_content = [{"type": "text", "text": content}]
image_content = [{"type": "image", "image": image} for image in doc.images]
message = {"role": "user", "content": text_content + image_content}
return [message]

# Regular text (not multimodal)
examples = []

# Few shot examples
for ex in fewshot_ex:
if use_chat_template:
@@ -269,8 +293,6 @@
examples.append(self.doc_to_text(ex, return_instructions=False) + self.doc_to_target(ex))

# Actual example
content = example + cot_prompt if cot_prompt is not None else example

if use_chat_template:
examples.append({"role": "user", "content": content})
else:
@@ -284,10 +306,8 @@
examples[0]["content"] = instruction + examples[0]["content"]
return examples
else:
if system_prompt is not None:
output = system_prompt + instruction + "\n\n".join(examples)
else:
output = instruction + "\n\n".join(examples)
system_prompt = system_prompt if system_prompt is not None else ""
output = system_prompt + instruction + "\n\n".join(examples)
if output == "\n\n":
return ""
return output
8 changes: 8 additions & 0 deletions src/lighteval/tasks/requests.py
@@ -75,6 +75,7 @@ class LoglikelihoodRequest(Request):
request_type = RequestType.LOGLIKELIHOOD
tokenized_context: list[int] = None
tokenized_continuation: list[int] = None
images: Optional[list["PIL.Image.Image"]] = None # noqa F821


@dataclass
@@ -92,6 +93,7 @@ class LoglikelihoodSingleTokenRequest(Request):
request_type = RequestType.LOGLIKELIHOOD_SINGLE_TOKEN
tokenized_context: list[int] = None
tokenized_continuation: list[int] = None
images: Optional[list["PIL.Image.Image"]] = None # noqa F821


@dataclass
@@ -105,6 +107,7 @@ class LoglikelihoodRollingRequest(Request):
request_type = RequestType.LOGLIKELIHOOD_ROLLING
tokenized_context: list[int] = None
tokenized_continuation: list[int] = None
images: Optional[list["PIL.Image.Image"]] = None # noqa F821


@dataclass
@@ -128,6 +131,7 @@ class GreedyUntilRequest(Request):
num_samples: int = None
do_sample: bool = False
use_logits: bool = False
images: Optional[list["PIL.Image.Image"]] = None # noqa F821


@dataclass
@@ -145,6 +149,7 @@ class GreedyUntilMultiTurnRequest(Request):
generation_size: int
request_type = RequestType.GREEDY_UNTIL_MULTI_TURN
use_logits: bool = False
images: Optional[list["PIL.Image.Image"]] = None # noqa F821


class SampleUid(NamedTuple):
@@ -190,6 +195,9 @@ class Doc:
# The uncoditioned query shouldn't contain any information about the task, thus usually it's empty string or 'Answer:'.
unconditioned_query: Optional[str] = None

# For multi-modal tasks
images: Optional[list["PIL.Image.Image"]] = None # noqa F821

def __post_init__(self):
if self.instruction is None:
self.instruction = ""