Skip to content

Commit ba041cb

Browse files
authored
Adding TensorBoard logging for inference metrics (#283)
* Added TensorBoard logging for inference metrics
* Removed config.tensorboard_dir
* Added tokamax as a requirement in requirements_with_jax_ai_image.txt
* Added logging for model details
* Added logging for model details
* Adding TensorBoard logging for inference metrics
* Adding TensorBoard logging for inference metrics
* Adding TensorBoard logging for inference metrics
1 parent 4896870 commit ba041cb

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

requirements_with_jax_ai_image.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ orbax-checkpoint
3030
tokenizers==0.21.0
3131
huggingface_hub>=0.30.2
3232
transformers==4.48.1
33+
tokamax
3334
einops==0.8.0
3435
sentencepiece
3536
aqtp

src/maxdiffusion/generate_wan.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,11 @@ def inference_generate_video(config, pipeline, filename_prefix=""):
139139

140140

141141
def run(config, pipeline=None, filename_prefix=""):
142-
print("seed: ", config.seed)
143142
model_key = config.model_name
143+
# Initialize TensorBoard writer
144+
writer = max_utils.initialize_summary_writer(config)
145+
if jax.process_index() == 0 and writer:
146+
max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}")
144147

145148
checkpointer_lib = get_checkpointer(model_key)
146149
WanCheckpointer = checkpointer_lib.WanCheckpointer
@@ -163,8 +166,19 @@ def run(config, pipeline=None, filename_prefix=""):
163166
)
164167

165168
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
166-
167-
print("compile time: ", (time.perf_counter() - s0))
169+
max_logging.log("===================== Model details =======================")
170+
max_logging.log(f"model name: {config.model_name}")
171+
max_logging.log(f"model path: {config.pretrained_model_name_or_path}")
172+
max_logging.log("model type: t2v")
173+
max_logging.log(f"hardware: {jax.devices()[0].platform}")
174+
max_logging.log(f"number of devices: {jax.device_count()}")
175+
max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}")
176+
max_logging.log("============================================================")
177+
178+
compile_time = time.perf_counter() - s0
179+
max_logging.log(f"compile_time: {compile_time}")
180+
if writer and jax.process_index() == 0:
181+
writer.add_scalar("inference/compile_time", compile_time, global_step=0)
168182
saved_video_path = []
169183
for i in range(len(videos)):
170184
video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
@@ -175,14 +189,30 @@ def run(config, pipeline=None, filename_prefix=""):
175189

176190
s0 = time.perf_counter()
177191
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
178-
print("generation time: ", (time.perf_counter() - s0))
192+
generation_time = time.perf_counter() - s0
193+
max_logging.log(f"generation_time: {generation_time}")
194+
if writer and jax.process_index() == 0:
195+
writer.add_scalar("inference/generation_time", generation_time, global_step=0)
196+
num_devices = jax.device_count()
197+
num_videos = num_devices * config.per_device_batch_size
198+
if num_videos > 0:
199+
generation_time_per_video = generation_time / num_videos
200+
writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0)
201+
max_logging.log(f"generation time per video: {generation_time_per_video}")
202+
else:
203+
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")
204+
179205

180206
s0 = time.perf_counter()
181207
if config.enable_profiler:
182208
max_utils.activate_profiler(config)
183209
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
184210
max_utils.deactivate_profiler(config)
185-
print("generation time: ", (time.perf_counter() - s0))
211+
generation_time_with_profiler = time.perf_counter() - s0
212+
max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}")
213+
if writer and jax.process_index() == 0:
214+
writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0)
215+
186216
return saved_video_path
187217

188218

0 commit comments

Comments (0)