diff --git a/trackers/core/deepsort/kalman_box_tracker.py b/trackers/core/deepsort/kalman_box_tracker.py
index 05873bce..dd37c6b5 100644
--- a/trackers/core/deepsort/kalman_box_tracker.py
+++ b/trackers/core/deepsort/kalman_box_tracker.py
@@ -1,6 +1,10 @@
-from typing import Optional, Union
+from typing import Optional, Tuple
 
 import numpy as np
+from scipy.linalg import solve_triangular
+
+# Chi-square 0.95 quantile for 4 degrees of freedom (Mahalanobis gating threshold)
+MAHALANOBIS_THRESHOLD = 9.4877
 
 
 class DeepSORTKalmanBoxTracker:
@@ -43,10 +47,16 @@ def get_next_tracker_id(cls) -> int:
         cls.count_id += 1
         return next_id
 
-    def __init__(self, bbox: np.ndarray, feature: Optional[np.ndarray] = None):
+    def __init__(
+        self,
+        bbox: np.ndarray,
+        feature: Optional[np.ndarray] = None,
+        max_features_gallery_size: int = 100,
+    ):
         # Initialize with a temporary ID of -1
         # Will be assigned a real ID when the track is considered mature
         self.tracker_id = -1
+        self.max_features_gallery_size = max_features_gallery_size
 
         # Number of hits indicates how many times the object has been
         # updated successfully
@@ -96,6 +106,58 @@ def _initialize_kalman_filter(self) -> None:
         # Error covariance matrix (P)
         self.P = np.eye(8, dtype=np.float32)
 
+    def project(self) -> Tuple[np.ndarray, np.ndarray]:
+        """
+        Projects the current state distribution to measurement space.
+
+        As per the Kalman filter formulation referenced in Section 2.1 of the
+        DeepSORT paper, this function computes:
+        (y_i, S_i) = (H·μ_i, H·Σ_i·H^T + R)
+
+        Returns:
+            Tuple[np.ndarray, np.ndarray]: Projected mean (y_i) and innovation
+                covariance (S_i) for gating and association.
+        """
+        # Project state mean to measurement space: y_i = H·μ_i
+        projected_mean = self.H @ self.state
+
+        # Project state covariance to measurement space: H·Σ_i·H^T
+        projected_covariance = self.H @ self.P @ self.H.T
+
+        # Add measurement noise: S_i = H·Σ_i·H^T + R
+        innovation_covariance = projected_covariance + self.R
+
+        return projected_mean, innovation_covariance
+
+    def compute_gating_distance(self, measurements: np.ndarray) -> np.ndarray:
+        """
+        Computes the squared Mahalanobis distance between the track and
+        measurements.
+
+        This function is used for gating (ruling out) unlikely associations
+        as described in Eq. (1)-(2) of the DeepSORT paper:
+        d^(1)(i,j) = (d_j - y_i)^T · S_i^(-1) · (d_j - y_i)
+
+        Args:
+            measurements (np.ndarray): An Nx4 matrix of N measurements, each in
+                the measurement-space format
+                [center x, center y, aspect ratio, height].
+
+        Returns:
+            np.ndarray: An array of length N, where the i-th element contains
+                the squared Mahalanobis distance between the track and
+                measurements[i].
+        """
+        # Project the current state to measurement space
+        mean, covariance = self.project()
+        mean = mean.reshape(1, 4)
+        cholesky_factor = np.linalg.cholesky(covariance)
+        d = measurements - mean
+        # Solve the triangular system L·z = d^T, i.e. z = L^(-1)·d^T
+        z = solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False)
+        # The squared Mahalanobis distance is the squared norm of z:
+        # d_m^2 = z^T·z = d^T·S^(-1)·d
+        return np.sum(z * z, axis=0)
+
     def predict(self) -> None:
         """
         Predict the next state of the bounding box (applies the state transition).
@@ -152,19 +214,5 @@ def get_state_bbox(self) -> np.ndarray:
 
     def update_feature(self, feature: np.ndarray):
         self.features.append(feature)
-
-    def get_feature(self) -> Union[np.ndarray, None]:
-        """
-        Get the mean feature vector for this tracker.
-
-        Returns:
-            np.ndarray: Mean feature vector.
-        """
-        if len(self.features) > 0:
-            # Return the mean of all features, thus (in theory) capturing the
-            # "average appearance" of the object, which should be more robust
-            # to minor appearance changes. Otherwise, the last feature can
-            # also be returned like the following:
-            # return self.features[-1]
-            return np.mean(self.features, axis=0)
-        return None
+        if len(self.features) > self.max_features_gallery_size:
+            self.features.pop(0)  # Remove the oldest feature
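
A quick standalone check of the Cholesky route used in compute_gating_distance (toy numbers, not part of the patch): factoring S = L·L^T and back-substituting avoids forming S^(-1) explicitly and matches the direct computation of Eq. (1).

import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
S = A @ A.T + 4.0 * np.eye(4)  # toy SPD innovation covariance
d = rng.normal(size=(3, 4))    # three residuals (d_j - y_i)

# Direct form of Eq. (1): d^T · S^(-1) · d per residual
S_inv = np.linalg.inv(S)
direct = np.array([r @ S_inv @ r for r in d])

# Cholesky form: S = L·L^T and z = L^(-1)·d^T, so d^T·S^(-1)·d = z^T·z
L = np.linalg.cholesky(S)
z = solve_triangular(L, d.T, lower=True, check_finite=False)
assert np.allclose(direct, np.sum(z * z, axis=0))
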
diff --git a/trackers/core/deepsort/tracker.py b/trackers/core/deepsort/tracker.py
index 0f157810..668ca604 100644
--- a/trackers/core/deepsort/tracker.py
+++ b/trackers/core/deepsort/tracker.py
@@ -1,16 +1,21 @@
-from typing import Optional
+from typing import List, Tuple
 
 import numpy as np
 import supervision as sv
+from scipy.optimize import linear_sum_assignment
 from scipy.spatial.distance import cdist
 
 from trackers.core.base import BaseTrackerWithFeatures
-from trackers.core.deepsort.kalman_box_tracker import DeepSORTKalmanBoxTracker
+from trackers.core.deepsort.kalman_box_tracker import (
+    MAHALANOBIS_THRESHOLD,
+    DeepSORTKalmanBoxTracker,
+)
 from trackers.core.reid import ReIDModel
 from trackers.utils.sort_utils import (
     get_alive_trackers,
     get_iou_matrix,
     update_detections_with_track_ids,
+    xyxy_to_xcycarh,
 )
 
@@ -24,8 +29,6 @@ class DeepSORTTracker(BaseTrackerWithFeatures):
 
     Args:
         reid_model (ReIDModel): An instance of a `ReIDModel` to extract appearance features.
-        device (Optional[str]): Device to run the feature extraction
-            model on (e.g., 'cpu', 'cuda').
         lost_track_buffer (int): Number of frames to buffer when a track is lost.
             Enhances occlusion handling but may increase ID switches for similar objects.
         frame_rate (float): Frame rate of the video (frames per second).
@@ -43,20 +46,21 @@ class DeepSORTTracker(BaseTrackerWithFeatures):
             distance in the combined matching cost.
         distance_metric (str): Distance metric for appearance features
             (e.g., 'cosine', 'euclidean'). See `scipy.spatial.distance.cdist`.
+        max_features_gallery_size (int): Maximum size of the feature gallery for appearance matching.
     """  # noqa: E501
 
     def __init__(
         self,
         reid_model: ReIDModel,
-        device: Optional[str] = None,
         lost_track_buffer: int = 30,
         frame_rate: float = 30.0,
         track_activation_threshold: float = 0.25,
         minimum_consecutive_frames: int = 3,
         minimum_iou_threshold: float = 0.3,
-        appearance_threshold: float = 0.7,
+        appearance_threshold: float = 0.2,
         appearance_weight: float = 0.5,
         distance_metric: str = "cosine",
+        max_features_gallery_size: int = 100,
     ):
         self.reid_model = reid_model
         self.lost_track_buffer = lost_track_buffer
@@ -67,6 +71,7 @@ def __init__(
         self.appearance_threshold = appearance_threshold
         self.appearance_weight = appearance_weight
         self.distance_metric = distance_metric
+        self.max_features_gallery_size = max_features_gallery_size
         # Calculate maximum frames without update based on lost_track_buffer and
         # frame_rate. This scales the buffer based on the frame rate to ensure
        # consistent time-based tracking across different frame rates.
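
For context on the tightened appearance_threshold default (0.7 → 0.2): with metric="cosine", cdist returns a distance (1 − cosine similarity), which is typically close to 0 for embeddings of the same identity, so the gate has to be small. A toy illustration with made-up vectors (not part of the patch):

import numpy as np
from scipy.spatial.distance import cosine

a = np.array([0.9, 0.1, 0.4])              # embedding of a tracked object
b = a + 0.02 * np.array([1.0, -1.0, 1.0])  # nearly identical appearance
c = np.array([0.1, 0.9, 0.2])              # clearly different appearance

# scipy's cosine() returns a distance (1 - similarity): ~0 for a match
print(round(cosine(a, b), 3))  # ~0.0  -> passes a 0.2 gate
print(round(cosine(a, c), 3))  # ~0.72 -> rejected by a 0.2 gate
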
@@ -93,91 +98,239 @@ def _get_appearance_distance_matrix(
         if len(self.trackers) == 0 or len(detection_features) == 0:
             return np.zeros((len(self.trackers), len(detection_features)))
 
-        track_features = np.array([t.get_feature() for t in self.trackers])
-        distance_matrix = cdist(
-            track_features, detection_features, metric=self.distance_metric
-        )
+        # Initialize all distances as infeasible (infinite); rows stay that
+        # way for trackers without gallery features
+        distance_matrix = np.full((len(self.trackers), len(detection_features)), np.inf)
+
+        for i, tracker in enumerate(self.trackers):
+            if not tracker.features:  # Skip if the tracker has no features
+                continue
+            track_gallery_features = np.array(tracker.features)
+            # Compute cdist between all features in the track's gallery
+            # and all detection features.
+            cost_matrix_gallery = cdist(
+                track_gallery_features, detection_features, metric=self.distance_metric
+            )
+            # Keep the minimum distance from this tracker's gallery to each detection
+            min_distances_to_detections = np.min(cost_matrix_gallery, axis=0)
+            distance_matrix[i, :] = min_distances_to_detections
+        distance_matrix = np.clip(distance_matrix, 0, 1)
         return distance_matrix
 
+    def _get_mahalanobis_distance_matrix(
+        self,
+        detection_boxes: np.ndarray,
+    ) -> np.ndarray:
+        """
+        Calculate the Mahalanobis distance matrix between tracks and detections,
+        as per Equation 1 in Section 2.2 of the
+        [DeepSORT paper](https://arxiv.org/pdf/1703.07402).
+
+        Args:
+            detection_boxes (np.ndarray): Detected bounding boxes in the
+                form [x1, y1, x2, y2].
+
+        Returns:
+            np.ndarray: Mahalanobis distance matrix.
+        """
+        if len(self.trackers) == 0 or len(detection_boxes) == 0:
+            return np.zeros((len(self.trackers), len(detection_boxes)))
+
+        distance_matrix = np.zeros((len(self.trackers), len(detection_boxes)))
+
+        # Convert the detections to measurement space once; the same
+        # measurements are gated against every track
+        measurements = np.array([xyxy_to_xcycarh(box) for box in detection_boxes])
+        for i, tracker in enumerate(self.trackers):
+            distance_matrix[i, :] = tracker.compute_gating_distance(measurements)
+
+        return distance_matrix
+
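
The gallery-based appearance cost above corresponds to Eq. (4) of the paper: the track-to-detection distance is the smallest distance between the detection feature and any feature in the track's gallery. A toy sketch with made-up 2-D features (not part of the patch):

import numpy as np
from scipy.spatial.distance import cdist

# One track with a gallery of three toy feature vectors, and two detections
gallery = np.array([[1.0, 0.0], [0.8, 0.6], [0.0, 1.0]])
detections = np.array([[0.9, 0.1], [0.1, 0.9]])

# Distance between every gallery feature and every detection, then the
# minimum over the gallery (cf. Eq. 4) gives one cost per detection
costs = cdist(gallery, detections, metric="cosine")
min_costs = costs.min(axis=0)
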
""" - iou_distance: np.ndarray = 1 - iou_matrix + # Using weighted sum to combine Mahalanobis and appearance distances combined_dist = ( - 1 - self.appearance_weight - ) * iou_distance + self.appearance_weight * appearance_dist_matrix + self.appearance_weight * appearance_dist_matrix + + (1 - self.appearance_weight) * mahalanobis_matrix + ) - # Set high distance for IOU below threshold - mask = iou_matrix < self.minimum_iou_threshold - combined_dist[mask] = 1.0 + mahalanobis_gate = mahalanobis_matrix > MAHALANOBIS_THRESHOLD + appearance_gate = appearance_dist_matrix > self.appearance_threshold - # Set high distance for appearance above threshold - mask = appearance_dist_matrix > self.appearance_threshold - combined_dist[mask] = 1.0 + # An association is inadmissible if either metric is above threshold + invalid_mask = np.logical_or(mahalanobis_gate, appearance_gate) + combined_dist[invalid_mask] = 1.0 # Mark as infeasible return combined_dist + def _match_tracks_using_linear_sum_assignment( + self, + cost_matrix: np.ndarray, + track_indices: list, + detection_indices: list, + ) -> Tuple[List[Tuple[int, int]], List[int], List[int]]: + """ + Match tracks with detections for a specific stage of the matching cascade. + This implements the linear assignment for a specific group of tracks based + on their maturity. + + Args: + cost_matrix (np.ndarray): Cost matrix between tracks and detections. + track_indices (list): Indices of tracks to match. + detection_indices (list): Indices of detections to match. + + Returns: + tuple[list[tuple[int, int]], list[int], list[int]]: Matched indices, + unmatched track indices, unmatched detection indices. + """ + if len(track_indices) == 0 or len(detection_indices) == 0: + return [], track_indices, detection_indices + + sub_cost_matrix = cost_matrix[np.ix_(track_indices, detection_indices)] + + # Apply threshold of 1.0 to mark infeasible associations + valid_mask = sub_cost_matrix < 1.0 + + if not np.any(valid_mask): + return [], track_indices, detection_indices + + # Create a masked cost matrix where invalid associations have a very high cost + masked_cost_matrix = np.where(valid_mask, sub_cost_matrix, 1e10) + + # Find optimal assignment minimizing the total cost using + # scipy.optimize.linear_sum_assignment. Note that it uses a a modified + # Jonker-Volgenant algorithm with no initialization instead of the + # Hungarian algorithm as mentioned in the SORT paper. 
+    def _match_tracks_using_linear_sum_assignment(
+        self,
+        cost_matrix: np.ndarray,
+        track_indices: list,
+        detection_indices: list,
+    ) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
+        """
+        Match tracks with detections for a single stage of the matching
+        cascade, i.e. the linear assignment for one group of tracks selected
+        by maturity.
+
+        Args:
+            cost_matrix (np.ndarray): Cost matrix between tracks and detections,
+                with infeasible associations marked as infinite.
+            track_indices (list): Indices of tracks to match.
+            detection_indices (list): Indices of detections to match.
+
+        Returns:
+            Tuple[List[Tuple[int, int]], List[int], List[int]]: Matched indices,
+                unmatched track indices, unmatched detection indices.
+        """
+        if len(track_indices) == 0 or len(detection_indices) == 0:
+            return [], track_indices, detection_indices
+
+        sub_cost_matrix = cost_matrix[np.ix_(track_indices, detection_indices)]
+
+        # Infeasible (gated) associations carry an infinite cost
+        valid_mask = np.isfinite(sub_cost_matrix)
+
+        if not np.any(valid_mask):
+            return [], track_indices, detection_indices
+
+        # Replace infeasible entries with a very high finite cost, since the
+        # solver rejects matrices containing infinities
+        masked_cost_matrix = np.where(valid_mask, sub_cost_matrix, 1e10)
+
+        # Find the optimal assignment minimizing the total cost using
+        # scipy.optimize.linear_sum_assignment. Note that it uses a modified
+        # Jonker-Volgenant algorithm with no initialization, rather than the
+        # Hungarian algorithm mentioned in the SORT paper.
+        row_indices, col_indices = linear_sum_assignment(masked_cost_matrix)
+
+        # Discard assignments that landed on infeasible (sentinel-cost) entries
+        valid_assignments = np.isfinite(sub_cost_matrix[row_indices, col_indices])
+        row_indices = row_indices[valid_assignments]
+        col_indices = col_indices[valid_assignments]
+
+        # Convert to original indices
+        matches = []
+        for row, col in zip(row_indices, col_indices):
+            track_idx = track_indices[row]
+            detection_idx = detection_indices[col]
+            matches.append((track_idx, detection_idx))
+
+        # Determine unmatched tracks and detections
+        matched_track_indices = {track_indices[row] for row in row_indices}
+        matched_detection_indices = {detection_indices[col] for col in col_indices}
+
+        unmatched_tracks = [
+            idx for idx in track_indices if idx not in matched_track_indices
+        ]
+        unmatched_detections = [
+            idx for idx in detection_indices if idx not in matched_detection_indices
+        ]
+
+        return matches, unmatched_tracks, unmatched_detections
+
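
A toy illustration (not part of the patch) of why both the sentinel cost and the post-filter are needed: linear_sum_assignment rejects infinite entries outright, and once they are replaced with a large finite sentinel the solver may still route an assignment through a gated cell when a row has no feasible option.

import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[0.2, np.inf],
                 [np.inf, np.inf]])
# The solver rejects infinities, so gated cells get a large finite sentinel
masked = np.where(np.isfinite(cost), cost, 1e10)
rows, cols = linear_sum_assignment(masked)
# The solver still assigns every row it can, so drop assignments that
# landed on gated cells
keep = np.isfinite(cost[rows, cols])
rows, cols = rows[keep], cols[keep]  # only the (0, 0) match survives
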
""" + confirmed_tracks = [] + unconfirmed_tracks = [] + for tracker_idx, tracker in enumerate(self.trackers): + if tracker.number_of_successful_updates >= self.minimum_consecutive_frames: + confirmed_tracks.append(tracker_idx) + else: + unconfirmed_tracks.append(tracker_idx) + + # Get Mahalanobis distance matrix + mahalanobis_matrix = self._get_mahalanobis_distance_matrix(detection_boxes) + + # Get appearance distance matrix appearance_dist_matrix = self._get_appearance_distance_matrix( detection_features ) - combined_dist = self._get_combined_distance_matrix( - iou_matrix, appearance_dist_matrix + + # Combine distances using weighted sum + combined_dist_matrix = self._get_combined_distance_matrix( + mahalanobis_matrix, appearance_dist_matrix ) - matched_indices = [] - unmatched_trackers = set(range(len(self.trackers))) - unmatched_detections = set(range(len(detection_features))) - - if combined_dist.size > 0: - row_indices, col_indices = np.where(combined_dist < 1.0) - sorted_pairs = sorted( - zip(map(int, row_indices), map(int, col_indices)), - key=lambda x: combined_dist[x[0], x[1]], + + confirmed_matches, unmatched_confirmed, unmatched_detections = ( + self._match_tracks_using_linear_sum_assignment( + combined_dist_matrix, + confirmed_tracks, + list(range(len(detection_features))), ) + ) - used_rows = set() - used_cols = set() - for row, col in sorted_pairs: - if (row not in used_rows) and (col not in used_cols): - used_rows.add(row) - used_cols.add(col) - matched_indices.append((row, col)) + # Find recently lost confirmed tracks (time_since_update == 1) + recently_lost = [ + tracker_idx + for tracker_idx in unmatched_confirmed + if self.trackers[tracker_idx].time_since_update == 1 + ] - unmatched_trackers = unmatched_trackers - {int(row) for row in used_rows} - unmatched_detections = unmatched_detections - { - int(col) for col in used_cols - } + # Remove recently lost from unmatched_confirmed + unmatched_confirmed = [ + tracker_idx + for tracker_idx in unmatched_confirmed + if tracker_idx not in recently_lost + ] - return matched_indices, unmatched_trackers, unmatched_detections + iou_track_candidates = unconfirmed_tracks + recently_lost + + # Match remaining tracks using IoU only + iou_matrix = get_iou_matrix( + trackers=self.trackers, detection_boxes=detection_boxes + ) + + iou_matches: list[tuple[int, int]] = [] + if iou_track_candidates and unmatched_detections: + iou_dist_matrix: np.ndarray = 1 - iou_matrix + + iou_matches, unmatched_candidates, unmatched_detections = ( + self._match_tracks_using_linear_sum_assignment( + iou_dist_matrix, + iou_track_candidates, + list(unmatched_detections), + ) + ) + else: + unmatched_candidates = iou_track_candidates + + matches = confirmed_matches + iou_matches + unmatched_tracks = set(unmatched_confirmed).union(set(unmatched_candidates)) + + return matches, unmatched_tracks, set(unmatched_detections) def _spawn_new_trackers( self, @@ -210,7 +363,9 @@ def _spawn_new_trackers( feature = detection_features[detection_idx] new_tracker = DeepSORTKalmanBoxTracker( - bbox=detection_boxes[detection_idx], feature=feature + bbox=detection_boxes[detection_idx], + feature=feature, + max_features_gallery_size=self.max_features_gallery_size, ) self.trackers.append(new_tracker) @@ -248,14 +403,9 @@ def update(self, detections: sv.Detections, frame: np.ndarray) -> sv.Detections: for tracker in self.trackers: tracker.predict() - # Build IOU cost matrix between detections and predicted bounding boxes - iou_matrix = get_iou_matrix( - trackers=self.trackers, 
@@ -210,7 +363,9 @@
             feature = detection_features[detection_idx]
 
             new_tracker = DeepSORTKalmanBoxTracker(
-                bbox=detection_boxes[detection_idx], feature=feature
+                bbox=detection_boxes[detection_idx],
+                feature=feature,
+                max_features_gallery_size=self.max_features_gallery_size,
             )
             self.trackers.append(new_tracker)
 
@@ -248,14 +403,9 @@ def update(self, detections: sv.Detections, frame: np.ndarray) -> sv.Detections:
         for tracker in self.trackers:
             tracker.predict()
 
-        # Build IOU cost matrix between detections and predicted bounding boxes
-        iou_matrix = get_iou_matrix(
-            trackers=self.trackers, detection_boxes=detection_boxes
-        )
-
-        # Associate detections to trackers based on IOU
-        matched_indices, _, unmatched_detections = self._get_associated_indices(
-            iou_matrix, detection_features
+        # Associate detections to trackers using the cascade matching approach
+        matched_indices, unmatched_tracks, unmatched_detections = (
+            self._get_associated_indices(detection_boxes, detection_features)
         )
 
         # Update matched trackers with assigned detections
diff --git a/trackers/utils/sort_utils.py b/trackers/utils/sort_utils.py
index 7ef83125..183953ec 100644
--- a/trackers/utils/sort_utils.py
+++ b/trackers/utils/sort_utils.py
@@ -143,3 +143,25 @@ def update_detections_with_track_ids(
     updated_detections.tracker_id = np.array(final_tracker_ids)
 
     return updated_detections
+
+
+def xyxy_to_xcycarh(xyxy: np.ndarray) -> np.ndarray:
+    """
+    Convert a bounding box into the measurement-space format
+    `(center x, center y, aspect ratio, height)`,
+    where the aspect ratio is `width / height`.
+
+    Args:
+        xyxy (np.ndarray): A single bounding box in format `(x1, y1, x2, y2)`.
+
+    Returns:
+        np.ndarray: Bounding box in format
+            `(center x, center y, aspect ratio, height)`.
+    """
+    x1, y1, x2, y2 = xyxy
+    width = x2 - x1
+    height = y2 - y1
+    center_x = x1 + width / 2
+    center_y = y1 + height / 2
+    aspect_ratio = width / height if height > 0 else 1.0
+    return np.array([center_x, center_y, aspect_ratio, height])
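
A quick numeric check of the new conversion helper (standalone, assuming xyxy_to_xcycarh is importable from trackers.utils.sort_utils):

import numpy as np

from trackers.utils.sort_utils import xyxy_to_xcycarh

box = np.array([10.0, 20.0, 50.0, 100.0])  # x1, y1, x2, y2
cx, cy, aspect_ratio, h = xyxy_to_xcycarh(box)
# width = 40, height = 80 -> center (30, 60), aspect ratio 0.5, height 80
assert (cx, cy, aspect_ratio, h) == (30.0, 60.0, 0.5, 80.0)
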