diff --git a/test/core/reid/dataset/test_base.py b/test/core/reid/dataset/test_base.py index f58327a7..6cfee952 100644 --- a/test/core/reid/dataset/test_base.py +++ b/test/core/reid/dataset/test_base.py @@ -1,159 +1,55 @@ -from contextlib import ExitStack as DoesNotRaise +import os +import shutil import pytest -from trackers.core.reid.dataset.base import TripletsDataset +from trackers.core.reid.dataset.base import IdentityDataset +from trackers.core.reid.dataset.market_1501 import parse_market1501_dataset +from trackers.core.reid.model import ReIDModel +from trackers.utils.data_utils import unzip_file +from trackers.utils.downloader import download_file +DATASET_URL = "https://storage.googleapis.com/com-roboflow-marketing/trackers/datasets/market_1501.zip" -@pytest.mark.parametrize( - "tracker_id_to_images, exception", - [ - ( - {"0111": []}, - pytest.raises(ValueError), - ), # Single tracker with no images - should raise ValueError - ( - {"0111": ["0111_00000000.jpg"]}, - pytest.raises(ValueError), - ), # Single tracker with one image - should raise ValueError - ( - {"0111": ["0111_00000000.jpg", "0111_00000001.jpg"]}, - pytest.raises(ValueError), - ), # Single tracker with multiple images - should raise ValueError - ( - { - "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], - "0112": ["0112_00000000.jpg"], - }, - pytest.raises(ValueError), - ), # Two trackers but one has only one image - should raise ValueError - ( - { - "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], - "0112": ["0112_00000000.jpg", "0112_00000001.jpg"], - }, - DoesNotRaise(), - ), # Two trackers with multiple images - should not raise - ( - { - "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], - "0112": ["0112_00000000.jpg", "0112_00000001.jpg"], - "0113": ["0113_00000000.jpg"], - }, - DoesNotRaise(), - ), # Three trackers, one with fewer images - should validate dataset length - ], -) -def test_triplet_dataset_initialization(tracker_id_to_images, exception): - with exception: - _ = TripletsDataset(tracker_id_to_images) + +@pytest.fixture +def market_1501_dataset(): + os.makedirs("test_data", exist_ok=True) + dataset_path = os.path.join("test_data", "Market-1501-v15.09.15") + zip_path = os.path.join("test_data", "market_1501.zip") + if not os.path.exists(dataset_path): + if not os.path.exists(zip_path): + download_file(DATASET_URL) + shutil.move("market_1501.zip", str(zip_path)) + unzip_file(str(zip_path), "test_data") + yield dataset_path @pytest.mark.parametrize( - "tracker_id_to_images, split_ratio, expected_train_size, expected_val_size, exception", # noqa: E501 - [ - ( - { - "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], - "0112": ["0112_00000000.jpg", "0112_00000001.jpg"], - }, - 0.5, - 1, - 1, - pytest.raises(ValueError), - ), # Split results in only 1 tracker in test set - should raise ValueError - ( - { - "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], - "0112": ["0112_00000000.jpg", "0112_00000001.jpg"], - "0113": ["0113_00000000.jpg", "0113_00000001.jpg"], - "0114": ["0114_00000000.jpg", "0114_00000001.jpg"], - "0115": ["0115_00000000.jpg", "0115_00000001.jpg"], - }, - 0.2, - 1, - 4, - pytest.raises(ValueError), - ), # Split results in only 1 tracker in test set - should raise ValueError - ( - { - "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], - "0112": ["0112_00000000.jpg", "0112_00000001.jpg"], - "0113": ["0113_00000000.jpg", "0113_00000001.jpg"], - "0114": ["0114_00000000.jpg", "0114_00000001.jpg"], - "0115": ["0115_00000000.jpg", "0115_00000001.jpg"], - }, - 0.8, - 4, - 1, - 
pytest.raises(ValueError),
-        ),  # Split results in only 1 tracker in val set - should raise ValueError
-        (
-            {
-                "0111": ["0111_00000000.jpg", "0111_00000001.jpg"],
-                "0112": ["0112_00000000.jpg", "0112_00000001.jpg"],
-                "0113": ["0113_00000000.jpg", "0113_00000001.jpg"],
-                "0114": ["0114_00000000.jpg", "0114_00000001.jpg"],
-                "0115": ["0115_00000000.jpg", "0115_00000001.jpg"],
-            },
-            0.6,
-            3,
-            2,
-            DoesNotRaise(),
-        ),  # Valid split with multiple trackers in both sets
-        (
-            {
-                "0111": ["0111_00000000.jpg", "0111_00000001.jpg"],
-                "0112": ["0112_00000000.jpg", "0112_00000001.jpg"],
-                "0113": ["0113_00000000.jpg", "0113_00000001.jpg"],
-                "0114": ["0114_00000000.jpg", "0114_00000001.jpg"],
-            },
-            0.5,
-            2,
-            2,
-            DoesNotRaise(),
-        ),  # 50% train, 50% validation - valid
-    ],
+    "dataset_split", ["bounding_box_train", "bounding_box_test", "query"]
 )
-def test_triplet_dataset_split(
-    tracker_id_to_images, split_ratio, expected_train_size, expected_val_size, exception
-):
-    with exception:
-        dataset = TripletsDataset(tracker_id_to_images)
-        train_dataset, val_dataset = dataset.split(split_ratio=split_ratio)
-
-        assert len(train_dataset) == expected_train_size, (
-            f"Expected train dataset size {expected_train_size}, "
-            f"got {len(train_dataset)}"
-        )
-        assert len(val_dataset) == expected_val_size, (
-            f"Expected validation dataset size {expected_val_size}, "
-            f"got {len(val_dataset)}"
-        )
+def test_identity_dataset(market_1501_dataset, dataset_split):
+    dataset_path = os.path.join(market_1501_dataset, dataset_split)
+    dataset = IdentityDataset(parse_market1501_dataset(data_dir=dataset_path))
+    # Reference Market-1501 sizes: 751 train identities, 750 query identities;
+    # the gallery additionally contains distractor images under identity 0.
+    if dataset_split == "bounding_box_train":
+        assert len(dataset) == 12936
+        assert dataset.get_num_identities() == 751
+    elif dataset_split == "bounding_box_test":
+        assert len(dataset) == 15913
+        assert dataset.get_num_identities() == 751
+    elif dataset_split == "query":
+        assert len(dataset) == 3368
+        assert dataset.get_num_identities() == 750


 @pytest.mark.parametrize(
-    "tracker_id_to_images, tracker_id, exception",
-    [
-        (
-            {
-                "0111": ["0111_00000000.jpg", "0111_00000001.jpg"],
-                "0112": ["0112_00000000.jpg", "0112_00000001.jpg"],
-                "0113": ["0113_00000000.jpg", "0113_00000001.jpg"],
-            },
-            "0111",
-            DoesNotRaise(),
-        ),
-    ],
+    "dataset_split", ["bounding_box_train", "bounding_box_test", "query"]
 )
-def test_get_triplet_image_paths(tracker_id_to_images, tracker_id, exception) -> None:
-    with exception:
-        dataset = TripletsDataset(tracker_id_to_images)
-        anchor_path, positive_path, negative_path = dataset._get_triplet_image_paths(
-            tracker_id
-        )
-
-        assert anchor_path in tracker_id_to_images[tracker_id]
-        assert positive_path in tracker_id_to_images[tracker_id]
-        assert negative_path not in tracker_id_to_images[tracker_id]
-        assert anchor_path != positive_path
+def test_reid_model_classification_head(market_1501_dataset, dataset_split):
+    dataset_path = os.path.join(market_1501_dataset, dataset_split)
+    dataset = IdentityDataset(parse_market1501_dataset(data_dir=dataset_path))
+    model = ReIDModel.from_timm("resnet50")
+    model.add_classification_head(
+        num_classes=dataset.get_num_identities(), freeze_backbone=True
+    )
+    # The classification head lives on the feature extractor wrapper, not on
+    # the raw timm backbone.
+    assert model.feature_extractor.fc.out_features == dataset.get_num_identities()
diff --git a/test/core/reid/dataset/test_market_1501.py b/test/core/reid/dataset/test_market_1501.py
deleted file mode 100644
index 2f435c34..00000000
--- a/test/core/reid/dataset/test_market_1501.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from unittest.mock import patch
-
-import pytest
-
-from
trackers.core.reid.dataset.market_1501 import parse_market1501_dataset - - -@pytest.mark.parametrize( - "mock_glob_output, expected_result", - [ - ( - # Empty dataset - [], - {}, - ), - ( - # Single image for one person - ["0111_00000000.jpg"], - {"0111": ["0111_00000000.jpg"]}, - ), - ( - # Multiple images for one person - ["0111_00000000.jpg", "0111_00000001.jpg"], - {"0111": ["0111_00000000.jpg", "0111_00000001.jpg"]}, - ), - ( - # Multiple people with multiple images - [ - "0111_00000000.jpg", - "0111_00000001.jpg", - "0112_00000000.jpg", - "0112_00000001.jpg", - ], - { - "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], - "0112": ["0112_00000000.jpg", "0112_00000001.jpg"], - }, - ), - ( - # Multiple people with varying number of images - ["0111_00000000.jpg", "0111_00000001.jpg", "0112_00000000.jpg"], - { - "0111": ["0111_00000000.jpg", "0111_00000001.jpg"], - "0112": ["0112_00000000.jpg"], - }, - ), - ], -) -def test_parse_market1501_dataset(mock_glob_output, expected_result): - with patch("glob.glob", return_value=mock_glob_output): - result = parse_market1501_dataset("dummy_path") - assert result == expected_result diff --git a/trackers/core/reid/__init__.py b/trackers/core/reid/__init__.py index ac00fe1c..e4c6eb12 100644 --- a/trackers/core/reid/__init__.py +++ b/trackers/core/reid/__init__.py @@ -3,11 +3,11 @@ logger = get_logger(__name__) try: - from trackers.core.reid.dataset.base import TripletsDataset + from trackers.core.reid.dataset.base import IdentityDataset from trackers.core.reid.dataset.market_1501 import get_market1501_dataset from trackers.core.reid.model import ReIDModel - __all__ = ["ReIDModel", "TripletsDataset", "get_market1501_dataset"] + __all__ = ["IdentityDataset", "ReIDModel", "get_market1501_dataset"] except ImportError: logger.warning( "ReIDModel dependencies not installed. ReIDModel will not be available. " diff --git a/trackers/core/reid/dataset/base.py b/trackers/core/reid/dataset/base.py index fc4efd43..22d192a2 100644 --- a/trackers/core/reid/dataset/base.py +++ b/trackers/core/reid/dataset/base.py @@ -1,8 +1,6 @@ from __future__ import annotations -import random -from pathlib import Path -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch from PIL import Image @@ -10,162 +8,49 @@ from torch.utils.data import Dataset from torchvision.transforms import Compose, ToTensor -from trackers.core.reid.dataset.utils import validate_tracker_id_to_images - - -class TripletsDataset(Dataset): - """A dataset that provides triplets of images for training ReID models. - - This dataset is designed for training models with triplet loss, where each sample - consists of an anchor image, a positive image (same identity as anchor), - and a negative image (different identity from anchor). 
- - Args: - tracker_id_to_images (dict[str, list[str]]): Dictionary mapping tracker IDs - to lists of image paths - transforms (Optional[Compose]): Optional image transformations to apply - - Attributes: - tracker_id_to_images (dict[str, list[str]]): Dictionary mapping tracker IDs - to lists of image paths - transforms (Optional[Compose]): Optional image transformations to apply - tracker_ids (list[str]): List of all unique tracker IDs in the dataset - """ +class IdentityDataset(Dataset): def __init__( self, - tracker_id_to_images: dict[str, list[str]], - transforms: Optional[Compose] = None, + data: list[Tuple[str, int, int]], + transforms: Optional[Union[Callable, Compose]] = None, ): - self.tracker_id_to_images = validate_tracker_id_to_images(tracker_id_to_images) - self.transforms = transforms or ToTensor() - self.tracker_ids = list(self.tracker_id_to_images.keys()) - - @classmethod - def from_image_directories( - cls, - root_directory: str, - transforms: Optional[Compose] = None, - image_extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png"), - ) -> TripletsDataset: - """ - Create TripletsDataset from a directory structured by tracker IDs. - - Args: - root_directory (str): Root directory with tracker folders. - transforms (Optional[Compose]): Optional image transformations. - image_extensions (Tuple[str, ...]): Valid image extensions to load. - - Returns: - TripletsDataset: An initialized dataset. - """ - root_path = Path(root_directory) - tracker_id_to_images = {} + self.data = data + self.transforms = transforms or Compose([ToTensor()]) - for tracker_path in sorted(root_path.iterdir()): - if not tracker_path.is_dir(): - continue + def __len__(self): + return len(self.data) - image_paths = sorted( - [ - str(image_path) - for image_path in tracker_path.glob("*") - if image_path.suffix.lower() in image_extensions - and image_path.is_file() - ] - ) - - if image_paths: - tracker_id_to_images[tracker_path.name] = image_paths - - return cls( - tracker_id_to_images=tracker_id_to_images, - transforms=transforms, - ) - - def __len__(self) -> int: - """ - Return the number of unique tracker IDs (identities) in the dataset. - - Returns: - int: The total number of unique identities (tracker IDs) available for - sampling triplets. - """ - return len(self.tracker_ids) - - def _load_and_transform_image(self, image_path: str) -> torch.Tensor: + def __getitem__(self, index: int) -> dict[str, Union[torch.Tensor, str, int]]: + image_path, identity, camera_id = self.data[index] image = Image.open(image_path).convert("RGB") - if self.transforms: - image = self.transforms(image) - return image - - def _get_triplet_image_paths(self, tracker_id: str) -> Tuple[str, str, str]: - tracker_id_image_paths = self.tracker_id_to_images[tracker_id] - - anchor_image_path, positive_image_path = random.sample( # nosec B311 - tracker_id_image_paths, 2 - ) - - negative_candidates = [tid for tid in self.tracker_ids if tid != tracker_id] - negative_tracker_id = random.choice(negative_candidates) # nosec B311 - - negative_image_path = random.choice( # nosec B311 - self.tracker_id_to_images[negative_tracker_id] - ) - - return anchor_image_path, positive_image_path, negative_image_path - - def __getitem__( - self, index: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Retrieve a random triplet (anchor, positive, negative) of images for a given - identity. 
-
-        For the tracker ID at the given index, samples two different images as the
-        anchor and positive (same identity), and one image from a different tracker ID
-        as the negative (different identity).
-
-        Args:
-            index (int): Index of the tracker ID (identity) to sample the triplet from.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-                A tuple containing the anchor, positive, and negative image tensors.
-        """
-        tracker_id = self.tracker_ids[index]
-
-        anchor_image_path, positive_image_path, negative_image_path = (
-            self._get_triplet_image_paths(tracker_id)
-        )
-
-        anchor_image = self._load_and_transform_image(anchor_image_path)
-        positive_image = self._load_and_transform_image(positive_image_path)
-        negative_image = self._load_and_transform_image(negative_image_path)
+        image = self.transforms(image)
+        return {
+            "image": image,
+            "image_path": image_path,
+            "identity": identity,
+            "camera_id": camera_id,
+        }

-        return anchor_image, positive_image, negative_image
+    def get_num_identities(self) -> int:
+        """Return the number of distinct identity labels in the dataset."""
+        return len({identity for _, identity, _ in self.data})

     def split(
         self,
         split_ratio: float = 0.8,
         random_state: Optional[Union[int, float, str, bytes, bytearray]] = None,
         shuffle: bool = True,
-    ) -> Tuple[TripletsDataset, TripletsDataset]:
-        train_tracker_id_to_images, validation_tracker_id_to_images = train_test_split(
-            list(self.tracker_id_to_images.keys()),
+    ) -> Tuple[IdentityDataset, IdentityDataset]:
+        train_data, validation_data = train_test_split(
+            data=self.data,
             train_ratio=split_ratio,
             random_state=random_state,
             shuffle=shuffle,
         )
-        train_tracker_id_to_images = {
-            tracker_id: self.tracker_id_to_images[tracker_id]
-            for tracker_id in train_tracker_id_to_images
-        }
-        validation_tracker_id_to_images = {
-            tracker_id: self.tracker_id_to_images[tracker_id]
-            for tracker_id in validation_tracker_id_to_images
-        }
         return (
-            TripletsDataset(train_tracker_id_to_images, self.transforms),
-            TripletsDataset(validation_tracker_id_to_images, self.transforms),
+            IdentityDataset(train_data, transforms=self.transforms),
+            IdentityDataset(validation_data, transforms=self.transforms),
         )
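Note on the new dataset API: each `IdentityDataset` item is now a dict rather than an `(anchor, positive, negative)` triplet. A minimal sketch of constructing and consuming one (the records and paths below are illustrative, not real files):

```python
from trackers.core.reid.dataset.base import IdentityDataset

# (image_path, identity, camera_id) records, e.g. from parse_market1501_dataset
records = [
    ("images/0001_c1s1_000151_01.jpg", 1, 0),
    ("images/0001_c2s1_000176_01.jpg", 1, 1),
    ("images/0002_c1s1_000301_01.jpg", 2, 0),
]
dataset = IdentityDataset(records)
print(len(dataset), dataset.get_num_identities())  # 3 2
sample = dataset[0]
# sample keys: "image" (tensor), "image_path" (str), "identity" (int), "camera_id" (int)
```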
diff --git a/trackers/core/reid/dataset/market_1501.py b/trackers/core/reid/dataset/market_1501.py
index a1b15b73..4afa6b8f 100644
--- a/trackers/core/reid/dataset/market_1501.py
+++ b/trackers/core/reid/dataset/market_1501.py
@@ -1,61 +1,125 @@
 import glob
 import os
-from collections import defaultdict
-from typing import Dict, List, Optional, Tuple, Union
+import re
+from typing import Callable, Optional, Tuple, Union

 from torchvision.transforms import Compose

-from trackers.core.reid.dataset.base import TripletsDataset
+from trackers.core.reid.dataset.base import IdentityDataset


-def parse_market1501_dataset(data_dir: str) -> Dict[str, List[str]]:
+def parse_market1501_dataset(
+    data_dir: str, relabel: bool = False
+) -> list[Tuple[str, int, int]]:
     """Parse the [Market1501 dataset](https://paperswithcode.com/dataset/market-1501)
-    to create a dictionary mapping tracker IDs to lists of image paths.
+    to create a list of tuples, each containing an image path, an identity label,
+    and a camera label.

     Args:
         data_dir (str): The path to the Market1501 dataset.
+        relabel (bool): Whether to relabel the identities to a compact range
+            starting from 0.

     Returns:
-        Dict[str, List[str]]: A dictionary mapping tracker IDs to lists of image paths.
+        list[Tuple[str, int, int]]: A list of tuples, each containing an image path,
+            an identity label, and a camera label.
     """
     image_files = glob.glob(os.path.join(data_dir, "*.jpg"))
-    tracker_id_to_images = defaultdict(list)
-    for image_file in image_files:
-        tracker_id = os.path.basename(image_file).split("_")[0]
-        tracker_id_to_images[tracker_id].append(image_file)
-    return dict(tracker_id_to_images)
+    file_pattern = re.compile(r"([-\d]+)_c(\d)")
+    id_container = set()
+    data = []
+    for img_path in image_files:
+        match = file_pattern.search(img_path)
+        if match is None:
+            continue
+        identity, camera_id = map(int, match.groups())
+        # Images labelled -1 are junk detections and are dropped entirely;
+        # camera IDs are shifted to be zero-based.
+        if identity != -1:
+            id_container.add(identity)
+            data.append((img_path, identity, camera_id - 1))
+
+    if relabel:
+        # Sort before enumerating so the identity -> label mapping is
+        # deterministic across runs.
+        id_to_label = {
+            identity: label for label, identity in enumerate(sorted(id_container))
+        }
+        data = [
+            (img_path, id_to_label[identity], camera_id)
+            for img_path, identity, camera_id in data
+        ]
+
+    return data


 def get_market1501_dataset(
     data_dir: str,
+    relabel: bool = False,
     split_ratio: Optional[float] = None,
     random_state: Optional[Union[int, float, str, bytes, bytearray]] = None,
     shuffle: bool = True,
-    transforms: Optional[Compose] = None,
-) -> Union[TripletsDataset, Tuple[TripletsDataset, TripletsDataset]]:
+    transforms: Optional[Union[Callable, Compose]] = None,
+) -> Union[
+    IdentityDataset,
+    Tuple[IdentityDataset, IdentityDataset],
+    Tuple[IdentityDataset, IdentityDataset, IdentityDataset],
+    Tuple[IdentityDataset, IdentityDataset, IdentityDataset, IdentityDataset],
+]:
     """Get the [Market1501 dataset](https://paperswithcode.com/dataset/market-1501).

     Args:
         data_dir (str): The path to the bounding box train/test directory of the
             [Market1501 dataset](https://paperswithcode.com/dataset/market-1501).
+        relabel (bool): Whether to relabel the identities to a compact range
+            starting from 0.
         split_ratio (Optional[float]): The ratio of the dataset to split into
             training and validation sets. If `None`, the dataset is returned as a single
-            `TripletsDataset` object, otherwise the dataset is split into a tuple of
-            training and validation `TripletsDataset` objects.
+            `IdentityDataset` object, otherwise the dataset is split into a tuple of
+            training and validation `IdentityDataset` objects.
         random_state (Optional[Union[int, float, str, bytes, bytearray]]): The
             random state to use for the split.
         shuffle (bool): Whether to shuffle the dataset.
-        transforms (Optional[Compose]): The transforms to apply to the dataset.
+        transforms (Optional[Union[Callable, Compose]]): The transforms to apply to
+            the dataset.

     Returns:
-        Tuple[TripletsDataset, TripletsDataset]: A tuple of training and validation
-            `TripletsDataset` objects.
- """ - tracker_id_to_images = parse_market1501_dataset(data_dir) - dataset = TripletsDataset(tracker_id_to_images, transforms) - if split_ratio is not None: - train_dataset, validation_dataset = dataset.split( - split_ratio=split_ratio, random_state=random_state, shuffle=shuffle + Union[ + IdentityDataset, + Tuple[IdentityDataset, IdentityDataset], + Tuple[IdentityDataset, IdentityDataset, IdentityDataset], + Tuple[IdentityDataset, IdentityDataset, IdentityDataset, IdentityDataset], + ]: The return type depends on the directory structure and split_ratio: + - If standard Market1501 structure with split_ratio: (train, validation, test, query) + - If standard Market1501 structure without split_ratio: (train, test, query) + - If custom directory with split_ratio: (train, validation) + - If custom directory without split_ratio: single IdentityDataset + """ # noqa: E501 + dirs = os.listdir(data_dir) + if "bounding_box_train" in dirs and "bounding_box_test" in dirs and "query" in dirs: + train_dataset = IdentityDataset( + parse_market1501_dataset( + os.path.join(data_dir, "bounding_box_train"), relabel=relabel + ), + transforms=transforms, + ) + test_dataset = IdentityDataset( + parse_market1501_dataset( + os.path.join(data_dir, "bounding_box_test"), relabel=relabel + ), + transforms=transforms, + ) + query_dataset = IdentityDataset( + parse_market1501_dataset(os.path.join(data_dir, "query"), relabel=relabel), + transforms=transforms, + ) + if split_ratio is not None: + train_dataset, validation_dataset = train_dataset.split( + split_ratio=split_ratio, random_state=random_state, shuffle=shuffle + ) + return train_dataset, validation_dataset, test_dataset, query_dataset + return train_dataset, test_dataset, query_dataset + else: + dataset = IdentityDataset( + parse_market1501_dataset(data_dir, relabel=relabel), transforms=transforms ) - return train_dataset, validation_dataset - return dataset + if split_ratio is not None: + train_dataset, validation_dataset = dataset.split( + split_ratio=split_ratio, random_state=random_state, shuffle=shuffle + ) + return train_dataset, validation_dataset + return dataset diff --git a/trackers/core/reid/dataset/utils.py b/trackers/core/reid/dataset/utils.py deleted file mode 100644 index 41436a82..00000000 --- a/trackers/core/reid/dataset/utils.py +++ /dev/null @@ -1,35 +0,0 @@ -from trackers.log import get_logger - -logger = get_logger(__name__) - - -def validate_tracker_id_to_images( - tracker_id_to_images: dict[str, list[str]], -) -> dict[str, list[str]]: - """Validates a dictionary that maps tracker IDs to lists of image paths for the - `TripletsDataset` for training ReID models using triplet loss. - - Args: - tracker_id_to_images (dict[str, list[str]]): The tracker ID to images - dictionary. - - Returns: - dict[str, list[str]]: The validated tracker ID to images dictionary. - """ - valid_tracker_ids = {} - for tracker_id, image_paths in tracker_id_to_images.items(): - if len(image_paths) < 2: - logger.warning( - f"Tracker ID '{tracker_id}' has less than 2 images. " - f"Skipping this tracker ID." - ) - else: - valid_tracker_ids[tracker_id] = image_paths - - if len(valid_tracker_ids) < 2: - raise ValueError( - "Tracker ID to images dictionary must contain at least 2 items " - "to select negative samples." 
-        )
-
-    return valid_tracker_ids
diff --git a/trackers/core/reid/model.py b/trackers/core/reid/model.py
index bbd7aacf..ca73f716 100644
--- a/trackers/core/reid/model.py
+++ b/trackers/core/reid/model.py
@@ -1,6 +1,5 @@
 from __future__ import annotations

-import json
 import os
 from typing import Any, Callable, Optional, Union
@@ -10,151 +9,106 @@
 import timm
 import torch
 import torch.nn as nn
-import torch.optim as optim
-from safetensors.torch import save_file
 from timm.data import resolve_data_config
 from timm.data.transforms_factory import create_transform
 from torch.utils.data import DataLoader
-from torchvision.transforms import Compose, ToPILImage
+from torchvision.transforms import Compose, ToPILImage, ToTensor
-from tqdm.auto import tqdm
-
-from trackers.core.reid.callbacks import BaseCallback
-from trackers.core.reid.metrics import (
-    TripletAccuracyMetric,
-    TripletMetric,
-)
+from trackers.core.reid.dataset.base import IdentityDataset
+from trackers.core.reid.trainer.cross_entropy_trainer import CrossEntropyTrainer
 from trackers.log import get_logger
 from trackers.utils.torch_utils import load_safetensors_checkpoint, parse_device_spec

 logger = get_logger(__name__)


-def _initialize_reid_model_from_timm(
-    cls,
-    model_name_or_checkpoint_path: str,
-    device: Optional[str] = "auto",
-    get_pooled_features: bool = True,
-    **kwargs,
-):
-    if model_name_or_checkpoint_path not in timm.list_models(
-        filter=model_name_or_checkpoint_path, pretrained=True
-    ):
-        probable_model_name_list = timm.list_models(
-            f"*{model_name_or_checkpoint_path}*", pretrained=True
-        )
-        if len(probable_model_name_list) == 0:
-            raise ValueError(
-                f"Model {model_name_or_checkpoint_path} not found in timm. "
-                + "Please check the model name and try again."
-            )
-        logger.warning(
-            f"Model {model_name_or_checkpoint_path} not found in timm. "
-            + f"Using {probable_model_name_list[0]} instead."
- ) - model_name_or_checkpoint_path = probable_model_name_list[0] - if not get_pooled_features: - kwargs["global_pool"] = "" - model = timm.create_model( - model_name_or_checkpoint_path, pretrained=True, num_classes=0, **kwargs - ) - config = resolve_data_config(model.pretrained_cfg) - transforms = create_transform(**config) - model_metadata = { - "model_name_or_checkpoint_path": model_name_or_checkpoint_path, - "get_pooled_features": get_pooled_features, - "kwargs": kwargs, - } - return cls(model, device, transforms, model_metadata) - +class FeatureExtractorModel(nn.Module): + def __init__(self, backbone: nn.Module): + super().__init__() + self.backbone = backbone + self.fc: Optional[nn.Linear] = None -def _initialize_reid_model_from_checkpoint(cls, checkpoint_path: str): - state_dict, config = load_safetensors_checkpoint(checkpoint_path) - reid_model_instance = _initialize_reid_model_from_timm( - cls, **config["model_metadata"] - ) - if config["projection_dimension"]: - reid_model_instance._add_projection_layer( - projection_dimension=config["projection_dimension"] - ) - for k, v in state_dict.items(): - state_dict[k].to(reid_model_instance.device) - reid_model_instance.backbone_model.load_state_dict(state_dict) - return reid_model_instance + def add_classification_head(self, num_classes: int) -> None: + self.fc = nn.Linear(self.backbone.num_features, num_classes) + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + features = self.backbone.forward_features(x) + pooled_features = self.backbone.global_pool(features) + return pooled_features -class ReIDModel: - """ - A ReID model that is used to extract features from detection crops for trackers - that utilize appearance features. + def forward(self, x: torch.Tensor) -> torch.Tensor: + pooled_features = self.forward_features(x) + output = self.fc(pooled_features) if self.fc is not None else pooled_features + return output - Args: - backbone_model (nn.Module): The torch model to use as the backbone. - device (Optional[str]): The device to run the model on. - transforms (Optional[Union[Callable, list[Callable]]]): The transforms to - apply to the input images. - model_metadata (dict[str, Any]): Metadata about the model architecture. 
- """ +class ReIDModel: def __init__( self, - backbone_model: nn.Module, + backbone: nn.Module, device: Optional[str] = "auto", - transforms: Optional[Union[Callable, list[Callable]]] = None, + transforms: Optional[Union[Callable, list[Callable], Compose]] = None, model_metadata: dict[str, Any] = {}, ): - self.backbone_model = backbone_model self.device = parse_device_spec(device or "auto") - self.backbone_model.to(self.device) - self.backbone_model.eval() - self.train_transforms = ( - (Compose(*transforms) if isinstance(transforms, list) else transforms) - if transforms is not None - else None - ) - self.inference_transforms = Compose( - [ToPILImage(), *transforms] - if isinstance(transforms, list) - else [ToPILImage(), transforms] - ) + self.feature_extractor = ( + FeatureExtractorModel(backbone) + if not isinstance(backbone, FeatureExtractorModel) + else backbone + ).to(self.device) + self._initialize_transforms(transforms) self.model_metadata = model_metadata + def _initialize_transforms( + self, transforms: Optional[Union[Callable, list[Callable], Compose]] + ) -> None: + if isinstance(transforms, list): + self.train_transforms = Compose(transforms) + self.inference_transforms = Compose([ToPILImage(), *transforms]) + else: + self.train_transforms = Compose([transforms]) + self.inference_transforms = Compose([ToPILImage(), transforms]) + @classmethod def from_timm( cls, model_name_or_checkpoint_path: str, device: Optional[str] = "auto", - get_pooled_features: bool = True, **kwargs, ) -> ReIDModel: - """ - Create a `ReIDModel` with a [timm](https://huggingface.co/docs/timm) - model as the backbone. - - Args: - model_name_or_checkpoint_path (str): Name of the timm model to use or - path to a safetensors checkpoint. If the exact model name is not - found, the closest match from `timm.list_models` will be used. - device (str): Device to run the model on. - get_pooled_features (bool): Whether to get the pooled features from the - model or not. - **kwargs: Additional keyword arguments to pass to - [`timm.create_model`](https://huggingface.co/docs/timm/en/reference/models#timm.create_model). - - Returns: - ReIDModel: A new instance of `ReIDModel`. 
- """ - if os.path.exists(model_name_or_checkpoint_path): - return _initialize_reid_model_from_checkpoint( - cls, model_name_or_checkpoint_path + if not os.path.exists(model_name_or_checkpoint_path): + model = timm.create_model( + model_name_or_checkpoint_path, pretrained=True, num_classes=0, **kwargs ) + config = resolve_data_config(model.pretrained_cfg) + transforms = create_transform(**config) + return cls(model, device, transforms) else: - return _initialize_reid_model_from_timm( - cls, - model_name_or_checkpoint_path, - device, - get_pooled_features, - **kwargs, + state_dict, config = load_safetensors_checkpoint( + model_name_or_checkpoint_path + ) + model = timm.create_model( + model_name_or_checkpoint_path, pretrained=True, num_classes=0, **kwargs ) + config = resolve_data_config(model.pretrained_cfg) + transforms = create_transform(**config) + reid_model_instance = cls(model, device, transforms) + if config["num_classes"]: + reid_model_instance.add_classification_head( + num_classes=config["num_classes"] + ) + for k, _ in state_dict.items(): + state_dict[k].to(reid_model_instance.device) + reid_model_instance.feature_extractor.load_state_dict(state_dict) + return reid_model_instance + + def add_classification_head( + self, num_classes: int, freeze_backbone: bool = False + ) -> None: + if freeze_backbone: + for param in self.feature_extractor.backbone.parameters(): + param.requires_grad = False + self.feature_extractor.add_classification_head(num_classes) def extract_features( self, detections: sv.Detections, frame: Union[np.ndarray, PIL.Image.Image] @@ -175,375 +129,57 @@ def extract_features( if isinstance(frame, PIL.Image.Image): frame = np.array(frame) - features = [] + features_list = [] with torch.inference_mode(): for box in detections.xyxy: crop = sv.crop_image(image=frame, xyxy=[*box.astype(int)]) - tensor = self.inference_transforms(crop).unsqueeze(0).to(self.device) - feature = ( - torch.squeeze(self.backbone_model(tensor)).cpu().numpy().flatten() + crop_tensor = ( + self.inference_transforms(crop).unsqueeze(0).to(self.device) ) - features.append(feature) - - return np.array(features) - - def _add_projection_layer( - self, projection_dimension: Optional[int] = None, freeze_backbone: bool = False - ): - """ - Perform model surgery to add a projection layer to the model and freeze the - backbone if specified. The backbone is only frozen if `projection_dimension` - is specified. - - Args: - projection_dimension (Optional[int]): The dimension of the projection layer. - freeze_backbone (bool): Whether to freeze the backbone of the model during - training. - """ - if projection_dimension is not None: - # Freeze backbone only if specified and projection_dimension is mentioned - if freeze_backbone: - for param in self.backbone_model.parameters(): - param.requires_grad = False - - # Add projection layer if projection_dimension is specified - self.backbone_model = nn.Sequential( - self.backbone_model, - nn.Linear(self.backbone_model.num_features, projection_dimension), - ) - self.backbone_model.to(self.device) - - def _train_step( - self, - anchor_image: torch.Tensor, - positive_image: torch.Tensor, - negative_image: torch.Tensor, - metrics_list: list[TripletMetric], - ) -> dict[str, float]: - """ - Perform a single training step. - - Args: - anchor_image (torch.Tensor): The anchor image. - positive_image (torch.Tensor): The positive image. - negative_image (torch.Tensor): The negative image. - metrics_list (list[Metric]): The list of metrics to update. 
- """ - self.optimizer.zero_grad() - anchor_image_features = self.backbone_model(anchor_image) - positive_image_features = self.backbone_model(positive_image) - negative_image_features = self.backbone_model(negative_image) - - loss = self.criterion( - anchor_image_features, - positive_image_features, - negative_image_features, - ) - loss.backward() - self.optimizer.step() - - # Update metrics - for metric in metrics_list: - metric.update( - anchor_image_features.detach(), - positive_image_features.detach(), - negative_image_features.detach(), - ) + pooled_features = self.feature_extractor.forward_features(crop_tensor) + pooled_features = torch.squeeze(pooled_features).cpu().numpy().flatten() + features_list.append(pooled_features) - train_logs = {"train/loss": loss.item()} - for metric in metrics_list: - train_logs[f"train/{metric!s}"] = metric.compute() - - return train_logs - - def _validation_step( - self, - anchor_image: torch.Tensor, - positive_image: torch.Tensor, - negative_image: torch.Tensor, - metrics_list: list[TripletMetric], - ) -> dict[str, float]: - """ - Perform a single validation step. - - Args: - anchor_image (torch.Tensor): The anchor image. - positive_image (torch.Tensor): The positive image. - negative_image (torch.Tensor): The negative image. - metrics_list (list[Metric]): The list of metrics to update. - """ - with torch.inference_mode(): - anchor_image_features = self.backbone_model(anchor_image) - positive_image_features = self.backbone_model(positive_image) - negative_image_features = self.backbone_model(negative_image) - - loss = self.criterion( - anchor_image_features, - positive_image_features, - negative_image_features, - ) - - # Update metrics - for metric in metrics_list: - metric.update( - anchor_image_features.detach(), - positive_image_features.detach(), - negative_image_features.detach(), - ) - - validation_logs = {"validation/loss": loss.item()} - for metric in metrics_list: - validation_logs[f"validation/{metric!s}"] = metric.compute() - - return validation_logs + return np.array(features_list) def train( self, train_loader: DataLoader, epochs: int, + num_classes: int, validation_loader: Optional[DataLoader] = None, - projection_dimension: Optional[int] = None, freeze_backbone: bool = False, - learning_rate: float = 5e-5, - weight_decay: float = 0.0, - triplet_margin: float = 1.0, + label_smoothing: float = 1e-2, + learning_rate: float = 1e-3, + weight_decay: float = 1e-2, random_state: Optional[Union[int, float, str, bytes, bytearray]] = None, checkpoint_interval: Optional[int] = None, log_dir: str = "logs", log_to_matplotlib: bool = False, log_to_tensorboard: bool = False, log_to_wandb: bool = False, - ) -> None: - """ - Train/fine-tune the ReID model. - - Args: - train_loader (DataLoader): The training data loader. - epochs (int): The number of epochs to train the model. - validation_loader (Optional[DataLoader]): The validation data loader. - projection_dimension (Optional[int]): The dimension of the projection layer. - freeze_backbone (bool): Whether to freeze the backbone of the model. The - backbone is only frozen if `projection_dimension` is specified. - learning_rate (float): The learning rate to use for the optimizer. - weight_decay (float): The weight decay to use for the optimizer. - triplet_margin (float): The margin to use for the triplet loss. - random_state (Optional[Union[int, float, str, bytes, bytearray]]): The - random state to use for the training. - checkpoint_interval (Optional[int]): The interval to save checkpoints. 
- log_dir (str): The directory to save logs. - log_to_matplotlib (bool): Whether to log to matplotlib. - log_to_tensorboard (bool): Whether to log to tensorboard. - log_to_wandb (bool): Whether to log to wandb. If `checkpoint_interval` is - specified, the model will be logged to wandb as well. - Project and entity name should be set using the environment variables - `WANDB_PROJECT` and `WANDB_ENTITY`. For more details, refer to - [wandb environment variables](https://docs.wandb.ai/guides/track/environment-variables). - """ - os.makedirs(log_dir, exist_ok=True) - os.makedirs(os.path.join(log_dir, "checkpoints"), exist_ok=True) - os.makedirs(os.path.join(log_dir, "tensorboard_logs"), exist_ok=True) - - if random_state is not None: - torch.manual_seed(random_state) - - self._add_projection_layer(projection_dimension, freeze_backbone) - - # Initialize optimizer, criterion and metrics - self.optimizer = optim.Adam( - self.backbone_model.parameters(), - lr=learning_rate, - weight_decay=weight_decay, - ) - self.criterion = nn.TripletMarginLoss(margin=triplet_margin) - metrics_list: list[TripletMetric] = [TripletAccuracyMetric()] - - config = { - "epochs": epochs, - "learning_rate": learning_rate, - "weight_decay": weight_decay, - "random_state": random_state, - "projection_dimension": projection_dimension, - "freeze_backbone": freeze_backbone, - "triplet_margin": triplet_margin, - "model_metadata": self.model_metadata, - } - - # Initialize callbacks - callbacks: list[BaseCallback] = [] - if log_to_matplotlib: - try: - from trackers.core.reid.callbacks import MatplotlibCallback - - callbacks.append(MatplotlibCallback(log_dir=log_dir)) - except (ImportError, AttributeError) as e: - logger.error( - "Metric logging dependencies are not installed. " - "Please install it using `pip install trackers[metrics]`.", - ) - raise e - if log_to_tensorboard: - try: - from trackers.core.reid.callbacks import TensorboardCallback - - callbacks.append( - TensorboardCallback( - log_dir=os.path.join(log_dir, "tensorboard_logs") - ) - ) - except (ImportError, AttributeError) as e: - logger.error( - "Metric logging dependencies are not installed. " - "Please install it using `pip install trackers[metrics]`." - ) - raise e - - if log_to_wandb: - try: - from trackers.core.reid.callbacks import WandbCallback - - callbacks.append(WandbCallback(config=config)) - except (ImportError, AttributeError) as e: - logger.error( - "Metric logging dependencies are not installed. " - "Please install it using `pip install trackers[metrics]`." 
- ) - raise e - - # Training loop over epochs - for epoch in tqdm(range(epochs), desc="Training"): - # Reset metrics at the start of each epoch - for metric in metrics_list: - metric.reset() - - # Training loop over batches - accumulated_train_logs: dict[str, Union[float, int]] = {} - for idx, data in tqdm( - enumerate(train_loader), - total=len(train_loader), - desc=f"Training Epoch {epoch + 1}/{epochs}", - leave=False, - ): - anchor_image, positive_image, negative_image = data - if self.train_transforms is not None: - anchor_image = self.train_transforms(anchor_image) - positive_image = self.train_transforms(positive_image) - negative_image = self.train_transforms(negative_image) - - anchor_image = anchor_image.to(self.device) - positive_image = positive_image.to(self.device) - negative_image = negative_image.to(self.device) - - if callbacks: - for callback in callbacks: - callback.on_train_batch_start( - {}, epoch * len(train_loader) + idx - ) - - train_logs = self._train_step( - anchor_image, positive_image, negative_image, metrics_list - ) - - for key, value in train_logs.items(): - accumulated_train_logs[key] = ( - accumulated_train_logs.get(key, 0) + value - ) - - if callbacks: - for callback in callbacks: - for key, value in train_logs.items(): - callback.on_train_batch_end( - {f"batch/{key}": value}, epoch * len(train_loader) + idx - ) - - for key, value in accumulated_train_logs.items(): - accumulated_train_logs[key] = value / len(train_loader) - - # Compute and add training metrics to logs - for metric in metrics_list: - accumulated_train_logs[f"train/{metric!s}"] = metric.compute() - # Metrics are reset at the start of the next epoch or before validation - - if callbacks: - for callback in callbacks: - callback.on_train_epoch_end(accumulated_train_logs, epoch) - - # Validation loop over batches - accumulated_validation_logs: dict[str, Union[float, int]] = {} + ): + if isinstance(train_loader.dataset, IdentityDataset): if validation_loader is not None: - # Reset metrics for validation - for metric in metrics_list: - metric.reset() - for idx, data in tqdm( - enumerate(validation_loader), - total=len(validation_loader), - desc=f"Validation Epoch {epoch + 1}/{epochs}", - leave=False, - ): - if callbacks: - for callback in callbacks: - callback.on_validation_batch_start( - {}, epoch * len(train_loader) + idx - ) - - anchor_image, positive_image, negative_image = data - if self.train_transforms is not None: - anchor_image = self.train_transforms(anchor_image) - positive_image = self.train_transforms(positive_image) - negative_image = self.train_transforms(negative_image) - - anchor_image = anchor_image.to(self.device) - positive_image = positive_image.to(self.device) - negative_image = negative_image.to(self.device) - - validation_logs = self._validation_step( - anchor_image, positive_image, negative_image, metrics_list - ) - - for key, value in validation_logs.items(): - accumulated_validation_logs[key] = ( - accumulated_validation_logs.get(key, 0) + value - ) - - if callbacks: - for callback in callbacks: - for key, value in validation_logs.items(): - callback.on_validation_batch_end( - {f"batch/{key}": value}, - epoch * len(train_loader) + idx, - ) - - for key, value in accumulated_validation_logs.items(): - accumulated_validation_logs[key] = value / len(validation_loader) - - # Compute and add validation metrics to logs - for metric in metrics_list: - accumulated_validation_logs[f"validation/{metric!s}"] = ( - metric.compute() - ) - # Metrics will be reset at the start of the 
next training epoch loop
-
-            if callbacks:
-                for callback in callbacks:
-                    callback.on_validation_epoch_end(accumulated_validation_logs, epoch)
-
-            # Save checkpoint
-            if (
-                checkpoint_interval is not None
-                and (epoch + 1) % checkpoint_interval == 0
-            ):
-                state_dict = self.backbone_model.state_dict()
-                checkpoint_path = os.path.join(
-                    log_dir, "checkpoints", f"reid_model_{epoch + 1}.safetensors"
-                )
-                save_file(
-                    state_dict,
-                    checkpoint_path,
-                    metadata={"config": json.dumps(config), "format": "pt"},
-                )
-                if callbacks:
-                    for callback in callbacks:
-                        callback.on_checkpoint_save(checkpoint_path, epoch + 1)
-
-        if callbacks:
-            for callback in callbacks:
-                callback.on_end()
+    ) -> None:
+        if isinstance(train_loader.dataset, IdentityDataset):
            if validation_loader is not None:
+                assert isinstance(validation_loader.dataset, IdentityDataset)
+        self.add_classification_head(num_classes, freeze_backbone=freeze_backbone)
+        trainer = CrossEntropyTrainer(
+            model=self.feature_extractor,
+            device=self.device,
+            transforms=self.train_transforms,
+            epochs=epochs,
+            label_smoothing=label_smoothing,
+            learning_rate=learning_rate,
+            weight_decay=weight_decay,
+            random_state=random_state,
+            log_dir=log_dir,
+            log_to_matplotlib=log_to_matplotlib,
+            log_to_tensorboard=log_to_tensorboard,
+            log_to_wandb=log_to_wandb,
+            model_metadata=self.model_metadata,
+        )
+        trainer.train(
+            train_loader=train_loader,
+            validation_loader=validation_loader,
+            checkpoint_interval=checkpoint_interval,
+        )
+        self.feature_extractor = trainer.model
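With the trainer extracted, retraining is now a single call; a minimal sketch of the intended usage (dataset path, batch size, and epoch count are illustrative, and assume the standard Market-1501 layout used by the tests):

```python
from torch.utils.data import DataLoader

from trackers import ReIDModel
from trackers.core.reid.dataset.market_1501 import get_market1501_dataset

train_dataset, validation_dataset, test_dataset, query_dataset = get_market1501_dataset(
    "Market-1501-v15.09.15", relabel=True, split_ratio=0.9
)
model = ReIDModel.from_timm("resnet50")
model.train(
    train_loader=DataLoader(train_dataset, batch_size=32, shuffle=True),
    epochs=10,
    num_classes=train_dataset.get_num_identities(),
    validation_loader=DataLoader(validation_dataset, batch_size=32),
    freeze_backbone=True,
)
```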
diff --git a/trackers/core/reid/trainer/__init__.py b/trackers/core/reid/trainer/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/trackers/core/reid/trainer/base.py b/trackers/core/reid/trainer/base.py
new file mode 100644
index 00000000..6e42fdcd
--- /dev/null
+++ b/trackers/core/reid/trainer/base.py
@@ -0,0 +1,189 @@
+import json
+import os
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from safetensors.torch import save_file
+from torch.utils.data import DataLoader
+from torchvision.transforms import Compose
+from tqdm.auto import tqdm
+
+from trackers.core.reid.trainer.callbacks import BaseCallback
+from trackers.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class BaseTrainer(ABC):
+    """Shared training loop; subclasses implement `train_step` and
+    `validation_step` for a concrete objective."""
+
+    def __init__(
+        self,
+        model: nn.Module,
+        device: torch.device,
+        transforms: Compose,
+        epochs: int,
+        learning_rate: float = 1e-3,
+        weight_decay: float = 1e-2,
+        random_state: Optional[Union[int, float, str, bytes, bytearray]] = None,
+        log_dir: str = "logs",
+        log_to_matplotlib: bool = False,
+        log_to_tensorboard: bool = False,
+        log_to_wandb: bool = False,
+        config: Optional[dict[str, Any]] = None,
+    ):
+        self.device = device
+        self.model = model.to(device)
+        self.transforms = transforms
+        self.epochs = epochs
+        self.learning_rate = learning_rate
+        self.weight_decay = weight_decay
+        self.random_state = random_state
+        self.log_dir = log_dir
+        self.config = config or {}
+
+        self.optimizer = optim.Adam(
+            self.model.parameters(), lr=learning_rate, weight_decay=weight_decay
+        )
+
+        os.makedirs(self.log_dir, exist_ok=True)
+        os.makedirs(os.path.join(self.log_dir, "checkpoints"), exist_ok=True)
+        os.makedirs(os.path.join(self.log_dir, "tensorboard_logs"), exist_ok=True)
+
+        if random_state is not None:
+            torch.manual_seed(random_state)
+
+        self.initialize_callbacks(log_to_matplotlib, log_to_tensorboard, log_to_wandb)
+
+    def initialize_callbacks(
+        self,
+        log_to_matplotlib: bool = False,
+        log_to_tensorboard: bool = False,
+        log_to_wandb: bool = False,
+    ):
+        self.callbacks: list[BaseCallback] = []
+        if log_to_matplotlib:
+            try:
+                from trackers.core.reid.trainer.callbacks import MatplotlibCallback
+
+                self.callbacks.append(MatplotlibCallback(log_dir=self.log_dir))
+            except (ImportError, AttributeError) as e:
+                logger.error(
+                    "Metric logging dependencies are not installed. "
+                    "Please install them using `pip install trackers[metrics]`.",
+                )
+                raise e
+        if log_to_tensorboard:
+            try:
+                from trackers.core.reid.trainer.callbacks import TensorboardCallback
+
+                self.callbacks.append(
+                    TensorboardCallback(
+                        log_dir=os.path.join(self.log_dir, "tensorboard_logs")
+                    )
+                )
+            except (ImportError, AttributeError) as e:
+                logger.error(
+                    "Metric logging dependencies are not installed. "
+                    "Please install them using `pip install trackers[metrics]`."
+                )
+                raise e
+        if log_to_wandb:
+            try:
+                from trackers.core.reid.trainer.callbacks import WandbCallback
+
+                self.callbacks.append(WandbCallback(config=self.config))
+            except (ImportError, AttributeError) as e:
+                logger.error(
+                    "Metric logging dependencies are not installed. "
+                    "Please install them using `pip install trackers[metrics]`."
+                )
+                raise e
+
+    @abstractmethod
+    def train_step(self, data: dict[str, torch.Tensor]):
+        raise NotImplementedError(
+            "Subclasses of `BaseTrainer` must implement `train_step`"
+        )
+
+    @abstractmethod
+    def validation_step(self, data: dict[str, torch.Tensor]):
+        raise NotImplementedError(
+            "Subclasses of `BaseTrainer` must implement `validation_step`"
+        )
+
+    def execute_train_batch_loop(self, train_loader: DataLoader, epoch: int):
+        accumulated_train_logs: dict[str, Union[float, int]] = {}
+        for idx, data in tqdm(
+            enumerate(train_loader),
+            total=len(train_loader),
+            desc=f"Training Epoch {epoch + 1}/{self.epochs}",
+            leave=False,
+        ):
+            train_logs = self.train_step(data)
+            for key, value in train_logs.items():
+                accumulated_train_logs[key] = accumulated_train_logs.get(key, 0) + value
+            for callback in self.callbacks:
+                for key, value in train_logs.items():
+                    callback.on_train_batch_end(
+                        {f"batch/{key}": value}, epoch * len(train_loader) + idx
+                    )
+        for key, value in accumulated_train_logs.items():
+            accumulated_train_logs[key] = value / len(train_loader)
+
+        for callback in self.callbacks:
+            callback.on_train_epoch_end(accumulated_train_logs, epoch)
+
+    def execute_validation_batch_loop(self, validation_loader: DataLoader, epoch: int):
+        # Run validation with normalization/dropout layers in eval mode; the
+        # epoch loop in `train` switches back to train mode each epoch.
+        self.model.eval()
+        accumulated_validation_logs: dict[str, Union[float, int]] = {}
+        for idx, data in tqdm(
+            enumerate(validation_loader),
+            total=len(validation_loader),
+            desc=f"Validation Epoch {epoch + 1}/{self.epochs}",
+            leave=False,
+        ):
+            validation_logs = self.validation_step(data)
+            for key, value in validation_logs.items():
+                accumulated_validation_logs[key] = (
+                    accumulated_validation_logs.get(key, 0) + value
+                )
+            for callback in self.callbacks:
+                for key, value in validation_logs.items():
+                    callback.on_validation_batch_end(
+                        {f"batch/{key}": value}, epoch * len(validation_loader) + idx
+                    )
+        for key, value in accumulated_validation_logs.items():
+            accumulated_validation_logs[key] = value / len(validation_loader)
+
+        for callback in self.callbacks:
+            callback.on_validation_epoch_end(accumulated_validation_logs, epoch)
+
+    def save_checkpoint(self, epoch: int, checkpoint_interval: Optional[int] = None):
+        if checkpoint_interval is not None and (epoch + 1) % checkpoint_interval == 0:
+            state_dict = self.model.state_dict()
+            checkpoint_path = os.path.join(
+                self.log_dir, "checkpoints", f"reid_model_{epoch + 1}.safetensors"
+            )
+            save_file(
+                state_dict,
+                checkpoint_path,
+                metadata={"config": json.dumps(self.config), "format": "pt"},
+            )
+            for callback in self.callbacks:
+                callback.on_checkpoint_save(checkpoint_path, epoch + 1)
metadata={"config": json.dumps(self.config), "format": "pt"}, + ) + for callback in self.callbacks: + callback.on_checkpoint_save(checkpoint_path, epoch + 1) + + def train( + self, + train_loader: DataLoader, + validation_loader: Optional[DataLoader] = None, + checkpoint_interval: Optional[int] = None, + ): + self.model.train() + for epoch in tqdm(range(self.epochs), desc="Training"): + self.execute_train_batch_loop(train_loader, epoch) + if validation_loader is not None: + self.execute_validation_batch_loop(validation_loader, epoch) + self.save_checkpoint(epoch, checkpoint_interval) + for callback in self.callbacks: + callback.on_end() diff --git a/trackers/core/reid/callbacks.py b/trackers/core/reid/trainer/callbacks.py similarity index 97% rename from trackers/core/reid/callbacks.py rename to trackers/core/reid/trainer/callbacks.py index c879382f..29d37802 100644 --- a/trackers/core/reid/callbacks.py +++ b/trackers/core/reid/trainer/callbacks.py @@ -5,18 +5,12 @@ class BaseCallback: - def on_train_batch_start(self, logs: dict, idx: int): - pass - def on_train_batch_end(self, logs: dict, idx: int): pass def on_train_epoch_end(self, logs: dict, epoch: int): pass - def on_validation_batch_start(self, logs: dict, idx: int): - pass - def on_validation_batch_end(self, logs: dict, idx: int): pass @@ -69,7 +63,7 @@ def on_end(self): class WandbCallback(BaseCallback): - def __init__(self, config: dict[str, Any]) -> None: + def __init__(self, config: Optional[dict[str, Any]] = {}) -> None: import wandb self.run = wandb.init(config=config) if not wandb.run else wandb.run # type: ignore diff --git a/trackers/core/reid/trainer/cross_entropy_trainer.py b/trackers/core/reid/trainer/cross_entropy_trainer.py new file mode 100644 index 00000000..1ff8eb2b --- /dev/null +++ b/trackers/core/reid/trainer/cross_entropy_trainer.py @@ -0,0 +1,78 @@ +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Compose + +from trackers.core.reid.trainer.base import BaseTrainer +from trackers.core.reid.trainer.metrics import top_k_accuracy +from trackers.log import get_logger + +logger = get_logger(__name__) + + +class CrossEntropyTrainer(BaseTrainer): + def __init__( + self, + model: nn.Module, + device: torch.device, + transforms: Compose, + epochs: int, + label_smoothing: float = 1e-2, + learning_rate: float = 1e-3, + weight_decay: float = 1e-2, + random_state: Optional[Union[int, float, str, bytes, bytearray]] = None, + log_dir: str = "logs", + log_to_matplotlib: bool = False, + log_to_tensorboard: bool = False, + log_to_wandb: bool = False, + model_metadata: dict[str, Any] = {}, + ): + config = { + "epochs": epochs, + "learning_rate": learning_rate, + "weight_decay": weight_decay, + "random_state": random_state, + "label_smoothing": label_smoothing, + "model_metadata": model_metadata, + } + self.criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing) + super().__init__( + model=model, + device=device, + transforms=transforms, + epochs=epochs, + learning_rate=learning_rate, + weight_decay=weight_decay, + random_state=random_state, + log_dir=log_dir, + log_to_matplotlib=log_to_matplotlib, + log_to_tensorboard=log_to_tensorboard, + log_to_wandb=log_to_wandb, + config=config, + ) + + def train_step(self, data: dict[str, torch.Tensor]): + images = self.transforms(data["image"]).to(self.device) + identities = data["identity"].to(self.device) + outputs = self.model(images) + loss = self.criterion(F.log_softmax(outputs, 
diff --git a/trackers/core/reid/trainer/cross_entropy_trainer.py b/trackers/core/reid/trainer/cross_entropy_trainer.py
new file mode 100644
index 00000000..1ff8eb2b
--- /dev/null
+++ b/trackers/core/reid/trainer/cross_entropy_trainer.py
@@ -0,0 +1,78 @@
+from typing import Any, Optional, Union
+
+import torch
+import torch.nn as nn
+from torchvision.transforms import Compose
+
+from trackers.core.reid.trainer.base import BaseTrainer
+from trackers.core.reid.trainer.metrics import top_k_accuracy
+from trackers.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class CrossEntropyTrainer(BaseTrainer):
+    """Trains the feature extractor as an identity classifier with
+    label-smoothed cross-entropy."""
+
+    def __init__(
+        self,
+        model: nn.Module,
+        device: torch.device,
+        transforms: Compose,
+        epochs: int,
+        label_smoothing: float = 1e-2,
+        learning_rate: float = 1e-3,
+        weight_decay: float = 1e-2,
+        random_state: Optional[Union[int, float, str, bytes, bytearray]] = None,
+        log_dir: str = "logs",
+        log_to_matplotlib: bool = False,
+        log_to_tensorboard: bool = False,
+        log_to_wandb: bool = False,
+        model_metadata: Optional[dict[str, Any]] = None,
+    ):
+        config = {
+            "epochs": epochs,
+            "learning_rate": learning_rate,
+            "weight_decay": weight_decay,
+            "random_state": random_state,
+            "label_smoothing": label_smoothing,
+            "model_metadata": model_metadata or {},
+        }
+        self.criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
+        super().__init__(
+            model=model,
+            device=device,
+            transforms=transforms,
+            epochs=epochs,
+            learning_rate=learning_rate,
+            weight_decay=weight_decay,
+            random_state=random_state,
+            log_dir=log_dir,
+            log_to_matplotlib=log_to_matplotlib,
+            log_to_tensorboard=log_to_tensorboard,
+            log_to_wandb=log_to_wandb,
+            config=config,
+        )
+
+    def train_step(self, data: dict[str, torch.Tensor]) -> dict[str, float]:
+        images = self.transforms(data["image"]).to(self.device)
+        identities = data["identity"].to(self.device)
+        outputs = self.model(images)
+        # CrossEntropyLoss expects raw logits and applies log-softmax
+        # internally, so the outputs are passed through unmodified.
+        loss = self.criterion(outputs, identities)
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
+        return {
+            "train/loss": loss.item(),
+            "train/accuracy": top_k_accuracy(outputs, identities).item(),
+        }
+
+    def validation_step(self, data: dict[str, torch.Tensor]) -> dict[str, float]:
+        images = self.transforms(data["image"]).to(self.device)
+        identities = data["identity"].to(self.device)
+        with torch.inference_mode():
+            outputs = self.model(images)
+            loss = self.criterion(outputs, identities)
+        return {
+            "validation/loss": loss.item(),
+            "validation/accuracy": top_k_accuracy(outputs, identities).item(),
+        }
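For reviewers checking `compute_distance_matrix` in the metrics module below: it computes squared euclidean distances via the expansion ||q - g||^2 = ||q||^2 + ||g||^2 - 2 q.g, with the cross term folded into a single in-place `addmm_`. A quick self-contained check of that identity (illustrative shapes):

```python
import torch

q = torch.randn(4, 8)  # query features
g = torch.randn(6, 8)  # gallery/test features

expanded = (
    q.pow(2).sum(1, keepdim=True)        # ||q||^2, shape (4, 1)
    + g.pow(2).sum(1, keepdim=True).t()  # ||g||^2, shape (1, 6)
    - 2.0 * q @ g.t()                    # cross term, shape (4, 6)
)
brute_force = torch.cdist(q, g).pow(2)   # same squared distances
assert torch.allclose(expanded, brute_force, atol=1e-5)
```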
diff --git a/trackers/core/reid/trainer/metrics.py b/trackers/core/reid/trainer/metrics.py
new file mode 100644
index 00000000..543b322a
--- /dev/null
+++ b/trackers/core/reid/trainer/metrics.py
@@ -0,0 +1,127 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+
+
+def top_k_accuracy(
+    logits: torch.Tensor, true_labels: torch.Tensor, top_k: int = 1
+) -> torch.Tensor:
+    if isinstance(logits, (tuple, list)):
+        logits = logits[0]
+    top_k_predicted_indices = torch.t(
+        torch.topk(logits, top_k, dim=1, largest=True, sorted=True)[1]
+    )
+    correct_matches = torch.eq(
+        top_k_predicted_indices,
+        true_labels.view(1, -1).expand_as(top_k_predicted_indices),
+    )
+    num_correct_in_top_k = correct_matches[:top_k].view(-1).float().sum(0, keepdim=True)
+    return num_correct_in_top_k.mul_(100.0 / true_labels.size(0))
+
+
+def compute_feature(
+    data_loader: DataLoader, feature_extractor: nn.Module, device: torch.device
+):
+    features, identities, camera_ids = [], [], []
+    with torch.inference_mode():
+        for data in tqdm(data_loader, total=len(data_loader)):
+            image_batch = data["image"].to(device)
+            identity_batch = data["identity"]
+            camera_id_batch = data["camera_id"]
+
+            extracted_features = feature_extractor.forward_features(image_batch).cpu()
+            image_batch = image_batch.cpu()
+            torch.cuda.empty_cache()
+
+            features.append(extracted_features)
+            identities.extend(identity_batch)
+            camera_ids.extend(camera_id_batch)
+
+    features = torch.cat(features, dim=0)
+    return features, identities, camera_ids
+
+
+def compute_distance_matrix(feature_1: torch.Tensor, feature_2: torch.Tensor):
+    num_query_features, num_test_features = feature_1.size(0), feature_2.size(0)
+    query_squared_norms = (
+        torch.pow(feature_1, 2)
+        .sum(dim=1, keepdim=True)
+        .expand(num_query_features, num_test_features)
+    )
+    test_squared_norms = (
+        torch.pow(feature_2, 2)
+        .sum(dim=1, keepdim=True)
+        .expand(num_test_features, num_query_features)
+        .t()
+    )
+    distance_matrix = query_squared_norms + test_squared_norms
+    # addmm_ adds -2 * feature_1 @ feature_2.T in place, completing the
+    # squared euclidean distance expansion.
+    distance_matrix.addmm_(feature_1, feature_2.t(), beta=1, alpha=-2)
+    return distance_matrix
+
+
+def evaluate_rank(
+    query_dataloader: DataLoader,
+    test_dataloader: DataLoader,
+    feature_extractor: nn.Module,
+    device: torch.device,
+    max_rank: int = 50,
+):
+    query_features, query_identities, query_camera_ids = compute_feature(
+        query_dataloader, feature_extractor, device
+    )
+    test_features, test_identities, test_camera_ids = compute_feature(
+        test_dataloader, feature_extractor, device
+    )
+    distance_matrix = compute_distance_matrix(query_features, test_features).numpy()
+    # The fancy indexing below requires arrays, not python lists of tensors.
+    query_identities = np.asarray(query_identities)
+    test_identities = np.asarray(test_identities)
+    query_camera_ids = np.asarray(query_camera_ids)
+    test_camera_ids = np.asarray(test_camera_ids)
+    num_queries, num_test_samples = distance_matrix.shape
+
+    if num_test_samples < max_rank:
+        max_rank = num_test_samples
+
+    ranked_test_indices = np.argsort(distance_matrix, axis=1)
+    identity_matches = (
+        test_identities[ranked_test_indices] == query_identities[:, np.newaxis]
+    ).astype(np.int32)
+
+    query_cmc_curves = []
+    query_average_precisions = []
+    num_valid_queries = 0.0
+
+    for query_idx in range(num_queries):
+        query_person_id = query_identities[query_idx]
+        query_camera_id = query_camera_ids[query_idx]
+
+        # Gallery images of the same identity taken by the same camera are
+        # excluded, following the standard Market-1501 protocol.
+        sorted_test_indices = ranked_test_indices[query_idx]
+        same_camera_same_identity_mask = (
+            test_identities[sorted_test_indices] == query_person_id
+        ) & (test_camera_ids[sorted_test_indices] == query_camera_id)
+        valid_test_mask = np.invert(same_camera_same_identity_mask)
+
+        # compute cmc curve
+        raw_cmc_curve = identity_matches[query_idx][valid_test_mask]
+        if not np.any(raw_cmc_curve):
+            continue
+
+        cmc_curve = raw_cmc_curve.cumsum()
+        cmc_curve[cmc_curve > 1] = 1
+
+        query_cmc_curves.append(cmc_curve[:max_rank])
+        num_valid_queries += 1.0
+
+        # compute average precision
+        num_relevant_matches = raw_cmc_curve.sum()
+        precision_at_rank = raw_cmc_curve.cumsum()
+        precision_at_rank = [x / (i + 1.0) for i, x in enumerate(precision_at_rank)]
+        precision_at_rank = np.asarray(precision_at_rank) * raw_cmc_curve
+        average_precision = precision_at_rank.sum() / num_relevant_matches
+        query_average_precisions.append(average_precision)
+
+    assert num_valid_queries > 0, (
+        "No valid query: none of the query identities appear in the test set"
+    )
+
+    query_cmc_curves = np.asarray(query_cmc_curves).astype(np.float32)
+    query_cmc_curves = query_cmc_curves.sum(0) / num_valid_queries
+    mean_average_precision = np.mean(query_average_precisions)
+
+    return query_cmc_curves, mean_average_precision
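Finally, a hedged sketch of how the new evaluation utilities are meant to be wired together (the `query_dataset`, `test_dataset`, and `model` objects are the ones from the training sketch above; batch size is illustrative). `evaluate_rank` returns the CMC curve and mAP:

```python
from torch.utils.data import DataLoader

from trackers.core.reid.trainer.metrics import evaluate_rank

cmc_curve, mean_ap = evaluate_rank(
    query_dataloader=DataLoader(query_dataset, batch_size=64),
    test_dataloader=DataLoader(test_dataset, batch_size=64),
    feature_extractor=model.feature_extractor,
    device=model.device,
)
print(f"Rank-1: {cmc_curve[0]:.3f}  Rank-5: {cmc_curve[4]:.3f}  mAP: {mean_ap:.3f}")
```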