From 2d4a6868d87b0212dac5a504906d0b850ae7c79a Mon Sep 17 00:00:00 2001 From: rolson24 Date: Sat, 26 Apr 2025 09:45:38 -0400 Subject: [PATCH 1/4] initial dataset loader commit --- trackers/dataset/base.py | 85 +++++ trackers/dataset/mot_challenge.py | 603 ++++++++++++++++++++++++++++++ trackers/dataset/utils.py | 85 +++++ 3 files changed, 773 insertions(+) create mode 100644 trackers/dataset/base.py create mode 100644 trackers/dataset/mot_challenge.py create mode 100644 trackers/dataset/utils.py diff --git a/trackers/dataset/base.py b/trackers/dataset/base.py new file mode 100644 index 00000000..177096e5 --- /dev/null +++ b/trackers/dataset/base.py @@ -0,0 +1,85 @@ +import abc +from typing import Any, Dict, Iterator, List, Optional, Tuple +import supervision as sv + +# --- Base Dataset --- +class Dataset(abc.ABC): + """Abstract base class for datasets used in tracking evaluation.""" + + @abc.abstractmethod + def load_ground_truth(self, sequence_name: str) -> Optional[sv.Detections]: + """ + Loads ground truth data for a specific sequence. + + Args: + sequence_name: The name of the sequence. + + Returns: + An sv.Detections object containing ground truth annotations, or None + if ground truth is not available or cannot be loaded. The Detections + object should ideally include `tracker_id` and `data['frame_idx']`. + """ + pass + + @abc.abstractmethod + def get_sequence_names(self) -> List[str]: + """Returns a list of sequence names available in the dataset.""" + pass + + @abc.abstractmethod + def get_sequence_info(self, sequence_name: str) -> Dict[str, Any]: + """ + Returns metadata for a specific sequence. + + Args: + sequence_name: The name of the sequence. + + Returns: + A dictionary containing sequence information (e.g., 'frame_rate', + 'seqLength', 'img_width', 'img_height', 'img_dir'). Keys and value + types may vary depending on the dataset format. + """ + pass + + @abc.abstractmethod + def get_frame_iterator(self, sequence_name: str) -> Iterator[Dict[str, Any]]: + """ + Returns an iterator over frame information dictionaries for a sequence. + + Args: + sequence_name: The name of the sequence. + + Yields: + Dictionaries, each representing a frame. Each dictionary should + contain at least 'frame_idx' (int, typically 1-based) and + 'image_path' (str, absolute path recommended). + """ + pass + + @abc.abstractmethod + def preprocess( + self, + ground_truth: sv.Detections, + predictions: sv.Detections, + iou_threshold: float = 0.5, + remove_distractor_matches: bool = True, + ) -> Tuple[sv.Detections, sv.Detections]: + """ + Applies dataset-specific preprocessing steps to ground truth and predictions. + + This typically involves filtering unwanted annotations (e.g., distractors, + zero-marked GTs) and potentially relabeling IDs. + + Args: + ground_truth (sv.Detections): Raw ground truth detections for a sequence. + predictions (sv.Detections): Raw prediction detections for a sequence. + iou_threshold (float): IoU threshold used for matching during preprocessing + (e.g., for removing predictions matching distractors). + remove_distractor_matches (bool): Flag indicating whether to remove + predictions matched to distractors. + + Returns: + Tuple[sv.Detections, sv.Detections]: A tuple containing the processed + ground truth detections and processed prediction detections. 
+ """ + pass \ No newline at end of file diff --git a/trackers/dataset/mot_challenge.py b/trackers/dataset/mot_challenge.py new file mode 100644 index 00000000..e623a25d --- /dev/null +++ b/trackers/dataset/mot_challenge.py @@ -0,0 +1,603 @@ +import abc +import configparser +from pathlib import Path +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + +import numpy as np +import supervision as sv +from scipy.optimize import linear_sum_assignment # Added import + +from trackers.dataset.base import Dataset +from trackers.dataset.utils import _relabel_ids + +# --- Define MOT Constants needed for preprocessing --- +MOT_PEDESTRIAN_ID = 1 +MOT_DISTRACTOR_IDS = [ + 2, + 7, + 8, + 12, +] # person_on_vehicle, static_person, distractor, reflection +MOT_IGNORE_IDS = [2, 7, 8, 12, 13] # Includes crowd (13) for ignore, adjust as needed +ZERO_MARKED_EPSILON = 1e-5 +# --- End MOT Constants --- + + +class MOTChallengeDataset(Dataset): + """ + Dataset class for loading sequences in the MOTChallenge format. + Handles parsing `seqinfo.ini`, `gt/gt.txt`, and optionally `det/det.txt`. + + Expected directory structure: + dataset_path/ + sequence_name_1/ + seqinfo.ini + gt/ + gt.txt # Ground truth annotations + img1/ # Image frames (directory name from seqinfo.ini) + 000001.jpg + ... + det/ # Optional, for public detections + det.txt # Format: frame,id,bb_left,bb_top,w,h,conf,-1,-1,-1 + sequence_name_2/ + ... + """ + + def __init__(self, dataset_path: Union[str, Path]): + """ + Initializes the MOTChallengeDataset. + + Args: + dataset_path: Path (str or Path object) to the root directory of the + MOTChallenge dataset (e.g., `/path/to/MOT17/train`). + + Raises: + FileNotFoundError: If the `dataset_path` does not exist or is not a + directory. + """ + self.root_path = Path(dataset_path) + if not self.root_path.is_dir(): + raise FileNotFoundError(f"Dataset path not found: {self.root_path}") + self._sequence_names = self._find_sequences() + self._public_detections: Optional[Dict[str, sv.Detections]] = ( + None # Cache for public detections + ) + + def _find_sequences(self) -> List[str]: + """ + Finds valid sequence directories (containing seqinfo.ini) within + the root path. + """ + sequences = [] + for item in self.root_path.iterdir(): + if item.is_dir() and (item / "seqinfo.ini").exists(): + sequences.append(item.name) + if not sequences: + print(f"Warning: No valid MOTChallenge sequences found in {self.root_path}") + return sorted(sequences) + + def _parse_mot_file( + self, file_path: Path, min_confidence: Optional[float] = None + ) -> Tuple[List[Dict[str, Any]], Dict[int, List[Dict[str, Any]]]]: + """ + Parses a MOT format file (gt.txt or det.txt) into structured dictionaries. + + Handles comma-separated values and converts bounding boxes from xywh to xyxy. + + Args: + file_path: Path object pointing to the MOT format file. + min_confidence: Optional minimum confidence threshold. Detections below + this threshold will be ignored. + + Returns: + A tuple containing: + - A list of dictionaries, where each dictionary represents a single + detection/annotation line parsed from the file. Keys include + 'frame_idx', 'obj_id', 'xyxy', 'confidence', 'class_id'. + - A dictionary mapping frame indices (int) to lists of detection/annotation + dictionaries belonging to that frame. 
+        Returns ([], {}) if the file doesn't exist or an error occurs during parsing.
+        """
+        if not file_path.exists():
+            return [], {}
+
+        all_detections = []
+        frame_detections: Dict[int, List[Dict[str, Any]]] = {}
+
+        try:
+            with open(file_path, "r") as f:
+                for line in f:
+                    line = line.strip()
+                    if not line:
+                        continue
+
+                    parts = line.split(",")
+                    if len(parts) < 7:
+                        print(
+                            f"Warning: Skipping malformed line in {file_path}: {line}"
+                        )
+                        continue
+
+                    try:
+                        frame_idx = int(parts[0])
+                        obj_id = int(parts[1])
+                        x = float(parts[2])
+                        y = float(parts[3])
+                        width = float(parts[4])
+                        height = float(parts[5])
+                        # Column 7 (index 6) is parsed as confidence: the
+                        # consider/ignore flag (0 = ignore) in gt.txt, or the
+                        # detection confidence in det.txt.
+                        confidence = float(parts[6])
+
+                        class_id = int(parts[7]) if len(parts) > 7 else -1
+
+                        if min_confidence is not None and confidence < min_confidence:
+                            continue
+
+                        detection = {
+                            "frame_idx": frame_idx,
+                            "obj_id": obj_id,
+                            "xyxy": [x, y, x + width, y + height],
+                            "confidence": confidence,  # Correctly assigned
+                            "class_id": class_id,
+                        }
+
+                        all_detections.append(detection)
+
+                        if frame_idx not in frame_detections:
+                            frame_detections[frame_idx] = []
+                        frame_detections[frame_idx].append(detection)
+
+                    except ValueError as ve:
+                        print(
+                            f"Warning: Skipping line with invalid numeric data in \
+                                {file_path}: {line} ({ve})"
+                        )
+                        continue
+
+            return all_detections, frame_detections
+
+        except Exception as e:
+            print(f"Error parsing MOT file {file_path}: {e}")
+            return [], {}
+
+    def load_ground_truth(self, sequence_name: str) -> Optional[sv.Detections]:
+        """
+        Loads ground truth data for a specific sequence from the `gt/gt.txt` file.
+
+        Parses the file and converts annotations into an sv.Detections object.
+        Frame indices are stored in `sv.Detections.data['frame_idx']`.
+
+        Args:
+            sequence_name: The name of the sequence (e.g., 'MOT17-02-SDP').
+
+        Returns:
+            An sv.Detections object containing all ground truth annotations for the
+            sequence, or None if the `gt.txt` file doesn't exist or an error occurs
+            during loading/parsing. Returns `sv.Detections.empty()` if the file
+            exists but contains no valid annotations.
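+
+        Example (illustrative; the dataset path and sequence name are
+        placeholders):
+            >>> ds = MOTChallengeDataset("/path/to/MOT17/train")
+            >>> gt = ds.load_ground_truth("MOT17-02-SDP")
+            >>> gt.data["frame_idx"]  # doctest: +SKIP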
+ """ + gt_path = self.root_path / sequence_name / "gt" / "gt.txt" + if not gt_path.exists(): + print( + f"Warning: Ground truth file not found for sequence \ + {sequence_name} at {gt_path}" + ) + return None + + try: + all_detections, _ = self._parse_mot_file(gt_path) + + if not all_detections: + print(f"Warning: No valid annotations found in {gt_path}") + return sv.Detections.empty() + + # Extract data from parsed detections + boxes = [det["xyxy"] for det in all_detections] + confidence_scores = [det["confidence"] for det in all_detections] + class_ids = [det["class_id"] for det in all_detections] + track_ids = [det["obj_id"] for det in all_detections] + frame_indices = [det["frame_idx"] for det in all_detections] + + # Convert lists to numpy arrays + xyxy = np.array(boxes, dtype=np.float32) + confidence = np.array(confidence_scores, dtype=np.float32) + class_id = np.array(class_ids, dtype=np.int32) + tracker_id = np.array(track_ids, dtype=np.int32) + frame_idx = np.array(frame_indices, dtype=np.int32) + + # Create sv.Detections object with frame indices in data + return sv.Detections( + xyxy=xyxy, + confidence=confidence, + class_id=class_id, + tracker_id=tracker_id, + data={"frame_idx": frame_idx}, + ) + + except Exception as e: + print(f"Error loading ground truth for {sequence_name}: {e}") + return None + + def get_sequence_names(self) -> List[str]: + """Returns a sorted list of sequence names found in the dataset directory.""" + return self._sequence_names + + def get_sequence_info(self, sequence_name: str) -> Dict[str, Any]: + """ + Parses the `seqinfo.ini` file for a given sequence and returns its contents. + + Attempts to convert values to int or float where possible. Standardizes + common keys like 'frame_rate', 'seqLength', 'img_width', 'img_height', + 'img_dir', 'name'. + + Args: + sequence_name: The name of the sequence. + + Returns: + A dictionary containing sequence information. Returns an empty dictionary + if `seqinfo.ini` is not found, cannot be parsed, or lacks the + '[Sequence]' section. The 'img_dir' value will be a Path object. 
+ """ + seq_info_path = self.root_path / sequence_name / "seqinfo.ini" + if not seq_info_path.exists(): + print(f"Warning: seqinfo.ini not found for sequence {sequence_name}") + return {} + + config = configparser.ConfigParser() + try: + config.read(seq_info_path) + if "Sequence" in config: + # Convert values to appropriate types (int, float, str) + info: Dict[str, Union[int, float, str]] = {} + for key, value in config["Sequence"].items(): + try: + # Attempt to convert to int, then float, else keep as string + info[key] = int(value) + except ValueError: + try: + info[key] = float(value) + except ValueError: + info[key] = value + + # Ensure standard keys exist (case-insensitive matching) + standard_info: Dict[str, Union[int, float, str, Path, None]] = { + "frame_rate": info.get("framerate"), + "seqLength": info.get("seqlength"), + "img_width": info.get("imwidth"), + "img_height": info.get("imheight"), + "img_dir": self.root_path + / sequence_name + / str( + info.get("imdir", "img1") + ), # Convert to string to ensure Path compatibility + "name": info.get("name", sequence_name), + } + # Filter out None values if keys weren't present + return {k: v for k, v in standard_info.items() if v is not None} + else: + print(f"Warning: '[Sequence]' section not found in {seq_info_path}") + return {} + except configparser.Error as e: + print(f"Error parsing {seq_info_path}: {e}") + return {} + + def get_frame_iterator(self, sequence_name: str) -> Iterator[Dict[str, Any]]: + """ + Returns an iterator yielding information about each frame in a sequence. + + Determines frame count and image directory from `seqinfo.ini`. Infers the + image file extension based on the first frame found. + + Args: + sequence_name: The name of the sequence. + + Yields: + Dictionaries, each containing: + - 'frame_idx': The frame number (int, 1-based). + - 'image_path': The absolute path to the image file (str). + Yields nothing if sequence info is incomplete or image files + cannot be found. + """ + seq_info = self.get_sequence_info(sequence_name) + num_frames = seq_info.get("seqLength") + img_dir = seq_info.get( + "img_dir" + ) + + if num_frames is None or img_dir is None or not img_dir.is_dir(): + print( + f"Warning: Could not determine frame count or image directory for \ + {sequence_name}. Check seqinfo.ini." + ) + return # Return empty iterator + + # Look for the first file to determine the extension + first_frame_pattern = f"{1:06d}.*" + potential_files = list(img_dir.glob(first_frame_pattern)) + if not potential_files: + print( + f"Warning: No image files found matching pattern \ + '{first_frame_pattern}' in {img_dir}" + ) + # Try common extensions explicitly if glob fails + if (img_dir / f"{1:06d}.jpg").exists(): + img_ext = ".jpg" + elif (img_dir / f"{1:06d}.png").exists(): + img_ext = ".png" + else: + print( + f"Warning: Could not determine image extension for sequence \ + {sequence_name}." + ) + return + else: + img_ext = potential_files[0].suffix + + for i in range(1, num_frames + 1): + frame_filename = f"{i:06d}{img_ext}" + frame_path = img_dir / frame_filename + if not frame_path.exists(): + print(f"Warning: Expected frame image not found: {frame_path}") + # Decide whether to skip or raise error. Skipping for now. + continue + + yield { + "frame_idx": i, + "image_path": str( + frame_path.resolve() + ), # Use absolute path for reliable dict key + } + + + def load_public_detections(self, min_confidence: Optional[float] = None) -> None: + """ + Loads public detections from `det/det.txt` for all sequences into memory. 
+ + Parses `det.txt` files and stores detections in an internal cache, keyed + by the absolute image path. Detections are stored as sv.Detections objects. + + Args: + min_confidence: Optional minimum detection confidence score to include. + If None, all detections are loaded. + """ + print("Loading public detections...") + self._public_detections = {} + loaded_count = 0 + total_dets = 0 + + for seq_name in self.get_sequence_names(): + det_path = self.root_path / seq_name / "det" / "det.txt" + if not det_path.exists(): + print(f" Info: No det.txt found for sequence {seq_name}") + continue + + try: + # Load detections using common parser + _, frame_detections = self._parse_mot_file(det_path, min_confidence) + + if not frame_detections: + continue + + loaded_count += 1 + seq_total_dets = 0 + + # Get frame iterator to map frame index to image path + frame_map = { + info["frame_idx"]: info["image_path"] + for info in self.get_frame_iterator(seq_name) + } + + for frame_idx, detections in frame_detections.items(): + if frame_idx not in frame_map: + print( + f" Warning: Detections found for frame {frame_idx} \ + outside sequence length in {seq_name}. Skipping." + ) + continue + + image_path = frame_map[frame_idx] + + # Prepare arrays for sv.Detections + xyxy = np.array([det["xyxy"] for det in detections]) + confidence = np.array([det["confidence"] for det in detections]) + + self._public_detections[image_path] = sv.Detections( + xyxy=xyxy, + confidence=confidence, + class_id=None, # MOT public detections don't have class IDs + ) + seq_total_dets += len(detections) + + print(f" Loaded {seq_total_dets} detections for sequence {seq_name}") + total_dets += seq_total_dets + + except Exception as e: + print(f" Error loading detections for sequence {seq_name}: {e}") + + print( + f"Finished loading public detections. Found {total_dets} \ + detections across {loaded_count} sequences." + ) + if not self._public_detections: + print("Warning: No public detections were loaded.") + + @property + def has_public_detections(self) -> bool: + """Returns True if public detections have been loaded via + `load_public_detections`.""" + return self._public_detections is not None + + def get_public_detections(self, image_path: str) -> sv.Detections: + """ + Retrieves the loaded public detections associated with a specific image path. + + Requires `load_public_detections()` to have been called first. + + Args: + image_path: The absolute path (str) to the image file. + + Returns: + An sv.Detections object containing the public detections for the given + image path. Returns `sv.Detections.empty()` if no detections were loaded + for this path or if `load_public_detections()` was not called. + """ + if not self.has_public_detections: + print( + "Warning: Public detections requested but not loaded. \ + # Call load_public_detections() first." + ) + return sv.Detections.empty() + + abs_image_path = str(Path(image_path).resolve()) + + return (self._public_detections or {}).get( + abs_image_path, sv.Detections.empty() + ) + + def preprocess( + self, + gt_dets: sv.Detections, + pred_dets: sv.Detections, + iou_threshold: float = 0.5, + remove_distractor_matches: bool = True, + ) -> Tuple[sv.Detections, sv.Detections]: + """ + Applies MOT specific preprocessing based on TrackEval logic. + + 1. Optionally removes tracker detections matched to ground truth distractors. + 2. Removes ground truth annotations marked as distractors or "zero-marked". + 3. 
Relabels ground truth and (processed) prediction `tracker_id`s to be + contiguous and 0-based using the utility function. + + Requires GT detections to have `frame_idx`, `tracker_id`, `class_id`, and + `confidence` (where confidence corresponds to MOT `gt.txt` column 7). + Requires prediction detections to have `frame_idx` and `tracker_id`. + + Args: + gt_dets (sv.Detections): Raw ground truth detections. + pred_dets (sv.Detections): Raw prediction detections. + iou_threshold (float): IoU threshold used for matching predictions to + distractors. Defaults to 0.5. + remove_distractor_matches (bool): If True, remove predictions matched to + GT distractors/ignored classes. Defaults to True. + + Returns: + Tuple[sv.Detections, sv.Detections]: A tuple containing the processed + ground truth detections and processed prediction detections. Returns + original inputs if required fields are missing. + """ + gt_out_list = [] + pred_out_list = [] + + # --- Input Validation --- + if ( + gt_dets.data is None + or "frame_idx" not in gt_dets.data + or gt_dets.tracker_id is None + or gt_dets.class_id is None + or gt_dets.confidence is None + ): + print( + "Warning: GT detections missing required fields " + "(frame_idx, tracker_id, class_id, confidence) for " + "MOT preprocessing. Skipping." + ) + return gt_dets, pred_dets + if ( + pred_dets.data is None + or "frame_idx" not in pred_dets.data + or pred_dets.tracker_id is None + ): + print( + "Warning: Prediction detections missing required fields " + "(frame_idx, tracker_id) for MOT preprocessing. Skipping." + ) + return gt_dets, pred_dets + + all_frame_indices = sorted( + list(set(gt_dets.data["frame_idx"]).union(set(pred_dets.data["frame_idx"]))) + ) + + for frame_idx in all_frame_indices: + gt_dets_t = gt_dets[gt_dets.data["frame_idx"] == frame_idx] + pred_dets_t = pred_dets[pred_dets.data["frame_idx"] == frame_idx] + + pred_dets_t_filtered = pred_dets_t + + # --- TrackEval Preprocessing Step 1 & 2: + # Optionally remove tracker dets matching distractor GTs --- + if remove_distractor_matches: + to_remove_tracker_indices = np.array([], dtype=int) + if len(gt_dets_t) > 0 and len(pred_dets_t) > 0: + # Match all preds against all GTs for this frame + similarity = sv.detection.utils.box_iou_batch( + gt_dets_t.xyxy, pred_dets_t.xyxy + ) + match_scores = similarity.copy() + match_scores[ + match_scores < iou_threshold - np.finfo("float").eps + ] = 0 + + match_rows, match_cols = linear_sum_assignment( + -match_scores + ) # Maximize score + valid_match_mask = ( + match_scores[match_rows, match_cols] > 0 + np.finfo("float").eps + ) + match_rows = match_rows[valid_match_mask] + match_cols = match_cols[valid_match_mask] + + # Identify matches where GT is a distractor + matched_gt_classes = gt_dets_t.class_id[match_rows] + is_distractor_match = np.isin( + matched_gt_classes, MOT_DISTRACTOR_IDS + ) + to_remove_tracker_indices = match_cols[is_distractor_match] + + # Filter tracker detections for the frame + if len(to_remove_tracker_indices) > 0: + pred_keep_mask = np.ones(len(pred_dets_t), dtype=bool) + pred_keep_mask[to_remove_tracker_indices] = False + pred_dets_t_filtered = pred_dets_t[pred_keep_mask] + # else: pred_dets_t_filtered remains pred_dets_t + + # --- TrackEval Preprocessing Step 4: Remove unwanted GT dets --- + gt_is_pedestrian = gt_dets_t.class_id == MOT_PEDESTRIAN_ID + + # Refined check for zero_marked: Check if confidence is very close to 0 + gt_is_effectively_zero = np.abs(gt_dets_t.confidence) < ZERO_MARKED_EPSILON + # Also consider explicit ignore 
classes + gt_is_ignore_class = np.isin(gt_dets_t.class_id, MOT_IGNORE_IDS) + + # Keep GT if it IS a pedestrian AND it is NOT effectively zero-marked + # AND NOT an ignore class + gt_keep_mask = ( + gt_is_pedestrian & ~gt_is_effectively_zero & ~gt_is_ignore_class + ) + + gt_dets_t_filtered = gt_dets_t[gt_keep_mask] + + # Append filtered detections for the frame + if len(gt_dets_t_filtered) > 0: + gt_out_list.append(gt_dets_t_filtered) + if len(pred_dets_t_filtered) > 0: + pred_out_list.append(pred_dets_t_filtered) + + # Merge filtered detections across all frames + gt_processed = ( + sv.Detections.merge(gt_out_list) if gt_out_list else sv.Detections.empty() + ) + pred_processed = ( + sv.Detections.merge(pred_out_list) + if pred_out_list + else sv.Detections.empty() + ) + + # --- TrackEval Preprocessing Step 6: Relabel IDs using the utility function --- + gt_processed = _relabel_ids(gt_processed) + pred_processed = _relabel_ids(pred_processed) + + return gt_processed, pred_processed diff --git a/trackers/dataset/utils.py b/trackers/dataset/utils.py new file mode 100644 index 00000000..00982ce2 --- /dev/null +++ b/trackers/dataset/utils.py @@ -0,0 +1,85 @@ +import numpy as np +import supervision as sv + + +def _relabel_ids(detections: sv.Detections) -> sv.Detections: + """ + Relabels `tracker_id`s to be contiguous integers starting from 0. + + Handles potential NaN or non-integer IDs gracefully. IDs that cannot be + processed are left as -1 in the output. + + Args: + detections (sv.Detections): The detections object whose `tracker_id`s + need relabeling. + + Returns: + sv.Detections: The detections object with relabeled `tracker_id`s. + Returns the original object if no valid IDs are found or if input is empty. + """ + if len(detections) == 0 or detections.tracker_id is None: + return detections + + # 1. Filter out potential NaN values first + valid_ids_mask = ~np.isnan(detections.tracker_id) + if not np.any(valid_ids_mask): + # All IDs were NaN or array was empty after filtering + return detections + + # 2. Get unique integer IDs + try: + unique_ids = np.unique(detections.tracker_id[valid_ids_mask].astype(int)) + except ValueError: + print( + "Warning: Could not convert tracker IDs to integers during relabeling. " + "Skipping." + ) + return detections + + if len(unique_ids) == 0: + return detections + + # Now unique_ids contains only valid integers + max_id = np.max(unique_ids) + min_id = np.min(unique_ids) + + offset = 0 + if min_id < 0: + print( + f"Warning: Negative tracker IDs found ({min_id}). " + "Shifting IDs for relabeling." + ) + offset = -min_id + max_id += offset + + if np.isnan(max_id): + print("Warning: Max ID is NaN during relabeling after offset. Skipping.") + return detections + + map_size = int(max_id) + 1 + id_map = np.full(map_size, fill_value=-1, dtype=int) + new_id_counter = 0 + # Initialize new_ids based on the original tracker_id shape and type + new_ids = np.full_like(detections.tracker_id, fill_value=-1, dtype=int) + + # Iterate through the original positions where IDs were valid + original_indices = np.where(valid_ids_mask)[0] + for i in original_indices: + original_id = int(detections.tracker_id[i]) + offset + if original_id >= map_size or original_id < 0: + print( + f"Warning: Original ID {original_id - offset} out of bounds for map " + "during relabeling. Skipping ID." 
+ ) + continue + + if id_map[original_id] == -1: + id_map[original_id] = new_id_counter + new_ids[i] = new_id_counter + new_id_counter += 1 + else: + new_ids[i] = id_map[original_id] + + + detections.tracker_id = new_ids + return detections \ No newline at end of file From df8ae6fc75732e18999c30e510c597ad7f41e369 Mon Sep 17 00:00:00 2001 From: rolson24 Date: Sat, 26 Apr 2025 09:54:47 -0400 Subject: [PATCH 2/4] replace prints with logging --- trackers/dataset/base.py | 4 +- trackers/dataset/mot_challenge.py | 101 +++++++++++++++--------------- trackers/dataset/utils.py | 23 +++---- 3 files changed, 65 insertions(+), 63 deletions(-) diff --git a/trackers/dataset/base.py b/trackers/dataset/base.py index 177096e5..b7e5538a 100644 --- a/trackers/dataset/base.py +++ b/trackers/dataset/base.py @@ -1,7 +1,9 @@ import abc from typing import Any, Dict, Iterator, List, Optional, Tuple + import supervision as sv + # --- Base Dataset --- class Dataset(abc.ABC): """Abstract base class for datasets used in tracking evaluation.""" @@ -82,4 +84,4 @@ def preprocess( Tuple[sv.Detections, sv.Detections]: A tuple containing the processed ground truth detections and processed prediction detections. """ - pass \ No newline at end of file + pass diff --git a/trackers/dataset/mot_challenge.py b/trackers/dataset/mot_challenge.py index e623a25d..43ae654e 100644 --- a/trackers/dataset/mot_challenge.py +++ b/trackers/dataset/mot_challenge.py @@ -1,14 +1,14 @@ -import abc import configparser from pathlib import Path from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import numpy as np import supervision as sv -from scipy.optimize import linear_sum_assignment # Added import +from scipy.optimize import linear_sum_assignment from trackers.dataset.base import Dataset from trackers.dataset.utils import _relabel_ids +from trackers.log import get_logger # --- Define MOT Constants needed for preprocessing --- MOT_PEDESTRIAN_ID = 1 @@ -22,6 +22,8 @@ ZERO_MARKED_EPSILON = 1e-5 # --- End MOT Constants --- +logger = get_logger(__name__) + class MOTChallengeDataset(Dataset): """ @@ -73,7 +75,7 @@ def _find_sequences(self) -> List[str]: if item.is_dir() and (item / "seqinfo.ini").exists(): sequences.append(item.name) if not sequences: - print(f"Warning: No valid MOTChallenge sequences found in {self.root_path}") + logger.warning(f"No valid MOTChallenge sequences found in {self.root_path}") return sorted(sequences) def _parse_mot_file( @@ -113,8 +115,8 @@ def _parse_mot_file( parts = line.split(",") if len(parts) < 7: - print( - f"Warning: Skipping malformed line in {file_path}: {line}" + logger.warning( + f"Skipping malformed line in {file_path}: {line}" ) continue @@ -151,16 +153,16 @@ def _parse_mot_file( frame_detections[frame_idx].append(detection) except ValueError as ve: - print( - f"Warning: Skipping line with invalid numeric data in \ - {file_path}: {line} ({ve})" + logger.warning( + f"Skipping line with invalid numeric data in {file_path}: \ + {line} ({ve})" ) continue return all_detections, frame_detections except Exception as e: - print(f"Error parsing MOT file {file_path}: {e}") + logger.error(f"Error parsing MOT file {file_path}: {e}") return [], {} def load_ground_truth(self, sequence_name: str) -> Optional[sv.Detections]: @@ -181,9 +183,8 @@ def load_ground_truth(self, sequence_name: str) -> Optional[sv.Detections]: """ gt_path = self.root_path / sequence_name / "gt" / "gt.txt" if not gt_path.exists(): - print( - f"Warning: Ground truth file not found for sequence \ - {sequence_name} at {gt_path}" + 
logger.warning( + f"Ground truth file not found for sequence {sequence_name} at {gt_path}" ) return None @@ -191,7 +192,7 @@ def load_ground_truth(self, sequence_name: str) -> Optional[sv.Detections]: all_detections, _ = self._parse_mot_file(gt_path) if not all_detections: - print(f"Warning: No valid annotations found in {gt_path}") + logger.warning(f"No valid annotations found in {gt_path}") return sv.Detections.empty() # Extract data from parsed detections @@ -218,7 +219,7 @@ def load_ground_truth(self, sequence_name: str) -> Optional[sv.Detections]: ) except Exception as e: - print(f"Error loading ground truth for {sequence_name}: {e}") + logger.error(f"Error loading ground truth for {sequence_name}: {e}") return None def get_sequence_names(self) -> List[str]: @@ -243,7 +244,7 @@ def get_sequence_info(self, sequence_name: str) -> Dict[str, Any]: """ seq_info_path = self.root_path / sequence_name / "seqinfo.ini" if not seq_info_path.exists(): - print(f"Warning: seqinfo.ini not found for sequence {sequence_name}") + logger.warning(f"seqinfo.ini not found for sequence {sequence_name}") return {} config = configparser.ConfigParser() @@ -278,10 +279,10 @@ def get_sequence_info(self, sequence_name: str) -> Dict[str, Any]: # Filter out None values if keys weren't present return {k: v for k, v in standard_info.items() if v is not None} else: - print(f"Warning: '[Sequence]' section not found in {seq_info_path}") + logger.warning(f"'[Sequence]' section not found in {seq_info_path}") return {} except configparser.Error as e: - print(f"Error parsing {seq_info_path}: {e}") + logger.error(f"Error parsing {seq_info_path}: {e}") return {} def get_frame_iterator(self, sequence_name: str) -> Iterator[Dict[str, Any]]: @@ -303,13 +304,11 @@ def get_frame_iterator(self, sequence_name: str) -> Iterator[Dict[str, Any]]: """ seq_info = self.get_sequence_info(sequence_name) num_frames = seq_info.get("seqLength") - img_dir = seq_info.get( - "img_dir" - ) + img_dir = seq_info.get("img_dir") if num_frames is None or img_dir is None or not img_dir.is_dir(): - print( - f"Warning: Could not determine frame count or image directory for \ + logger.warning( + f"Could not determine frame count or image directory for \ {sequence_name}. Check seqinfo.ini." ) return # Return empty iterator @@ -318,9 +317,9 @@ def get_frame_iterator(self, sequence_name: str) -> Iterator[Dict[str, Any]]: first_frame_pattern = f"{1:06d}.*" potential_files = list(img_dir.glob(first_frame_pattern)) if not potential_files: - print( - f"Warning: No image files found matching pattern \ - '{first_frame_pattern}' in {img_dir}" + logger.warning( + f"No image files found matching pattern '{first_frame_pattern}' \ + in {img_dir}" ) # Try common extensions explicitly if glob fails if (img_dir / f"{1:06d}.jpg").exists(): @@ -328,19 +327,18 @@ def get_frame_iterator(self, sequence_name: str) -> Iterator[Dict[str, Any]]: elif (img_dir / f"{1:06d}.png").exists(): img_ext = ".png" else: - print( - f"Warning: Could not determine image extension for sequence \ - {sequence_name}." + logger.warning( + f"Could not determine image extension for sequence {sequence_name}." 
) - return + return else: img_ext = potential_files[0].suffix for i in range(1, num_frames + 1): - frame_filename = f"{i:06d}{img_ext}" + frame_filename = f"{i:06d}{img_ext}" frame_path = img_dir / frame_filename if not frame_path.exists(): - print(f"Warning: Expected frame image not found: {frame_path}") + logger.warning(f"Expected frame image not found: {frame_path}") # Decide whether to skip or raise error. Skipping for now. continue @@ -351,7 +349,6 @@ def get_frame_iterator(self, sequence_name: str) -> Iterator[Dict[str, Any]]: ), # Use absolute path for reliable dict key } - def load_public_detections(self, min_confidence: Optional[float] = None) -> None: """ Loads public detections from `det/det.txt` for all sequences into memory. @@ -363,7 +360,7 @@ def load_public_detections(self, min_confidence: Optional[float] = None) -> None min_confidence: Optional minimum detection confidence score to include. If None, all detections are loaded. """ - print("Loading public detections...") + logger.info("Loading public detections...") self._public_detections = {} loaded_count = 0 total_dets = 0 @@ -371,7 +368,7 @@ def load_public_detections(self, min_confidence: Optional[float] = None) -> None for seq_name in self.get_sequence_names(): det_path = self.root_path / seq_name / "det" / "det.txt" if not det_path.exists(): - print(f" Info: No det.txt found for sequence {seq_name}") + logger.info(f"No det.txt found for sequence {seq_name}") continue try: @@ -392,9 +389,9 @@ def load_public_detections(self, min_confidence: Optional[float] = None) -> None for frame_idx, detections in frame_detections.items(): if frame_idx not in frame_map: - print( - f" Warning: Detections found for frame {frame_idx} \ - outside sequence length in {seq_name}. Skipping." + logger.warning( + f"Detections found for frame {frame_idx} outside sequence \ + length in {seq_name}. Skipping." ) continue @@ -411,18 +408,20 @@ def load_public_detections(self, min_confidence: Optional[float] = None) -> None ) seq_total_dets += len(detections) - print(f" Loaded {seq_total_dets} detections for sequence {seq_name}") + logger.info( + f"Loaded {seq_total_dets} detections for sequence {seq_name}" + ) total_dets += seq_total_dets except Exception as e: - print(f" Error loading detections for sequence {seq_name}: {e}") + logger.error(f"Error loading detections for sequence {seq_name}: {e}") - print( - f"Finished loading public detections. Found {total_dets} \ - detections across {loaded_count} sequences." + logger.info( + f"Finished loading public detections. Found {total_dets} detections \ + across {loaded_count} sequences." ) if not self._public_detections: - print("Warning: No public detections were loaded.") + logger.warning("No public detections were loaded.") @property def has_public_detections(self) -> bool: @@ -445,9 +444,9 @@ def get_public_detections(self, image_path: str) -> sv.Detections: for this path or if `load_public_detections()` was not called. """ if not self.has_public_detections: - print( - "Warning: Public detections requested but not loaded. \ - # Call load_public_detections() first." + logger.warning( + "Public detections requested but not loaded. \ + Call load_public_detections() first." 
) return sv.Detections.empty() @@ -494,14 +493,14 @@ def preprocess( # --- Input Validation --- if ( - gt_dets.data is None + gt_dets.data is None or "frame_idx" not in gt_dets.data or gt_dets.tracker_id is None or gt_dets.class_id is None or gt_dets.confidence is None ): - print( - "Warning: GT detections missing required fields " + logger.warning( + "GT detections missing required fields " "(frame_idx, tracker_id, class_id, confidence) for " "MOT preprocessing. Skipping." ) @@ -511,8 +510,8 @@ def preprocess( or "frame_idx" not in pred_dets.data or pred_dets.tracker_id is None ): - print( - "Warning: Prediction detections missing required fields " + logger.warning( + "Prediction detections missing required fields " "(frame_idx, tracker_id) for MOT preprocessing. Skipping." ) return gt_dets, pred_dets diff --git a/trackers/dataset/utils.py b/trackers/dataset/utils.py index 00982ce2..8d7757ca 100644 --- a/trackers/dataset/utils.py +++ b/trackers/dataset/utils.py @@ -1,6 +1,10 @@ import numpy as np import supervision as sv +from trackers.log import get_logger # Added import + +logger = get_logger(__name__) # Added logger instance + def _relabel_ids(detections: sv.Detections) -> sv.Detections: """ @@ -30,9 +34,8 @@ def _relabel_ids(detections: sv.Detections) -> sv.Detections: try: unique_ids = np.unique(detections.tracker_id[valid_ids_mask].astype(int)) except ValueError: - print( - "Warning: Could not convert tracker IDs to integers during relabeling. " - "Skipping." + logger.warning( + "Could not convert tracker IDs to integers during relabeling. Skipping." ) return detections @@ -45,15 +48,14 @@ def _relabel_ids(detections: sv.Detections) -> sv.Detections: offset = 0 if min_id < 0: - print( - f"Warning: Negative tracker IDs found ({min_id}). " - "Shifting IDs for relabeling." + logger.warning( + f"Negative tracker IDs found ({min_id}). Shifting IDs for relabeling." ) offset = -min_id max_id += offset if np.isnan(max_id): - print("Warning: Max ID is NaN during relabeling after offset. Skipping.") + logger.warning("Max ID is NaN during relabeling after offset. Skipping.") return detections map_size = int(max_id) + 1 @@ -67,8 +69,8 @@ def _relabel_ids(detections: sv.Detections) -> sv.Detections: for i in original_indices: original_id = int(detections.tracker_id[i]) + offset if original_id >= map_size or original_id < 0: - print( - f"Warning: Original ID {original_id - offset} out of bounds for map " + logger.warning( + f"Original ID {original_id - offset} out of bounds for map " "during relabeling. Skipping ID." 
             )
             continue
 
@@ -80,6 +82,5 @@ def _relabel_ids(detections: sv.Detections) -> sv.Detections:
         else:
             new_ids[i] = id_map[original_id]
 
-
     detections.tracker_id = new_ids
-    return detections
\ No newline at end of file
+    return detections

From 390a342ecb0ffdc93991d4cc442b38a7d6c77c5f Mon Sep 17 00:00:00 2001
From: rolson24
Date: Mon, 28 Apr 2025 23:24:49 -0400
Subject: [PATCH 3/4] Chore: respond to feedback

---
 trackers/dataset/base.py          |  2 +-
 trackers/dataset/mot_challenge.py | 55 ++++++++++++++++++++++++++-----
 trackers/dataset/utils.py         |  6 ++--
 3 files changed, 50 insertions(+), 13 deletions(-)

diff --git a/trackers/dataset/base.py b/trackers/dataset/base.py
index b7e5538a..866dcad3 100644
--- a/trackers/dataset/base.py
+++ b/trackers/dataset/base.py
@@ -5,7 +5,7 @@
 
 
 # --- Base Dataset ---
-class Dataset(abc.ABC):
+class EvaluationDataset(abc.ABC):
     """Abstract base class for datasets used in tracking evaluation."""
 
     @abc.abstractmethod
diff --git a/trackers/dataset/mot_challenge.py b/trackers/dataset/mot_challenge.py
index 43ae654e..417a7d76 100644
--- a/trackers/dataset/mot_challenge.py
+++ b/trackers/dataset/mot_challenge.py
@@ -6,8 +6,8 @@
 import supervision as sv
 from scipy.optimize import linear_sum_assignment
 
-from trackers.dataset.base import Dataset
-from trackers.dataset.utils import _relabel_ids
+from trackers.dataset.base import EvaluationDataset
+from trackers.dataset.utils import relabel_ids
 from trackers.log import get_logger
 
 # --- Define MOT Constants needed for preprocessing ---
 MOT_PEDESTRIAN_ID = 1
@@ -25,7 +25,7 @@
 logger = get_logger(__name__)
 
 
-class MOTChallengeDataset(Dataset):
+class MOTChallengeDataset(EvaluationDataset):
     """
     Dataset class for loading sequences in the MOTChallenge format.
     Handles parsing `seqinfo.ini`, `gt/gt.txt`, and optionally `det/det.txt`.
@@ -61,9 +61,8 @@ def __init__(self, dataset_path: Union[str, Path]):
         if not self.root_path.is_dir():
             raise FileNotFoundError(f"Dataset path not found: {self.root_path}")
         self._sequence_names = self._find_sequences()
-        self._public_detections: Optional[Dict[str, sv.Detections]] = (
-            None  # Cache for public detections
-        )
+        self._public_detections: Dict[str, sv.Detections] = {}
+        self._frame_maps: Dict[str, Dict[int, str]] = {}
 
     def _find_sequences(self) -> List[str]:
         """
@@ -362,6 +361,7 @@ def load_public_detections(self, min_confidence: Optional[float] = None) -> None
         """
         logger.info("Loading public detections...")
         self._public_detections = {}
+        self._frame_maps = {}
         loaded_count = 0
         total_dets = 0
 
@@ -387,6 +387,8 @@
                 for info in self.get_frame_iterator(seq_name)
             }
 
+            self._frame_maps[seq_name] = frame_map
+
             for frame_idx, detections in frame_detections.items():
                 if frame_idx not in frame_map:
                     logger.warning(
@@ -429,7 +431,7 @@ def has_public_detections(self) -> bool:
         `load_public_detections`."""
-        return self._public_detections is not None
+        return bool(self._public_detections)
 
-    def get_public_detections(self, image_path: str) -> sv.Detections:
+    def get_public_detections_from_image_path(self, image_path: str) -> sv.Detections:
         """
         Retrieves the loaded public detections associated with a specific image path.
 
@@ -456,6 +458,41 @@ def get_public_detections(self, image_path: str) -> sv.Detections:
             abs_image_path, sv.Detections.empty()
         )
 
+    def get_public_detections_from_frame_index(
+        self, sequence_name: str, frame_idx: int
+    ) -> sv.Detections:
+        """
+        Retrieves the loaded public detections for a specific frame index in a
+        sequence.
+        Requires `load_public_detections()` to have been called first.
+ Args: + sequence_name: The name of the sequence (e.g., 'MOT17-02-SDP'). + frame_idx: The frame index (1-based). + Returns: + An sv.Detections object containing the public detections for the + specified frame index. Returns `sv.Detections.empty()` if no detections + were loaded for this frame or if `load_public_detections()` was not + called. + """ + if not self.has_public_detections: + logger.warning( + "Public detections requested but not loaded. \ + Call load_public_detections() first." + ) + return sv.Detections.empty() + + frame_map = self._frame_maps.get(sequence_name, {}) + abs_image_path = frame_map.get(frame_idx) + + if abs_image_path is None: + logger.warning( + f"No public detections found for sequence {sequence_name} at frame \ + {frame_idx}" + ) + return sv.Detections.empty() + + return self.get_public_detections_from_image_path(abs_image_path) + def preprocess( self, gt_dets: sv.Detections, @@ -596,7 +633,7 @@ def preprocess( ) # --- TrackEval Preprocessing Step 6: Relabel IDs using the utility function --- - gt_processed = _relabel_ids(gt_processed) - pred_processed = _relabel_ids(pred_processed) + gt_processed = relabel_ids(gt_processed) + pred_processed = relabel_ids(pred_processed) return gt_processed, pred_processed diff --git a/trackers/dataset/utils.py b/trackers/dataset/utils.py index 8d7757ca..5edb7852 100644 --- a/trackers/dataset/utils.py +++ b/trackers/dataset/utils.py @@ -6,7 +6,7 @@ logger = get_logger(__name__) # Added logger instance -def _relabel_ids(detections: sv.Detections) -> sv.Detections: +def relabel_ids(detections: sv.Detections) -> sv.Detections: """ Relabels `tracker_id`s to be contiguous integers starting from 0. @@ -43,8 +43,8 @@ def _relabel_ids(detections: sv.Detections) -> sv.Detections: return detections # Now unique_ids contains only valid integers - max_id = np.max(unique_ids) - min_id = np.min(unique_ids) + max_id: int = np.max(unique_ids) + min_id: int = np.min(unique_ids) offset = 0 if min_id < 0: From 3034cb5a615b88c3363f3ed2d5a7834e9e111734 Mon Sep 17 00:00:00 2001 From: rolson24 Date: Tue, 29 Apr 2025 22:35:23 -0400 Subject: [PATCH 4/4] chore: fix nitpicks --- trackers/dataset/mot_challenge.py | 14 +++++--------- trackers/dataset/utils.py | 8 ++++---- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/trackers/dataset/mot_challenge.py b/trackers/dataset/mot_challenge.py index 417a7d76..52ced261 100644 --- a/trackers/dataset/mot_challenge.py +++ b/trackers/dataset/mot_challenge.py @@ -20,7 +20,7 @@ ] # person_on_vehicle, static_person, distractor, reflection MOT_IGNORE_IDS = [2, 7, 8, 12, 13] # Includes crowd (13) for ignore, adjust as needed ZERO_MARKED_EPSILON = 1e-5 -# --- End MOT Constants --- + logger = get_logger(__name__) @@ -141,7 +141,7 @@ def _parse_mot_file( "frame_idx": frame_idx, "obj_id": obj_id, "xyxy": [x, y, x + width, y + height], - "confidence": confidence, # Correctly assigned + "confidence": confidence, "class_id": class_id, } @@ -254,7 +254,6 @@ def get_sequence_info(self, sequence_name: str) -> Dict[str, Any]: info: Dict[str, Union[int, float, str]] = {} for key, value in config["Sequence"].items(): try: - # Attempt to convert to int, then float, else keep as string info[key] = int(value) except ValueError: try: @@ -528,7 +527,6 @@ def preprocess( gt_out_list = [] pred_out_list = [] - # --- Input Validation --- if ( gt_dets.data is None or "frame_idx" not in gt_dets.data @@ -563,8 +561,7 @@ def preprocess( pred_dets_t_filtered = pred_dets_t - # --- TrackEval Preprocessing Step 1 & 2: - # Optionally 
remove tracker dets matching distractor GTs --- + # Optionally remove tracker dets matching distractor GTs if remove_distractor_matches: to_remove_tracker_indices = np.array([], dtype=int) if len(gt_dets_t) > 0 and len(pred_dets_t) > 0: @@ -598,9 +595,8 @@ def preprocess( pred_keep_mask = np.ones(len(pred_dets_t), dtype=bool) pred_keep_mask[to_remove_tracker_indices] = False pred_dets_t_filtered = pred_dets_t[pred_keep_mask] - # else: pred_dets_t_filtered remains pred_dets_t - # --- TrackEval Preprocessing Step 4: Remove unwanted GT dets --- + # Remove unwanted GT dets gt_is_pedestrian = gt_dets_t.class_id == MOT_PEDESTRIAN_ID # Refined check for zero_marked: Check if confidence is very close to 0 @@ -632,7 +628,7 @@ def preprocess( else sv.Detections.empty() ) - # --- TrackEval Preprocessing Step 6: Relabel IDs using the utility function --- + # Relabel IDs using the utility function gt_processed = relabel_ids(gt_processed) pred_processed = relabel_ids(pred_processed) diff --git a/trackers/dataset/utils.py b/trackers/dataset/utils.py index 5edb7852..a3b1944c 100644 --- a/trackers/dataset/utils.py +++ b/trackers/dataset/utils.py @@ -1,9 +1,9 @@ import numpy as np import supervision as sv -from trackers.log import get_logger # Added import +from trackers.log import get_logger -logger = get_logger(__name__) # Added logger instance +logger = get_logger(__name__) def relabel_ids(detections: sv.Detections) -> sv.Detections: @@ -24,13 +24,13 @@ def relabel_ids(detections: sv.Detections) -> sv.Detections: if len(detections) == 0 or detections.tracker_id is None: return detections - # 1. Filter out potential NaN values first + # Filter out potential NaN values valid_ids_mask = ~np.isnan(detections.tracker_id) if not np.any(valid_ids_mask): # All IDs were NaN or array was empty after filtering return detections - # 2. Get unique integer IDs + # Get unique integer IDs try: unique_ids = np.unique(detections.tracker_id[valid_ids_mask].astype(int)) except ValueError:
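
Usage sketch for the series above (illustrative, not part of the patches:
the dataset path and `my_tracker` are placeholders, and the snippet assumes
the post-series API, i.e. `MOTChallengeDataset` with
`get_public_detections_from_image_path` and `preprocess`):

    import numpy as np
    import supervision as sv

    from trackers.dataset.mot_challenge import MOTChallengeDataset

    dataset = MOTChallengeDataset("/path/to/MOT17/train")
    dataset.load_public_detections(min_confidence=0.1)

    for seq_name in dataset.get_sequence_names():
        gt = dataset.load_ground_truth(seq_name)
        if gt is None:
            continue

        per_frame = []
        for frame in dataset.get_frame_iterator(seq_name):
            # Seed a tracker with the cached public detections for this frame;
            # my_tracker stands in for anything that returns sv.Detections with
            # tracker_id set.
            dets = dataset.get_public_detections_from_image_path(frame["image_path"])
            tracked = my_tracker.update(dets)
            tracked.data["frame_idx"] = np.full(len(tracked), frame["frame_idx"])
            per_frame.append(tracked)

        preds = sv.Detections.merge(per_frame) if per_frame else sv.Detections.empty()

        # MOT-style preprocessing (distractor removal, ID relabeling) before
        # computing tracking metrics.
        gt_clean, preds_clean = dataset.preprocess(gt, preds, iou_threshold=0.5)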