diff --git a/movement/filtering.py b/movement/filtering.py index 0361d9099..4faab3ff7 100644 --- a/movement/filtering.py +++ b/movement/filtering.py @@ -6,8 +6,10 @@ import xarray as xr from scipy import signal +from movement.kinematics import compute_displacement from movement.utils.logging import log_to_attrs, logger from movement.utils.reports import report_nan_values +from movement.utils.vector import compute_norm @log_to_attrs @@ -60,6 +62,92 @@ def filter_by_confidence( return data_filtered +@log_to_attrs +def filter_by_displacement( + position: xr.DataArray, + threshold: float = 10.0, + direction: Literal["forward", "bidirectional"] = "forward", + print_report: bool = False, +) -> xr.DataArray: + """Filter data points based on displacement magnitude. + + Frames in the ``position`` array that exceed a displacement magnitude + threshold are set to NaN. + + Two modes are supported via the ``direction`` parameter: + - "forward" (default): A point at time ``t`` is set to NaN if it has + moved more than the ``threshold`` Euclidean distance from the same + point at time ``t-1`` (i.e., if ``|pos(t) - pos(t-1)| > threshold``). + - "bidirectional": A point at time ``t`` is set to NaN only if BOTH the + displacement from ``t-1`` to ``t`` AND the displacement from ``t`` to + ``t+1`` exceed the threshold (i.e., if + ``|pos(t) - pos(t-1)| > threshold`` AND + ``|pos(t+1) - pos(t)| > threshold``). This corresponds to the + Stage 1 outlier detection described by the user. + + Parameters + ---------- + position : xarray.DataArray + The input data containing position information, with ``time`` + and ``space`` (in Cartesian coordinates) as required dimensions. + threshold : float, optional + The maximum Euclidean distance allowed for displacement. + Defaults to 10.0. + direction : Literal["forward", "bidirectional"], optional + The directionality of the displacement check. Defaults to "forward". + print_report : bool, optional + Whether to print a report of the number of NaN values before and after + filtering. Defaults to False. + + Returns + ------- + xr.DataArray + The filtered position array, where points exceeding the displacement + threshold condition have been set to NaN. + + See Also + -------- + movement.kinematics.compute_displacement: + The function used to compute an array of displacement vectors. + movement.utils.vector.compute_norm: + The function used to compute distance as the magnitude of + displacement vectors. + + """ + if not isinstance(position, xr.DataArray): + raise TypeError("Input 'position' must be an xarray.DataArray.") + + # Calculate forward displacement magnitude: + # norm at time t = |pos(t) - pos(t-1)| + displacement_fwd = compute_displacement(position) + mag_fwd = compute_norm(displacement_fwd) + + if direction == "forward": + # Uni-directional: Keep if magnitude from t-1 to t is below threshold + condition = mag_fwd < threshold + elif direction == "bidirectional": + # Bi-directional: Keep unless BOTH jump in and jump out are large. + # Equivalent to: Keep if jump in is small OR jump out is small. + # Calculate backward magnitude: + # norm at t+1 related to t = |pos(t+1) - pos(t)| + # mag_bwd[t] = mag_fwd[t+1] + mag_bwd = mag_fwd.shift(time=-1, fill_value=0) + condition = (mag_fwd < threshold) | (mag_bwd < threshold) + else: + raise ValueError( + f"Invalid direction: {direction}. " + f"Must be 'forward' or 'bidirectional'." + ) + + position_filtered = position.where(condition) + + if print_report: + print(report_nan_values(position, "input")) + print(report_nan_values(position_filtered, "output")) + + return position_filtered + + @log_to_attrs def interpolate_over_time( data: xr.DataArray, @@ -177,15 +265,24 @@ def rolling_filter( """ half_window = window // 2 - data_windows = data.pad( # Pad the edges to avoid NaNs + # Pad the edges to avoid NaNs before applying rolling window + # Transpose ensures padding happens correctly regardless of dim order + padded_data = data.transpose("time", ...).pad( time=half_window, mode="reflect" - ).rolling( # Take rolling windows across time + ) + # Apply rolling window across time on padded data + data_windows = padded_data.rolling( time=window, center=True, min_periods=min_periods ) # Compute the statistic over each window allowed_statistics = ["mean", "median", "max", "min"] if statistic not in allowed_statistics: + raise ValueError( + f"Invalid statistic '{statistic}'. " + f"Must be one of {allowed_statistics}." + ) # <-- Corrected: Added closing parenthesis + raise logger.error( ValueError( f"Invalid statistic '{statistic}'. " @@ -195,8 +292,15 @@ def rolling_filter( data_rolled = getattr(data_windows, statistic)(skipna=True) - # Remove the padded edges - data_rolled = data_rolled.isel(time=slice(half_window, -half_window)) + # Remove the padded edges by slicing + # Ensure the slice matches the original time dimension size + original_time_size = data.sizes["time"] + data_rolled = data_rolled.isel( + time=slice(half_window, half_window + original_time_size) + ) + + # Transpose back to original dimension order + data_rolled = data_rolled.transpose(*data.dims) # Optional: Print NaN report if print_report: @@ -257,15 +361,25 @@ def savgol_filter( """ if "axis" in kwargs: + raise ValueError("The 'axis' argument may not be overridden.") + raise logger.error( ValueError("The 'axis' argument may not be overridden.") ) + data_smoothed = data.copy() + # Find the axis index corresponding to the 'time' dimension + try: + time_axis = data.dims.index("time") + except ValueError as e: + raise ValueError("Input data must have a 'time' dimension.") from e + + # Apply savgol_filter along the identified time axis data_smoothed.values = signal.savgol_filter( - data, + data.values, # Pass numpy array to savgol_filter window, polyorder, - axis=0, + axis=time_axis, **kwargs, ) if print_report: