diff --git a/movement/kinematics/__init__.py b/movement/kinematics/__init__.py index 0e6e16a97..dae7c427f 100644 --- a/movement/kinematics/__init__.py +++ b/movement/kinematics/__init__.py @@ -8,6 +8,7 @@ compute_speed, compute_time_derivative, compute_velocity, + detect_u_turns, ) from movement.kinematics.orientation import ( compute_forward_vector, @@ -26,4 +27,5 @@ "compute_forward_vector", "compute_head_direction_vector", "compute_forward_vector_angle", + "detect_u_turns", ] diff --git a/movement/kinematics/kinematics.py b/movement/kinematics/kinematics.py index d2a1003aa..efe707b55 100644 --- a/movement/kinematics/kinematics.py +++ b/movement/kinematics/kinematics.py @@ -13,11 +13,13 @@ import warnings from typing import Literal +import numpy as np import xarray as xr +from movement.kinematics.orientation import compute_forward_vector from movement.utils.logging import logger from movement.utils.reports import report_nan_values -from movement.utils.vector import compute_norm +from movement.utils.vector import compute_norm, compute_signed_angle_2d from movement.validators.arrays import validate_dims_coords @@ -379,3 +381,68 @@ def _compute_scaled_path_length( valid_proportion = valid_segments / (data.sizes["time"] - 1) # return scaled path length return compute_norm(displacement).sum(dim="time") / valid_proportion + + +def detect_u_turns( + data: xr.DataArray, + use_direction: Literal["forward_vector", "displacement"] = "displacement", + u_turn_threshold: float = np.pi * 5 / 6, # 150 degrees in radians + camera_view: Literal["top_down", "bottom_up"] = "bottom_up", +) -> xr.DataArray: + """Detect U-turn behavior in a trajectory. + + This function computes the directional change between consecutive time + frames and accumulates the rotation angles. If the accumulated angle + exceeds a specified threshold, a U-turn is detected. + + Parameters + ---------- + data : xarray.DataArray + The trajectory data, which must contain the 'time' and 'space' (x, y). + use_direction : Literal["forward_vector", "displacement"], optional + Method to compute direction vectors, default is `"displacement"`: + - `"forward_vector"`: Computes the forward direction vector. + - `"displacement"`: Computes displacement vectors. + u_turn_threshold : float, optional + The angle threshold (in radians) to detect U-turn. Default is (`5π/6`). + camera_view : Literal["top_down", "bottom_up"], optional + Specifies the camera perspective used for computing direction vectors. + + Returns + ------- + xarray.DataArray + A boolean scalar DataArray indicating whether a U-turn has occurred. + + """ + # Compute direction vectors + if use_direction == "forward_vector": + direction_vectors = compute_forward_vector( + data, "left_ear", "right_ear", camera_view=camera_view + ) + elif use_direction == "displacement": + if "keypoints" in data.dims: + raise ValueError( + "Displacement expects single keypoint data " + "and must not include the 'keypoints' dimension." + ) + direction_vectors = compute_displacement(data) + else: + raise ValueError( + "The parameter `use_direction` must be one of `forward_vector` " + f" or `displacement`, but got {use_direction}." + ) + + # Compute angle between vectors + angles = compute_signed_angle_2d( + direction_vectors.shift(time=1), direction_vectors + ) + angles = angles.fillna(0) + + # Accumulate angles over time and compute range + cumulative_rotation = angles.cumsum(dim="time") + rotation_range = cumulative_rotation.max( + dim="time" + ) - cumulative_rotation.min(dim="time") + + # Return scalar boolean as xarray.DataArray + return xr.DataArray(rotation_range >= u_turn_threshold) diff --git a/tests/test_unit/test_kinematics/test_kinematics.py b/tests/test_unit/test_kinematics/test_kinematics.py index c56ee0801..ec823df97 100644 --- a/tests/test_unit/test_kinematics/test_kinematics.py +++ b/tests/test_unit/test_kinematics/test_kinematics.py @@ -318,3 +318,71 @@ def test_path_length_nan_warn_threshold( position, nan_warn_threshold=nan_warn_threshold ) assert result.name == "path_length" + + +@pytest.fixture +def valid_data_array_for_u_turn_detection(): + """Return a position data array for an individual with 3 keypoints + (left ear, right ear, and nose), tracked for 4 frames, in x-y space. + """ + time = [0, 1, 2, 3] + keypoints = ["left_ear", "right_ear", "nose"] + space = ["x", "y"] + + ds = xr.DataArray( + [ + [[-1, 0], [1, 0], [0, 1]], # time 0 + [[0, 2], [0, 0], [1, 1]], # time 1 + [[2, 1], [0, 1], [1, 0]], # time 2 + [[1, -1], [1, 1], [0, 0]], # time 3 + ], + dims=["time", "keypoints", "space"], + coords={ + "time": time, + "keypoints": keypoints, + "space": space, + }, + ) + return ds + + +def test_detect_u_turns(valid_data_array_for_u_turn_detection): + """Test that U-turn detection works correctly using a mock dataset.""" + # Forward vector method + u_turn_forward_vector = kinematics.detect_u_turns( + valid_data_array_for_u_turn_detection, use_direction="forward_vector" + ) + assert u_turn_forward_vector.item() is True + + # Displacement method (nose-only) + nose_data = valid_data_array_for_u_turn_detection.sel( + keypoints="nose" + ).drop_vars("keypoints") + u_turn_displacement = kinematics.detect_u_turns( + nose_data, use_direction="displacement" + ) + assert u_turn_displacement.item() is True + + # Stricter threshold - displacement should return False + strict_displacement = kinematics.detect_u_turns( + nose_data, use_direction="displacement", u_turn_threshold=np.pi * 7 / 6 + ) + assert strict_displacement.item() is False + + # Stricter threshold - forward vector still returns True + strict_forward_vector = kinematics.detect_u_turns( + valid_data_array_for_u_turn_detection, + use_direction="forward_vector", + u_turn_threshold=np.pi * 7 / 6, + ) + assert strict_forward_vector.item() is True + + # Invalid use_direction check + with pytest.raises( + ValueError, + match="must be one of `forward_vector`.*but got invalid_direction", + ): + kinematics.detect_u_turns( + valid_data_array_for_u_turn_detection, + use_direction="invalid_direction", + ) diff --git a/tests/test_unit/test_kinematics/test_orientation.py b/tests/test_unit/test_kinematics/test_orientation.py index a7ed7660f..02a6f7a29 100644 --- a/tests/test_unit/test_kinematics/test_orientation.py +++ b/tests/test_unit/test_kinematics/test_orientation.py @@ -431,3 +431,25 @@ def test_casts_from_tuple( xr.testing.assert_allclose(pass_numpy, pass_tuple) xr.testing.assert_allclose(pass_numpy, pass_list) + + def test_nan_handling(self, spinning_on_the_spot: xr.DataArray) -> None: + """Ensure NaNs don't crash the angle computation.""" + data = spinning_on_the_spot.copy() + data[1, :, :] = np.nan + result = kinematics.compute_forward_vector_angle( + data, "left", "right", self.x_axis + ) + # NaN at index 1 + assert np.isnan(result[1]) + # Valid elsewhere + assert not np.isnan(result[0]) + + def test_invalid_dimensionality_raises(self) -> None: + """Invalid space dimension should raise a ValueError.""" + data = xr.DataArray( + np.zeros((5, 2, 3)), # 3D space is invalid + dims=["time", "keypoints", "space"], + coords={"space": ["x", "y", "z"], "keypoints": ["left", "right"]}, + ) + with pytest.raises(ValueError, match="2 spatial dimensions"): + kinematics.compute_forward_vector_angle(data, "left", "right")