Commit 6e95d2a

Fix breakdown infer (#534)
* fully support ormsgpack
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* dependency
* torch==2.4.1 windows compilable
* Update docs
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* remove unused code
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* remove autorerank
* api usage
* back slash
* fix docs
* Fix infer warmup params
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* max_new_tokens=1024
* Fix break down infer

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4ae0d07 commit 6e95d2a

File tree

1 file changed (+5 −2 lines)


tools/llama/generate.py

+5 −2
```diff
@@ -605,7 +605,7 @@ def worker():
     multiple=True,
 )
 @click.option("--num-samples", type=int, default=1)
-@click.option("--max-new-tokens", type=int, default=0)
+@click.option("--max-new-tokens", type=int, default=1024)
 @click.option("--top-p", type=float, default=0.7)
 @click.option("--repetition-penalty", type=float, default=1.2)
 @click.option("--temperature", type=float, default=0.7)
@@ -650,7 +650,10 @@ def main(
     model, decode_one_token = load_model(
         checkpoint_path, device, precision, compile=compile
     )
-
+    with torch.device(device):
+        model.setup_caches(
+            max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
+        )
     if torch.cuda.is_available():
         torch.cuda.synchronize()

```
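For reference, a minimal sketch of how the changed option behaves once the default is 1024. The command name `demo` is a hypothetical stand-in; only `--max-new-tokens`, its type, and the new default come from the diff above.

```python
import click


@click.command()
@click.option("--max-new-tokens", type=int, default=1024)
def demo(max_new_tokens: int) -> None:
    # Invoked without the flag, click now passes 1024 to the
    # callback instead of the old default of 0.
    click.echo(f"max_new_tokens={max_new_tokens}")


if __name__ == "__main__":
    demo()
```

Running `python demo.py` prints `max_new_tokens=1024`; `python demo.py --max-new-tokens 256` overrides it.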

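The second hunk pre-allocates the model's caches before the CUDA warmup that follows it. Below is a minimal, self-contained sketch of the allocation pattern it relies on; `TinyModel` and its `setup_caches` body are hypothetical stand-ins for the repository's model class, and only the call shape (`max_batch_size=1`, `max_seq_len=2048`, dtype taken from the model's parameters) comes from the diff. Under PyTorch 2.x, `torch.device(device)` used as a context manager makes tensors created inside it default to that device, so the cache buffers land next to the model without explicit `.to(device)` calls.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the repository's transformer model."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.kv_cache = None

    def setup_caches(self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype):
        # Tensors created here inherit the device from the enclosing
        # `with torch.device(...)` block, so no device argument is needed.
        self.kv_cache = torch.zeros(
            max_batch_size, max_seq_len, self.proj.in_features, dtype=dtype
        )


device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyModel().to(device)

with torch.device(device):
    model.setup_caches(
        max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
    )

print(model.kv_cache.shape, model.kv_cache.device)
```

Setting the caches up once, before the warmup pass, means the warmup and every subsequent generation reuse the same pre-sized buffers.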