Add validation and batched inference to flux #1205
Changes from all commits
1a0e264
8d2da80
2318c69
73feafd
37035db
3876386
f94bf0c
0df6366
d1fa2d1
0459518
bc2f123
a04df23
c232b63
81e6f46
25c0af1
caaf8f5
d389c1e
d06d1be
a7b2630
80e1768
7ed24ae
47567a4
@@ -9,7 +9,7 @@ | |
|
||
from torchtitan.components.lr_scheduler import build_lr_schedulers | ||
from torchtitan.components.optimizer import build_optimizers | ||
from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader | ||
from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_train_dataloader, build_flux_val_dataloader | ||
from torchtitan.experiments.flux.loss import build_mse_loss | ||
from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams | ||
from torchtitan.experiments.flux.parallelize_flux import parallelize_flux | ||
|
@@ -104,17 +104,20 @@ | |
} | ||
|
||
|
||
register_train_spec( | ||
TrainSpec( | ||
train_spec = TrainSpec( | ||
name="flux", | ||
cls=FluxModel, | ||
config=flux_configs, | ||
parallelize_fn=parallelize_flux, | ||
pipelining_fn=None, | ||
build_optimizers_fn=build_optimizers, | ||
build_lr_schedulers_fn=build_lr_schedulers, | ||
build_dataloader_fn=build_flux_dataloader, | ||
build_dataloader_fn=build_flux_train_dataloader, | ||
Reviewer: I think this aligns with my proposal to do validation in torchtitan (not just for flux but also for other models). #1210
Author: Absolutely. I wanted to enable this functionality for flux ASAP, so this is hacky. Since a proper version would involve changes to some central components in torchtitan, I didn't want to attempt a full implementation just yet, and I'm not sure I'd have the bandwidth for it, especially if it's work that someone is already doing or plans to do. I'm happy to remove the validation dataset bit and wait for a proper implementation to be added to main; until then, the validation metrics I added in this PR could instead target a subset of the training set, for example.
||
build_tokenizer_fn=None, | ||
build_loss_fn=build_mse_loss, | ||
) | ||
# monkey patch a build_val_dataloader_fn to the train_spec | ||
train_spec.build_val_dataloader_fn = build_flux_val_dataloader | ||
register_train_spec( | ||
train_spec | ||
) |
Reviewer: If we take the first 30_000 samples as the validation dataset, will it overlap with the training dataset? An alternative way to specify the split: if we are loading data locally, we could keep an _info.json (https://huggingface.co/datasets/pixparse/cc12m-wds/blob/main/_info.json) to specify the train / validation split.
Author: Yes, it will. This was just a temporary solution to functionally verify the validation loop. I wanted to ask if you had some insight on how we should include the coco2014 dataset, given that it's not easily available on the HF Hub. Would we add download instructions to the README and load it locally?
Reviewer: Do you want to use the coco dataset because of the stable diffusion paper? I think we should keep it simple and just cut out part of the cc12m dataset to serve as the validation set.
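One way to address the overlap concern, assuming the validation split stays a simple prefix of the streaming dataset, is to pair the `.take()` used by the `cc12m-wds-30k` config below with a matching `.skip()` on the training loader. A minimal sketch, illustrative only and not part of this diff (the helper names are hypothetical):

```python
# Illustrative sketch: non-overlapping train/validation split on the
# streaming cc12m-wds dataset via IterableDataset.skip()/take().
# VAL_SIZE mirrors the 30_000 cutoff used in this PR and is arbitrary.
from datasets import load_dataset

VAL_SIZE = 30_000

def load_cc12m_val(path: str = "pixparse/cc12m-wds"):
    # Reserve the first VAL_SIZE streamed samples for validation.
    return load_dataset(path, split="train", streaming=True).take(VAL_SIZE)

def load_cc12m_train(path: str = "pixparse/cc12m-wds"):
    # Skip the validation prefix so the two splits never overlap.
    return load_dataset(path, split="train", streaming=True).skip(VAL_SIZE)
```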
@@ -6,6 +6,8 @@ | |
|
||
import json | ||
import math | ||
import os | ||
import random | ||
from dataclasses import dataclass | ||
from typing import Any, Callable, Optional | ||
|
||
|
@@ -33,12 +35,13 @@ | |
def _process_cc12m_image( | ||
img: PIL.Image.Image, | ||
output_size: int = 256, | ||
skip_low_resolution: bool = True, | ||
) -> Optional[torch.Tensor]: | ||
"""Process CC12M image to the desired size.""" | ||
|
||
width, height = img.size | ||
# Skip low resolution images | ||
if width < output_size or height < output_size: | ||
if skip_low_resolution and (width < output_size or height < output_size): | ||
CarlosGomes98 marked this conversation as resolved.
|
||
return None | ||
|
||
if width >= height: | ||
|
@@ -106,7 +109,8 @@ def _cc12m_wds_data_processor( | |
result = { | ||
"image": img, | ||
"clip_tokens": clip_tokens, # type: List[int] | ||
"t5_tokens": t5_tokens, # type: List[int] | ||
"t5_tokens": t5_tokens, # type: List[int], | ||
"txt": sample["txt"], | ||
Reviewer: Hmm, why add this?
Author: It is useful for evaluation and inference to be able to associate each generated image with the prompt that was used.
||
} | ||
if include_sample_id: | ||
result["id"] = sample["__key__"] | ||
|
@@ -174,6 +178,13 @@ class TextToImageDatasetConfig: | |
loader=lambda path: load_dataset(path, split="train", streaming=True), | ||
data_processor=_cc12m_wds_data_processor, | ||
), | ||
"cc12m-wds-30k": TextToImageDatasetConfig( | ||
path="pixparse/cc12m-wds", | ||
loader=lambda path: load_dataset(path, split="train", streaming=True).take( | ||
30_000 | ||
), | ||
data_processor=_cc12m_wds_data_processor, | ||
), | ||
"cc12m-preprocessed": TextToImageDatasetConfig( | ||
path="outputs/preprocessed", | ||
loader=lambda path: load_dataset( | ||
|
@@ -285,42 +296,48 @@ def __init__( | |
|
||
# Variables for checkpointing | ||
self._sample_idx = 0 | ||
self._all_samples: list[dict[str, Any]] = [] | ||
self._epoch = 0 | ||
Reviewer: What's the purpose of adding this variable and, more generally, of making these changes around the data loader?
Author: The problem is how the dataloader should be initialized when we load from a checkpoint. To work around this, I introduce two things: a `_restored_checkpoint` flag that is set in `load_state_dict`, and logic in `__iter__` that skips the first `_sample_idx` samples only when that flag is set.
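To make the mechanism concrete, here is a runnable toy stand-in (not the real dataset class) that mimics the two pieces described above:

```python
# Toy stand-in mimicking the resume mechanism in this diff: _sample_idx is
# checkpointed, and the skip happens only once, on the first __iter__ after
# load_state_dict().
class ToyCheckpointableIterable:
    def __init__(self, data):
        self._data = data
        self._sample_idx = 0
        self._restored_checkpoint = False

    def __iter__(self):
        it = iter(self._data)
        if self._restored_checkpoint:
            for _ in range(self._sample_idx):
                next(it)                 # fast-forward past already-seen samples
            self._restored_checkpoint = False
        for sample in it:
            self._sample_idx += 1
            yield sample

    def state_dict(self):
        return {"sample_idx": self._sample_idx}

    def load_state_dict(self, sd):
        self._sample_idx = sd.get("sample_idx", 0)
        self._restored_checkpoint = True


ds = ToyCheckpointableIterable(range(100))
it = iter(ds)
first_five = [next(it) for _ in range(5)]      # [0, 1, 2, 3, 4]

resumed = ToyCheckpointableIterable(range(100))
resumed.load_state_dict(ds.state_dict())       # {"sample_idx": 5}
print(next(iter(resumed)))                     # 5 -- resumes where we left off
```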
|
||
self._restored_checkpoint = False | ||
|
||
def reset(self): | ||
self._sample_idx = 0 | ||
|
||
def _get_data_iter(self): | ||
if isinstance(self._data, Dataset) and self._sample_idx == len(self._data): | ||
return iter([]) | ||
|
||
it = iter(self._data) | ||
for _ in range(self._sample_idx): | ||
next(it) | ||
return it | ||
return iter(self._data) | ||
|
||
def __iter__(self): | ||
dataset_iterator = self._get_data_iter() | ||
# Initialize the dataset iterator | ||
iterator = self._get_data_iter() | ||
|
||
# Skip samples if we're resuming from a checkpoint | ||
if self._restored_checkpoint: | ||
logger.info(f"Restoring dataset state: skipping {self._sample_idx} samples") | ||
for _ in range(self._sample_idx): | ||
next(iterator) | ||
self._restored_checkpoint = False | ||
|
||
while True: | ||
try: | ||
sample = next(dataset_iterator) | ||
except StopIteration: | ||
if not self.infinite: | ||
logger.warning( | ||
f"Dataset {self.dataset_name} has run out of data. \ | ||
This might cause NCCL timeout if data parallelism is enabled." | ||
) | ||
break | ||
else: | ||
# Reset offset for the next iteration if infinite | ||
self._sample_idx = 0 | ||
logger.info(f"Dataset {self.dataset_name} is being re-looped.") | ||
dataset_iterator = self._get_data_iter() | ||
continue | ||
sample = next(iterator) | ||
except (UnicodeDecodeError, SyntaxError, OSError) as e: | ||
Reviewer: In my training I added this line to catch data-loading errors, e.g. a corrupted image header while PIL.Image is reading, or a corrupted .tar file header, etc.
Author: Makes sense, I must have removed it by accident.
||
# Handle other exception, eg, dataset corruption | ||
logger.warning( | ||
f"Dataset {self.dataset_name} has error while loading batch data. \ | ||
Error {type(e).__name__}: {e}. The error could be the result of a streaming glitch." | ||
) | ||
continue | ||
except StopIteration: | ||
# Handle the end of the iterator | ||
self.reset() | ||
if not self.infinite: | ||
logger.warning(f"Dataset {self.dataset_name} has run out of data") | ||
break | ||
# Reset for next epoch if infinite | ||
logger.warning(f"Dataset {self.dataset_name} is being re-looped") | ||
iterator = self._get_data_iter() | ||
|
||
# Use the dataset-specific preprocessor | ||
sample_dict = self._data_processor( | ||
|
@@ -333,46 +350,85 @@ def __iter__(self): | |
# skip low quality image or image with color channel = 1 | ||
if sample_dict["image"] is None: | ||
logger.warning( | ||
f"Low quality image {sample['__key__']} is skipped in Flux Dataloader." | ||
f"Low quality image {sample['__key__']} is skipped in Flux Dataloader" | ||
) | ||
continue | ||
|
||
# Classifier-free guidance: Replace some of the strings with empty strings. | ||
# Distinct random seed is initialized at the beginning of training for each FSDP rank. | ||
dropout_prob = self.job_config.training.classifer_free_guidance_prob | ||
if dropout_prob > 0.0: | ||
if torch.rand(1).item() < dropout_prob: | ||
sample_dict["t5_tokens"] = self._t5_empty_token | ||
sample_dict["clip_tokens"] = self._clip_empty_token | ||
if dropout_prob > 0.0 and random.random() < dropout_prob: | ||
sample_dict["t5_tokens"] = self._t5_empty_token | ||
sample_dict["clip_tokens"] = self._clip_empty_token | ||
|
||
self._sample_idx += 1 | ||
|
||
labels = sample_dict.pop("image") | ||
yield sample_dict, labels | ||
|
||
def load_state_dict(self, state_dict): | ||
self._sample_idx = state_dict["sample_idx"] | ||
self._sample_idx = state_dict.get("sample_idx", 0) | ||
self._restored_checkpoint = True # Mark that we've loaded from a checkpoint | ||
|
||
def state_dict(self): | ||
return { | ||
"sample_idx": self._sample_idx, | ||
} | ||
|
||
|
||
def build_flux_dataloader( | ||
def build_flux_train_dataloader( | ||
dp_world_size: int, | ||
dp_rank: int, | ||
job_config: JobConfig, | ||
tokenizer: FluxTokenizer | None, | ||
infinite: bool = True, | ||
) -> ParallelAwareDataloader: | ||
return _build_flux_dataloader( | ||
dataset_name=job_config.training.dataset, | ||
dataset_path=job_config.training.dataset_path, | ||
dp_world_size=dp_world_size, | ||
dp_rank=dp_rank, | ||
job_config=job_config, | ||
tokenizer=tokenizer, | ||
infinite=infinite, | ||
batch_size=job_config.training.batch_size, | ||
) | ||
|
||
|
||
def build_flux_val_dataloader( | ||
dp_world_size: int, | ||
dp_rank: int, | ||
job_config: JobConfig, | ||
tokenizer: FluxTokenizer | None, | ||
infinite: bool = False, | ||
) -> ParallelAwareDataloader: | ||
print(job_config.eval.dataset_path) | ||
|
||
return _build_flux_dataloader( | ||
dataset_name=job_config.eval.dataset, | ||
dataset_path=job_config.eval.dataset_path, | ||
dp_world_size=dp_world_size, | ||
dp_rank=dp_rank, | ||
job_config=job_config, | ||
tokenizer=tokenizer, | ||
infinite=infinite, | ||
batch_size=job_config.eval.batch_size, | ||
) | ||
|
||
|
||
def _build_flux_dataloader( | ||
dataset_name: str, | ||
dataset_path: str, | ||
dp_world_size: int, | ||
dp_rank: int, | ||
job_config: JobConfig, | ||
# This parameter is not used, keep it for compatibility | ||
tokenizer: FluxTokenizer | None, | ||
infinite: bool = True, | ||
include_sample_id: bool = False, | ||
batch_size: int = 4, | ||
Reviewer: Why this magic number?
Author: This should indeed be a parameter.
||
) -> ParallelAwareDataloader: | ||
"""Build a data loader for HuggingFace datasets.""" | ||
dataset_name = job_config.training.dataset | ||
dataset_path = job_config.training.dataset_path | ||
batch_size = job_config.training.batch_size | ||
|
||
t5_encoder_name = job_config.encoder.t5_encoder | ||
clip_encoder_name = job_config.encoder.clip_encoder | ||
max_t5_encoding_len = job_config.encoder.max_t5_encoding_len | ||
|