
Commit f1f3864

fix streamer problems (#3601)
* fix streamer problems
* format correction

Co-authored-by: Chunyuan WU <[email protected]>
1 parent d8f39f4 commit f1f3864
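This change makes --streaming fail soft: each script now initializes streamer = None up front and only constructs a TextStreamer when the generation settings are compatible with streamed output (greedy search with batch size 1), warning and falling back to non-streamed generation otherwise. For context, a minimal sketch of how a Hugging Face TextStreamer plugs into generation; the model name and prompt below are illustrative, not taken from this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

# "gpt2" is a placeholder model; any causal LM works the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# TextStreamer prints decoded tokens to stdout as generate() produces them.
streamer = TextStreamer(tokenizer)
inputs = tokenizer("The weather today is", return_tensors="pt")

# transformers rejects a streamer combined with beam search, and TextStreamer
# only supports batch size 1 -- hence the guards this commit adds.
model.generate(**inputs, streamer=streamer, do_sample=False, max_new_tokens=32)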

4 files changed (+34, -14 lines)


examples/cpu/llm/inference/distributed/run_generation_tp.py (+9, -4)

@@ -248,11 +248,16 @@
 model = model.to(memory_format=torch.channels_last)
 
 num_beams = 1 if args.greedy else 4
-# generate args
+streamer = None
 if args.streaming:
-    streamer = TextStreamer(tokenizer)
-else:
-    streamer = None
+    if num_beams != 1 or args.batch_size != 1:
+        logger.warning(
+            "--streaming only supported in greedy search mode (--greedy) with --batch-size 1. Disabling streaming output."
+        )
+    else:
+        streamer = TextStreamer(tokenizer)
+
+# generate args
 generate_kwargs = dict(
     do_sample=False,
     temperature=0.9,

examples/cpu/llm/inference/distributed/run_generation_with_deepspeed.py (+8, -3)

@@ -713,10 +713,15 @@ def write_checkpoints_json():
 # Generate
 print_rank0(f"*** Starting to generate {num_tokens} tokens with bs={args.batch_size}")
 
+streamer = None
 if args.streaming:
-    streamer = TextStreamer(tokenizer)
-else:
-    streamer = None
+    if num_beams != 1 or args.batch_size != 1:
+        logger.warning(
+            "--streaming only supported in greedy search mode (--greedy) with --batch-size 1. Disabling streaming output."
+        )
+    elif local_rank == 0:
+        streamer = TextStreamer(tokenizer)
+
 generate_kwargs = dict(
     do_sample=False,
     num_beams=num_beams,
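The DeepSpeed variant adds one more condition: only local_rank == 0 constructs the streamer, so under tensor parallelism the streamed text is printed once rather than once per rank. A minimal sketch of that guard, assuming the LOCAL_RANK environment variable set by the usual DeepSpeed/torchrun launchers (names illustrative, not lifted from the script):

import os
from transformers import AutoTokenizer, TextStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model

# Every rank runs generate(), but only rank 0 gets a streamer, so the
# token-by-token output is emitted by a single process.
local_rank = int(os.getenv("LOCAL_RANK", "0"))
streamer = TextStreamer(tokenizer) if local_rank == 0 else None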

examples/cpu/llm/inference/single_instance/run_generation.py (+9, -4)

@@ -251,11 +251,16 @@
 model = model.eval()
 model = model.to(memory_format=torch.channels_last)
 num_beams = 1 if args.greedy else 4
-# generate args
+streamer = None
 if args.streaming:
-    streamer = TextStreamer(tokenizer)
-else:
-    streamer = None
+    if num_beams != 1 or args.batch_size != 1:
+        logger.warning(
+            "--streaming only supported in greedy search mode (--greedy) with --batch-size 1. Disabling streaming output."
+        )
+    else:
+        streamer = TextStreamer(tokenizer)
+
+# generate args
 generate_kwargs = dict(
     do_sample=False,
     temperature=0.9,

examples/cpu/llm/inference/single_instance/run_quantization.py (+8, -3)

@@ -579,10 +579,15 @@ def download_and_open(url: str) -> Image.Image:
 
 tokenizer = model.get_tokenizer()
 print("Data type of the model:", user_model.dtype)
+streamer = None
 if args.streaming:
-    streamer = TextStreamer(tokenizer)
-else:
-    streamer = None
+    if num_beams != 1 or args.batch_size != 1:
+        print(
+            "--streaming only supported in greedy search mode (--greedy) with --batch-size 1. Disabling streaming output."
+        )
+    else:
+        streamer = TextStreamer(tokenizer)
+
 generate_kwargs = dict(
     do_sample=False,
     temperature=0.9,
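The gate is identical across all four scripts (run_quantization.py reports the fallback with print rather than logger.warning). As a hypothetical consolidation, not part of this commit, the repeated pattern could live in a single helper:

from transformers import TextStreamer

def maybe_streamer(tokenizer, streaming, num_beams, batch_size, warn=print):
    # Return a TextStreamer only when streamed output is supported:
    # greedy search (num_beams == 1) with batch size 1.
    if not streaming:
        return None
    if num_beams != 1 or batch_size != 1:
        warn(
            "--streaming only supported in greedy search mode (--greedy) "
            "with --batch-size 1. Disabling streaming output."
        )
        return None
    return TextStreamer(tokenizer)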
