@@ -139,8 +139,11 @@ def inference_generate_video(config, pipeline, filename_prefix=""):
139139
140140
141141def run (config , pipeline = None , filename_prefix = "" ):
142- print ("seed: " , config .seed )
143142 model_key = config .model_name
143+ # Initialize TensorBoard writer
144+ writer = max_utils .initialize_summary_writer (config )
145+ if jax .process_index () == 0 and writer :
146+ max_logging .log (f"TensorBoard logs will be written to: { config .tensorboard_dir } " )
144147
145148 checkpointer_lib = get_checkpointer (model_key )
146149 WanCheckpointer = checkpointer_lib .WanCheckpointer
@@ -163,8 +166,19 @@ def run(config, pipeline=None, filename_prefix=""):
163166 )
164167
165168 videos = call_pipeline (config , pipeline , prompt , negative_prompt )
166-
167- print ("compile time: " , (time .perf_counter () - s0 ))
169+ max_logging .log ("===================== Model details =======================" )
170+ max_logging .log (f"model name: { config .model_name } " )
171+ max_logging .log (f"model path: { config .pretrained_model_name_or_path } " )
172+ max_logging .log ("model type: t2v" )
173+ max_logging .log (f"hardware: { jax .devices ()[0 ].platform } " )
174+ max_logging .log (f"number of devices: { jax .device_count ()} " )
175+ max_logging .log (f"per_device_batch_size: { config .per_device_batch_size } " )
176+ max_logging .log ("============================================================" )
177+
178+ compile_time = time .perf_counter () - s0
179+ max_logging .log (f"compile_time: { compile_time } " )
180+ if writer and jax .process_index () == 0 :
181+ writer .add_scalar ("inference/compile_time" , compile_time , global_step = 0 )
168182 saved_video_path = []
169183 for i in range (len (videos )):
170184 video_path = f"{ filename_prefix } wan_output_{ config .seed } _{ i } .mp4"
@@ -175,14 +189,30 @@ def run(config, pipeline=None, filename_prefix=""):
175189
176190 s0 = time .perf_counter ()
177191 videos = call_pipeline (config , pipeline , prompt , negative_prompt )
178- print ("generation time: " , (time .perf_counter () - s0 ))
192+ generation_time = time .perf_counter () - s0
193+ max_logging .log (f"generation_time: { generation_time } " )
194+ if writer and jax .process_index () == 0 :
195+ writer .add_scalar ("inference/generation_time" , generation_time , global_step = 0 )
196+ num_devices = jax .device_count ()
197+ num_videos = num_devices * config .per_device_batch_size
198+ if num_videos > 0 :
199+ generation_time_per_video = generation_time / num_videos
200+ writer .add_scalar ("inference/generation_time_per_video" , generation_time_per_video , global_step = 0 )
201+ max_logging .log (f"generation time per video: { generation_time_per_video } " )
202+ else :
203+ max_logging .log ("Warning: Number of videos is zero, cannot calculate generation_time_per_video." )
204+
179205
180206 s0 = time .perf_counter ()
181207 if config .enable_profiler :
182208 max_utils .activate_profiler (config )
183209 videos = call_pipeline (config , pipeline , prompt , negative_prompt )
184210 max_utils .deactivate_profiler (config )
185- print ("generation time: " , (time .perf_counter () - s0 ))
211+ generation_time_with_profiler = time .perf_counter () - s0
212+ max_logging .log (f"generation_time_with_profiler: { generation_time_with_profiler } " )
213+ if writer and jax .process_index () == 0 :
214+ writer .add_scalar ("inference/generation_time_with_profiler" , generation_time_with_profiler , global_step = 0 )
215+
186216 return saved_video_path
187217
188218
0 commit comments