Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 120 additions & 6 deletions movement/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import xarray as xr
from scipy import signal

from movement.kinematics import compute_displacement
from movement.utils.logging import log_to_attrs, logger
from movement.utils.reports import report_nan_values
from movement.utils.vector import compute_norm


@log_to_attrs
Expand Down Expand Up @@ -60,6 +62,92 @@
return data_filtered


@log_to_attrs
def filter_by_displacement(
    position: xr.DataArray,
    threshold: float = 10.0,
    direction: Literal["forward", "bidirectional"] = "forward",
    print_report: bool = False,
) -> xr.DataArray:
    """Filter data points based on displacement magnitude.

    Frames in the ``position`` array that exceed a displacement magnitude
    threshold are set to NaN.

    Two modes are supported via the ``direction`` parameter:

    - "forward" (default): A point at time ``t`` is set to NaN if it has
      moved more than the ``threshold`` Euclidean distance from the same
      point at time ``t-1`` (i.e., if ``|pos(t) - pos(t-1)| > threshold``).
    - "bidirectional": A point at time ``t`` is set to NaN only if BOTH the
      displacement from ``t-1`` to ``t`` AND the displacement from ``t`` to
      ``t+1`` exceed the threshold (i.e., if
      ``|pos(t) - pos(t-1)| > threshold`` AND
      ``|pos(t+1) - pos(t)| > threshold``). This targets isolated
      single-frame jumps, where a point leaps away and immediately back.

    Parameters
    ----------
    position : xarray.DataArray
        The input data containing position information, with ``time``
        and ``space`` (in Cartesian coordinates) as required dimensions.
    threshold : float, optional
        The maximum Euclidean distance allowed for displacement.
        Defaults to 10.0.
    direction : Literal["forward", "bidirectional"], optional
        The directionality of the displacement check. Defaults to "forward".
    print_report : bool, optional
        Whether to print a report of the number of NaN values before and after
        filtering. Defaults to False.

    Returns
    -------
    xr.DataArray
        The filtered position array, where points exceeding the displacement
        threshold condition have been set to NaN.

    Raises
    ------
    TypeError
        If ``position`` is not an ``xarray.DataArray``.
    ValueError
        If ``direction`` is neither "forward" nor "bidirectional".

    Notes
    -----
    Only displacements strictly greater than ``threshold`` count as
    exceeding it. A displacement that cannot be measured (its magnitude is
    NaN, e.g. because a neighbouring frame is missing) never counts as
    exceeding the threshold, so existing gaps do not cause additional
    points to be removed.

    In "bidirectional" mode the last frame has no following frame; its
    outgoing displacement is taken to be 0, so for a non-negative
    ``threshold`` the last frame is never removed.

    See Also
    --------
    movement.kinematics.compute_displacement:
        The function used to compute an array of displacement vectors.
    movement.utils.vector.compute_norm:
        The function used to compute distance as the magnitude of
        displacement vectors.

    """
    if not isinstance(position, xr.DataArray):
        raise TypeError("Input 'position' must be an xarray.DataArray.")
    # Validate the mode up front so bad input fails before any array work.
    if direction not in ("forward", "bidirectional"):
        raise ValueError(
            f"Invalid direction: {direction}. "
            f"Must be 'forward' or 'bidirectional'."
        )

    # Magnitude of the jump INTO each frame:
    # jump_in[t] = |pos(t) - pos(t-1)|
    jump_in = compute_norm(compute_displacement(position))

    # Build the mask of points to drop with ``>`` rather than inverting a
    # ``<`` keep-mask: any comparison with NaN is False, so a NaN magnitude
    # must appear on the "does not exceed" side to leave the point intact.
    is_outlier = jump_in > threshold

    if direction == "bidirectional":
        # Magnitude of the jump OUT OF each frame:
        # jump_out[t] = jump_in[t+1] = |pos(t+1) - pos(t)|
        # The last frame has no successor; fill its jump-out with 0.
        jump_out = jump_in.shift(time=-1, fill_value=0)
        # Drop only when BOTH the jump in and the jump out are too large.
        is_outlier = is_outlier & (jump_out > threshold)

    # ``where`` keeps values where the condition holds and inserts NaN
    # elsewhere, so pass the inverse of the outlier mask.
    position_filtered = position.where(~is_outlier)

    if print_report:
        print(report_nan_values(position, "input"))
        print(report_nan_values(position_filtered, "output"))

    return position_filtered

Check warning on line 148 in movement/filtering.py

View check run for this annotation

Codecov / codecov/patch

movement/filtering.py#L148

Added line #L148 was not covered by tests


@log_to_attrs
def interpolate_over_time(
data: xr.DataArray,
Expand Down Expand Up @@ -177,15 +265,24 @@

"""
half_window = window // 2
data_windows = data.pad( # Pad the edges to avoid NaNs
# Pad the edges to avoid NaNs before applying rolling window
# Transpose ensures padding happens correctly regardless of dim order
padded_data = data.transpose("time", ...).pad(
time=half_window, mode="reflect"
).rolling( # Take rolling windows across time
)
# Apply rolling window across time on padded data
data_windows = padded_data.rolling(
time=window, center=True, min_periods=min_periods
)

# Compute the statistic over each window
allowed_statistics = ["mean", "median", "max", "min"]
if statistic not in allowed_statistics:
raise ValueError(
f"Invalid statistic '{statistic}'. "
f"Must be one of {allowed_statistics}."
) # <-- Corrected: Added closing parenthesis

raise logger.error(
ValueError(
f"Invalid statistic '{statistic}'. "
Expand All @@ -195,8 +292,15 @@

data_rolled = getattr(data_windows, statistic)(skipna=True)

# Remove the padded edges
data_rolled = data_rolled.isel(time=slice(half_window, -half_window))
# Remove the padded edges by slicing
# Ensure the slice matches the original time dimension size
original_time_size = data.sizes["time"]
data_rolled = data_rolled.isel(
time=slice(half_window, half_window + original_time_size)
)

# Transpose back to original dimension order
data_rolled = data_rolled.transpose(*data.dims)

# Optional: Print NaN report
if print_report:
Expand Down Expand Up @@ -257,15 +361,25 @@

"""
if "axis" in kwargs:
raise ValueError("The 'axis' argument may not be overridden.")

raise logger.error(
ValueError("The 'axis' argument may not be overridden.")
)

data_smoothed = data.copy()
# Find the axis index corresponding to the 'time' dimension
try:
time_axis = data.dims.index("time")
except ValueError as e:
raise ValueError("Input data must have a 'time' dimension.") from e

Check warning on line 375 in movement/filtering.py

View check run for this annotation

Codecov / codecov/patch

movement/filtering.py#L374-L375

Added lines #L374 - L375 were not covered by tests

# Apply savgol_filter along the identified time axis
data_smoothed.values = signal.savgol_filter(
data,
data.values, # Pass numpy array to savgol_filter
window,
polyorder,
axis=0,
axis=time_axis,
**kwargs,
)
if print_report:
Expand Down
Loading