inference/generator.py (2 changes: 1 addition & 1 deletion)
@@ -65,7 +65,7 @@ def generate(self, task_name):
        ds_loader = DataLoader(ds_tokenized, batch_size=1)

        generations, generations_raw = complete_code(
            task, self.model, sampling_params, ds_loader, self.args.batch_size, n_tasks
            task, self.model, sampling_params, ds_loader, self.args.batch_size, n_tasks, backend=self.args.backend, tokenizer=self.tokenizer
        )

        references = [task.get_reference(dataset[i]) for i in range(n_tasks)]
inference/main.py (41 changes: 34 additions & 7 deletions)
@@ -148,6 +148,14 @@ def parse_args():
        default="references.json",
        help="Path for saving the reference solutions/tests",
    )

    parser.add_argument(
        "--backend",
        type=str,
        default="vllm",
        help="Backend to load the model with (default: `vllm`). Use `gptqmodel` for GPTQ-quantized models.",
    )

    args = parser.parse_args()

    precision_map = {
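As an aside, here is a small, self-contained sketch (not the project's parser, which defines many more options) of how the new --backend flag is expected to behave: it defaults to "vllm" and can be switched to "gptqmodel" for GPTQ-quantized checkpoints.

import argparse

# Stand-alone reproduction of the new flag; the real parser in inference/main.py
# defines it alongside the existing arguments.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--backend",
    type=str,
    default="vllm",
    help="Backend to load the model with (default: `vllm`). Use `gptqmodel` for GPTQ-quantized models.",
)

print(parser.parse_args([]).backend)                          # -> vllm
print(parser.parse_args(["--backend", "gptqmodel"]).backend)  # -> gptqmodel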
@@ -185,13 +193,32 @@ def main():
    transformers.logging.set_verbosity_error()
    datasets.logging.set_verbosity_error()

    model = LLM(
        model=args.model,
        dtype=args.precision,
        trust_remote_code=args.trust_remote_code,
        gpu_memory_utilization=0.98,
        tensor_parallel_size=args.tensor_parallel_size,
    )
    if args.backend == 'vllm':
        model = LLM(
            model=args.model,
            dtype=args.precision,
            trust_remote_code=args.trust_remote_code,
            gpu_memory_utilization=0.98,
            tensor_parallel_size=args.tensor_parallel_size,
        )
    elif args.backend == 'gptqmodel':
        try:
            from gptqmodel import GPTQModel
        except ModuleNotFoundError as exception:
            raise type(exception)(
                "Tried to load gptqmodel, but gptqmodel is not installed. "
                "Please install it via `pip install gptqmodel --no-build-isolation`."
            ) from exception

        kwargs = {
            "model_id_or_path": args.model,
            "trust_remote_code": args.trust_remote_code,
            "torch_dtype": args.precision,
        }
        model = GPTQModel.load(**kwargs)
    else:
        raise ValueError(f"Unsupported backend {args.backend!r}; supported backends are 'vllm' and 'gptqmodel'.")


    tokenizer = AutoTokenizer.from_pretrained(
        args.model,
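For readers unfamiliar with gptqmodel, here is a minimal sketch of the path added above (the checkpoint name is only an example and not tied to this repository; it assumes gptqmodel is installed and a GPU is available). GPTQModel.load returns a model exposing a transformers-style interface, i.e. .device and .generate(input_ids=...), which is what inference/utils.py relies on below.

from gptqmodel import GPTQModel
from transformers import AutoTokenizer

model_id = "TheBloke/Llama-2-7B-GPTQ"  # example GPTQ-quantized checkpoint
model = GPTQModel.load(model_id_or_path=model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Encode a prompt, move it to the model's device, and generate greedily.
input_ids = tokenizer("def fibonacci(n):", return_tensors="pt").input_ids.to(model.device)
output_ids = model.generate(input_ids=input_ids, max_new_tokens=32, do_sample=False)
# Decode only the newly generated tokens, mirroring the slicing done in utils.py.
print(tokenizer.decode(output_ids[0, input_ids.size(-1):], skip_special_tokens=True))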
inference/utils.py (55 changes: 50 additions & 5 deletions)
@@ -76,6 +76,8 @@ def complete_code(
    n_tasks,
    prefix="",
    postprocess=True,
    backend='vllm',
    tokenizer=None,
):
    max_length_generation = sampling_params.max_tokens
    code_gens = defaultdict(list)
@@ -92,12 +94,55 @@
            )
            continue
        sampling_params.max_tokens = max_length_generation - num_tokens
        outputs = model.generate(
            prompt_token_ids=inputs, sampling_params=sampling_params, use_tqdm=False
        )

        generated_tasks = batch["row_index"].repeat(batch_size)
        generated_texts = [o.text for o in outputs[0].outputs]
        if backend == 'vllm':
            outputs = model.generate(
                prompt_token_ids=inputs, sampling_params=sampling_params, use_tqdm=False
            )
            generated_texts = [o.text for o in outputs[0].outputs]
        elif backend == 'gptqmodel':
            # requires `import torch` at the top of this module
            inputs_tensor = torch.tensor(inputs).to(model.device)
            model_kwargs = {
                "input_ids": inputs_tensor,
                "max_new_tokens": sampling_params.max_tokens,
                "num_return_sequences": sampling_params.n,
            }

            # normalize stop_strings
            if sampling_params.stop is not None:
                if isinstance(sampling_params.stop, str):
                    sampling_params.stop = [sampling_params.stop]

                # transformers requires a tokenizer to be passed when stop_strings is used
                model_kwargs["stop_strings"] = sampling_params.stop
                model_kwargs["tokenizer"] = tokenizer

            do_sample = sampling_params.temperature != 1.0

            if do_sample:
                model_kwargs["temperature"] = sampling_params.temperature
            model_kwargs["do_sample"] = do_sample

            if sampling_params.top_k > 0:
                model_kwargs["top_k"] = sampling_params.top_k

            if sampling_params.top_p != 1.0:
                model_kwargs["top_p"] = sampling_params.top_p

            outputs = model.generate(**model_kwargs)
            generated_texts = tokenizer.batch_decode(
                outputs[:, inputs_tensor.size(-1):],
                skip_special_tokens=True,
            )

            # transformers includes the stop string in the output; remove it to match vLLM's default
            if sampling_params.stop is not None:
                for stop_string in sampling_params.stop:
                    generated_texts = [
                        generated_text.replace(stop_string, "")
                        for generated_text in generated_texts
                    ]
        else:
            raise ValueError(f"Unsupported backend {backend!r}; supported backends are 'vllm' and 'gptqmodel'.")

        combined_texts = [
            batch["prompt"][0] + generated_text for generated_text in generated_texts
        ]
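One subtlety worth illustrating: vLLM omits stop strings from its output by default (include_stop_str_in_output=False), while transformers' stop_strings criterion ends generation only after the stop string has been emitted, so the gptqmodel branch removes it afterwards. A self-contained sketch of that normalization (the helper name and sample text are illustrative, not part of the patch):

def strip_stop_strings(generated_texts, stop):
    # Remove stop strings from decoded text to match vLLM's default output.
    if stop is None:
        return generated_texts
    stop = [stop] if isinstance(stop, str) else stop
    for stop_string in stop:
        generated_texts = [text.replace(stop_string, "") for text in generated_texts]
    return generated_texts

print(strip_stop_strings(["def add(a, b):\n    return a + b\n</code>"], "</code>"))
# -> ['def add(a, b):\n    return a + b\n']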
@@ -109,4 +154,4 @@
            code_gens[task_idx].append(text_processed)
            code_gens_raw[task_idx].append(text)

    return code_gens, code_gens_raw
    return code_gens, code_gens_raw