Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 42 additions & 146 deletions test/core/reid/dataset/test_base.py
Original file line number Diff line number Diff line change
@@ -1,159 +1,55 @@
from contextlib import ExitStack as DoesNotRaise
import os
import shutil

import pytest

from trackers.core.reid.dataset.base import TripletsDataset
from trackers.core.reid.dataset.base import IdentityDataset
from trackers.core.reid.dataset.market_1501 import parse_market1501_dataset
from trackers.core.reid.model import ReIDModel
from trackers.utils.data_utils import unzip_file
from trackers.utils.downloader import download_file

DATASET_URL = "https://storage.googleapis.com/com-roboflow-marketing/trackers/datasets/market_1501.zip"

@pytest.mark.parametrize(
"tracker_id_to_images, exception",
[
(
{"0111": []},
pytest.raises(ValueError),
), # Single tracker with no images - should raise ValueError
(
{"0111": ["0111_00000000.jpg"]},
pytest.raises(ValueError),
), # Single tracker with one image - should raise ValueError
(
{"0111": ["0111_00000000.jpg", "0111_00000001.jpg"]},
pytest.raises(ValueError),
), # Single tracker with multiple images - should raise ValueError
(
{
"0111": ["0111_00000000.jpg", "0111_00000001.jpg"],
"0112": ["0112_00000000.jpg"],
},
pytest.raises(ValueError),
), # Two trackers but one has only one image - should raise ValueError
(
{
"0111": ["0111_00000000.jpg", "0111_00000001.jpg"],
"0112": ["0112_00000000.jpg", "0112_00000001.jpg"],
},
DoesNotRaise(),
), # Two trackers with multiple images - should not raise
(
{
"0111": ["0111_00000000.jpg", "0111_00000001.jpg"],
"0112": ["0112_00000000.jpg", "0112_00000001.jpg"],
"0113": ["0113_00000000.jpg"],
},
DoesNotRaise(),
), # Three trackers, one with fewer images - should validate dataset length
],
)
def test_triplet_dataset_initialization(tracker_id_to_images, exception):
with exception:
_ = TripletsDataset(tracker_id_to_images)

@pytest.fixture
def market_1501_dataset():
os.makedirs("test_data", exist_ok=True)
dataset_path = os.path.join("test_data", "Market-1501-v15.09.15")
zip_path = os.path.join("test_data", "market_1501.zip")
if not os.path.exists(dataset_path):
if not os.path.exists(zip_path):
download_file(DATASET_URL)
shutil.move("market_1501.zip", str(zip_path))
unzip_file(str(zip_path), "test_data")
yield dataset_path


@pytest.mark.parametrize(
"tracker_id_to_images, split_ratio, expected_train_size, expected_val_size, exception", # noqa: E501
[
(
{
"0111": ["0111_00000000.jpg", "0111_00000001.jpg"],
"0112": ["0112_00000000.jpg", "0112_00000001.jpg"],
},
0.5,
1,
1,
pytest.raises(ValueError),
), # Split results in only 1 tracker in test set - should raise ValueError
(
{
"0111": ["0111_00000000.jpg", "0111_00000001.jpg"],
"0112": ["0112_00000000.jpg", "0112_00000001.jpg"],
"0113": ["0113_00000000.jpg", "0113_00000001.jpg"],
"0114": ["0114_00000000.jpg", "0114_00000001.jpg"],
"0115": ["0115_00000000.jpg", "0115_00000001.jpg"],
},
0.2,
1,
4,
pytest.raises(ValueError),
), # Split results in only 1 tracker in test set - should raise ValueError
(
{
"0111": ["0111_00000000.jpg", "0111_00000001.jpg"],
"0112": ["0112_00000000.jpg", "0112_00000001.jpg"],
"0113": ["0113_00000000.jpg", "0113_00000001.jpg"],
"0114": ["0114_00000000.jpg", "0114_00000001.jpg"],
"0115": ["0115_00000000.jpg", "0115_00000001.jpg"],
},
0.8,
4,
1,
pytest.raises(ValueError),
), # Split results in only 1 tracker in val set - should raise ValueError
(
{
"0111": ["0111_00000000.jpg", "0111_00000001.jpg"],
"0112": ["0112_00000000.jpg", "0112_00000001.jpg"],
"0113": ["0113_00000000.jpg", "0113_00000001.jpg"],
"0114": ["0114_00000000.jpg", "0114_00000001.jpg"],
"0115": ["0115_00000000.jpg", "0115_00000001.jpg"],
},
0.6,
3,
2,
DoesNotRaise(),
), # Valid split with multiple trackers in both sets
(
{
"0111": ["0111_00000000.jpg", "0111_00000001.jpg"],
"0112": ["0112_00000000.jpg", "0112_00000001.jpg"],
"0113": ["0113_00000000.jpg", "0113_00000001.jpg"],
"0114": ["0114_00000000.jpg", "0114_00000001.jpg"],
},
0.5,
2,
2,
DoesNotRaise(),
), # 50% train, 50% validation - valid
],
"dataset_split", ["bounding_box_train, bounding_box_test", "query"]
)
def test_triplet_dataset_split(
tracker_id_to_images, split_ratio, expected_train_size, expected_val_size, exception
):
with exception:
dataset = TripletsDataset(tracker_id_to_images)
train_dataset, val_dataset = dataset.split(split_ratio=split_ratio)

