
Add validation and batched inference to flux #1205


Closed
wants to merge 22 commits
Changes from all commits
50 changes: 50 additions & 0 deletions torchtitan/components/metrics.py
@@ -347,6 +347,56 @@ def __init__(
def should_log(self, step: int) -> bool:
return step == 1 or step % self.job_config.metrics.log_freq == 0

def val_log(
self,
step: int,
loss: float,
extra_metrics: dict[str, Any] | None = None,
):

time_delta = time.perf_counter() - self.time_last_log

# tokens per second per device, abbreviated as tps
tps = self.ntokens_since_last_log / (
time_delta * self.parallel_dims.non_data_parallel_size
)

time_end_to_end = time_delta / self.job_config.metrics.log_freq

device_mem_stats = self.device_memory_monitor.get_peak_stats()

metrics = {
"val_loss_metrics/avg_loss": loss,
"val_throughput(tps)": tps,
"val_time_metrics/end_to_end(s)": time_end_to_end,
"val_memory/max_active(GiB)": device_mem_stats.max_active_gib,
"val_memory/max_active(%)": device_mem_stats.max_active_pct,
"val_memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib,
"val_memory/max_reserved(%)": device_mem_stats.max_reserved_pct,
"val_memory/num_alloc_retries": device_mem_stats.num_alloc_retries,
"val_memory/num_ooms": device_mem_stats.num_ooms,
}


if extra_metrics:
metrics.update(extra_metrics)

self.logger.log(metrics, step)

color = self.color
logger.info(
f"{color.magenta}val: {step:2} "
f"{color.green}val loss: {loss:7.4f} "
f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
f"({device_mem_stats.max_reserved_pct:.2f}%) "
f"{color.blue}tps: {round(tps):,} {color.reset}"
)

self.ntokens_since_last_log = 0
self.data_loading_times.clear()
self.time_last_log = time.perf_counter()
self.device_memory_monitor.reset_peak_stats()

def log(
self,
step: int,
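For orientation, here is a minimal sketch of how a trainer-side validation loop might feed the new val_log hook. The validate function, its arguments, the model(**inputs) call, and the max_batches cap are assumptions for illustration; only val_log itself comes from this PR.

import torch

def validate(model, val_dataloader, loss_fn, metrics, step, max_batches=50):
    """Hypothetical driver loop: average the loss over a few validation batches,
    then report it through the val_log hook added in this PR."""
    model.eval()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(val_dataloader):
            if i >= max_batches:
                break
            preds = model(**inputs)  # placeholder forward pass
            total_loss += loss_fn(preds, labels).item()
            # val_log derives its tps estimate from this counter, as the training path does
            metrics.ntokens_since_last_log += labels.numel()
            num_batches += 1
    model.train()
    if num_batches > 0:
        metrics.val_log(step, total_loss / num_batches)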
30 changes: 30 additions & 0 deletions torchtitan/distributed/utils.py
@@ -21,6 +21,36 @@
from torchtitan.tools.utils import device_module, device_type


def dist_collect(
x: torch.Tensor,
mesh: DeviceMesh,
extra_pg: dist.ProcessGroup | None = None,
) -> torch.Tensor:
"""Collect and sum tensors across devices.

Unlike _dist_reduce, this function returns the full tensor after reduction
rather than extracting a scalar item.

Args:
x (torch.Tensor): Input tensor.
mesh (DeviceMesh): Device mesh to use for reduction.
        extra_pg (dist.ProcessGroup, optional): Extra process group to use for reduction.
            Defaults to None. If provided, all_reduce is first called over this extra
            process group, and the result is then all_reduced over the mesh.

Returns:
torch.Tensor: The summed tensor collected from all devices.
"""
if isinstance(x, DTensor):
# functional collectives do not support DTensor inputs
x = x.full_tensor()

if extra_pg is not None:
x = funcol.all_reduce(x, reduceOp="sum", group=extra_pg)

return funcol.all_reduce(x, reduceOp="sum", group=mesh)


def _dist_reduce(
x: torch.Tensor,
reduceOp: str,
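A small usage sketch for dist_collect, e.g. averaging a validation loss across data-parallel ranks. The mesh construction and tensor values here are illustrative assumptions (a 1-D mesh over all ranks, with the process group already initialized); the function itself accepts a DeviceMesh just like _dist_reduce.

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from torchtitan.distributed.utils import dist_collect

# Illustrative 1-D data-parallel mesh over all ranks.
mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("dp",))

local_loss_sum = torch.tensor(12.5, device="cuda")  # sum of per-sample losses on this rank
local_count = torch.tensor(8.0, device="cuda")      # number of samples on this rank

# Sum the full tensors across ranks, then divide for a global average.
global_loss = dist_collect(local_loss_sum, mesh)
global_count = dist_collect(local_count, mesh)
avg_val_loss = (global_loss / global_count).item()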
11 changes: 7 additions & 4 deletions torchtitan/experiments/flux/__init__.py
@@ -9,7 +9,7 @@

from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_train_dataloader, build_flux_val_dataloader
from torchtitan.experiments.flux.loss import build_mse_loss
from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
from torchtitan.experiments.flux.parallelize_flux import parallelize_flux
@@ -104,17 +104,20 @@
}


register_train_spec(
TrainSpec(
train_spec = TrainSpec(
name="flux",
cls=FluxModel,
config=flux_configs,
parallelize_fn=parallelize_flux,
pipelining_fn=None,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_flux_dataloader,
build_dataloader_fn=build_flux_train_dataloader,
Contributor:

I think this aligns with my proposal (#1210) to do validation in torchtitan (not just for flux but also for other models).
I would hope we can take a more principled approach and make general improvements, instead of doing an ad hoc change here.

Contributor Author:

Absolutely. I wanted to enable this functionality for flux ASAP, so this is hacky.

Since it will involve changes to some central components in torchtitan, I didn't want to attempt a full implementation just yet, and I'm not sure I'd have the bandwidth for it, especially if it's work that someone is already doing / plans on doing.

I'm happy to remove the validation dataset bit and wait for a proper implementation to be added to main. Until then, the validation metrics I added in this PR could instead target a subset of the training set, for example.

build_tokenizer_fn=None,
build_loss_fn=build_mse_loss,
)
# Monkey-patch a build_val_dataloader_fn onto the train_spec
train_spec.build_val_dataloader_fn = build_flux_val_dataloader
register_train_spec(
train_spec
)
122 changes: 89 additions & 33 deletions torchtitan/experiments/flux/dataset/flux_dataset.py
Contributor:

If we take the first 30,000 samples as the validation dataset, will it overlap with the training dataset?

An alternative is to specify the data_files explicitly (e.g., dataset = load_dataset("json", data_files={"train": base_url + "train-v1.1.json", "validation": base_url + "dev-v1.1.json"}, field="data")); see https://huggingface.co/docs/datasets/en/loading. This works if we are loading the dataset from Hugging Face directly.

If we are loading data locally, we could keep an _info.json (https://huggingface.co/datasets/pixparse/cc12m-wds/blob/main/_info.json) to specify the train / validation split.
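As a concrete sketch of the split-based loading the reviewer describes (the URLs are the SQuAD example from the linked Hugging Face docs, not flux data):

from datasets import load_dataset

base_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/"
dataset = load_dataset(
    "json",
    data_files={
        "train": base_url + "train-v1.1.json",
        "validation": base_url + "dev-v1.1.json",
    },
    field="data",
)
train_ds = dataset["train"]        # the returned DatasetDict is keyed by the split names above
val_ds = dataset["validation"]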

Contributor Author:

Yes, it will. This was just a temporary solution to functionally verify the validation loop. I wanted to ask if you had some insights on how we should include the COCO 2014 dataset, given that it's not easily available on the HF hub.

Would we add download instructions to the README and load it locally?

Contributor:

Do you want to use the COCO dataset because of the Stable Diffusion paper? I think we should keep it simple and just cut some part of the cc12m dataset to serve as the validation set.
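A minimal sketch of the non-overlapping split the reviewer suggests, assuming the streaming cc12m-wds dataset is used for both splits:

from datasets import load_dataset

# Carve a held-out slice from the front of the stream and train on the rest,
# so the validation samples never appear in the training stream.
stream = load_dataset("pixparse/cc12m-wds", split="train", streaming=True)
val_stream = stream.take(30_000)
train_stream = stream.skip(30_000)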

@@ -6,6 +6,8 @@

import json
import math
import os
import random
from dataclasses import dataclass
from typing import Any, Callable, Optional

@@ -33,12 +35,13 @@
def _process_cc12m_image(
img: PIL.Image.Image,
output_size: int = 256,
skip_low_resolution: bool = True,
) -> Optional[torch.Tensor]:
"""Process CC12M image to the desired size."""

width, height = img.size
# Skip low resolution images
if width < output_size or height < output_size:
if skip_low_resolution and (width < output_size or height < output_size):
return None

if width >= height:
@@ -106,7 +109,8 @@ def _cc12m_wds_data_processor(
result = {
"image": img,
"clip_tokens": clip_tokens, # type: List[int]
"t5_tokens": t5_tokens, # type: List[int]
"t5_tokens": t5_tokens, # type: List[int],
"txt": sample["txt"],
Contributor:

Hmm, why add this?

Contributor Author:

It is useful for evaluation and inference to be able to associate the generated image with the prompt that was used.

}
if include_sample_id:
result["id"] = sample["__key__"]
@@ -174,6 +178,13 @@ class TextToImageDatasetConfig:
loader=lambda path: load_dataset(path, split="train", streaming=True),
data_processor=_cc12m_wds_data_processor,
),
"cc12m-wds-30k": TextToImageDatasetConfig(
path="pixparse/cc12m-wds",
loader=lambda path: load_dataset(path, split="train", streaming=True).take(
30_000
),
data_processor=_cc12m_wds_data_processor,
),
"cc12m-preprocessed": TextToImageDatasetConfig(
path="outputs/preprocessed",
loader=lambda path: load_dataset(
@@ -285,42 +296,48 @@ def __init__(

# Variables for checkpointing
self._sample_idx = 0
self._all_samples: list[dict[str, Any]] = []
self._epoch = 0
Contributor:

What's the purpose of adding this variable, and in general what's the purpose of making these changes around the data loader?

Contributor Author:

The problem is around how a dataloader should be initialized when we load from a checkpoint.
The issue with the current logic is that it does not behave properly with non-infinite datasets. Once the end of the dataset is reached, sample_idx is not reset. This means that every subsequent time the dataset is used, all of its samples will be skipped.

To work around this, I introduce two things:

  1. A _restored_checkpoint flag, which ensures we only skip samples if we have just restored from a checkpoint. Subsequent epochs should not skip samples.
  2. Non-infinite datasets also need _sample_idx to be reset at the end.

self._restored_checkpoint = False

def reset(self):
self._sample_idx = 0

def _get_data_iter(self):
if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
return iter([])

it = iter(self._data)
for _ in range(self._sample_idx):
next(it)
return it
return iter(self._data)

def __iter__(self):
dataset_iterator = self._get_data_iter()
# Initialize the dataset iterator
iterator = self._get_data_iter()

# Skip samples if we're resuming from a checkpoint
if self._restored_checkpoint:
logger.info(f"Restoring dataset state: skipping {self._sample_idx} samples")
for _ in range(self._sample_idx):
next(iterator)
self._restored_checkpoint = False

while True:
try:
sample = next(dataset_iterator)
except StopIteration:
if not self.infinite:
logger.warning(
f"Dataset {self.dataset_name} has run out of data. \
This might cause NCCL timeout if data parallelism is enabled."
)
break
else:
# Reset offset for the next iteration if infinite
self._sample_idx = 0
logger.info(f"Dataset {self.dataset_name} is being re-looped.")
dataset_iterator = self._get_data_iter()
continue
sample = next(iterator)
except (UnicodeDecodeError, SyntaxError, OSError) as e:
Contributor:

In my training, I added this line to capture some data loading errors, e.g., a corrupted image header when PIL.Image is reading, or a corrupted .tar file header, etc.

Contributor Author:

Makes sense; I must have removed it by accident.

                # Handle other exceptions, e.g., dataset corruption
                logger.warning(
                    f"Dataset {self.dataset_name} hit an error while loading batch data. \
                    Error {type(e).__name__}: {e}. The error could be the result of a streaming glitch."
)
continue
except StopIteration:
# Handle the end of the iterator
self.reset()
if not self.infinite:
logger.warning(f"Dataset {self.dataset_name} has run out of data")
break
# Reset for next epoch if infinite
logger.warning(f"Dataset {self.dataset_name} is being re-looped")
iterator = self._get_data_iter()

# Use the dataset-specific preprocessor
sample_dict = self._data_processor(
@@ -333,46 +350,85 @@ def __iter__(self):
# skip low quality image or image with color channel = 1
if sample_dict["image"] is None:
logger.warning(
f"Low quality image {sample['__key__']} is skipped in Flux Dataloader."
f"Low quality image {sample['__key__']} is skipped in Flux Dataloader"
)
continue

# Classifier-free guidance: Replace some of the strings with empty strings.
# Distinct random seed is initialized at the beginning of training for each FSDP rank.
dropout_prob = self.job_config.training.classifer_free_guidance_prob
if dropout_prob > 0.0:
if torch.rand(1).item() < dropout_prob:
sample_dict["t5_tokens"] = self._t5_empty_token
sample_dict["clip_tokens"] = self._clip_empty_token
if dropout_prob > 0.0 and random.random() < dropout_prob:
sample_dict["t5_tokens"] = self._t5_empty_token
sample_dict["clip_tokens"] = self._clip_empty_token

self._sample_idx += 1

labels = sample_dict.pop("image")
yield sample_dict, labels

def load_state_dict(self, state_dict):
self._sample_idx = state_dict["sample_idx"]
self._sample_idx = state_dict.get("sample_idx", 0)
self._restored_checkpoint = True # Mark that we've loaded from a checkpoint

def state_dict(self):
return {
"sample_idx": self._sample_idx,
}
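
To make the resume semantics discussed above concrete, here is a hypothetical round-trip; make_dataset stands in for however the dataset is actually constructed, and the exact sample_idx value is illustrative.

ds = make_dataset()
it = iter(ds)
for _ in range(100):
    next(it)                 # consume some samples during "training"

state = ds.state_dict()      # e.g. {"sample_idx": 100}

ds2 = make_dataset()         # fresh instance after a restart
ds2.load_state_dict(state)   # sets _restored_checkpoint = True
it2 = iter(ds2)              # this first iterator skips the already-seen samples
# The flag is cleared after the skip, so later epochs (new iter() calls)
# start again from the beginning instead of skipping forever.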


def build_flux_dataloader(
def build_flux_train_dataloader(
dp_world_size: int,
dp_rank: int,
job_config: JobConfig,
tokenizer: FluxTokenizer | None,
infinite: bool = True,
) -> ParallelAwareDataloader:
return _build_flux_dataloader(
dataset_name=job_config.training.dataset,
dataset_path=job_config.training.dataset_path,
dp_world_size=dp_world_size,
dp_rank=dp_rank,
job_config=job_config,
tokenizer=tokenizer,
infinite=infinite,
batch_size=job_config.training.batch_size,
)


def build_flux_val_dataloader(
dp_world_size: int,
dp_rank: int,
job_config: JobConfig,
tokenizer: FluxTokenizer | None,
infinite: bool = False,
) -> ParallelAwareDataloader:
print(job_config.eval.dataset_path)

return _build_flux_dataloader(
dataset_name=job_config.eval.dataset,
dataset_path=job_config.eval.dataset_path,
dp_world_size=dp_world_size,
dp_rank=dp_rank,
job_config=job_config,
tokenizer=tokenizer,
infinite=infinite,
batch_size=job_config.eval.batch_size,
)


def _build_flux_dataloader(
dataset_name: str,
dataset_path: str,
dp_world_size: int,
dp_rank: int,
job_config: JobConfig,
# This parameter is not used, keep it for compatibility
tokenizer: FluxTokenizer | None,
infinite: bool = True,
include_sample_id: bool = False,
batch_size: int = 4,
Contributor:

Why this magic number?

Contributor Author:

This should indeed be a parameter.

) -> ParallelAwareDataloader:
"""Build a data loader for HuggingFace datasets."""
dataset_name = job_config.training.dataset
dataset_path = job_config.training.dataset_path
batch_size = job_config.training.batch_size

t5_encoder_name = job_config.encoder.t5_encoder
clip_encoder_name = job_config.encoder.clip_encoder
max_t5_encoding_len = job_config.encoder.max_t5_encoding_len
19 changes: 19 additions & 0 deletions torchtitan/experiments/flux/flux_argparser.py
@@ -43,6 +43,24 @@ class Eval:
"""Frequency of evaluation/sampling during training"""
save_img_folder: str = "img"
"""Directory to save image generated/sampled from the model"""
dataset: str | None = None
"""Dataset to use for validation."""
dataset_path: str | None = None
"""
Path to the dataset in the file system.
"""
batch_size: int = 16
"""Batch size for validation."""

@dataclass
class Inference:
"""Inference configuration"""
save_path: str = "inference_results"
"""Path to save the inference results"""
prompts_path: str = "prompts.txt"
"""Path to file with newline separated prompts to generate images for"""
batch_size: int = 16
"""Batch size for inference"""


@dataclass
@@ -54,3 +72,4 @@
training: Training = field(default_factory=Training)
encoder: Encoder = field(default_factory=Encoder)
eval: Eval = field(default_factory=Eval)
inference: Inference = field(default_factory=Inference)
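
A minimal sketch of how the new inference options might be consumed; load_prompts and the batching shown here are assumptions for illustration, not part of this PR.

def load_prompts(prompts_path: str, batch_size: int) -> list[list[str]]:
    """Read newline-separated prompts and group them into batches."""
    with open(prompts_path) as f:
        prompts = [line.strip() for line in f if line.strip()]
    return [prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)]

# With the defaults above: batches of 16 prompts read from "prompts.txt",
# with generated images then written under "inference_results".
batches = load_prompts("prompts.txt", batch_size=16)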