Files changed: 4 (+34 −14), all under examples/cpu/llm/inference.
File 1 of 4:

```diff
@@ -248,11 +248,16 @@
     model = model.to(memory_format=torch.channels_last)
 
     num_beams = 1 if args.greedy else 4
-    # generate args
+    streamer = None
     if args.streaming:
-        streamer = TextStreamer(tokenizer)
-    else:
-        streamer = None
+        if num_beams != 1 or args.batch_size != 1:
+            logger.warning(
+                "--streaming only supported in greedy search mode (--greedy) with --batch-size 1. Disabling streaming output."
+            )
+        else:
+            streamer = TextStreamer(tokenizer)
+
+    # generate args
     generate_kwargs = dict(
         do_sample=False,
         temperature=0.9,
```
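All four files apply the same fix: `streamer` starts as `None`, and a `TextStreamer` is only constructed when the configuration actually supports streaming (greedy search with batch size 1); otherwise a warning is emitted and generation falls back to non-streaming output. A minimal standalone sketch of the guard, assuming an argparse namespace with `greedy`, `streaming`, and `batch_size` fields (`make_streamer` is a hypothetical helper, not part of the patch):

```python
import logging

from transformers import TextStreamer

logger = logging.getLogger(__name__)


def make_streamer(args, tokenizer):
    """Return a TextStreamer when streaming is supported, else None.

    Hypothetical helper mirroring the patched logic: streaming only
    works with greedy decoding (num_beams == 1) and batch size 1.
    """
    num_beams = 1 if args.greedy else 4
    streamer = None
    if args.streaming:
        if num_beams != 1 or args.batch_size != 1:
            logger.warning(
                "--streaming only supported in greedy search mode (--greedy) "
                "with --batch-size 1. Disabling streaming output."
            )
        else:
            streamer = TextStreamer(tokenizer)
    return streamer
```

Hoisting the `None` default above the branch removes the old `else` arm and keeps `streamer` defined on every path, including the new warning path, so the later `generate_kwargs` call sites can stay unchanged.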
File 2 of 4:

```diff
@@ -713,10 +713,15 @@ def write_checkpoints_json():
     # Generate
     print_rank0(f"*** Starting to generate {num_tokens} tokens with bs={args.batch_size}")
 
+    streamer = None
     if args.streaming:
-        streamer = TextStreamer(tokenizer)
-    else:
-        streamer = None
+        if num_beams != 1 or args.batch_size != 1:
+            logger.warning(
+                "--streaming only supported in greedy search mode (--greedy) with --batch-size 1. Disabling streaming output."
+            )
+        elif local_rank == 0:
+            streamer = TextStreamer(tokenizer)
+
     generate_kwargs = dict(
         do_sample=False,
         num_beams=num_beams,
```
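The distributed variant adds one more condition: only the process with `local_rank == 0` gets a real streamer, so a multi-rank run prints each token once instead of once per worker. A hedged sketch of how this plays out (the model id and prompt are placeholders; the real scripts take these from the command line):

```python
import os

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

model_id = "gpt2"  # placeholder; the example scripts use a --model-id argument
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# torchrun/deepspeed launchers set LOCAL_RANK for each worker process.
local_rank = int(os.getenv("LOCAL_RANK", "0"))
streamer = TextStreamer(tokenizer) if local_rank == 0 else None

input_ids = tokenizer("An example prompt", return_tensors="pt").input_ids
# generate() accepts streamer=None, so every rank makes the same call,
# but only rank 0 writes tokens to stdout as they are produced.
output = model.generate(
    input_ids, max_new_tokens=32, do_sample=False, streamer=streamer
)
```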
File 3 of 4:

```diff
@@ -251,11 +251,16 @@
     model = model.eval()
     model = model.to(memory_format=torch.channels_last)
     num_beams = 1 if args.greedy else 4
-    # generate args
+    streamer = None
     if args.streaming:
-        streamer = TextStreamer(tokenizer)
-    else:
-        streamer = None
+        if num_beams != 1 or args.batch_size != 1:
+            logger.warning(
+                "--streaming only supported in greedy search mode (--greedy) with --batch-size 1. Disabling streaming output."
+            )
+        else:
+            streamer = TextStreamer(tokenizer)
+
+    # generate args
     generate_kwargs = dict(
         do_sample=False,
         temperature=0.9,
```
File 4 of 4:

```diff
@@ -579,10 +579,15 @@ def download_and_open(url: str) -> Image.Image:
 
     tokenizer = model.get_tokenizer()
     print("Data type of the model:", user_model.dtype)
+    streamer = None
     if args.streaming:
-        streamer = TextStreamer(tokenizer)
-    else:
-        streamer = None
+        if num_beams != 1 or args.batch_size != 1:
+            print(
+                "--streaming only supported in greedy search mode (--greedy) with --batch-size 1. Disabling streaming output."
+            )
+        else:
+            streamer = TextStreamer(tokenizer)
+
     generate_kwargs = dict(
         do_sample=False,
         temperature=0.9,
```
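This last file reports the fallback with `print` rather than `logger.warning`, matching its surrounding code. As background for why the guard exists (my reading of `transformers` behavior, not something stated in the diff): `generate` refuses a streamer when beam search is enabled, and `TextStreamer.put` raises for batches larger than one, so the previous code would crash rather than stream. A quick illustration of the batch-size limit:

```python
import torch
from transformers import AutoTokenizer, TextStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works here
streamer = TextStreamer(tokenizer)

streamer.put(torch.tensor([[1, 2, 3]]))  # batch of 1: decoded text is printed
try:
    streamer.put(torch.tensor([[1, 2], [3, 4]]))  # batch of 2
except ValueError as err:
    print("streaming rejected:", err)  # TextStreamer only supports batch size 1
```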