diff --git a/lightning_pose/data/datamodules.py b/lightning_pose/data/datamodules.py
index 9392d569..349c29df 100644
--- a/lightning_pose/data/datamodules.py
+++ b/lightning_pose/data/datamodules.py
@@ -1,22 +1,16 @@
 """Data modules split a dataset into train, val, and test modules."""
 
-import copy
 import os
-from typing import Literal
+from typing import Literal, Callable
 
-import imgaug.augmenters as iaa
 import lightning.pytorch as pl
 import torch
 from lightning.pytorch.utilities import CombinedLoader
 from omegaconf import DictConfig
-from torch.utils.data import DataLoader, Subset, random_split
+from torch.utils.data import DataLoader, Subset
 
 from lightning_pose.data.dali import PrepareDALI
 from lightning_pose.data.datatypes import SemiSupervisedDataLoaderDict
-from lightning_pose.data.utils import (
-    compute_num_train_frames,
-    split_sizes_from_probabilities,
-)
 from lightning_pose.utils.io import check_video_paths
 
 # to ignore imports for sphix-autoapidoc
@@ -27,170 +21,39 @@ class BaseDataModule(pl.LightningDataModule):
-    """Splits a labeled dataset into train, val, and test data loaders."""
+    """Wraps labeled dataset splits and delegates loader creation to a factory."""
 
     def __init__(
         self,
         dataset: torch.utils.data.Dataset,
-        train_batch_size: int = 16,
-        val_batch_size: int = 16,
-        test_batch_size: int = 1,
-        num_workers: int | None = None,
-        train_probability: float = 0.8,
-        val_probability: float | None = None,
-        test_probability: float | None = None,
-        train_frames: float | int | None = None,
-        torch_seed: int = 42,
+        splits: tuple[Subset, Subset, Subset],
+        dataloader_factory: Callable[[str], DataLoader],
     ) -> None:
-        """Data module splits a dataset into train, val, and test data loaders.
+        """Data module that uses an injected dataloader factory.
 
         Args:
-            dataset: base dataset to be split into train/val/test
-            train_batch_size: number of samples of training batches
-            val_batch_size: number of samples in validation batches
-            test_batch_size: number of samples in test batches
-            num_workers: number of threads used for prefetching data
-            train_probability: fraction of full dataset used for training
-            val_probability: fraction of full dataset used for validation
-            test_probability: fraction of full dataset used for testing
-            train_frames: if integer, select this number of training frames
-                from the initially selected train frames (defined by
-                `train_probability`); if float, must be between 0 and 1
-                (exclusive) and defines the fraction of the initially selected
-                train frames
-            torch_seed: control data splits
-
+            dataset: base dataset corresponding to provided splits
+            splits: tuple of (train_subset, val_subset, test_subset)
+            dataloader_factory: function mapping a stage string ("train"|"val"|"test")
+                to a configured DataLoader for the corresponding split
         """
         super().__init__()
         self.dataset = dataset
-        self.train_batch_size = train_batch_size
-        self.val_batch_size = val_batch_size
-        self.test_batch_size = test_batch_size
-        if num_workers is not None:
-            self.num_workers = num_workers
-        else:
-            slurm_cpus = os.getenv("SLURM_CPUS_PER_TASK")
-            if slurm_cpus:
-                self.num_workers = int(slurm_cpus)
-            else:
-                # Fallback to os.cpu_count()
-                self.num_workers = os.cpu_count()
-        self.train_probability = train_probability
-        self.val_probability = val_probability
-        self.test_probability = test_probability
-        self.train_frames = train_frames
-        self.train_dataset = None  # populated by self.setup()
-        self.val_dataset = None  # populated by self.setup()
-        self.test_dataset = None  # populated by self.setup()
-        self.torch_seed = torch_seed
-        self._setup()
-
-    def _setup(self) -> None:
-
-        datalen = self.dataset.__len__()
-        print(f"Number of labeled images in the full dataset (train+val+test): {datalen}")
-
-        # split data based on provided probabilities
-        data_splits_list = split_sizes_from_probabilities(
-            datalen,
-            train_probability=self.train_probability,
-            val_probability=self.val_probability,
-            test_probability=self.test_probability,
-        )
-
-        if len(self.dataset.imgaug_transform) == 1:
-            # no augmentations in the pipeline; subsets can share same underlying dataset
-            self.train_dataset, self.val_dataset, self.test_dataset = random_split(
-                self.dataset,
-                data_splits_list,
-                generator=torch.Generator().manual_seed(self.torch_seed),
-            )
-        else:
-            # augmentations in the pipeline; we want validation and test datasets that only resize
-            # we can't simply change the imgaug pipeline in the datasets after they've been split
-            # because the subsets actually point to the same underlying dataset, so we create
-            # separate datasets here
-            train_idxs, val_idxs, test_idxs = random_split(
-                range(len(self.dataset)),
-                data_splits_list,
-                generator=torch.Generator().manual_seed(self.torch_seed),
-            )
-
-            self.train_dataset = Subset(copy.deepcopy(self.dataset), indices=list(train_idxs))
-            self.val_dataset = Subset(copy.deepcopy(self.dataset), indices=list(val_idxs))
-            self.test_dataset = Subset(copy.deepcopy(self.dataset), indices=list(test_idxs))
-
-            # only use the final resize transform for the validation and test datasets
-            if self.dataset.imgaug_transform[-1].__str__().find("Resize") == 0:
-                final_transform = iaa.Sequential([self.dataset.imgaug_transform[-1]])
-            else:
-                # if we're here it's because the dataset is a MultiviewHeatmapDataset that doesn't
-                # resize by default in the pipeline; we enforce resizing here on val/test batches
-                height = self.dataset.height
-                width = self.dataset.width
-                final_transform = iaa.Sequential([iaa.Resize({"height": height, "width": width})])
-
-            self.val_dataset.dataset.imgaug_transform = final_transform
-            if hasattr(self.val_dataset.dataset, "dataset"):
-                # this will get triggered for multiview datasets
-                print("val: updating children datasets with resize imgaug pipeline")
-                for view_name, dset in self.val_dataset.dataset.dataset.items():
-                    dset.imgaug_transform = final_transform
-
-            self.test_dataset.dataset.imgaug_transform = final_transform
-            if hasattr(self.test_dataset.dataset, "dataset"):
-                # this will get triggered for multiview datasets
-                print("test: updating children datasets with resize imgaug pipeline")
-                for view_name, dset in self.test_dataset.dataset.dataset.items():
-                    dset.imgaug_transform = final_transform
-
-        # further subsample training data if desired
-        if self.train_frames is not None:
-            n_frames = compute_num_train_frames(len(self.train_dataset), self.train_frames)
-
-            if n_frames < len(self.train_dataset):
-                # split the data a second time to reflect further subsampling from
-                # train_frames
-                self.train_dataset.indices = self.train_dataset.indices[:n_frames]
-
-        print(
-            f"Dataset splits -- "
-            f"train: {len(self.train_dataset)}, "
-            f"val: {len(self.val_dataset)}, "
-            f"test: {len(self.test_dataset)}"
-        )
+        self.train_dataset, self.val_dataset, self.test_dataset = splits
+        self._get_dataloader = dataloader_factory
 
     def train_dataloader(self) -> torch.utils.data.DataLoader:
-        return DataLoader(
-            self.train_dataset,
-            batch_size=self.train_batch_size,
-            num_workers=self.num_workers,
-            persistent_workers=True if self.num_workers > 0 else False,
-            shuffle=True,
-            generator=torch.Generator().manual_seed(self.torch_seed),
-        )
+        return self._get_dataloader("train")
 
     def val_dataloader(self) -> torch.utils.data.DataLoader:
-        return DataLoader(
-            self.val_dataset,
-            batch_size=self.val_batch_size,
-            num_workers=self.num_workers,
-            persistent_workers=True if self.num_workers > 0 else False,
-        )
+        return self._get_dataloader("val")
 
     def test_dataloader(self) -> torch.utils.data.DataLoader:
-        return DataLoader(
-            self.test_dataset,
-            batch_size=self.test_batch_size,
-            num_workers=self.num_workers,
-        )
+        return self._get_dataloader("test")
 
     def full_labeled_dataloader(self) -> torch.utils.data.DataLoader:
-        return DataLoader(
-            self.dataset,
-            batch_size=self.val_batch_size,
-            num_workers=self.num_workers,
-        )
+        # Delegate to factory; expect it to support a 'full' stage for labeled data.
+        return self._get_dataloader("full")
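
The refactored `BaseDataModule` no longer splits data or builds loaders itself; both concerns are injected. A minimal wiring sketch, assuming the `get_split_datasets` and `get_dataloader_factory` helpers added to `lightning_pose/utils/scripts.py` later in this diff, with `cfg` (a DictConfig) and `dataset` already in hand:

```python
# Sketch: constructing the refactored BaseDataModule by hand.
from lightning_pose.data.datamodules import BaseDataModule
from lightning_pose.utils.scripts import get_dataloader_factory, get_split_datasets

splits = get_split_datasets(cfg=cfg, dataset=dataset)  # (train, val, test) Subsets
factory = get_dataloader_factory(cfg=cfg, dataset=dataset, splits=splits)

data_module = BaseDataModule(dataset=dataset, splits=splits, dataloader_factory=factory)
train_loader = data_module.train_dataloader()  # delegates to factory("train")
```
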
 
 
 class UnlabeledDataModule(BaseDataModule):
@@ -199,18 +62,11 @@
     def __init__(
         self,
         dataset: torch.utils.data.Dataset,
+        splits: tuple[Subset, Subset, Subset],
+        dataloader_factory: Callable[[str], DataLoader],
         video_paths_list: list[str] | str,
         dali_config: dict | DictConfig,
         view_names: list[str] | None = None,
-        train_batch_size: int = 16,
-        val_batch_size: int = 16,
-        test_batch_size: int = 1,
-        num_workers: int | None = None,
-        train_probability: float = 0.8,
-        val_probability: float | None = None,
-        test_probability: float | None = None,
-        train_frames: float | None = None,
-        torch_seed: int = 42,
         imgaug: Literal["default", "dlc", "dlc-top-down"] = "default",
     ) -> None:
         """Data module that contains labeled and unlabeled data loaders.
@@ -221,34 +77,13 @@ def __init__(
             view_names: if fitting a non-mirrored multiview model, pass view names
                 in order to correctly organize the video paths
             dali_config: see `dali` entry of default config file for keys
-            train_batch_size: number of samples of training batches
-            val_batch_size: number of samples in validation batches
-            test_batch_size: number of samples in test batches
-            num_workers: number of threads used for prefetching data
-            train_probability: fraction of full dataset used for training
-            val_probability: fraction of full dataset used for validation
-            test_probability: fraction of full dataset used for testing
-            train_frames: if integer, select this number of training frames
-                from the initially selected train frames (defined by
-                `train_probability`); if float, must be between 0 and 1
-                (exclusive) and defines the fraction of the initially selected
-                train frames
-            torch_seed: control data splits
-            torch_seed: control randomness of labeled data loading
             imgaug: type of image augmentation to apply to unlabeled frames
 
         """
         super().__init__(
             dataset=dataset,
-            train_batch_size=train_batch_size,
-            val_batch_size=val_batch_size,
-            test_batch_size=test_batch_size,
-            num_workers=num_workers,
-            train_probability=train_probability,
-            val_probability=val_probability,
-            test_probability=test_probability,
-            train_frames=train_frames,
-            torch_seed=torch_seed,
+            splits=splits,
+            dataloader_factory=dataloader_factory,
         )
         self.video_paths_list = video_paths_list
         self.filenames = check_video_paths(self.video_paths_list, view_names=view_names)
diff --git a/lightning_pose/train.py b/lightning_pose/train.py
index 193309f9..8565ab51 100644
--- a/lightning_pose/train.py
+++ b/lightning_pose/train.py
@@ -9,7 +9,6 @@
 import sys
 from pathlib import Path
 
-import lightning.pytorch as pl
 import numpy as np
 import torch
 from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
@@ -21,17 +20,8 @@
 from lightning_pose.utils import pretty_print_cfg, pretty_print_str
 from lightning_pose.utils.io import (
     find_video_files_for_views,
-    return_absolute_data_paths,
-)
-from lightning_pose.utils.scripts import (
-    calculate_steps_per_epoch,
-    get_callbacks,
-    get_data_module,
-    get_dataset,
-    get_imgaug_transform,
-    get_loss_factories,
-    get_model,
 )
+from lightning_pose.utils.mega_factory_impl import ModelComponentContainerImpl
 
 # to ignore imports for sphinx-autoapidoc
 __all__ = ["train"]
@@ -221,46 +211,7 @@ def _train(cfg: DictConfig) -> Model:
 
     ModelConfig(cfg).validate()
 
-    # path handling for toy data
-    data_dir, video_dir = return_absolute_data_paths(data_cfg=cfg.data)
-
-    # ----------------------------------------------------------------------------------
-    # Set up data/model objects
-    # ----------------------------------------------------------------------------------
-
-    # imgaug transform
-    imgaug_transform = get_imgaug_transform(cfg=cfg)
-
-    # dataset
-    dataset = get_dataset(cfg=cfg, data_dir=data_dir, imgaug_transform=imgaug_transform)
-
-    # datamodule; breaks up dataset into train/val/test
-    data_module = get_data_module(cfg=cfg, dataset=dataset, video_dir=video_dir)
-
-    # build loss factory which orchestrates different losses
-    loss_factories = get_loss_factories(cfg=cfg, data_module=data_module)
-
-    steps_per_epoch = calculate_steps_per_epoch(data_module)
-
-    # convert milestone_steps to milestones if applicable (before `get_model`).
-    if (
-        "multisteplr" in cfg.training.lr_scheduler_params
-        and "milestone_steps" in cfg.training.lr_scheduler_params.multisteplr
-    ):
-        milestone_steps = cfg.training.lr_scheduler_params.multisteplr.milestone_steps
-        milestones = [math.ceil(s / steps_per_epoch) for s in milestone_steps]
-        cfg.training.lr_scheduler_params.multisteplr.milestones = milestones
-
-    # convert patch masking epochs if applicable (before `get_callbacks`)
-    if "patch_mask" in cfg.training and "init_epoch" in cfg.training.patch_mask:
-        init_step = math.ceil(cfg.training.patch_mask.init_epoch * steps_per_epoch)
-        final_step = math.ceil(cfg.training.patch_mask.final_epoch * steps_per_epoch)
-        with open_dict(cfg):
-            cfg.training.patch_mask.init_step = init_step
-            cfg.training.patch_mask.final_step = final_step
-
-    # model
-    model = get_model(cfg=cfg, data_module=data_module, loss_factories=loss_factories)
+    container = ModelComponentContainerImpl(cfg)
 
     # ----------------------------------------------------------------------------------
     # Save configuration in output directory
@@ -284,7 +235,7 @@ def _train(cfg: DictConfig) -> Model:
         for csv_file in csv_files:
             src_csv_file = Path(csv_file)
             if not src_csv_file.is_absolute():
-                src_csv_file = Path(data_dir) / src_csv_file
+                src_csv_file = Path(cfg.data.data_dir) / src_csv_file
 
             dest_csv_file = Path(hydra_output_directory) / src_csv_file.name
             shutil.copyfile(src_csv_file, dest_csv_file)
@@ -294,7 +245,7 @@ def _train(cfg: DictConfig) -> Model:
     # ----------------------------------------------------------------------------------
 
     # logger
-    logger = pl.loggers.TensorBoardLogger("tb_logs", name=cfg.model.model_name)
+    logger = container.get_logger()
     # Log hydra config to tensorboard as helpful metadata.
     for key, value in cfg.items():
         logger.experiment.add_text(
@@ -302,51 +253,10 @@ def _train(cfg: DictConfig) -> Model:
         )
 
     # early stopping, learning rate monitoring, model checkpointing, backbone unfreezing
-    callbacks = get_callbacks(
-        cfg,
-        early_stopping=cfg.training.get("early_stopping", False),
-        lr_monitor=True,
-        ckpt_every_n_epochs=cfg.training.get("ckpt_every_n_epochs", None),
-    )
-
-    # set up trainer
-
-    cfg.training.num_gpus = max(cfg.training.num_gpus, 1)
-
-    # initialize to Trainer defaults. Note max_steps defaults to -1.
-    min_steps, max_steps, min_epochs, max_epochs = (None, -1, None, None)
-    if "min_steps" in cfg.training:
-        min_steps = cfg.training.min_steps
-        max_steps = cfg.training.max_steps
-    else:
-        min_epochs = cfg.training.min_epochs
-        max_epochs = cfg.training.max_epochs
-
-    # Unlike min_epoch/min_step, both of these are valid to specify.
-    check_val_every_n_epoch = cfg.training.get("check_val_every_n_epoch", 1)
-    val_check_interval = cfg.training.get("val_check_interval")
-
-    trainer = pl.Trainer(
-        accelerator="gpu",
-        devices=cfg.training.num_gpus,
-        max_epochs=max_epochs,
-        min_epochs=min_epochs,
-        max_steps=max_steps,
-        min_steps=min_steps,
-        check_val_every_n_epoch=check_val_every_n_epoch,
-        val_check_interval=val_check_interval,
-        log_every_n_steps=cfg.training.log_every_n_steps,
-        callbacks=callbacks,
-        logger=logger,
-        # To understand why we set this, see 'max_size_cycle' in UnlabeledDataModule.
-        limit_train_batches=cfg.training.get("limit_train_batches") or steps_per_epoch,
-        accumulate_grad_batches=cfg.training.get("accumulate_grad_batches", 1),
-        profiler=cfg.training.get("profiler", None),
-        sync_batchnorm=True,
-    )
 
     # train model!
-    trainer.fit(model=model, datamodule=data_module)
+    trainer = container.get_trainer()
+    trainer.fit(model=container.get_model(), datamodule=container.get_data_module())
 
     # When devices > 0, lightning creates a process per device.
     # Kill processes other than the main process, otherwise they all go forward.
diff --git a/lightning_pose/utils/mega_factory.py b/lightning_pose/utils/mega_factory.py
new file mode 100644
index 00000000..8147721d
--- /dev/null
+++ b/lightning_pose/utils/mega_factory.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+from typing import Any, Protocol, Callable, TypeAlias, Tuple, Union
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+    import torch
+    from omegaconf import DictConfig
+    import imgaug.augmenters as iaa
+    import lightning.pytorch as pl
+    from torchtyping import TensorType
+
+    from torch.utils.data import Subset
+
+    from lightning_pose.losses.factory import LossFactory
+    from lightning_pose.models import HeatmapTracker
+
+    DataExtractorOutput: TypeAlias = Tuple[
+        TensorType["num_examples", Any],
+        Union[
+            TensorType["num_examples", 3, "image_width", "image_height"],
+            TensorType["num_examples", "frames", 3, "image_width", "image_height"],
+            None,
+        ],
+    ]
+
+
+class ModelComponentContainer(Protocol):
+    def __init__(self, cfg: DictConfig): ...
+
+    def get_imgaug_transform(self) -> iaa.Sequential: ...
+
+    def get_dataset(self) -> torch.utils.data.Dataset: ...
+
+    def get_split_datasets(self) -> tuple[Subset, Subset, Subset]: ...
+
+    def get_dataloader_factory(self) -> Callable[[str], torch.utils.data.DataLoader]:
+        """stage -> dataloader"""
+        ...
+
+    def get_predict_dali_dataloader_factory(
+        self,
+    ) -> Callable[[str], torch.utils.data.DataLoader]:
+        """filename -> predict dataloader"""
+        ...
+
+    def get_combined_dataloader_factory(
+        self,
+    ) -> Callable[[Subset], torch.utils.data.DataLoader]: ...
+
+    def get_data_module(self) -> pl.LightningDataModule: ...
+
+    def get_loss_factories(self) -> dict[str, LossFactory | None]: ...
+
+    def get_model(self) -> HeatmapTracker: ...
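
Because `ModelComponentContainer` is a `typing.Protocol`, implementations satisfy it structurally rather than by inheritance. A hedged sketch of how downstream code can depend on the protocol while swapping implementations, for example a lightweight fake in tests:

```python
# Sketch: depend on the protocol, not the concrete container.
from lightning_pose.utils.mega_factory import ModelComponentContainer
from lightning_pose.utils.mega_factory_impl import ModelComponentContainerImpl


def build_container(cfg) -> ModelComponentContainer:
    # Any class exposing the protocol's methods type-checks here,
    # so a test could return a stub container instead.
    return ModelComponentContainerImpl(cfg)
```
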
diff --git a/lightning_pose/utils/mega_factory_impl.py b/lightning_pose/utils/mega_factory_impl.py
new file mode 100644
index 00000000..e9aa2226
--- /dev/null
+++ b/lightning_pose/utils/mega_factory_impl.py
@@ -0,0 +1,134 @@
+from __future__ import annotations
+
+from functools import wraps
+from typing import Any, Callable, TypeAlias, Tuple, Union
+from typing import TYPE_CHECKING
+
+from lightning_pose.utils import scripts
+from lightning_pose.utils.mega_factory import ModelComponentContainer
+
+if TYPE_CHECKING:
+    import torch
+    from omegaconf import DictConfig
+    import imgaug.augmenters as iaa
+    import lightning.pytorch as pl
+    from torchtyping import TensorType
+
+    from torch.utils.data import Subset
+
+    from lightning_pose.losses.factory import LossFactory
+    from lightning_pose.models import HeatmapTracker
+
+    DataExtractorOutput: TypeAlias = Tuple[
+        TensorType["num_examples", Any],
+        Union[
+            TensorType["num_examples", 3, "image_width", "image_height"],
+            TensorType["num_examples", "frames", 3, "image_width", "image_height"],
+            None,
+        ],
+    ]
+
+
+def cached(func):
+    """Cache a method's result on the instance itself."""
+
+    @wraps(func)
+    def wrapper(self):
+        cache_attr = "_instance_cache"
+        if not hasattr(self, cache_attr):
+            setattr(self, cache_attr, {})
+        cache = getattr(self, cache_attr)
+        key = func.__name__
+        if key not in cache:
+            cache[key] = func(self)
+        return cache[key]
+
+    return wrapper
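
The `@cached` decorator memoizes per instance rather than per class, so each container builds its object graph exactly once. A quick illustration of the intended semantics:

```python
# Sketch: per-instance memoization provided by @cached.
from lightning_pose.utils.mega_factory_impl import cached


class Demo:
    @cached
    def get_value(self):
        print("building")  # printed only once per instance
        return object()


d1, d2 = Demo(), Demo()
assert d1.get_value() is d1.get_value()      # cached on d1
assert d1.get_value() is not d2.get_value()  # d2 keeps its own cache
```
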
+
+
+class ModelComponentContainerImpl:
+    def __init__(self, cfg: DictConfig):
+        self.cfg = cfg
+
+    # Intentionally not cached since this can be mutated later.
+    def get_imgaug_transform(self) -> iaa.Sequential:
+        return scripts.get_imgaug_transform(cfg=self.cfg)
+
+    @cached
+    def get_dataset(self) -> torch.utils.data.Dataset:
+        imgaug_transform = self.get_imgaug_transform()
+        return scripts.get_dataset(
+            cfg=self.cfg,
+            data_dir=self.cfg.data.data_dir,
+            imgaug_transform=imgaug_transform,
+        )
+
+    @cached
+    def get_split_datasets(self) -> tuple[Subset, Subset, Subset]:
+        dataset = self.get_dataset()
+        return scripts.get_split_datasets(cfg=self.cfg, dataset=dataset)
+
+    @cached
+    def get_dataloader_factory(self) -> Callable[[str], torch.utils.data.DataLoader]:
+        """stage -> dataloader"""
+        return scripts.get_dataloader_factory(
+            cfg=self.cfg, dataset=self.get_dataset(), splits=self.get_split_datasets()
+        )
+
+    @cached
+    def get_predict_dali_dataloader_factory(
+        self,
+    ) -> Callable[[str], torch.utils.data.DataLoader]:
+        """filename -> predict dataloader"""
+        # Not implemented yet; declared for parity with the protocol.
+        raise NotImplementedError
+
+    @cached
+    def get_combined_dataloader_factory(
+        self,
+    ) -> Callable[[Subset], torch.utils.data.DataLoader]:
+        # Not implemented yet; declared for parity with the protocol.
+        raise NotImplementedError
+
+    @cached
+    def get_data_module(self) -> pl.LightningDataModule:
+        # Reuse the cached dataset and dataloader factory so every consumer
+        # sees the same objects; video_dir is needed for the semi-supervised
+        # UnlabeledDataModule path.
+        return scripts.get_data_module(
+            cfg=self.cfg,
+            dataset=self.get_dataset(),
+            video_dir=self.cfg.data.get("video_dir"),
+            dataloader_factory=self.get_dataloader_factory(),
+        )
+
+    @cached
+    def get_loss_factories(self) -> dict[str, LossFactory | None]:
+        data_module = self.get_data_module()
+        return scripts.get_loss_factories(cfg=self.cfg, data_module=data_module)
+
+    @cached
+    def get_model(self) -> HeatmapTracker:
+        data_module = self.get_data_module()
+        loss_factories = self.get_loss_factories()
+        return scripts.get_model(
+            cfg=self.cfg, loss_factories=loss_factories, data_module=data_module
+        )
+
+    # Trainer dependencies
+
+    def get_steps_for_epoch(self):
+        return scripts.calculate_steps_per_epoch(self.get_data_module())
+
+    def get_logger(self):
+        return scripts.get_training_logger(cfg=self.cfg)
+
+    def get_callbacks(self):
+        return scripts.get_callbacks(cfg=self.cfg)
+
+    def get_trainer(self) -> pl.Trainer:
+        return scripts.get_training_trainer(
+            cfg=self.cfg,
+            steps_per_epoch=self.get_steps_for_epoch(),
+            logger=self.get_logger(),
+            callbacks=self.get_callbacks(),
+        )
+
+
+if TYPE_CHECKING:
+    # Static structural check: the implementation must satisfy the protocol.
+    # Evaluated only by type checkers, never at runtime.
+    _check: ModelComponentContainer = ModelComponentContainerImpl(None)  # type: ignore[arg-type]
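
Putting the container together, the training entry point reduces to the flow below, mirroring the `train.py` changes above; call order is irrelevant because every getter is lazy and cached:

```python
# Sketch: container-driven training, as _train() now does.
container = ModelComponentContainerImpl(cfg)

trainer = container.get_trainer()  # builds steps/logger/callbacks on demand
trainer.fit(
    model=container.get_model(),   # reuses the cached data module and losses
    datamodule=container.get_data_module(),
)
```
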
diff --git a/lightning_pose/utils/scripts.py b/lightning_pose/utils/scripts.py
index 6bcd8629..42244a38 100644
--- a/lightning_pose/utils/scripts.py
+++ b/lightning_pose/utils/scripts.py
@@ -6,6 +6,8 @@
 from collections import OrderedDict
 from pathlib import Path
 
+import copy
+from typing import Callable
 import imgaug.augmenters as iaa
 import lightning.pytorch as pl
 import numpy as np
@@ -20,6 +22,7 @@
     expand_imgaug_str_to_dict,
     imgaug_transform,
 )
+
 from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule
 from lightning_pose.data.datasets import (
     BaseTrackingDataset,
@@ -27,6 +30,10 @@
     MultiviewHeatmapDataset,
 )
 from lightning_pose.data.datatypes import ComputeMetricsSingleResult
+from lightning_pose.data.utils import (
+    compute_num_train_frames,
+    split_sizes_from_probabilities,
+)
 from lightning_pose.losses.factory import LossFactory
 from lightning_pose.metrics import (
     pca_multiview_reprojection_error,
@@ -40,6 +47,7 @@
 )
 from lightning_pose.utils import io as io_utils
 from lightning_pose.utils.pca import KeypointPCA
+from torch.utils.data import Subset, random_split, DataLoader
 
 # to ignore imports for sphix-autoapidoc
 __all__ = [
@@ -51,6 +59,7 @@
     "get_callbacks",
     "calculate_steps_per_epoch",
     "compute_metrics",
+    "get_split_datasets",
 ]
@@ -103,9 +112,13 @@ def get_imgaug_transform(cfg: DictConfig) -> iaa.Sequential:
         else:
             params_dict = params.copy()
         for transform, val in params_dict.items():
-            assert getattr(iaa, transform), f"{transform} is not a valid imgaug transform"
+            assert getattr(
+                iaa, transform
+            ), f"{transform} is not a valid imgaug transform"
     else:
-        raise TypeError(f"params is of type {type(params)}, must be str, dict, or DictConfig")
+        raise TypeError(
+            f"params is of type {type(params)}, must be str, dict, or DictConfig"
+        )
 
     return imgaug_transform(params_dict)
@@ -120,7 +133,9 @@ def get_dataset(
 
     if cfg.model.model_type == "regression":
         if cfg.data.get("view_names", None) and len(cfg.data.view_names) > 1:
-            raise NotImplementedError("Multi-view support only available for heatmap-based models")
+            raise NotImplementedError(
+                "Multi-view support only available for heatmap-based models"
+            )
         else:
             dataset = BaseTrackingDataset(
                 root_directory=data_dir,
@@ -136,9 +151,8 @@
                 "No precautions regarding the size of the images were considered here, "
                 "images will be resized accordingly to configs!"
             )
-            if (
-                cfg.training.imgaug in ["default", "none"]
-                or not cfg.data.get("camera_params_file")
+            if cfg.training.imgaug in ["default", "none"] or not cfg.data.get(
+                "camera_params_file"
             ):
                 # we are either
                 # 1. running inference on un-augmented data, and need to make sure to resize
@@ -154,9 +168,12 @@
                 image_resize_width=cfg.data.image_resize_dims.width,
                 imgaug_transform=imgaug_transform,
                 downsample_factor=cfg.data.get("downsample_factor", 2),
-                do_context=cfg.model.model_type == "heatmap_mhcrnn",  # context only for mhcrnn
+                do_context=cfg.model.model_type
+                == "heatmap_mhcrnn",  # context only for mhcrnn
                 resize=resize,
-                uniform_heatmaps=cfg.training.get("uniform_heatmaps_for_nan_keypoints", False),
+                uniform_heatmaps=cfg.training.get(
+                    "uniform_heatmaps_for_nan_keypoints", False
+                ),
                 camera_params_path=cfg.data.get("camera_params_file", None),
                 bbox_paths=cfg.data.get("bbox_file", None),
             )
@@ -168,23 +185,45 @@
             image_resize_width=cfg.data.image_resize_dims.width,
             imgaug_transform=imgaug_transform,
             downsample_factor=cfg.data.get("downsample_factor", 2),
-            do_context=cfg.model.model_type == "heatmap_mhcrnn",  # context only for mhcrnn
-            uniform_heatmaps=cfg.training.get("uniform_heatmaps_for_nan_keypoints", False),
+            do_context=cfg.model.model_type
+            == "heatmap_mhcrnn",  # context only for mhcrnn
+            uniform_heatmaps=cfg.training.get(
+                "uniform_heatmaps_for_nan_keypoints", False
+            ),
         )
     else:
-        raise NotImplementedError("%s is an invalid cfg.model.model_type" % cfg.model.model_type)
+        raise NotImplementedError(
+            "%s is an invalid cfg.model.model_type" % cfg.model.model_type
+        )
 
     return dataset
 
 
+def get_train_val_batches(cfg: DictConfig) -> tuple[int, int]:
+    """Compute the per-GPU batch sizes to use for training and validation."""
+    # Divide config batch_size by num_gpus to maintain the same effective batch
+    # size in a multi-gpu setting.
+    train_batch_size = int(
+        np.ceil(cfg.training.train_batch_size / cfg.training.num_gpus)
+    )
+    val_batch_size = int(np.ceil(cfg.training.val_batch_size / cfg.training.num_gpus))
+
+    return train_batch_size, val_batch_size
+
+
 @typechecked
 def get_data_module(
     cfg: DictConfig,
     dataset: BaseTrackingDataset | HeatmapDataset | MultiviewHeatmapDataset,
     video_dir: str | None = None,
+    dataloader_factory: Callable[[str], DataLoader] | None = None,
 ) -> BaseDataModule | UnlabeledDataModule:
-    """Create a data module that splits a dataset into train/val/test iterators."""
+    """Create a data module using provided dataloader factory (preferred).
+
+    If `dataloader_factory` is None, this function will derive splits and create a
+    default labeled dataloader factory from cfg for backward compatibility.
+    """
 
     # Old configs may have num_gpus: 0. We will remove support in a future release.
     if cfg.training.num_gpus == 0:
@@ -194,25 +233,20 @@
         )
     cfg.training.num_gpus = max(cfg.training.num_gpus, 1)
 
-    # Divide config batch_size by num_gpus to maintain the same effective batch
-    # size in a multi-gpu setting.
-    train_batch_size = int(
-        np.ceil(cfg.training.train_batch_size / cfg.training.num_gpus)
-    )
-    val_batch_size = int(np.ceil(cfg.training.val_batch_size / cfg.training.num_gpus))
-
     semi_supervised = io_utils.check_if_semi_supervised(cfg.model.losses_to_use)
+
+    # Build splits and default factory if not provided
+    splits = get_split_datasets(cfg=cfg, dataset=dataset)
+    if dataloader_factory is None:
+        dataloader_factory = get_dataloader_factory(
+            cfg=cfg, dataset=dataset, splits=splits
+        )
+
     if not semi_supervised:
         data_module = BaseDataModule(
             dataset=dataset,
-            train_batch_size=train_batch_size,
-            val_batch_size=val_batch_size,
-            test_batch_size=cfg.training.test_batch_size,
-            num_workers=cfg.training.get("num_workers"),
-            train_probability=cfg.training.train_prob,
-            val_probability=cfg.training.val_prob,
-            train_frames=cfg.training.train_frames,
-            torch_seed=cfg.training.rng_seed_data_pt,
+            splits=splits,
+            dataloader_factory=dataloader_factory,
         )
     else:
         # Divide config batch_size by num_gpus to maintain the same effective batch
@@ -248,17 +282,11 @@
         view_names = list(view_names) if view_names is not None else None
         data_module = UnlabeledDataModule(
             dataset=dataset,
+            splits=splits,
+            dataloader_factory=dataloader_factory,
             video_paths_list=video_dir,
             view_names=view_names,
-            train_batch_size=train_batch_size,
-            val_batch_size=val_batch_size,
-            test_batch_size=cfg.training.test_batch_size,
-            num_workers=cfg.training.get("num_workers"),
-            train_probability=cfg.training.train_prob,
-            val_probability=cfg.training.val_prob,
-            train_frames=cfg.training.train_frames,
             dali_config=dali_config,
-            torch_seed=cfg.training.rng_seed_data_pt,
             imgaug=cfg.training.get("imgaug", "default"),
         )
     return data_module
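
`get_data_module` keeps its legacy behavior when no factory is supplied, so existing callers are unaffected; new callers can inject a factory built elsewhere. Both styles, sketched:

```python
# Legacy path: splits and a default factory are derived from cfg internally.
dm_legacy = get_data_module(cfg=cfg, dataset=dataset, video_dir=video_dir)

# Injected path: build splits/factory once and share them with other consumers.
splits = get_split_datasets(cfg=cfg, dataset=dataset)
factory = get_dataloader_factory(cfg=cfg, dataset=dataset, splits=splits)
dm_injected = get_data_module(
    cfg=cfg, dataset=dataset, video_dir=video_dir, dataloader_factory=factory
)
```
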
"mirrored_column_matches" ] = [ - (v * num_keypoints - + np.array(cfg.data.mirrored_column_matches, dtype=int)).tolist() + ( + v * num_keypoints + + np.array(cfg.data.mirrored_column_matches, dtype=int) + ).tolist() for v in range(num_views) ] else: @@ -360,7 +396,7 @@ def get_loss_factories( else: loss_params_dict["unsupervised"][loss_name][ "columns_for_singleview_pca" - ] = cfg.data.get('columns_for_singleview_pca', None) + ] = cfg.data.get("columns_for_singleview_pca", None) # build supervised loss factory, which orchestrates all supervised losses loss_factory_sup = LossFactory( @@ -380,10 +416,31 @@ def get_loss_factories( def get_model( cfg: DictConfig, data_module: BaseDataModule | UnlabeledDataModule | None, - loss_factories: dict[str, LossFactory] | dict[str, None] + loss_factories: dict[str, LossFactory] | dict[str, None], ) -> pl.LightningModule: """Create model: regression or heatmap based, supervised or semi-supervised.""" + ## BEGIN: Hack for model training only (not needed for inference) + steps_per_epoch = calculate_steps_per_epoch(data_module) + + # convert milestone_steps to milestones if applicable (before `get_model`). + if ( + "multisteplr" in cfg.training.lr_scheduler_params + and "milestone_steps" in cfg.training.lr_scheduler_params.multisteplr + ): + milestone_steps = cfg.training.lr_scheduler_params.multisteplr.milestone_steps + milestones = [math.ceil(s / steps_per_epoch) for s in milestone_steps] + cfg.training.lr_scheduler_params.multisteplr.milestones = milestones + + # convert patch masking epochs if applicable (before `get_callbacks`) + if "patch_mask" in cfg.training and "init_epoch" in cfg.training.patch_mask: + init_step = math.ceil(cfg.training.patch_mask.init_epoch * steps_per_epoch) + final_step = math.ceil(cfg.training.patch_mask.final_epoch * steps_per_epoch) + with open_dict(cfg): + cfg.training.patch_mask.init_step = init_step + cfg.training.patch_mask.final_step = final_step + ## END: Hack for model training only (not needed for inference) + optimizer = cfg.training.get("optimizer", "Adam") optimizer_params = _apply_defaults_for_optimizer_params( optimizer, @@ -392,8 +449,7 @@ def get_model( lr_scheduler = cfg.training.get("lr_scheduler", "multisteplr") lr_scheduler_params = _apply_defaults_for_lr_scheduler_params( - lr_scheduler, - cfg.training.get("lr_scheduler_params", {}).get(f"{lr_scheduler}") + lr_scheduler, cfg.training.get("lr_scheduler_params", {}).get(f"{lr_scheduler}") ) semi_supervised = io_utils.check_if_semi_supervised(cfg.model.losses_to_use) @@ -401,12 +457,15 @@ def get_model( image_w = cfg.data.image_resize_dims.width if "vit" in cfg.model.backbone: if image_h != image_w: - raise RuntimeError("ViT model requires resized height and width to be equal") + raise RuntimeError( + "ViT model requires resized height and width to be equal" + ) backbone_pretrained = cfg.model.get("backbone_pretrained", True) if not semi_supervised: if cfg.model.model_type == "regression": from lightning_pose.models import RegressionTracker + model = RegressionTracker( num_keypoints=cfg.data.num_keypoints, loss_factory=loss_factories["supervised"], @@ -425,6 +484,7 @@ def get_model( else: num_targets = None from lightning_pose.models import HeatmapTracker + model = HeatmapTracker( num_keypoints=cfg.data.num_keypoints, num_targets=num_targets, @@ -438,10 +498,13 @@ def get_model( lr_scheduler=lr_scheduler, lr_scheduler_params=lr_scheduler_params, image_size=image_h, # only used by ViT - backbone_checkpoint=cfg.model.get("backbone_checkpoint"), # only used by 
 
     optimizer = cfg.training.get("optimizer", "Adam")
     optimizer_params = _apply_defaults_for_optimizer_params(
         optimizer,
@@ -392,8 +449,7 @@
 
     lr_scheduler = cfg.training.get("lr_scheduler", "multisteplr")
     lr_scheduler_params = _apply_defaults_for_lr_scheduler_params(
-        lr_scheduler,
-        cfg.training.get("lr_scheduler_params", {}).get(f"{lr_scheduler}")
+        lr_scheduler, cfg.training.get("lr_scheduler_params", {}).get(f"{lr_scheduler}")
     )
 
     semi_supervised = io_utils.check_if_semi_supervised(cfg.model.losses_to_use)
@@ -401,12 +457,15 @@
     image_w = cfg.data.image_resize_dims.width
     if "vit" in cfg.model.backbone:
         if image_h != image_w:
-            raise RuntimeError("ViT model requires resized height and width to be equal")
+            raise RuntimeError(
+                "ViT model requires resized height and width to be equal"
+            )
     backbone_pretrained = cfg.model.get("backbone_pretrained", True)
 
     if not semi_supervised:
         if cfg.model.model_type == "regression":
             from lightning_pose.models import RegressionTracker
+
             model = RegressionTracker(
                 num_keypoints=cfg.data.num_keypoints,
                 loss_factory=loss_factories["supervised"],
@@ -425,6 +484,7 @@
             else:
                 num_targets = None
             from lightning_pose.models import HeatmapTracker
+
             model = HeatmapTracker(
                 num_keypoints=cfg.data.num_keypoints,
                 num_targets=num_targets,
@@ -438,10 +498,13 @@
                 lr_scheduler=lr_scheduler,
                 lr_scheduler_params=lr_scheduler_params,
                 image_size=image_h,  # only used by ViT
-                backbone_checkpoint=cfg.model.get("backbone_checkpoint"),  # only used by ViTMAE
+                backbone_checkpoint=cfg.model.get(
+                    "backbone_checkpoint"
+                ),  # only used by ViTMAE
             )
         elif cfg.model.model_type == "heatmap_mhcrnn":
             from lightning_pose.models import HeatmapTrackerMHCRNN
+
             model = HeatmapTrackerMHCRNN(
                 num_keypoints=cfg.data.num_keypoints,
                 loss_factory=loss_factories["supervised"],
@@ -454,10 +517,13 @@
                 lr_scheduler=lr_scheduler,
                 lr_scheduler_params=lr_scheduler_params,
                 image_size=image_h,  # only used by ViT
-                backbone_checkpoint=cfg.model.get("backbone_checkpoint"),  # only used by ViTMAE
+                backbone_checkpoint=cfg.model.get(
+                    "backbone_checkpoint"
+                ),  # only used by ViTMAE
             )
         elif cfg.model.model_type == "heatmap_multiview_transformer":
             from lightning_pose.models import HeatmapTrackerMultiviewTransformer
+
             model = HeatmapTrackerMultiviewTransformer(
                 num_keypoints=cfg.data.num_keypoints,
                 num_views=len(cfg.data.view_names),
@@ -482,6 +548,7 @@
     else:
         if cfg.model.model_type == "regression":
             from lightning_pose.models import SemiSupervisedRegressionTracker
+
             model = SemiSupervisedRegressionTracker(
                 num_keypoints=cfg.data.num_keypoints,
                 loss_factory=loss_factories["supervised"],
@@ -498,6 +565,7 @@
 
         elif cfg.model.model_type == "heatmap":
             from lightning_pose.models import SemiSupervisedHeatmapTracker
+
             model = SemiSupervisedHeatmapTracker(
                 num_keypoints=cfg.data.num_keypoints,
                 loss_factory=loss_factories["supervised"],
@@ -511,10 +579,13 @@
                 lr_scheduler=lr_scheduler,
                 lr_scheduler_params=lr_scheduler_params,
                 image_size=image_h,  # only used by ViT
-                backbone_checkpoint=cfg.model.get("backbone_checkpoint"),  # only used by ViTMAE
+                backbone_checkpoint=cfg.model.get(
+                    "backbone_checkpoint"
+                ),  # only used by ViTMAE
             )
         elif cfg.model.model_type == "heatmap_mhcrnn":
             from lightning_pose.models import SemiSupervisedHeatmapTrackerMHCRNN
+
             model = SemiSupervisedHeatmapTrackerMHCRNN(
                 num_keypoints=cfg.data.num_keypoints,
                 loss_factory=loss_factories["supervised"],
@@ -528,10 +599,15 @@
                 lr_scheduler=lr_scheduler,
                 lr_scheduler_params=lr_scheduler_params,
                 image_size=image_h,  # only used by ViT
-                backbone_checkpoint=cfg.model.get("backbone_checkpoint"),  # only used by ViTMAE
+                backbone_checkpoint=cfg.model.get(
+                    "backbone_checkpoint"
+                ),  # only used by ViTMAE
             )
         elif cfg.model.model_type == "heatmap_multiview_transformer":
-            from lightning_pose.models import SemiSupervisedHeatmapTrackerMultiviewTransformer
+            from lightning_pose.models import (
+                SemiSupervisedHeatmapTrackerMultiviewTransformer,
+            )
+
             model = SemiSupervisedHeatmapTrackerMultiviewTransformer(
                 num_keypoints=cfg.data.num_keypoints,
                 num_views=len(cfg.data.view_names),
@@ -560,6 +636,7 @@
         print(f"Loading weights from {ckpt}")
         if not ckpt.endswith(".ckpt"):
             import glob
+
             ckpt = glob.glob(os.path.join(ckpt, "**", "*.ckpt"), recursive=True)[0]
         # Try loading with default settings first, fallback to weights_only=False if needed
         try:
@@ -582,15 +659,20 @@
     return model
 
 
+def get_training_logger(cfg):
+    return pl.loggers.TensorBoardLogger("tb_logs", name=cfg.model.model_name)
+
+
 @typechecked
 def get_callbacks(
     cfg: DictConfig,
-    early_stopping=False,
     checkpointing=True,
-    lr_monitor=True,
-    ckpt_every_n_epochs=None,
     backbone_unfreeze=True,
 ) -> list:
+    # These options were previously passed in by train.py; they are now read
+    # directly from the config (lr_monitor is hard-coded and can be adjusted
+    # here for testing).
+    early_stopping = cfg.training.get("early_stopping", False)
+    lr_monitor = True
+    ckpt_every_n_epochs = cfg.training.get("ckpt_every_n_epochs", None)
 
     callbacks = []
@@ -635,9 +717,10 @@
 
     # we just need this callback for unsupervised losses or multiview models with 3d loss
     if (
-        ((cfg.model.losses_to_use != []) and (cfg.model.losses_to_use is not None))
-        or cfg.losses.get("supervised_pairwise_projections", {}).get("log_weight") is not None
-    ):
+        (cfg.model.losses_to_use != []) and (cfg.model.losses_to_use is not None)
+    ) or cfg.losses.get("supervised_pairwise_projections", {}).get(
+        "log_weight"
+    ) is not None:
         anneal_weight_callback = AnnealWeight(**cfg.callbacks.anneal_weight)
         callbacks.append(anneal_weight_callback)
@@ -657,8 +740,23 @@
 def calculate_steps_per_epoch(data_module: BaseDataModule):
-    train_dataset_length = len(data_module.train_dataset)
-    steps_per_epoch = math.ceil(train_dataset_length / data_module.train_batch_size)
+    """Infer steps per epoch from the training dataloader.
+
+    For semi-supervised (CombinedLoader), we still enforce a minimum of 10 steps
+    to encourage more unlabeled exposure when labeled data is scarce.
+    """
+    train_loader = data_module.train_dataloader()
+    try:
+        steps_per_epoch = len(train_loader)
+    except TypeError:
+        # Fallback: compute from dataset length and an inferred batch size where possible
+        train_dataset_length = len(data_module.train_dataset)
+        # Try to get a batch_size attribute if available
+        batch_size = getattr(train_loader, "batch_size", None)
+        if not batch_size:
+            # conservative default to avoid division by zero
+            batch_size = 1
+        steps_per_epoch = math.ceil(train_dataset_length / batch_size)
 
     is_unsupervised = isinstance(data_module, UnlabeledDataModule)
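
Preferring `len(train_loader)` means batch size, `drop_last`, and sampler effects are all accounted for by the loader itself; the arithmetic fallback only fires for loaders without a usable `__len__`. A small sanity check of the fallback arithmetic, with illustrative numbers:

```python
import math

# 101 training frames, batch size 16, no drop_last -> 7 steps per epoch.
assert math.ceil(101 / 16) == 7
```
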
@@ -725,7 +823,8 @@ def compute_metrics_single(
     # load predictions
     pred_df = pd.read_csv(preds_file, header=[0, 1, 2], index_col=0)
     keypoint_names = io_utils.get_keypoint_names(
-        cfg, csv_file=str(preds_file), header_rows=[0, 1, 2])
+        cfg, csv_file=str(preds_file), header_rows=[0, 1, 2]
+    )
 
     xyl_mask = pred_df.columns.get_level_values("coords").isin(["x", "y", "likelihood"])
     tmp = pred_df.loc[:, xyl_mask].to_numpy().reshape(pred_df.shape[0], -1, 3)
@@ -756,14 +855,18 @@
         data_module is not None
         and cfg.data.get("columns_for_singleview_pca", None) is not None
         and len(cfg.data.columns_for_singleview_pca) != 0
-        and not isinstance(data_module.dataset, MultiviewHeatmapDataset)  # mirrored-only for now
+        and not isinstance(
+            data_module.dataset, MultiviewHeatmapDataset
+        )  # mirrored-only for now
     ):
         metrics_to_compute += ["pca_singleview"]
     if (
         data_module is not None
         and cfg.data.get("mirrored_column_matches", None) is not None
        and len(cfg.data.mirrored_column_matches) != 0
-        and not isinstance(data_module.dataset, MultiviewHeatmapDataset)  # mirrored-only for now
+        and not isinstance(
+            data_module.dataset, MultiviewHeatmapDataset
+        )  # mirrored-only for now
     ):
         metrics_to_compute += ["pca_multiview"]
@@ -795,7 +898,9 @@
         # add train/val/test split
         if set is not None:
             temporal_norm_df["set"] = set
-        save_file = preds_file_path.with_name(preds_file_path.stem + "_temporal_norm.csv")
+        save_file = preds_file_path.with_name(
+            preds_file_path.stem + "_temporal_norm.csv"
+        )
         temporal_norm_df.to_csv(save_file)
         result.temporal_norm_df = temporal_norm_df
@@ -807,15 +912,22 @@
             data_module=data_module,
             components_to_keep=cfg.losses.pca_singleview.components_to_keep,
             empirical_epsilon_percentile=cfg.losses.pca_singleview.get(
-                "empirical_epsilon_percentile", 1.0),
+                "empirical_epsilon_percentile", 1.0
+            ),
             columns_for_singleview_pca=cfg.data.columns_for_singleview_pca,
-            centering_method=cfg.losses.pca_singleview.get("centering_method", None),
+            centering_method=cfg.losses.pca_singleview.get(
+                "centering_method", None
+            ),
         )
         # re-fit pca on the labeled data to get params
         pca()
         # compute reprojection error
-        pcasv_error_per_keypoint = pca_singleview_reprojection_error(keypoints_pred, pca)
-        pcasv_df = pd.DataFrame(pcasv_error_per_keypoint, index=index, columns=keypoint_names)
+        pcasv_error_per_keypoint = pca_singleview_reprojection_error(
+            keypoints_pred, pca
+        )
+        pcasv_df = pd.DataFrame(
+            pcasv_error_per_keypoint, index=index, columns=keypoint_names
+        )
         # add train/val/test split
         if set is not None:
             pcasv_df["set"] = set
@@ -839,19 +951,231 @@
             data_module=data_module,
             components_to_keep=cfg.losses.pca_singleview.components_to_keep,
             empirical_epsilon_percentile=cfg.losses.pca_singleview.get(
-                "empirical_epsilon_percentile", 1.0),
+                "empirical_epsilon_percentile", 1.0
+            ),
             mirrored_column_matches=cfg.data.mirrored_column_matches,
         )
         # re-fit pca on the labeled data to get params
         pca()
         # compute reprojection error
         pcamv_error_per_keypoint = pca_multiview_reprojection_error(keypoints_pred, pca)
-        pcamv_df = pd.DataFrame(pcamv_error_per_keypoint, index=index, columns=keypoint_names)
+        pcamv_df = pd.DataFrame(
+            pcamv_error_per_keypoint, index=index, columns=keypoint_names
+        )
         # add train/val/test split
         if set is not None:
             pcamv_df["set"] = set
-        save_file = preds_file_path.with_name(preds_file_path.stem + "_pca_multiview_error.csv")
+        save_file = preds_file_path.with_name(
+            preds_file_path.stem + "_pca_multiview_error.csv"
+        )
         pcamv_df.to_csv(save_file)
         result.pca_mv_df = pcamv_df
 
     return result
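
For context on the reshape above: prediction CSVs carry a three-row header (scorer / bodyparts / coords), and selecting the x/y/likelihood columns then reshaping yields one row per frame and one (x, y, likelihood) triple per keypoint. A sketch with a hypothetical file:

```python
import pandas as pd

pred_df = pd.read_csv("preds.csv", header=[0, 1, 2], index_col=0)  # hypothetical path
xyl_mask = pred_df.columns.get_level_values("coords").isin(["x", "y", "likelihood"])
tmp = pred_df.loc[:, xyl_mask].to_numpy().reshape(pred_df.shape[0], -1, 3)
keypoints_pred = tmp[:, :, :2]  # (n_frames, n_keypoints, 2)
confidences = tmp[:, :, 2]      # (n_frames, n_keypoints)
```
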
+ """ + datalen = len(dataset) + print(f"Number of labeled images in the full dataset (train+val+test): {datalen}") + + # derive split parameters from cfg + train_probability = cfg.training.get("train_prob", 0.8) + val_probability = cfg.training.get("val_prob", None) + test_probability = cfg.training.get("test_prob", None) + train_frames = cfg.training.get("train_frames", None) + torch_seed = cfg.training.get("rng_seed_data_pt", 42) + + # split data based on provided probabilities + data_splits_list = split_sizes_from_probabilities( + datalen, + train_probability=train_probability, + val_probability=val_probability, + test_probability=test_probability, + ) + + if ( + getattr(dataset, "imgaug_transform", None) is not None + and len(dataset.imgaug_transform) == 1 + ): + # no augmentations in the pipeline; subsets can share same underlying dataset + train_dataset, val_dataset, test_dataset = random_split( + dataset, + data_splits_list, + generator=torch.Generator().manual_seed(torch_seed), + ) + else: + # augmentations in the pipeline; we want validation and test datasets that only resize + # we can't simply change the imgaug pipeline in the datasets after they've been split + # because the subsets actually point to the same underlying dataset, so we create + # separate datasets here + train_idxs, val_idxs, test_idxs = random_split( + range(len(dataset)), + data_splits_list, + generator=torch.Generator().manual_seed(torch_seed), + ) + + train_dataset = Subset(copy.deepcopy(dataset), indices=list(train_idxs)) + val_dataset = Subset(copy.deepcopy(dataset), indices=list(val_idxs)) + test_dataset = Subset(copy.deepcopy(dataset), indices=list(test_idxs)) + + # only use the final resize transform for the validation and test datasets + # try to pull the final transform; if unavailable (e.g., multiview that doesn't resize + # by default), enforce resizing to dataset.height/width + if ( + getattr(dataset, "imgaug_transform", None) is not None + and len(dataset.imgaug_transform) > 0 + and dataset.imgaug_transform[-1].__str__().find("Resize") == 0 + ): + final_transform = iaa.Sequential([dataset.imgaug_transform[-1]]) + else: + height = getattr(dataset, "height", None) + width = getattr(dataset, "width", None) + if height is None or width is None: + raise AttributeError( + "Dataset must have 'height' and 'width' attributes when no final Resize transform is present." 
+
+
+def get_dataloader_factory(
+    cfg: DictConfig,
+    dataset: torch.utils.data.Dataset,
+    splits: tuple[Subset, Subset, Subset],
+) -> Callable[[str], DataLoader]:
+    """Return a stage -> dataloader factory for labeled data (train/val/test/full)."""
+    train_batch_size, val_batch_size = get_train_val_batches(cfg)
+    train_dataset, val_dataset, test_dataset = splits
+
+    def get_dataloader(stage: str) -> DataLoader:
+        num_workers = cfg.training.get("num_workers")
+        if num_workers is None:
+            slurm_cpus = os.getenv("SLURM_CPUS_PER_TASK")
+            if slurm_cpus:
+                num_workers = int(slurm_cpus)
+            else:
+                # Fallback to os.cpu_count() (which can return None)
+                num_workers = os.cpu_count() or 0
+        if stage == "train":
+            return DataLoader(
+                train_dataset,
+                batch_size=train_batch_size,
+                num_workers=num_workers,
+                persistent_workers=True if num_workers > 0 else False,
+                shuffle=True,
+                generator=torch.Generator().manual_seed(cfg.training.rng_seed_data_pt),
+            )
+        if stage == "val":
+            return DataLoader(
+                val_dataset,
+                batch_size=val_batch_size,
+                num_workers=num_workers,
+                persistent_workers=True if num_workers > 0 else False,
+            )
+        if stage == "test":
+            return DataLoader(
+                test_dataset,
+                batch_size=cfg.training.test_batch_size,
+                num_workers=num_workers,
+                persistent_workers=True if num_workers > 0 else False,
+            )
+        if stage == "full":
+            return DataLoader(
+                dataset,
+                batch_size=val_batch_size,
+                num_workers=num_workers,
+                persistent_workers=True if num_workers > 0 else False,
+            )
+        raise NotImplementedError(f"Unknown stage: {stage}")
+
+    return get_dataloader
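
The returned closure is the single place where labeled loaders are configured; stages map directly onto the data module hooks, plus the extra `"full"` stage used by `full_labeled_dataloader`:

```python
factory = get_dataloader_factory(cfg=cfg, dataset=dataset, splits=splits)

train_loader = factory("train")  # shuffled, with a seeded generator
val_loader = factory("val")
full_loader = factory("full")    # whole labeled dataset at val batch size
# factory("predict") would raise NotImplementedError (unknown stage)
```
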
+
+
+def get_training_trainer(
+    cfg: DictConfig, logger, callbacks, steps_per_epoch
+) -> pl.Trainer:
+    """Get trainer for training."""
+    # set up trainer
+
+    cfg.training.num_gpus = max(cfg.training.num_gpus, 1)
+
+    # initialize to Trainer defaults. Note max_steps defaults to -1.
+    min_steps, max_steps, min_epochs, max_epochs = (None, -1, None, None)
+    if "min_steps" in cfg.training:
+        min_steps = cfg.training.min_steps
+        max_steps = cfg.training.max_steps
+    else:
+        min_epochs = cfg.training.min_epochs
+        max_epochs = cfg.training.max_epochs
+
+    # Unlike min_epoch/min_step, both of these are valid to specify.
+    check_val_every_n_epoch = cfg.training.get("check_val_every_n_epoch", 1)
+    val_check_interval = cfg.training.get("val_check_interval")
+
+    return pl.Trainer(
+        accelerator="gpu",
+        devices=cfg.training.num_gpus,
+        max_epochs=max_epochs,
+        min_epochs=min_epochs,
+        max_steps=max_steps,
+        min_steps=min_steps,
+        check_val_every_n_epoch=check_val_every_n_epoch,
+        val_check_interval=val_check_interval,
+        log_every_n_steps=cfg.training.log_every_n_steps,
+        callbacks=callbacks,
+        logger=logger,
+        # To understand why we set this, see 'max_size_cycle' in UnlabeledDataModule.
+        limit_train_batches=cfg.training.get("limit_train_batches") or steps_per_epoch,
+        accumulate_grad_batches=cfg.training.get("accumulate_grad_batches", 1),
+        profiler=cfg.training.get("profiler", None),
+        sync_batchnorm=True,
+    )
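
Note the either/or handling above: if `cfg.training.min_steps` is present, the step budget wins and the epoch limits stay at Trainer defaults; otherwise the epoch limits apply. A sketch of the two config shapes, with illustrative values:

```python
# Step-based budget: min/max epochs remain at Lightning defaults.
cfg_steps = {"training": {"min_steps": 1000, "max_steps": 20000}}

# Epoch-based budget: max_steps stays at its -1 default (no step cap).
cfg_epochs = {"training": {"min_epochs": 20, "max_epochs": 300}}
```
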