|
1 | 1 | """Filter and interpolate tracks in ``movement`` datasets.""" |
2 | 2 |
|
| 3 | +import logging |
3 | 4 | import warnings |
4 | 5 | from typing import Literal |
5 | 6 |
|
6 | 7 | import xarray as xr |
7 | 8 | from scipy import signal |
8 | 9 |
|
| 10 | +from movement.kinematics import compute_displacement |
9 | 11 | from movement.utils.logging import log_error, log_to_attrs |
10 | 12 | from movement.utils.reports import report_nan_values |
| 13 | +from movement.utils.vector import compute_norm |
| 14 | + |
| 15 | +logger = logging.getLogger(__name__) |
11 | 16 |
|
12 | 17 |
|
13 | 18 | @log_to_attrs |
@@ -60,6 +65,97 @@ def filter_by_confidence( |
60 | 65 | return data_filtered |
61 | 66 |
|
62 | 67 |
|
| 68 | +@log_to_attrs |
| 69 | +def filter_by_displacement( |
| 70 | + position: xr.DataArray, |
| 71 | + threshold: float = 10.0, |
| 72 | + direction: Literal["forward", "bidirectional"] = "forward", |
| 73 | + print_report: bool = False, |
| 74 | +) -> xr.DataArray: |
| 75 | + """Filter data points based on displacement magnitude. |
| 76 | +
|
| 77 | + Frames in the ``position`` array that exceed a displacement magnitude |
| 78 | + threshold are set to NaN. |
| 79 | +
|
| 80 | + Two modes are supported via the ``direction`` parameter: |
| 81 | + - "forward" (default): A point at time ``t`` is set to NaN if it has |
| 82 | + moved more than the ``threshold`` Euclidean distance from the same |
| 83 | + point at time ``t-1`` (i.e., if ``|pos(t) - pos(t-1)| > threshold``). |
| 84 | + - "bidirectional": A point at time ``t`` is set to NaN only if BOTH the |
| 85 | + displacement from ``t-1`` to ``t`` AND the displacement from ``t`` to |
| 86 | + ``t+1`` exceed the threshold (i.e., if |
| 87 | + ``|pos(t) - pos(t-1)| > threshold`` AND |
| 88 | + ``|pos(t+1) - pos(t)| > threshold``). This corresponds to the |
| 89 | + Stage 1 outlier detection described by the user. |
| 90 | +
|
| 91 | + Parameters |
| 92 | + ---------- |
| 93 | + position : xr.DataArray |
| 94 | + The input data containing position information, with ``time`` |
| 95 | + and ``space`` (in Cartesian coordinates) as required dimensions. |
| 96 | + threshold : float, optional |
| 97 | + The maximum Euclidean distance allowed for displacement. |
| 98 | + Defaults to 10.0. |
| 99 | + direction : Literal["forward", "bidirectional"], optional |
| 100 | + The directionality of the displacement check. Defaults to "forward". |
| 101 | + print_report : bool, optional |
| 102 | + Whether to print a report of the number of NaN values before and after |
| 103 | + filtering. Defaults to False. |
| 104 | +
|
| 105 | + Returns |
| 106 | + ------- |
| 107 | + xr.DataArray |
| 108 | + The filtered position array, where points exceeding the displacement |
| 109 | + threshold condition have been set to NaN. |
| 110 | +
|
| 111 | + See Also |
| 112 | + -------- |
| 113 | + movement.kinematics.compute_displacement: |
| 114 | + The function used to compute an array of displacement vectors. |
| 115 | + movement.utils.vector.compute_norm: |
| 116 | + The function used to compute distance as the magnitude of |
| 117 | + displacement vectors. |
| 118 | +
|
| 119 | + """ |
| 120 | + if not isinstance(position, xr.DataArray): |
| 121 | + raise log_error( |
| 122 | + TypeError, "Input 'position' must be an xarray.DataArray." |
| 123 | + ) |
| 124 | + |
| 125 | + # Calculate forward displacement magnitude: |
| 126 | + # norm at time t = |pos(t) - pos(t-1)| |
| 127 | + displacement_fwd = compute_displacement(position) |
| 128 | + mag_fwd = compute_norm(displacement_fwd) |
| 129 | + |
| 130 | + if direction == "forward": |
| 131 | + # Uni-directional: Keep if magnitude from t-1 to t is below threshold |
| 132 | + condition = mag_fwd < threshold |
| 133 | + elif direction == "bidirectional": |
| 134 | + # Bi-directional: Keep unless BOTH jump in and jump out are large. |
| 135 | + # Equivalent to: Keep if jump in is small OR jump out is small. |
| 136 | + # Calculate backward magnitude: |
| 137 | + # norm at t+1 related to t = |pos(t+1) - pos(t)| |
| 138 | + # mag_bwd[t] = mag_fwd[t+1] |
| 139 | + mag_bwd = mag_fwd.shift(time=-1, fill_value=0) |
| 140 | + condition = (mag_fwd < threshold) | (mag_bwd < threshold) |
| 141 | + else: |
| 142 | + raise log_error( |
| 143 | + ValueError, |
| 144 | + ( |
| 145 | + f"Invalid direction: {direction}. " |
| 146 | + f"Must be 'forward' or 'bidirectional'." |
| 147 | + ), |
| 148 | + ) |
| 149 | + |
| 150 | + position_filtered = position.where(condition) |
| 151 | + |
| 152 | + if print_report: |
| 153 | + print(report_nan_values(position, "input")) |
| 154 | + print(report_nan_values(position_filtered, "output")) |
| 155 | + |
| 156 | + return position_filtered |
| 157 | + |
| 158 | + |
63 | 159 | @log_to_attrs |
64 | 160 | def interpolate_over_time( |
65 | 161 | data: xr.DataArray, |
|
0 commit comments