Skip to content

Commit 890fcf2

Browse files
mipesonmaarnio
authored andcommitted
feat(IDW): Add search radius parameter
1 parent 5ca25b3 commit 890fcf2

File tree

7 files changed

+143
-109
lines changed

7 files changed

+143
-109
lines changed

eis_toolkit/cli.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1960,6 +1960,7 @@ def idw_interpolation_cli(
19601960
pixel_size: float = None,
19611961
extent: Tuple[float, float, float, float] = (None, None, None, None),
19621962
power: float = 2.0,
1963+
search_radius: Optional[float] = None,
19631964
):
19641965
"""Apply inverse distance weighting (IDW) interpolation to input vector file."""
19651966
from eis_toolkit.exceptions import InvalidParameterValueException
@@ -1985,7 +1986,13 @@ def idw_interpolation_cli(
19851986
with rasterio.open(base_raster) as raster:
19861987
profile = raster.profile.copy()
19871988

1988-
out_image = idw(geodataframe=geodataframe, target_column=target_column, raster_profile=profile, power=power)
1989+
out_image = idw(
1990+
geodataframe=geodataframe,
1991+
target_column=target_column,
1992+
raster_profile=profile,
1993+
power=power,
1994+
search_radius=search_radius,
1995+
)
19891996
typer.echo("Progress: 75%")
19901997

19911998
profile["count"] = 1

eis_toolkit/vector_processing/idw_interpolation.py

+35-10
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import geopandas as gpd
44
import numpy as np
55
from beartype import beartype
6-
from beartype.typing import Union
6+
from beartype.typing import Optional, Union
77
from rasterio import profiles, transform
88

99
from eis_toolkit.exceptions import EmptyDataFrameException, InvalidParameterValueException, NonMatchingCrsException
@@ -18,6 +18,7 @@ def _idw_interpolation(
1818
raster_height: int,
1919
raster_transform: transform.Affine,
2020
power: Number,
21+
search_radius: Optional[Number],
2122
) -> np.ndarray:
2223

2324
points = np.array(geodataframe.geometry.apply(lambda geom: (geom.x, geom.y)).tolist())
@@ -34,26 +35,47 @@ def _idw_interpolation(
3435
y = np.linspace(grid_y_min, grid_y_max, raster_height)
3536
y = y[::-1].reshape(-1, 1)
3637

37-
interpolated_values = _idw_core(points[:, 0], points[:, 1], values, x, y, power)
38+
interpolated_values = _idw_core(points[:, 0], points[:, 1], values, x, y, power, search_radius)
3839
interpolated_values = interpolated_values.reshape(raster_height, raster_width)
3940

4041
return interpolated_values
4142

4243

4344
# Distance calculations
44-
def _idw_core(x, y, z, xi, yi: np.ndarray, power: Number) -> np.ndarray:
45+
def _idw_core(
46+
x: np.ndarray,
47+
y: np.ndarray,
48+
z: np.ndarray,
49+
xi: np.ndarray,
50+
yi: np.ndarray,
51+
power: Number,
52+
search_radius: Optional[Number],
53+
) -> np.ndarray:
4554
over = np.zeros((len(yi), len(xi)))
4655
under = np.zeros((len(yi), len(xi)))
4756
for n in range(len(x)):
4857
dist = np.hypot(xi - x[n], yi - y[n])
49-
# Add a small epsilon to avoid division by zero
50-
dist = np.where(dist == 0, 1e-12, dist)
51-
dist = dist**power
5258

53-
over += z[n] / dist
54-
under += 1.0 / dist
59+
# Exclude points outside search radius
60+
if search_radius is not None:
61+
mask = dist <= search_radius
62+
if not np.any(mask):
63+
continue
64+
65+
# Add a small epsilon to avoid division by zero
66+
dist = np.where(dist[mask] == 0, 1e-12, dist[mask]) ** power
67+
68+
over[mask] += z[n] / dist
69+
under[mask] += 1.0 / dist
70+
71+
else:
72+
# Add a small epsilon to avoid division by zero
73+
dist = np.where(dist == 0, 1e-12, dist) ** power
74+
75+
over += z[n] / dist
76+
under += 1.0 / dist
5577

56-
interpolated_values = over / under
78+
interpolated_values = np.divide(over, under, out=np.full_like(over, np.nan), where=under != 0)
5779
return interpolated_values
5880

5981

@@ -63,6 +85,7 @@ def idw(
6385
target_column: str,
6486
raster_profile: Union[profiles.Profile, dict],
6587
power: Number = 2,
88+
search_radius: Optional[Number] = None,
6689
) -> np.ndarray:
6790
"""Calculate inverse distance weighted (IDW) interpolation.
6891
@@ -73,6 +96,8 @@ def idw(
7396
crs, transform, width and height.
7497
power: The value for determining the rate at which the weights decrease. As power increases,
7598
the weights for distant points decrease rapidly. Defaults to 2.
99+
search_radius: The search radius within which to consider points for interpolation.
100+
If None, all points are used.
76101
77102
Returns:
78103
Numpy array containing the interpolated values.
@@ -97,7 +122,7 @@ def idw(
97122
raster_transform = raster_profile.get("transform")
98123

99124
interpolated_values = _idw_interpolation(
100-
geodataframe, target_column, raster_width, raster_height, raster_transform, power
125+
geodataframe, target_column, raster_width, raster_height, raster_transform, power, search_radius
101126
)
102127

103128
return interpolated_values

notebooks/testing_idw.ipynb

+81-98
Large diffs are not rendered by default.
Binary file not shown.
Binary file not shown.

tests/data/remote/interpolating/interpolation_test_data_small.gpkg-wal

Whitespace-only changes.

tests/vector_processing/idw_interpolation_test.py

+19
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
test_dir = Path(__file__).parent.parent
1515
idw_test_data = test_dir.joinpath("data/remote/interpolating/idw_test_data.tif")
16+
idw_radius_test_data = test_dir.joinpath("data/remote/interpolating/idw_radius_test_data.tif")
1617

1718

1819
@pytest.fixture
@@ -74,6 +75,24 @@ def test_validated_points_with_extent(validated_points, raster_profile):
7475
np.testing.assert_almost_equal(interpolated_values, external_values, decimal=2)
7576

7677

78+
def test_validated_points_with_radius(validated_points, raster_profile):
79+
"""Test IDW with search radius."""
80+
target_column = "random_number"
81+
interpolated_values = idw(
82+
geodataframe=validated_points,
83+
target_column=target_column,
84+
raster_profile=raster_profile,
85+
power=2,
86+
search_radius=0.5,
87+
)
88+
assert target_column in validated_points.columns
89+
90+
with rasterio.open(idw_radius_test_data) as src:
91+
external_values = src.read(1)
92+
93+
np.testing.assert_almost_equal(interpolated_values, external_values, decimal=2)
94+
95+
7796
def test_invalid_column(test_points, raster_profile):
7897
"""Test invalid column GeoDataFrame."""
7998
target_column = "not-in-data-column"

0 commit comments

Comments
 (0)