168 changes: 162 additions & 6 deletions src/lighteval/models/transformers/transformers_model.py
@@ -22,9 +22,9 @@

import logging
import os
from typing import Dict, Optional, Tuple, Union
from datetime import timedelta
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
import transformers
@@ -252,14 +252,15 @@ def from_model(

# Instantiate the object without using __init__
self = cls.__new__(cls)
self.config = config
self.transformers_config = model.config
self.generation_config_dict = config.generation_parameters.to_transformers_dict()
self.config = config if config is not None else TransformersModelConfig(model_name=model.config.name_or_path)
if config is not None:
self.generation_config_dict = config.generation_parameters.to_transformers_dict()
self._max_length = self._init_max_length()
self._tokenizer = self._create_auto_tokenizer()
self.batch_size = config.batch_size
self.batch_size = getattr(config, "batch_size", None)
self.model_name = _simplify_name(model.name_or_path)
self.model_sha = config.get_model_sha()
self.model_sha = self.config.get_model_sha()

# If model_parallel is not set we compare the number of processes with the number of GPUs
self.model = model
@@ -500,6 +501,121 @@ def forward_batch(batch_size):
logger.info(f"Determined largest batch size: {batch_size}")
return batch_size


def greedy_until_multi_turn( # noqa: C901
self,
docs: list[Doc],
) -> ModelResponse:
raise NotImplementedError("This method is not implemented for this model")

    def _continious_greedy_until(
        self,
        requests: list[GreedyUntilRequest],
    ) -> list[GenerativeResponse]:
        """
        Generates responses with a greedy decoding strategy until the ending conditions are met,
        using the model's continuous-batching generation path.

        Args:
            requests (list[GreedyUntilRequest]): list of requests containing the context and ending conditions.

        Returns:
            list[GenerativeResponse]: list of generated responses.
        """
for request in requests:
request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]
request.tokenized_context = self.tok_encode(request.context)

dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
results = []

for split in tqdm(
dataset.splits_iterator(),
total=dataset.num_dataset_splits,
desc="Splits",
position=0,
disable=False, # self.disable_tqdm,
):
# For chat models, generation stops with EOS token, so we don't need to specify stop tokens
if self.use_chat_template:
stop_tokens = []
else:
                # NOTE: we are assuming all items in a batch behave similarly (same
                # stop_tokens and max_tokens generated), which is not necessarily
                # the case! Because of that we only use a batch size of 1.
stop_tokens = split[0].stop_sequence

max_new_tokens = self.config.generation_parameters.max_new_tokens or split[0].generation_size
returns_logits = split[0].use_logits
num_samples = split[0].num_samples

context = [sample.context for sample in split]
tokenized = self.tokenizer(context, add_special_tokens=self.add_special_tokens)

# The main question for this step is the following:
# Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
# of losing some meaning, or have some generations that are exceedingly short?
# The choice we go for here is to avoid truncating the prompt if we can, since it
# should have been managed by the prompt creator/few shot manager if requested by the user.
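            # Illustrative numbers (hypothetical): with max_length=4096 and max_new_tokens=512,
            # a 4000-token prompt would be left-truncated to 4096 - 512 = 3584 tokens so that
            # the full generation budget still fits in the model's context window.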
inputs = tokenized["input_ids"]
context_size = len(inputs[0])

# left truncate the inputs to the maximum length
if max_new_tokens is not None:
if context_size + max_new_tokens > self.max_length:
logger.warning(
f"{context_size + max_new_tokens=} which is greater than {self.max_length=}. Truncating context to {self.max_length - max_new_tokens} tokens."
)
context_size = self.max_length - max_new_tokens
if context_size < 0:
logger.critical(
f"{context_size=} is less than 0, either reduce the max_new_tokens or increase model max length."
)
raise ValueError("Context size is less than 0.")
inputs = [input[-context_size:] for input in inputs]
else:
if context_size > self.max_length:
logger.warning(
f"{context_size=} which is greater than {self.max_length=}. Truncating context to {self.max_length} tokens."
)
context_size = self.max_length
inputs = [input[-context_size:] for input in inputs]

_outputs = self._generate(
inputs=inputs,
max_new_tokens=max_new_tokens,
stop_tokens=stop_tokens,
returns_logits=returns_logits,
num_samples=num_samples,
)

            for req_id, _output in _outputs.items():
                output_token_ids = []
                result = []

                # The continuous-batching path currently returns a single generated sequence
                # per request, exposed as `static_outputs`.
                output_token_ids.append(_output.static_outputs)
                result.append(self.tokenizer.decode(_output.static_outputs))

                # Logprobs are not yet surfaced by this path, so we return an empty list for now.
                logprobs = []

input_token_ids = _output.full_prompt_ids
cur_response = GenerativeResponse(
result=result,
logits=logprobs,
generated_tokens=output_token_ids,
input_tokens=input_token_ids,
)
results.append(cur_response)

return dataset.get_original_order(results)


def greedy_until(
self,
docs: list[Doc],
@@ -613,12 +729,42 @@ def greedy_until(
stop_tokens=stop_tokens,
returns_logits=False,
num_samples=num_samples,
do_sample=do_sample,
use_fast=False,
)
results.extend(cur_reponses)

return dataset.get_original_order(results)

def _generate(
def greedy_until(
self,
requests: list[GreedyUntilRequest],
use_fast: bool = True,
) -> list[GenerativeResponse]:
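        """Route generation through the continuous-batching fast path (`use_fast=True`)
        or the original padded generation path (`use_fast=False`)."""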
if use_fast:
return self._continious_greedy_until(requests)
else:
return self._padded_greedy_until(requests)

def _generate_fast(
self,
inputs: list[list[int]],
max_new_tokens: Optional[int] = None,
stop_tokens: Optional[list[str]] = None,
returns_logits: Optional[bool] = False,
num_samples: int = 1,
generate: bool = True,
    ) -> dict[str, GenerativeResponse]:
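        """Thin wrapper around `model.generate_batch`, the continuous-batching entry point of
        recent `transformers` versions, returning its outputs as a mapping from request id to
        generation result. The remaining arguments (max_new_tokens, stop_tokens, returns_logits,
        num_samples) are accepted for interface compatibility but are not all forwarded yet."""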
# Compute model generation
batch_outputs = self.model.generate_batch(
inputs=inputs,
generation_config=self.model.generation_config,
# You can pass request-specific overrides here, e.g., max_new_tokens=100
)

return batch_outputs

def _generate_padded(
self,
batch: Batch,
max_new_tokens: int,
@@ -704,6 +850,16 @@ def _generate(

return all_responses

def _generate(
self,
use_fast: bool = True,
**kwargs,
) -> list[GenerativeResponse]:
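        """Forward the keyword arguments to `_generate_fast` (continuous batching) or
        `_generate_padded` (padded batch generation) depending on `use_fast`."""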
if use_fast:
return self._generate_fast(**kwargs)
else:
return self._generate_padded(**kwargs)

def loglikelihood(
self,
docs: list[Doc],