Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion trackers/core/deepsort/kalman_box_tracker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Optional, Union
from typing import Optional, Tuple, Union

import numpy as np
from scipy.linalg import solve_triangular

# Chi-square 0.95 quantile for 4 degrees of freedom (Mahalanobis threshold)
MAHALANOBIS_THRESHOLD = 9.4877


class DeepSORTKalmanBoxTracker:
Expand Down Expand Up @@ -96,6 +100,58 @@ def _initialize_kalman_filter(self) -> None:
# Error covariance matrix (P)
self.P = np.eye(8, dtype=np.float32)

def project(self) -> Tuple[np.ndarray, np.ndarray]:
    """
    Map the current state estimate into measurement space.

    Following the Kalman Filter formulation mentioned implicitly in
    Section 2.1 of the DeepSORT paper, the state distribution (μ_i, Σ_i)
    is turned into the measurement-space pair
        (y_i, S_i) = (H·μ_i, H·Σ_i·H^T + R)
    where μ_i is `self.state`, Σ_i is `self.P`, H is `self.H` and R is
    `self.R`.

    Returns:
        Tuple[np.ndarray, np.ndarray]: The projected mean y_i and the
            innovation covariance S_i, in that order.
    """
    observation_model = self.H

    # S_i = H·Σ_i·H^T + R (state uncertainty seen through H, plus the
    # measurement noise).
    innovation_covariance = (
        observation_model @ self.P @ observation_model.T + self.R
    )

    # y_i = H·μ_i
    return observation_model @ self.state, innovation_covariance

def compute_gating_distance(self, measurements: np.ndarray) -> np.ndarray:
    """
    Return the squared Mahalanobis distance from this track to each
    measurement.

    Used to gate (rule out) unlikely associations, as described in
    Eq. (1)-(2) of the DeepSORT paper:
        d^(1)(i,j) = (d_j - y_i)^T · S_i^(-1) · (d_j - y_i)
    where (y_i, S_i) is the result of `self.project()`.

    Args:
        measurements (np.ndarray): An Nx4 matrix of N measurements. They
            are subtracted directly from the projected mean, so they must
            be expressed in the same coordinates as the filter's
            measurement space.
            NOTE(review): presumably [x1, y1, x2, y2] boxes - confirm
            against the callers and the layout of `self.state`.

    Returns:
        np.ndarray: An array of length N whose i-th element is the squared
            Mahalanobis distance between the track and measurements[i].
    """
    predicted_mean, innovation_covariance = self.project()

    # One residual row per measurement: d_j - y_i.
    residuals = measurements - predicted_mean.reshape(1, 4)

    # With S = L·L^T, solving L·z = d gives z^T·z = d^T·S^(-1)·d, so S is
    # never inverted explicitly.
    lower_factor = np.linalg.cholesky(innovation_covariance)
    whitened = solve_triangular(
        lower_factor, residuals.T, lower=True, check_finite=False
    )

    # Column-wise squared norm of z is the squared Mahalanobis distance.
    return np.sum(np.square(whitened), axis=0)

