diff --git a/examples/model_configs/transformers_vlm_model.yaml b/examples/model_configs/transformers_vlm_model.yaml new file mode 100644 index 000000000..6a32c0932 --- /dev/null +++ b/examples/model_configs/transformers_vlm_model.yaml @@ -0,0 +1,10 @@ +model_parameters: + model_name: "Qwen/Qwen2.5-VL-3B-Instruct" + revision: "main" + dtype: "float16" + compile: false + model_parallel: false + batch_size: 1 + generation_parameters: + temperature: 0.2 + top_p: 0.9 diff --git a/pyproject.toml b/pyproject.toml index 4b2c4c768..76181f18a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ classifiers = [ keywords = ["evaluation", "nlp", "llm"] dependencies = [ # Base dependencies - "transformers>=4.38.0", + "transformers>=4.51.0", "accelerate", "huggingface_hub[hf_xet]>=0.30.2", "torch>=2.0,<3.0", diff --git a/src/lighteval/main_accelerate.py b/src/lighteval/main_accelerate.py index 5a1fe28cf..5255aaefc 100644 --- a/src/lighteval/main_accelerate.py +++ b/src/lighteval/main_accelerate.py @@ -48,6 +48,9 @@ def accelerate( # noqa C901 use_chat_template: Annotated[ bool, Option(help="Use chat template for evaluation.", rich_help_panel=HELP_PANEL_NAME_4) ] = False, + vision_model: Annotated[ + bool, Option(help="Use vision model for evaluation.", rich_help_panel=HELP_PANEL_NAME_4) + ] = False, system_prompt: Annotated[ Optional[str], Option(help="Use system prompt for evaluation.", rich_help_panel=HELP_PANEL_NAME_4) ] = None, @@ -109,6 +112,7 @@ def accelerate( # noqa C901 from lighteval.models.transformers.adapter_model import AdapterModelConfig from lighteval.models.transformers.delta_model import DeltaModelConfig from lighteval.models.transformers.transformers_model import TransformersModelConfig + from lighteval.models.transformers.vlm_transformers_model import VLMTransformersModelConfig from lighteval.models.utils import ModelConfig from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters @@ -147,7 +151,10 @@ def accelerate( # noqa C901 elif config.get("adapter_weights", False): model_config = AdapterModelConfig(**config) else: - model_config = TransformersModelConfig(**config) + if vision_model: + model_config = VLMTransformersModelConfig(**config) + else: + model_config = TransformersModelConfig(**config) pipeline = Pipeline( tasks=tasks, diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py index 7b4cb1cac..ef9c77549 100644 --- a/src/lighteval/models/model_loader.py +++ b/src/lighteval/models/model_loader.py @@ -42,6 +42,8 @@ from lighteval.models.transformers.adapter_model import AdapterModel, AdapterModelConfig from lighteval.models.transformers.delta_model import DeltaModel, DeltaModelConfig from lighteval.models.transformers.transformers_model import TransformersModel, TransformersModelConfig +from lighteval.models.transformers.vlm_transformers_model import VLMTransformersModel, VLMTransformersModelConfig +from lighteval.models.utils import ModelConfig from lighteval.models.vllm.vllm_model import VLLMModel, VLLMModelConfig from lighteval.utils.imports import ( NO_LITELLM_ERROR_MSG, @@ -60,21 +62,8 @@ def load_model( # noqa: C901 - config: Union[ - TransformersModelConfig, - AdapterModelConfig, - DeltaModelConfig, - TGIModelConfig, - InferenceEndpointModelConfig, - DummyModelConfig, - VLLMModelConfig, - CustomModelConfig, - OpenAIModelConfig, - LiteLLMModelConfig, - SGLangModelConfig, - InferenceProvidersModelConfig, - ], -) -> Union[TransformersModel, AdapterModel, DeltaModel, ModelClient, DummyModel]: + config: 
ModelConfig, +) -> LightevalModel: """Will load either a model from an inference server or a model from a checkpoint, depending on the config type. @@ -100,6 +89,9 @@ def load_model( # noqa: C901 if isinstance(config, TransformersModelConfig): return load_model_with_accelerate_or_default(config) + if isinstance(config, VLMTransformersModelConfig): + return load_model_with_accelerate_or_default(config) + if isinstance(config, DummyModelConfig): return load_dummy_model(config) @@ -186,7 +178,9 @@ def load_model_with_inference_endpoints(config: Union[InferenceEndpointModelConf def load_model_with_accelerate_or_default( - config: Union[AdapterModelConfig, TransformersModelConfig, DeltaModelConfig], + config: Union[ + AdapterModelConfig, TransformersModelConfig, DeltaModelConfig, VLLMModelConfig, VLMTransformersModelConfig + ], ): if isinstance(config, AdapterModelConfig): model = AdapterModel(config=config) @@ -197,6 +191,9 @@ def load_model_with_accelerate_or_default( raise ImportError(NO_VLLM_ERROR_MSG) model = VLLMModel(config=config) return model + elif isinstance(config, VLMTransformersModelConfig): + model = VLMTransformersModel(config=config) + return model else: model = TransformersModel(config=config) diff --git a/src/lighteval/models/transformers/vlm_transformers_model.py b/src/lighteval/models/transformers/vlm_transformers_model.py new file mode 100644 index 000000000..611073e88 --- /dev/null +++ b/src/lighteval/models/transformers/vlm_transformers_model.py @@ -0,0 +1,455 @@ +# MIT License + +# Copyright (c) 2025 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +import logging +import os +from typing import Optional, Tuple, Union + +import torch +from pydantic import PositiveInt +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import ( + AutoConfig, + AutoModelForImageTextToText, + AutoProcessor, + BitsAndBytesConfig, + PretrainedConfig, +) + +from lighteval.data import GenerativeTaskDataset +from lighteval.models.abstract_model import LightevalModel, ModelInfo +from lighteval.models.model_output import ( + GenerativeResponse, + LoglikelihoodResponse, + LoglikelihoodSingleTokenResponse, +) +from lighteval.models.utils import ModelConfig, _get_dtype, _get_model_sha, _simplify_name +from lighteval.tasks.requests import ( + GreedyUntilRequest, + LoglikelihoodRequest, + LoglikelihoodSingleTokenRequest, +) +from lighteval.utils.imports import ( + is_accelerate_available, +) +from lighteval.utils.utils import as_list + + +logger = logging.getLogger(__name__) + + +if is_accelerate_available(): + from datetime import timedelta + + from accelerate import Accelerator, InitProcessGroupKwargs + from accelerate.utils import gather_object, get_max_memory + + +class BatchCollator: + """Collator for batching requests""" + + def __init__(self, processor, **kwargs): + self.processor = processor + self.kwargs = kwargs + + def __call__(self, requests: list[GreedyUntilRequest]) -> Tuple[dict[str, torch.Tensor], list[GreedyUntilRequest]]: + texts = [request.context for request in requests] + images = [request.images for request in requests] + inputs = self.processor(text=texts, images=images, **self.kwargs) + return inputs, requests + + +class VLMTransformersModelConfig(ModelConfig): + """ + Base configuration class for models. + + Attributes: + model_name (str): + HuggingFace Hub model ID name or the path to a pre-trained + model to load. This is effectively the `pretrained_model_name_or_path` + argument of `from_pretrained` in the HuggingFace `transformers` API. + processor (Optional[str]): HuggingFace Hub processor ID that will be + used for preprocessing images and text. + use_fast_image_processor (Optional[bool]): + Whether to use a fast image processor. Not all VLMs support this yet. + subfolder (Optional[str]): The subfolder within the model repository. + revision (str): The revision of the model. + batch_size (int): The batch size for model training. + generation_size (Optional[int]): The maximum number of tokens to generate. + max_length (Optional[int]): The maximum length of the generated output. + add_special_tokens (bool, optional, defaults to True): Whether to add special tokens to the input sequences. + model_parallel (bool, optional, defaults to None): + True/False: force to use or not the `accelerate` library to load a large + model across multiple devices. + Default: None which corresponds to comparing the number of processes with + the number of GPUs. If it's smaller => model-parallelism, else not. + dtype (Union[str, torch.dtype], optional, defaults to None): + Converts the model weights to `dtype`, if specified. Strings get + converted to `torch.dtype` objects (e.g. `float16` -> `torch.float16`). + Use `dtype="auto"` to derive the type from the model's weights. + device (Union[int, str]): device to use for model training. + quantization_config (Optional[BitsAndBytesConfig]): quantization + configuration for the model, manually provided to load a normally floating point + model at a quantized precision. Needed for 4-bit and 8-bit precision. 
+ trust_remote_code (bool): Whether to trust remote code during model + loading. + generation_parameters (GenerationParameters): Range of parameters which will affect the generation. + generation_config (GenerationConfig): GenerationConfig object (only passed during manual creation) + + Methods: + __post_init__(): Performs post-initialization checks on the configuration. + _init_configs(model_name, env_config): Initializes the model configuration. + init_configs(env_config): Initializes the model configuration using the environment configuration. + get_model_sha(): Retrieves the SHA of the model. + + """ + + model_name: str + processor: str | None = None + use_fast_image_processor: bool | None = None + subfolder: str | None = None + revision: str = "main" + batch_size: PositiveInt = 1 + generation_size: PositiveInt | None = None + max_length: PositiveInt | None = None + add_special_tokens: bool = True + model_parallel: bool | None = None + dtype: str | None = None + device: Union[int, str] = "cuda" + trust_remote_code: bool = False + use_chat_template: bool = False + compile: bool = False + device_map: str | None = None + + def get_model_sha(self): + return _get_model_sha(repo_id=self.model_name, revision=self.revision) + + def get_transformers_config(self) -> PretrainedConfig: + revision = f"{self.revision}/{self.subfolder}" if self.subfolder else self.revision + config = AutoConfig.from_pretrained( + self.model_name, + revision=revision, + trust_remote_code=self.trust_remote_code, + ) + return config + + +class VLMTransformersModel(LightevalModel): + def __init__( + self, + config: VLMTransformersModelConfig, + ): + """Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation.""" + + self.accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))]) + self.device = self.accelerator.device + self.torch_dtype = _get_dtype(config.dtype) + + # Config attributes + self.config = config + self.use_chat_template = config.use_chat_template + self.batch_size = config.batch_size + + # Model, config, and processor + self.model_sha = config.get_model_sha() + self.model_name = _simplify_name(config.model_name) + self.model = self._create_auto_model() + self.processor = self._create_auto_processor() + self.transformers_config = config.get_transformers_config() + + # Attributes exposed by @property + self._max_length = self._init_max_length() + self._add_special_tokens = config.add_special_tokens or False + + # Generation config + self.generation_config_dict = config.generation_parameters.to_transformers_dict() + self.generation_config_dict["pad_token_id"] = self.pad_token_id + self.generation_config_dict["eos_token_id"] = self.eos_token_id + self.generation_config_dict["renormalize_logits"] = True + + self.model_info = ModelInfo( + model_name=self.config.model_name, + model_sha=self.model_sha, + model_dtype=config.dtype, + ) + + @property + def tokenizer(self): + return self.processor.tokenizer + + @property + def pad_token_id(self): + return self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id + + @property + def eos_token_id(self): + return self.tokenizer.eos_token_id + + @property + def add_special_tokens(self): + return self._add_special_tokens + + @property + def max_length(self): + return self._max_length + + @property + def disable_tqdm(self) -> bool: + disable_tqdm = False + if self.accelerator: + disable_tqdm = bool(not self.accelerator.is_main_process) + return disable_tqdm + + # Copied from 
./transformers_model.py + def init_model_parallel(self, model_parallel: bool | None = None) -> Tuple[bool, Optional[dict], Optional[str]]: + """Compute all the parameters related to model_parallel""" + if not is_accelerate_available(): + return False, None, None + + self.num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) + self.num_machines = torch.cuda.device_count() // self.num_local_processes + + if self.num_machines == 1: + logger.info("We are not in a distributed setting. Setting model_parallel to False.") + model_parallel = False + + if model_parallel is None: + max_memory_all_gpus = get_max_memory() # A dict of the max memory for all the gpus + if "cpu" in max_memory_all_gpus: + del max_memory_all_gpus["cpu"] + model_parallel = bool(self.num_local_processes < len(max_memory_all_gpus)) + logger.info( + f"Setting model parallel to {model_parallel} since " + f"the number of local processes is {self.num_local_processes} " + f"and the number of GPUs is {len(max_memory_all_gpus)}" + ) + if model_parallel is True: + max_memory_all_gpus = get_max_memory() # A dict of the max memory for all the gpus + if "cpu" in max_memory_all_gpus: + del max_memory_all_gpus["cpu"] + max_mem_this_process = { + k: v + for k, v in max_memory_all_gpus.items() + if k % self.num_local_processes == (self.accelerator.process_index % self.num_local_processes) + } + device_map = "auto" + logger.info( + f"Model parallel was set to True, setting max memory per GPU to {max_mem_this_process} and device map to {device_map}" + ) + else: + max_mem_this_process = None + device_map = None + logger.info( + f"Model parallel was set to False, max memory set to {max_mem_this_process} and device map to {device_map}" + ) + return model_parallel, max_mem_this_process, device_map + + @staticmethod + def _get_quantization_config(config: VLMTransformersModelConfig) -> BitsAndBytesConfig | None: + if config.dtype == "4bit": + quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16) + elif config.dtype == "8bit": + quantization_config = BitsAndBytesConfig(load_in_8bit=True) + else: + quantization_config = None + return quantization_config + + def _create_auto_model(self): + model_parallel, max_memory, device_map = self.init_model_parallel(self.config.model_parallel) + self.config.model_parallel = model_parallel + + quantization_config = self._get_quantization_config(self.config) + + subfolder = self.config.subfolder + revision = f"{self.config.revision}/{subfolder}" if subfolder is not None else self.config.revision + + model = AutoModelForImageTextToText.from_pretrained( + self.config.model_name, + revision=revision, + device_map=device_map, + max_memory=max_memory, + torch_dtype=self.torch_dtype, + quantization_config=quantization_config, + trust_remote_code=self.config.trust_remote_code, + ) + model.eval() + torch.set_grad_enabled(False) + + if self.config.compile: + raise NotImplementedError("Compiling VLM models is not supported yet") + + # We are in DP (and launch the script with `accelerate launch`) + if model_parallel is False and self.config.dtype not in ["4bit", "8bit"]: + logger.info(f"Using Data Parallelism, putting model on device {self.device}") + model = model.to(self.device) + + return model + + def _create_auto_processor(self): + """ + Create a transformers `Processor` for VLM (image-text-to-text) model. + + Returns: + transformers.ProcessorMixin: The created processor. 
+ """ + processor_name = self.config.processor or self.config.model_name + revision, subfolder = self.config.revision, self.config.subfolder + revision = revision if not subfolder else f"{revision}/{subfolder}" + + processor = AutoProcessor.from_pretrained( + processor_name, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=self.config.trust_remote_code, + use_fast=self.config.use_fast_image_processor, + ) + + return processor + + def _init_max_length(self) -> int: + """Return the maximum sequence length of the model. + + NOTE: For relative position encoded models you should specify the max + sequence length of the model in the constructor via `max_length`. + + Returns: + int: Max length to use depending on the available args and config + """ + if self.config.max_length is not None: + return self.config.max_length + + # Try to get the sequence length from the model config. It's no super robust + text_model_config = self.transformers_config.get_text_config() + max_seq_length = getattr(text_model_config, "max_position_embeddings", None) + if max_seq_length is not None: + return max_seq_length + + logger.warning( + "No max_length attribute found in the model config. Using the default max sequence length setting `2048`. " + "It is recommended to set max_length trough the model args: max_length=..." + ) + + return 2048 + + def _tokenize_requests_context_inplace(self, requests: list[GreedyUntilRequest]): + """Preprocess requests to fill in the tokenized_context field for sorting in the dataset""" + for request in requests: + request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token] + inputs = self.processor(text=request.context, images=request.images) + request.tokenized_context = inputs["input_ids"][0] + + def greedy_until( + self, + requests: list[GreedyUntilRequest], + ) -> list[GenerativeResponse]: + """ + Generates responses using a greedy decoding strategy until certain ending conditions are met. + + Args: + requests (list[Request]): list of requests containing the context and ending conditions. + override_bs (int, optional): Override the batch size for generation. Defaults to None. + + Returns: + list[GenerativeResponse]: list of generated responses. 
+ """ + + # Tokenizing context for sorting in the dataset + logger.info("Tokenizing requests context for sorting in the dataset") + self._tokenize_requests_context_inplace(requests) + logger.info("Done tokenizing!") + + dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS) + collator = BatchCollator( + self.processor, + truncation="longest_first", # we truncate to the model max length if needed + padding="longest", # we pad to the longest sequence + max_length=self.max_length - 1, # we should always allow minimum one token of generation + add_special_tokens=self.add_special_tokens, + return_tensors="pt", + ) + + results = [] + for split in dataset.splits_iterator(): + batch_size = self.batch_size or 1 + dataloader = DataLoader(split, batch_size=batch_size, collate_fn=collator) + if self.accelerator: + dataloader = self.accelerator.prepare(dataloader) + + for batch_inputs, batch_requests in tqdm( + dataloader, desc="Greedy generation", position=1, leave=True, disable=self.disable_tqdm + ): + batch_inputs = batch_inputs.to(self.device) + if self.torch_dtype is not None: + batch_inputs = batch_inputs.to(self.torch_dtype) + + max_new_tokens = self.config.generation_size or batch_requests[0].generation_size + outputs = self.model.generate( + **batch_inputs, + **self.generation_config_dict, # custom generation params + max_new_tokens=max_new_tokens, + do_sample=batch_requests[0].do_sample, + num_return_sequences=batch_requests[0].num_samples, + output_logits=batch_requests[0].use_logits, + ) + input_tokens = batch_inputs.input_ids + generated_tokens = outputs.sequences[:, input_tokens.shape[1] :] + generated_texts = self.processor.batch_decode(generated_tokens, skip_special_tokens=True) + attention_mask = batch_inputs["attention_mask"] + padded_tokens_count = (attention_mask == 0).sum(dim=1) + + batch_results = [] + for i in range(len(generated_texts)): + generated_response = GenerativeResponse( + result=generated_texts[i], + generated_tokens=generated_tokens[i].cpu().numpy(), + input_tokens=input_tokens[i].cpu().numpy(), + truncated_tokens_count=-1, + padded_tokens_count=padded_tokens_count[i].item(), + logits=outputs.logits[i].cpu().numpy() if outputs.logits is not None else None, + ) + batch_results.append(generated_response) + + if self.accelerator: + batch_results = gather_object(batch_results) + + results.extend(batch_results) + + return dataset.get_original_order(results) + + def loglikelihood( + self, + requests: list[LoglikelihoodRequest], + ) -> list[LoglikelihoodResponse]: + raise NotImplementedError() + + def loglikelihood_single_token( + self, requests: list[LoglikelihoodSingleTokenRequest] + ) -> list[LoglikelihoodSingleTokenResponse]: + raise NotImplementedError() + + def loglikelihood_rolling( + self, + requests: list[LoglikelihoodRequest], + ) -> list[LoglikelihoodResponse]: + raise NotImplementedError() diff --git a/src/lighteval/tasks/default_prompts.py b/src/lighteval/tasks/default_prompts.py index 2745b63c5..8f973ffda 100644 --- a/src/lighteval/tasks/default_prompts.py +++ b/src/lighteval/tasks/default_prompts.py @@ -26,6 +26,7 @@ import random import re import string +from typing import Optional import numpy as np import pycountry @@ -43,6 +44,82 @@ # fmt: on +def mmmu_pro(line, task_name: Optional[str] = None): + # fmt: off + question = line["question"] # "What is the capital of France?" 
+ choices_string = line["options"] # "[Paris, London, Berlin, Madrid]" + answer = line["answer"] # "A" + # fmt: on + + instructions = "Answer with the option letter from the given choices directly." + + # Preprocess choices + # "[Paris, London, Berlin, Madrid]" -> ["A. Paris", "B. London", "C. Berlin", "D. Madrid"] + choices = ast.literal_eval(str(choices_string)) + choices_letters = [chr(ord("A") + i) for i in range(len(choices))] # ["A", "B", "C", "D"] + choices = [f"{letter}. {choice}" for letter, choice in zip(choices_letters, choices)] + + # Construct prompt + formatted_choices = "\n".join(choices) + prompt = f"{instructions}\n{question}\n{formatted_choices}" + + # Collect images + image_order = [] + for num in re.findall(r"", prompt): + num = int(num) + if num not in image_order: + image_order.append(num) + images = [line[f"image_{i}"].convert("RGB") for i in image_order] + + gold_index = string.ascii_uppercase.index(answer) + + # Replace image placeholders in prompt , , ... with [image 1], [image 2], ... + prompt = re.sub(r"", "[image \\1]", prompt) + choices = [re.sub(r"", "[image \\1]", choice) for choice in choices] + + return Doc( + task_name=task_name, + query=prompt, + choices=choices, + gold_index=gold_index, + images=images, + specific={"id": line["id"]}, + instruction=instructions, + ) + + +def mmmu_pro_vision(line, task_name: str = None): + instruction = ( + "Answer with the option letter from the given choices directly." + " The last line of your response should be of the following format: " + "'Answer: $LETTER' (without quotes) where LETTER is one of options." + ) + + # Preprocess choices + # "[Paris, London, Berlin, Madrid]" -> ["A. Paris", "B. London", "C. Berlin", "D. Madrid"] + choices_string = line["options"] + choices = ast.literal_eval(str(choices_string)) + choices_letters = [chr(ord("A") + i) for i in range(len(choices))] # ["A", "B", "C", "D"] + choices = [f"{letter}. {choice}" for letter, choice in zip(choices_letters, choices)] + + # Preprocess answer + # e.g. 
"A" -> 0 + answer = line["answer"] + gold_index = string.ascii_uppercase.index(answer) + + # Preprocess images + images = [line["image"]] + + return Doc( + task_name=task_name, + query=instruction, + choices=choices, + gold_index=gold_index, + images=images, + instruction=instruction, + ) + + def simpleqa(line, task_name: str = None): query = line["problem"] choices = [line["answer"]] diff --git a/src/lighteval/tasks/default_tasks.py b/src/lighteval/tasks/default_tasks.py index 3960e6f5c..676aaeae2 100644 --- a/src/lighteval/tasks/default_tasks.py +++ b/src/lighteval/tasks/default_tasks.py @@ -24,6 +24,54 @@ from lighteval.tasks.lighteval_task import LightevalTaskConfig +mmmu_pro_standard_4_options = LightevalTaskConfig( + name="mmmu_pro:standard-4", + suite=["lighteval"], + prompt_function=prompt.mmmu_pro, + hf_repo="MMMU/MMMU_pro", + hf_subset="standard (4 options)", + hf_avail_splits=["test"], + evaluation_splits=["test"], + few_shots_split=None, + few_shots_select=None, + generation_size=30, # expected an answer in a format 'Answer: B' + metric=[Metrics.gpqa_instruct_metric], + stop_sequence=None, + trust_dataset=True, + version=0, +) +mmmu_pro_standard_10_options = LightevalTaskConfig( + name="mmmu_pro:standard-10", + suite=["lighteval"], + prompt_function=prompt.mmmu_pro, + hf_repo="MMMU/MMMU_pro", + hf_subset="standard (10 options)", + hf_avail_splits=["test"], + evaluation_splits=["test"], + few_shots_split=None, + few_shots_select=None, + generation_size=30, # expected an answer in a format 'Answer: B' + metric=[Metrics.gpqa_instruct_metric], + stop_sequence=None, + trust_dataset=True, + version=0, +) +mmmu_pro_vision = LightevalTaskConfig( + name="mmmu_pro:vision", + suite=["lighteval"], + prompt_function=prompt.mmmu_pro_vision, + hf_repo="MMMU/MMMU_pro", + hf_subset="vision", + hf_avail_splits=["test"], + evaluation_splits=["test"], + few_shots_split=None, + few_shots_select=None, + generation_size=30, # expected an answer in a format 'Answer: B' + metric=[Metrics.gpqa_instruct_metric], + stop_sequence=None, + trust_dataset=True, + version=0, +) abstract_narrative_understanding_bigbench = LightevalTaskConfig( name="abstract_narrative_understanding", suite=["bigbench", "bigbench_json"], diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index da09ec000..d6b203d58 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -364,6 +364,7 @@ def construct_requests( context=context, choice=gold, metric_categories=[MetricCategory.TARGET_PERPLEXITY], + images=formatted_doc.images, ) for i, gold in enumerate(golds) ] @@ -375,12 +376,13 @@ def construct_requests( request_index=0, context=context, metric_categories=[MetricCategory.PERPLEXITY], + images=formatted_doc.images, ) ] if self.has_metric_category[MetricCategory.GENERATIVE_SAMPLING]: # All the possible sampling tasks require the same generation process - we can do them in one step # so we select the maximum number of samples and the metrics will select only the - # relevant number of tiems + # relevant number of items requests[RequestType.GREEDY_UNTIL] += [ GreedyUntilRequest( task_name=current_task_name, @@ -394,6 +396,7 @@ def construct_requests( do_sample=True, use_logits=False, metric_categories=[MetricCategory.GENERATIVE_SAMPLING], + images=formatted_doc.images, ) ] if ( @@ -420,6 +423,7 @@ def construct_requests( ] if self.has_metric_category[c] ], + images=formatted_doc.images, ) ] if ( @@ -438,6 +442,7 @@ def construct_requests( for c in 
[MetricCategory.MULTICHOICE, MetricCategory.MULTICHOICE_PMI] if self.has_metric_category[c] ], + images=formatted_doc.images, ) for i, choice in enumerate(formatted_doc.choices) ] @@ -454,6 +459,7 @@ def construct_requests( context=formatted_doc.unconditioned_query, choice=choice, metric_categories=[MetricCategory.MULTICHOICE_PMI], + images=formatted_doc.images, ) for i, choice in enumerate(formatted_doc.choices) ] @@ -466,6 +472,7 @@ def construct_requests( context=context, choices=formatted_doc.choices, metric_categories=[MetricCategory.MULTICHOICE_ONE_TOKEN], + images=formatted_doc.images, ) ] if self.has_metric_category[MetricCategory.LLM_AS_JUDGE_MULTI_TURN]: @@ -478,6 +485,7 @@ def construct_requests( stop_sequence=self.stop_sequence, generation_size=self.generation_size, metric_categories=[MetricCategory.LLM_AS_JUDGE_MULTI_TURN], + images=formatted_doc.images, ) ] if self.has_metric_category[MetricCategory.LLM_AS_JUDGE]: @@ -492,6 +500,7 @@ def construct_requests( generation_grammar=self.generation_grammar, num_samples=1, metric_categories=[MetricCategory.LLM_AS_JUDGE], + images=formatted_doc.images, ) ] diff --git a/src/lighteval/tasks/prompt_manager.py b/src/lighteval/tasks/prompt_manager.py index d1dc15d0d..3ec959cb9 100644 --- a/src/lighteval/tasks/prompt_manager.py +++ b/src/lighteval/tasks/prompt_manager.py @@ -210,15 +210,20 @@ def _single_turn_context( system_prompt=system_prompt, use_chat_template=use_chat_template, cot_prompt=cot_prompt, + doc=doc, ) - if not use_chat_template: - toks = self.model.tok_encode(output) - else: - toks = [self.model.tok_encode(msg["content"]) for msg in output] - toks = [t for ts in toks for t in ts] + + if truncate_few_shots and doc.images is not None: + raise NotImplementedError("Few shot evaluation is not supported for multi-modal tasks yet.") # If we need to truncate few-shots to fit in the context if truncate_few_shots and self.model.max_length is not None and self.model.tokenizer is not None: + if not use_chat_template: + toks = self.model.tok_encode(output) + else: + toks = [self.model.tok_encode(msg["content"]) for msg in output] + toks = [t for ts in toks for t in ts] + # If self.generation_size is None, the maximum allowed generation size depends # on the model maximum context length, not on the task - we don't take it into account here # but we probably should @@ -250,7 +255,7 @@ def _single_turn_context( return output, num_effective_fewshots - def get_examples( + def get_examples( # noqa: C901 self, example: str, instruction: Union[str | None], @@ -258,8 +263,37 @@ def get_examples( system_prompt: Union[str | None], use_chat_template: bool, cot_prompt: Union[str | None], + doc: Doc, ): + is_multimodal = doc.images is not None + + if is_multimodal and not use_chat_template: + raise NotImplementedError("Multi-modal tasks do not support formatting without chat template yet.") + + if is_multimodal and fewshot_ex: + raise NotImplementedError("Multi-modal tasks do not support fewshot evaluation yet.") + + content = example + cot_prompt if cot_prompt is not None else example + + if is_multimodal: + text_content = [{"type": "text", "text": content}] + image_content = [{"type": "image", "image": image} for image in doc.images] + message = {"role": "user", "content": text_content + image_content} + + if ( + system_prompt is not None or instruction is not None + ): # We add system prompt and instruction jointly if possible + system_prompt = system_prompt if system_prompt is not None else "" + instruction = instruction if instruction is not 
None else "" + system_content = [{"type": "text", "text": system_prompt + instruction}] + system_prompt_message = {"role": "system", "content": system_content} + return [system_prompt_message, message] + + return [message] + + # Regular text (not multimodal) examples = [] + # Few shot examples for ex in fewshot_ex: if use_chat_template: @@ -269,8 +303,6 @@ def get_examples( examples.append(self.doc_to_text(ex, return_instructions=False) + self.doc_to_target(ex)) # Actual example - content = example + cot_prompt if cot_prompt is not None else example - if use_chat_template: examples.append({"role": "user", "content": content}) else: @@ -284,10 +316,8 @@ def get_examples( examples[0]["content"] = instruction + examples[0]["content"] return examples else: - if system_prompt is not None: - output = system_prompt + instruction + "\n\n".join(examples) - else: - output = instruction + "\n\n".join(examples) + system_prompt = system_prompt if system_prompt is not None else "" + output = system_prompt + instruction + "\n\n".join(examples) if output == "\n\n": return "" return output diff --git a/src/lighteval/tasks/requests.py b/src/lighteval/tasks/requests.py index cd75ad402..733040406 100644 --- a/src/lighteval/tasks/requests.py +++ b/src/lighteval/tasks/requests.py @@ -23,13 +23,17 @@ import json from dataclasses import asdict, dataclass from enum import Enum, auto -from typing import NamedTuple, Optional, Union +from typing import TYPE_CHECKING, NamedTuple, Optional, Union from huggingface_hub import TextGenerationInputGrammarType from lighteval.utils.utils import as_list +if TYPE_CHECKING: + from PIL.Image import Image + + class RequestType(Enum): LOGLIKELIHOOD = auto() LOGLIKELIHOOD_SINGLE_TOKEN = auto() @@ -75,6 +79,7 @@ class LoglikelihoodRequest(Request): request_type = RequestType.LOGLIKELIHOOD tokenized_context: list[int] = None tokenized_continuation: list[int] = None + images: Optional[list["Image"]] = None @dataclass @@ -92,6 +97,7 @@ class LoglikelihoodSingleTokenRequest(Request): request_type = RequestType.LOGLIKELIHOOD_SINGLE_TOKEN tokenized_context: list[int] = None tokenized_continuation: list[int] = None + images: Optional[list["Image"]] = None @dataclass @@ -105,6 +111,7 @@ class LoglikelihoodRollingRequest(Request): request_type = RequestType.LOGLIKELIHOOD_ROLLING tokenized_context: list[int] = None tokenized_continuation: list[int] = None + images: Optional[list["Image"]] = None @dataclass @@ -128,6 +135,7 @@ class GreedyUntilRequest(Request): num_samples: int = None do_sample: bool = False use_logits: bool = False + images: Optional[list["Image"]] = None @dataclass @@ -145,6 +153,7 @@ class GreedyUntilMultiTurnRequest(Request): generation_size: int request_type = RequestType.GREEDY_UNTIL_MULTI_TURN use_logits: bool = False + images: Optional[list["Image"]] = None class SampleUid(NamedTuple): @@ -190,6 +199,9 @@ class Doc: # The uncoditioned query shouldn't contain any information about the task, thus usually it's empty string or 'Answer:'. 
unconditioned_query: Optional[str] = None + # For multi-modal tasks + images: Optional[list["Image"]] = None + def __post_init__(self): if self.instruction is None: self.instruction = "" diff --git a/tests/reference_scores/Qwen2.5-VL-3B-Instruct-results-vlm.json b/tests/reference_scores/Qwen2.5-VL-3B-Instruct-results-vlm.json new file mode 100644 index 000000000..f6e7849b3 --- /dev/null +++ b/tests/reference_scores/Qwen2.5-VL-3B-Instruct-results-vlm.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3bce71550c40c0934883a71385ab6aaad19a823dcc28481e3b535d376bf9625c +size 3080 diff --git a/tests/reference_scores/Qwen2.5-VL-7B-Instruct-results-vlm.json b/tests/reference_scores/Qwen2.5-VL-7B-Instruct-results-vlm.json new file mode 100644 index 000000000..b9080f9f7 --- /dev/null +++ b/tests/reference_scores/Qwen2.5-VL-7B-Instruct-results-vlm.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d07d8341188999f359a530e1dae4cd8ec3936d4046232a68b90a56c9f2994b3c +size 3083 diff --git a/tests/slow_tests/test_accelerate_vlm_model.py b/tests/slow_tests/test_accelerate_vlm_model.py new file mode 100644 index 000000000..fed43edc3 --- /dev/null +++ b/tests/slow_tests/test_accelerate_vlm_model.py @@ -0,0 +1,103 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +import json +import os +from functools import lru_cache, partial +from typing import Callable, Tuple + +import pytest +from deepdiff import DeepDiff + +from lighteval.main_accelerate import accelerate # noqa: E402 + + +# Set env var for deterministic run of models +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + +MODELS_ARGS = [ + { + "model_name": "examples/model_configs/transformers_vlm_model.yaml", + "use_chat_template": True, + "results_file": "tests/reference_scores/Qwen2.5-VL-3B-Instruct-results-vlm.json", + } +] +TASKS = "lighteval|mmmu_pro:standard-4|0|0" + +ModelInput = Tuple[str, Callable[[], dict]] + + +@lru_cache(maxsize=len(MODELS_ARGS)) +def run_model(model_name: str, use_chat_template: bool): + """Runs the full main as a black box, using the input model and tasks, on 10 samples without parallelism""" + results = accelerate( + model_args=model_name, + tasks=TASKS, + use_chat_template=use_chat_template, + output_dir="", + dataset_loading_processes=1, + save_details=False, + max_samples=30, + vision_model=True, + ) + return results + + +def generate_tests() -> list[ModelInput]: + """Generate test parameters for all models and tasks.""" + + tests = [] + for model_args in MODELS_ARGS: + predictions_lite = partial(run_model, model_args["model_name"], model_args["use_chat_template"]) + tests.append((model_args, predictions_lite)) + + return tests + + +# generates the model predictions parameters at test collection time +tests: list[ModelInput] = generate_tests() +ids = [f"{model_input[0]['model_name']}" for model_input in tests] + + +@pytest.mark.slow +@pytest.mark.parametrize("tests", tests, ids=ids) +def test_accelerate_model_prediction(tests: list[ModelInput]): + """Evaluates a model on a full task - is parametrized using pytest_generate_test""" + model_args, get_predictions = tests + + # Load the reference results + with open(model_args["results_file"], "r") as f: + reference_results = json.load(f)["results"] + + # Change the key names, replace '|' with ':' + reference_results = {k.replace("|", ":"): v for k, v in reference_results.items()} + + # Get the predictions + predictions = get_predictions()["results"] + + # Convert defaultdict values to regular dict for comparison + predictions_dict = {k: dict(v) if hasattr(v, "default_factory") else v for k, v in predictions.items()} + + # Compare the predictions with the reference results + diff = DeepDiff(reference_results, predictions_dict, ignore_numeric_type_changes=True, math_epsilon=0.05) + + assert diff == {}, f"Differences found: {diff}"
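
For reference, the pieces introduced by this diff fit together as follows: the YAML config added under examples/model_configs/transformers_vlm_model.yaml is passed as model_args, one of the new mmmu_pro tasks is selected, and vision_model=True routes model loading through VLMTransformersModelConfig / VLMTransformersModel. The snippet below is a minimal usage sketch derived from the slow test above; the output_dir and max_samples values are illustrative assumptions, and the equivalent CLI switch would presumably be spelled --vision-model (inferred from the typer Option), which may differ from the final CLI.

# Minimal usage sketch (assumptions: output_dir and max_samples are illustrative;
# the call mirrors tests/slow_tests/test_accelerate_vlm_model.py from this diff).
from lighteval.main_accelerate import accelerate

results = accelerate(
    model_args="examples/model_configs/transformers_vlm_model.yaml",  # VLM config added in this diff
    tasks="lighteval|mmmu_pro:standard-4|0|0",                        # multimodal task added in this diff
    vision_model=True,         # selects VLMTransformersModelConfig instead of TransformersModelConfig
    use_chat_template=True,    # multimodal prompt formatting currently requires the chat template
    output_dir="./results",
    max_samples=10,
    save_details=False,
    dataset_loading_processes=1,
)
print(results["results"])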