Skip to content

Commit bfd17c5

Browse files
feat: add bidirectional option to filter_by_displacement
1 parent d2cd4e6 commit bfd17c5

File tree

1 file changed

+96
-0
lines changed

1 file changed

+96
-0
lines changed

movement/filtering.py

Lines changed: 96 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,18 @@
11
"""Filter and interpolate tracks in ``movement`` datasets."""
22

3+
import logging
34
import warnings
45
from typing import Literal
56

67
import xarray as xr
78
from scipy import signal
89

10+
from movement.kinematics import compute_displacement
911
from movement.utils.logging import log_error, log_to_attrs
1012
from movement.utils.reports import report_nan_values
13+
from movement.utils.vector import compute_norm
14+
15+
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
1116

1217

1318
@log_to_attrs
@@ -60,6 +65,97 @@ def filter_by_confidence(
6065
return data_filtered
6166

6267

68+
@log_to_attrs
def filter_by_displacement(
    position: xr.DataArray,
    threshold: float = 10.0,
    direction: Literal["forward", "bidirectional"] = "forward",
    print_report: bool = False,
) -> xr.DataArray:
    """Filter data points based on displacement magnitude.

    Frames in the ``position`` array whose displacement magnitude exceeds
    ``threshold`` are set to NaN.

    Two modes are supported via the ``direction`` parameter:

    - ``"forward"`` (default): a point at time ``t`` is set to NaN if it has
      moved more than ``threshold`` (Euclidean distance) from the same
      point at time ``t-1``, i.e. if ``|pos(t) - pos(t-1)| > threshold``.
    - ``"bidirectional"``: a point at time ``t`` is set to NaN only if BOTH
      the displacement from ``t-1`` to ``t`` AND the displacement from ``t``
      to ``t+1`` exceed the threshold, i.e. if
      ``|pos(t) - pos(t-1)| > threshold`` AND
      ``|pos(t+1) - pos(t)| > threshold``. This targets isolated
      single-frame jumps (a jump "in" immediately followed by a jump "out")
      while keeping the first frame of a genuine, sustained relocation.

    Parameters
    ----------
    position : xr.DataArray
        The input data containing position information, with ``time``
        and ``space`` (in Cartesian coordinates) as required dimensions.
    threshold : float, optional
        The maximum Euclidean distance allowed for displacement. A
        displacement exactly equal to ``threshold`` is allowed (kept).
        Defaults to 10.0.
    direction : Literal["forward", "bidirectional"], optional
        The directionality of the displacement check. Defaults to "forward".
    print_report : bool, optional
        Whether to print a report of the number of NaN values before and after
        filtering. Defaults to False.

    Returns
    -------
    xr.DataArray
        The filtered position array, where points exceeding the displacement
        threshold condition have been set to NaN.

    Raises
    ------
    TypeError
        If ``position`` is not an ``xarray.DataArray``.
    ValueError
        If ``direction`` is neither ``"forward"`` nor ``"bidirectional"``.

    Notes
    -----
    A frame is only removed when a displacement is *known* to exceed the
    threshold. If a displacement cannot be evaluated because a neighbouring
    frame is NaN, it is treated as not exceeding the threshold, so existing
    gaps in the data do not cause adjacent valid frames to be discarded.

    In ``"bidirectional"`` mode the last frame has no outgoing displacement;
    it is treated as zero, so the last frame is never removed in that mode.

    See Also
    --------
    movement.kinematics.compute_displacement:
        The function used to compute an array of displacement vectors.
    movement.utils.vector.compute_norm:
        The function used to compute distance as the magnitude of
        displacement vectors.

    """
    if not isinstance(position, xr.DataArray):
        raise log_error(
            TypeError, "Input 'position' must be an xarray.DataArray."
        )
    # Validate the mode up front so an invalid call fails before any
    # displacement computation is done.
    if direction not in ("forward", "bidirectional"):
        raise log_error(
            ValueError,
            (
                f"Invalid direction: {direction}. "
                "Must be 'forward' or 'bidirectional'."
            ),
        )

    # Magnitude of the step INTO each frame: |pos(t) - pos(t-1)|.
    # NOTE(review): the value at the first time point is whatever
    # ``compute_displacement`` yields there -- presumably zero; confirm.
    step_in = compute_norm(compute_displacement(position))

    # Build the outlier mask with ``>`` and invert it afterwards, rather than
    # building a keep-mask with ``<``: comparisons against NaN are False, so
    # an undefined displacement never marks a frame as an outlier, and a
    # displacement equal to the threshold is kept, as documented.
    jump_in = step_in > threshold

    if direction == "forward":
        is_outlier = jump_in
    else:
        # The step OUT of frame t is the step INTO frame t+1. Shifting by -1
        # aligns it with frame t; the last frame is padded with 0.
        step_out = step_in.shift(time=-1, fill_value=0)
        is_outlier = jump_in & (step_out > threshold)

    position_filtered = position.where(~is_outlier)

    if print_report:
        print(report_nan_values(position, "input"))
        print(report_nan_values(position_filtered, "output"))

    return position_filtered
157+
158+
63159
@log_to_attrs
64160
def interpolate_over_time(
65161
data: xr.DataArray,

0 commit comments

Comments
 (0)