assert len(train_dataset) == expected_train_size, (
f"Expected train dataset size {expected_train_size}, "
f"got {len(train_dataset)}"
)
assert len(val_dataset) == expected_val_size, (
f"Expected validation dataset size {expected_val_size}, "
f"got {len(val_dataset)}"
)
def test_identity_dataset(market_1501_dataset, dataset_split):
dataset_path = os.path.join(market_1501_dataset, dataset_split)
dataset = IdentityDataset(parse_market1501_dataset(data_dir=dataset_path))
if dataset_split == "bounding_box_train":
assert len(dataset) == 12936
assert dataset.get_num_identities() == 751
elif dataset_split == "bounding_box_test":
assert len(dataset) == 15913
assert dataset.get_num_identities() == 751
elif dataset_split == "query":
assert len(dataset) == 3368
assert dataset.get_num_identities() == 750


@pytest.mark.parametrize(
"tracker_id_to_images, tracker_id, exception",
[
(
{
"0111": ["0111_00000000.jpg", "0111_00000001.jpg"],
"0112": ["0112_00000000.jpg", "0112_00000001.jpg"],
"0113": ["0113_00000000.jpg", "0113_00000001.jpg"],
},
"0111",
DoesNotRaise(),
),
],
"dataset_split", ["bounding_box_train, bounding_box_test", "query"]
)
def test_get_triplet_image_paths(tracker_id_to_images, tracker_id, exception) -> None:
with exception:
dataset = TripletsDataset(tracker_id_to_images)
anchor_path, positive_path, negative_path = dataset._get_triplet_image_paths(
tracker_id
)

assert anchor_path in tracker_id_to_images[tracker_id]
assert positive_path in tracker_id_to_images[tracker_id]
assert negative_path not in tracker_id_to_images[tracker_id]
assert anchor_path != positive_path
def test_reid_model_classification_head(market_1501_dataset, dataset_split):
dataset_path = os.path.join(market_1501_dataset, dataset_split)
dataset = IdentityDataset(parse_market1501_dataset(data_dir=dataset_path))
model = ReIDModel.from_timm("resnet50")
model.add_classification_head(
num_classes=dataset.get_num_identities(), freeze_backbone=True
)
assert model.backbone.fc.out_features == dataset.get_num_identities()
52 changes: 0 additions & 52 deletions test/core/reid/dataset/test_market_1501.py

This file was deleted.

4 changes: 2 additions & 2 deletions trackers/core/reid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
logger = get_logger(__name__)

try:
from trackers.core.reid.dataset.base import TripletsDataset
from trackers.core.reid.dataset.base import IdentityDataset
from trackers.core.reid.dataset.market_1501 import get_market1501_dataset
from trackers.core.reid.model import ReIDModel

__all__ = ["ReIDModel", "TripletsDataset", "get_market1501_dataset"]
__all__ = ["IdentityDataset", "ReIDModel", "get_market1501_dataset"]
except ImportError:
logger.warning(
"ReIDModel dependencies not installed. ReIDModel will not be available. "
Expand Down
Loading
Loading