def predict(self) -> None:
"""
Predict the next state of the bounding box (applies the state transition).
Expand Down
236 changes: 187 additions & 49 deletions trackers/core/deepsort/tracker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import List, Optional, Tuple, Union

import numpy as np
import supervision as sv
Expand All @@ -7,8 +7,12 @@

from trackers.core.base import BaseTrackerWithFeatures
from trackers.core.deepsort.feature_extractor import DeepSORTFeatureExtractor
from trackers.core.deepsort.kalman_box_tracker import DeepSORTKalmanBoxTracker
from trackers.core.deepsort.kalman_box_tracker import (
MAHALANOBIS_THRESHOLD,
DeepSORTKalmanBoxTracker,
)
from trackers.utils.sort_utils import (
convert_bbox_to_xyah,
get_alive_trackers,
get_iou_matrix,
update_detections_with_track_ids,
Expand Down Expand Up @@ -198,7 +202,7 @@ def _initialize_feature_extractor(

Args:
feature_extractor: The feature extractor input, which can be a model path,
a torch module, or a DeepSORTFeatureExtractor instance.
a torch module, or a DeepSORTFeatureExtractor instance.
device: The device to run the model on.

Returns:
Expand Down Expand Up @@ -236,83 +240,222 @@ def _get_appearance_distance_matrix(

return distance_matrix

def _get_mahalanobis_distance_matrix(
self,
detection_boxes: np.ndarray,
) -> np.ndarray:
"""
Calculate Mahalanobis distance matrix between tracks and detections,
as per Equation 1 in section 2.2 of the [DeepSORT paper](https://arxiv.org/pdf/1703.07402).

Args:
detection_boxes (np.ndarray): Detected bounding boxes in the
form [x1, y1, x2, y2].

Returns:
np.ndarray: Mahalanobis distance matrix.
"""
if len(self.trackers) == 0 or len(detection_boxes) == 0:
return np.zeros((len(self.trackers), len(detection_boxes)))

distance_matrix = np.zeros((len(self.trackers), len(detection_boxes)))

for i, tracker in enumerate(self.trackers):
measurements = np.array(
[convert_bbox_to_xyah(box) for box in detection_boxes]
)

distance_matrix[i, :] = tracker.compute_gating_distance(measurements)

return distance_matrix

def _get_combined_distance_matrix(
    self,
    mahalanobis_matrix: np.ndarray,
    appearance_dist_matrix: np.ndarray,
) -> np.ndarray:
    """
    Combine Mahalanobis and appearance distances into a single distance matrix,
    as per Equation 5 in section 2.2 of the [DeepSORT paper](https://arxiv.org/pdf/1703.07402).

    Pairs that fail either gate (Mahalanobis distance above
    `MAHALANOBIS_THRESHOLD`, or appearance distance above
    `self.appearance_threshold`) are assigned the cost 1.0, which the
    matching stage treats as infeasible.

    Args:
        mahalanobis_matrix (np.ndarray): Squared Mahalanobis distance matrix
            between tracks and detections.
        appearance_dist_matrix (np.ndarray): Appearance distance matrix.

    Returns:
        np.ndarray: Combined distance matrix with the same shape as the
            inputs.
    """
    # An association is inadmissible if either metric is above its
    # threshold. Gating is done on the raw (unscaled) values.
    mahalanobis_gate = mahalanobis_matrix > MAHALANOBIS_THRESHOLD
    appearance_gate = appearance_dist_matrix > self.appearance_threshold
    invalid_mask = np.logical_or(mahalanobis_gate, appearance_gate)

    # Squared Mahalanobis distances are admissible up to
    # MAHALANOBIS_THRESHOLD (> 1), but 1.0 is the "infeasible" cost and
    # the matching stage only keeps costs below it. Rescale the motion
    # term so every gated-in value lies in [0, 1]; otherwise admissible
    # pairs are discarded and infeasible ones can look cheaper than them.
    motion_cost = mahalanobis_matrix / MAHALANOBIS_THRESHOLD

    # Weighted sum of the appearance and (normalised) motion costs.
    combined_dist = (
        self.appearance_weight * appearance_dist_matrix
        + (1 - self.appearance_weight) * motion_cost
    )

    combined_dist[invalid_mask] = 1.0  # Mark as infeasible

    return combined_dist

def _match_tracks_stage(
self,
cost_matrix: np.ndarray,
track_indices: list,
detection_indices: list,
) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
"""
Match tracks with detections for a specific stage of the matching cascade.
This implements the linear assignment for a specific group of tracks based
on their maturity.

Args:
cost_matrix (np.ndarray): Cost matrix between tracks and detections.
track_indices (list): Indices of tracks to match.
detection_indices (list): Indices of detections to match.

Returns:
tuple[list[tuple[int, int]], list[int], list[int]]: Matched indices,
unmatched track indices, unmatched detection indices.
"""
if len(track_indices) == 0 or len(detection_indices) == 0:
return [], track_indices, detection_indices

sub_cost_matrix = cost_matrix[np.ix_(track_indices, detection_indices)]

# Apply threshold of 1.0 to mark infeasible associations
# Only consider associations where cost < 1.0
valid_mask = sub_cost_matrix < 1.0

if not np.any(valid_mask):
return [], track_indices, detection_indices

row_indices, col_indices = np.where(valid_mask)

indices = np.stack([row_indices, col_indices], axis=1)
indices = indices[np.argsort(sub_cost_matrix[row_indices, col_indices])]

matches = []
unmatched_tracks = list(track_indices)
unmatched_detections = list(detection_indices)

matched_track_indices = set()
matched_detection_indices = set()

for row, col in indices:
track_idx = track_indices[row]
detection_idx = detection_indices[col]

# Skip if either track or detection is already matched
if row in matched_track_indices or col in matched_detection_indices:
continue

matches.append((track_idx, detection_idx))
matched_track_indices.add(row)
matched_detection_indices.add(col)

if track_idx in unmatched_tracks:
unmatched_tracks.remove(track_idx)
if detection_idx in unmatched_detections:
unmatched_detections.remove(detection_idx)

return matches, unmatched_tracks, unmatched_detections

def _get_associated_indices(
    self,
    detection_boxes: np.ndarray,
    detection_features: np.ndarray,
) -> tuple[list[tuple[int, int]], set[int], set[int]]:
    """
    Associate detections to trackers using a two-stage matching approach.

    Stage 1 matches confirmed tracks (those with at least
    `self.minimum_consecutive_frames` successful updates) using a weighted
    combination of Mahalanobis distance (for motion information) and
    appearance distance (for visual similarity). Stage 2 matches the
    unconfirmed tracks, plus confirmed tracks left unmatched by stage 1
    whose `time_since_update` is 1, against the remaining detections
    using IoU only.

    Args:
        detection_boxes (np.ndarray): Detected bounding boxes in the
            form [x1, y1, x2, y2].
        detection_features (np.ndarray): Features extracted from current detections.

    Returns:
        tuple[list[tuple[int, int]], set[int], set[int]]: Matched indices,
            unmatched trackers, unmatched detections.
    """
    # Split tracks by maturity.
    confirmed_tracks = []
    unconfirmed_tracks = []
    for tracker_idx, tracker in enumerate(self.trackers):
        if tracker.number_of_successful_updates >= self.minimum_consecutive_frames:
            confirmed_tracks.append(tracker_idx)
        else:
            unconfirmed_tracks.append(tracker_idx)

    # Stage 1: motion + appearance cost for confirmed tracks.
    mahalanobis_matrix = self._get_mahalanobis_distance_matrix(detection_boxes)
    appearance_dist_matrix = self._get_appearance_distance_matrix(
        detection_features
    )
    combined_dist_matrix = self._get_combined_distance_matrix(
        mahalanobis_matrix, appearance_dist_matrix
    )

    confirmed_matches, unmatched_confirmed, unmatched_detections = (
        self._match_tracks_stage(
            combined_dist_matrix,
            confirmed_tracks,
            list(range(len(detection_features))),
        )
    )

    # Confirmed tracks that stage 1 left unmatched and whose
    # time_since_update is 1 get a second chance in the IoU stage.
    recently_lost = [
        tracker_idx
        for tracker_idx in unmatched_confirmed
        if self.trackers[tracker_idx].time_since_update == 1
    ]
    recently_lost_lookup = set(recently_lost)
    unmatched_confirmed = [
        tracker_idx
        for tracker_idx in unmatched_confirmed
        if tracker_idx not in recently_lost_lookup
    ]

    iou_track_candidates = unconfirmed_tracks + recently_lost

    # Stage 2: match remaining tracks using IoU only.
    iou_matches: list[tuple[int, int]] = []
    unmatched_candidates = iou_track_candidates
    if iou_track_candidates and unmatched_detections:
        # Built only when this stage actually runs.
        iou_matrix = get_iou_matrix(
            trackers=self.trackers, detection_boxes=detection_boxes
        )
        # `1 - iou_matrix` is already a new array, so it can be edited
        # in place.
        iou_dist_matrix = 1 - iou_matrix
        # Pairs overlapping less than the minimum IoU are infeasible.
        iou_dist_matrix[iou_matrix < self.minimum_iou_threshold] = 1.0

        iou_matches, unmatched_candidates, unmatched_detections = (
            self._match_tracks_stage(
                iou_dist_matrix,
                iou_track_candidates,
                list(unmatched_detections),
            )
        )

    matches = confirmed_matches + iou_matches
    unmatched_tracks = set(unmatched_confirmed).union(unmatched_candidates)

    return matches, unmatched_tracks, set(unmatched_detections)

def _spawn_new_trackers(
self,
Expand Down Expand Up @@ -380,14 +523,9 @@ def update(self, detections: sv.Detections, frame: np.ndarray) -> sv.Detections:
for tracker in self.trackers:
tracker.predict()

# Build IOU cost matrix between detections and predicted bounding boxes
iou_matrix = get_iou_matrix(
trackers=self.trackers, detection_boxes=detection_boxes
)

# Associate detections to trackers based on IOU
matched_indices, _, unmatched_detections = self._get_associated_indices(
iou_matrix, detection_features
# Associate detections to trackers using the cascade matching approach
matched_indices, unmatched_tracks, unmatched_detections = (
self._get_associated_indices(detection_boxes, detection_features)
)

# Update matched trackers with assigned detections
Expand Down
Loading