Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions movement/kinematics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
compute_speed,
compute_time_derivative,
compute_velocity,
detect_u_turns,
)
from movement.kinematics.orientation import (
compute_forward_vector,
Expand All @@ -26,4 +27,5 @@
"compute_forward_vector",
"compute_head_direction_vector",
"compute_forward_vector_angle",
"detect_u_turns",
]
69 changes: 68 additions & 1 deletion movement/kinematics/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
import warnings
from typing import Literal

import numpy as np
import xarray as xr

from movement.kinematics.orientation import compute_forward_vector
from movement.utils.logging import logger
from movement.utils.reports import report_nan_values
from movement.utils.vector import compute_norm
from movement.utils.vector import compute_norm, compute_signed_angle_2d
from movement.validators.arrays import validate_dims_coords


Expand Down Expand Up @@ -379,3 +381,68 @@
valid_proportion = valid_segments / (data.sizes["time"] - 1)
# return scaled path length
return compute_norm(displacement).sum(dim="time") / valid_proportion


def detect_u_turns(
data: xr.DataArray,
use_direction: Literal["forward_vector", "displacement"] = "displacement",
u_turn_threshold: float = np.pi * 5 / 6, # 150 degrees in radians
camera_view: Literal["top_down", "bottom_up"] = "bottom_up",
) -> xr.DataArray:
"""Detect U-turn behavior in a trajectory.

This function computes the directional change between consecutive time
frames and accumulates the rotation angles. If the accumulated angle
exceeds a specified threshold, a U-turn is detected.

Parameters
----------
data : xarray.DataArray
The trajectory data, which must contain the 'time' and 'space' (x, y).
use_direction : Literal["forward_vector", "displacement"], optional
Method to compute direction vectors, default is `"displacement"`:
- `"forward_vector"`: Computes the forward direction vector.
- `"displacement"`: Computes displacement vectors.
u_turn_threshold : float, optional
The angle threshold (in radians) to detect U-turn. Default is (`5π/6`).
camera_view : Literal["top_down", "bottom_up"], optional
Specifies the camera perspective used for computing direction vectors.

Returns
-------
xarray.DataArray
A boolean scalar DataArray indicating whether a U-turn has occurred.

"""
# Compute direction vectors
if use_direction == "forward_vector":
direction_vectors = compute_forward_vector(
data, "left_ear", "right_ear", camera_view=camera_view
)
elif use_direction == "displacement":
if "keypoints" in data.dims:
raise ValueError(

Check warning on line 424 in movement/kinematics/kinematics.py

View check run for this annotation

Codecov / codecov/patch

movement/kinematics/kinematics.py#L424

Added line #L424 was not covered by tests
"Displacement expects single keypoint data "
"and must not include the 'keypoints' dimension."
)
direction_vectors = compute_displacement(data)
else:
raise ValueError(
"The parameter `use_direction` must be one of `forward_vector` "
f" or `displacement`, but got {use_direction}."
)

# Compute angle between vectors
angles = compute_signed_angle_2d(
direction_vectors.shift(time=1), direction_vectors
)
angles = angles.fillna(0)

# Accumulate angles over time and compute range
cumulative_rotation = angles.cumsum(dim="time")
rotation_range = cumulative_rotation.max(
dim="time"
) - cumulative_rotation.min(dim="time")

# Return scalar boolean as xarray.DataArray
return xr.DataArray(rotation_range >= u_turn_threshold)
68 changes: 68 additions & 0 deletions tests/test_unit/test_kinematics/test_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,71 @@ def test_path_length_nan_warn_threshold(
position, nan_warn_threshold=nan_warn_threshold
)
assert result.name == "path_length"


@pytest.fixture
def valid_data_array_for_u_turn_detection():
"""Return a position data array for an individual with 3 keypoints
(left ear, right ear, and nose), tracked for 4 frames, in x-y space.
"""
time = [0, 1, 2, 3]
keypoints = ["left_ear", "right_ear", "nose"]
space = ["x", "y"]

ds = xr.DataArray(
[
[[-1, 0], [1, 0], [0, 1]], # time 0
[[0, 2], [0, 0], [1, 1]], # time 1
[[2, 1], [0, 1], [1, 0]], # time 2
[[1, -1], [1, 1], [0, 0]], # time 3
],
dims=["time", "keypoints", "space"],
coords={
"time": time,
"keypoints": keypoints,
"space": space,
},
)
return ds


def test_detect_u_turns(valid_data_array_for_u_turn_detection):
"""Test that U-turn detection works correctly using a mock dataset."""
# Forward vector method
u_turn_forward_vector = kinematics.detect_u_turns(
valid_data_array_for_u_turn_detection, use_direction="forward_vector"
)
assert u_turn_forward_vector.item() is True

# Displacement method (nose-only)
nose_data = valid_data_array_for_u_turn_detection.sel(
keypoints="nose"
).drop_vars("keypoints")
u_turn_displacement = kinematics.detect_u_turns(
nose_data, use_direction="displacement"
)
assert u_turn_displacement.item() is True

# Stricter threshold - displacement should return False
strict_displacement = kinematics.detect_u_turns(
nose_data, use_direction="displacement", u_turn_threshold=np.pi * 7 / 6
)
assert strict_displacement.item() is False

# Stricter threshold - forward vector still returns True
strict_forward_vector = kinematics.detect_u_turns(
valid_data_array_for_u_turn_detection,
use_direction="forward_vector",
u_turn_threshold=np.pi * 7 / 6,
)
assert strict_forward_vector.item() is True

# Invalid use_direction check
with pytest.raises(
ValueError,
match="must be one of `forward_vector`.*but got invalid_direction",
):
kinematics.detect_u_turns(
valid_data_array_for_u_turn_detection,
use_direction="invalid_direction",
)
22 changes: 22 additions & 0 deletions tests/test_unit/test_kinematics/test_orientation.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,3 +431,25 @@ def test_casts_from_tuple(

xr.testing.assert_allclose(pass_numpy, pass_tuple)
xr.testing.assert_allclose(pass_numpy, pass_list)

def test_nan_handling(self, spinning_on_the_spot: xr.DataArray) -> None:
"""Ensure NaNs don't crash the angle computation."""
data = spinning_on_the_spot.copy()
data[1, :, :] = np.nan
result = kinematics.compute_forward_vector_angle(
data, "left", "right", self.x_axis
)
# NaN at index 1
assert np.isnan(result[1])
# Valid elsewhere
assert not np.isnan(result[0])

def test_invalid_dimensionality_raises(self) -> None:
"""Invalid space dimension should raise a ValueError."""
data = xr.DataArray(
np.zeros((5, 2, 3)), # 3D space is invalid
dims=["time", "keypoints", "space"],
coords={"space": ["x", "y", "z"], "keypoints": ["left", "right"]},
)
with pytest.raises(ValueError, match="2 spatial dimensions"):
kinematics.compute_forward_vector_angle(data, "left", "right")
Loading