Skip to content

Commit ce6414f

Browse files
Mika Sorvoja and nmaarnio authored
perf, build(vector_processing): add optimized version of distance computation (#455)
* add dependency: Numba
* improve documentation
* comment out the old implementation

Note: distance_to_anomaly, proximity_computation and proximity_to_anomaly all benefit from this optimization

---------

Co-authored-by: Niko Aarnio <[email protected]>
1 parent 1783195 commit ce6414f

File tree

4 files changed

+356
-99
lines changed

4 files changed

+356
-99
lines changed

eis_toolkit/vector_processing/distance_computation.py

+258-43
Original file line numberDiff line numberDiff line change
@@ -4,89 +4,304 @@
44
import numpy as np
55
from beartype import beartype
66
from beartype.typing import Optional, Union
7+
from numba import njit, prange
78
from rasterio import profiles, transform
8-
from shapely.geometry.base import BaseGeometry, BaseMultipartGeometry
99

10-
from eis_toolkit.exceptions import EmptyDataFrameException, NonMatchingCrsException, NumericValueSignException
10+
from eis_toolkit import exceptions
1111
from eis_toolkit.utilities.checks.raster import check_raster_profile
12-
from eis_toolkit.utilities.miscellaneous import row_points
1312

1413

1514
def _coords_to_segments(coords: list) -> list:
    """Pair consecutive vertices into (x1, y1, x2, y2) tuples, ignoring any Z coordinate."""
    return [(start[0], start[1], end[0], end[1]) for start, end in zip(coords[:-1], coords[1:])]


@beartype
def distance_computation(
    geodataframe: gpd.GeoDataFrame, raster_profile: Union[profiles.Profile, dict], max_distance: Optional[Number] = None
) -> np.ndarray:
    """
    Calculate distance from each raster cell (centre) to the nearest input geometry.

    Pixels on top of input geometries are assigned distance of 0. Only polygon exteriors
    are considered, so cells that fall inside a polygon hole are also assigned 0.

    Uses Numba to perform calculations quickly. The computation time increases (roughly)
    linearly with the amount of raster pixels defined by given `raster_profile`. Supports
    Polygon, MultiPolygon, LineString, MultiLineString, Point and MultiPoint geometries,
    also when several of these types are mixed in the same GeoDataFrame.

    Args:
        geodataframe: The GeoDataFrame with geometries to determine distance to.
        raster_profile: The raster profile of the raster in which the distances
            to the nearest geometry are determined.
        max_distance: The maximum distance in the output array. Pixels beyond this
            distance will be assigned `max_distance` value.

    Returns:
        A 2D numpy array with the distances computed.

    Raises:
        NonMatchingCrsException: The input raster profile and geodataframe have mismatching CRS.
        EmptyDataFrameException: The input geodataframe is empty.
        NumericValueSignException: Max distance is defined and is not a positive number.
        GeometryTypeException: The geodataframe contains an unsupported geometry type.
    """
    if raster_profile.get("crs") != geodataframe.crs:
        raise exceptions.NonMatchingCrsException(
            "Expected coordinate systems to match between raster and GeoDataFrame."
        )
    if geodataframe.empty:
        raise exceptions.EmptyDataFrameException("Expected GeoDataFrame to not be empty.")
    if max_distance is not None and max_distance <= 0:
        raise exceptions.NumericValueSignException("Expected max distance to be a positive number.")

    check_raster_profile(raster_profile=raster_profile)

    raster_width = raster_profile.get("width")
    raster_height = raster_profile.get("height")
    raster_transform = raster_profile.get("transform")

    # Generate the grid of raster cell center points
    raster_points = _generate_raster_points(raster_width, raster_height, raster_transform)

    # Flat, Numba-friendly buffers. Every entry of `segment_coords` is a 4-tuple
    # (x1, y1, x2, y2); points are stored as zero-length segments (x, y, x, y) so that
    # point and line/polygon geometries can share one homogeneous float array.
    segment_coords = []
    segment_indices = [0]  # Start index; one end index is appended per geometry
    polygon_coords = []  # Flattened x, y, x, y, ... of polygon exterior rings
    polygon_indices = [0]  # Start index; one end index (in vertices) is appended per ring

    for geometry in geodataframe.geometry:
        geom_type = geometry.geom_type

        if geom_type in ("Polygon", "MultiPolygon"):
            polygons = [geometry] if geom_type == "Polygon" else geometry.geoms
            segments = []
            for poly in polygons:
                coords = list(poly.exterior.coords)
                if not coords:
                    # Empty polygon: registering a zero-length ring would break the
                    # point-in-polygon test, and there is nothing to measure against.
                    continue
                for vertex in coords:
                    polygon_coords.extend((vertex[0], vertex[1]))
                polygon_indices.append(len(polygon_coords) // 2)
                # Polygon boundary doubles as segments for cells outside the polygon
                segments.extend(_coords_to_segments(coords))

        elif geom_type in ("LineString", "MultiLineString"):
            lines = [geometry] if geom_type == "LineString" else geometry.geoms
            segments = []
            for line in lines:
                segments.extend(_coords_to_segments(list(line.coords)))

        elif geom_type in ("Point", "MultiPoint"):
            points = [geometry] if geom_type == "Point" else geometry.geoms
            segments = [(point.x, point.y, point.x, point.y) for point in points]

        else:
            raise exceptions.GeometryTypeException(f"Encountered unsupported geometry type: {geometry.geom_type}.")

        segment_coords.extend(segments)
        segment_indices.append(len(segment_coords))  # End index for this geometry's segments

    # Convert all lists to numpy arrays. The reshape keeps `segment_coords` two-dimensional
    # even when no segment was collected at all.
    segment_coords = np.array(segment_coords, dtype=np.float64).reshape(-1, 4)
    segment_indices = np.array(segment_indices, dtype=np.int64)
    polygon_coords = np.array(polygon_coords, dtype=np.float64)
    polygon_indices = np.array(polygon_indices, dtype=np.int64)

    distance_matrix = _compute_distances_core(
        raster_points,
        segment_coords,
        segment_indices,
        polygon_coords,
        polygon_indices,
        raster_width,
        raster_height,
        max_distance,
    )

    return distance_matrix
57136

58137

59-
def _calculate_row_distances(
60-
row: int,
61-
cols: np.ndarray,
62-
raster_transform: transform.Affine,
63-
geometries_unary_union: Union[BaseGeometry, BaseMultipartGeometry],
138+
@njit(parallel=True)
def _compute_distances_core(
    raster_points: np.ndarray,
    segment_coords: np.ndarray,
    segment_indices: np.ndarray,
    polygon_coords: np.ndarray,
    polygon_indices: np.ndarray,
    width: int,
    height: int,
    max_distance: Optional[Number],
) -> np.ndarray:
    """
    Fill a (height, width) matrix with the distance from each raster point to the nearest geometry.

    Raster point `i` is written to row `i // width`, column `i % width`. A point that
    `_point_in_polygon` reports as inside any polygon gets distance 0; otherwise the
    minimum distance over all entries of `segment_coords` is used. Two-element entries
    are treated as points, all other entries as (x1, y1, x2, y2) line segments.
    When `max_distance` is not None, results are capped at that value.
    """
    point_count = len(raster_points)
    geometry_count = len(segment_indices) - 1
    has_polygons = len(polygon_indices) > 1

    distances = np.full((height, width), np.inf)

    for idx in prange(point_count):
        x = raster_points[idx, 0]
        y = raster_points[idx, 1]
        nearest = np.inf

        if has_polygons and _point_in_polygon(x, y, polygon_coords, polygon_indices):
            # Inside a polygon: no need to look at any segment
            nearest = 0.0
        else:
            for g in range(geometry_count):
                for s in range(segment_indices[g], segment_indices[g + 1]):
                    entry = segment_coords[s]
                    if len(entry) == 2:
                        # Point geometry: plain Euclidean distance
                        sx, sy = entry
                        candidate = np.sqrt((x - sx) ** 2 + (y - sy) ** 2)
                    else:
                        # Line segment
                        ax, ay, bx, by = entry
                        candidate = _point_to_segment_distance(x, y, ax, ay, bx, by)
                    if candidate < nearest:
                        nearest = candidate

        # Cap the result when a maximum distance was requested
        if max_distance is not None:
            nearest = min(nearest, max_distance)

        distances[idx // width, idx % width] = nearest

    return distances
80179

81-
geometries_unary_union = geodataframe.geometry.unary_union
82180

83-
distance_matrix = np.array(
84-
[
85-
_calculate_row_distances(
86-
row=row, cols=cols, raster_transform=raster_transform, geometries_unary_union=geometries_unary_union
87-
)
88-
for row in rows
89-
]
90-
)
181+
def _generate_raster_points(width: int, height: int, affine_transform: transform.Affine) -> np.ndarray:
    """
    Build the (x, y) coordinates of every raster cell centre.

    Cells are listed row by row (all columns of row 0, then row 1, ...), and the
    result has one row per cell with the x coordinate first and the y coordinate second.
    """
    # Row-major enumeration of all (row, col) index pairs
    row_idx = np.repeat(np.arange(height), width)
    col_idx = np.tile(np.arange(width), height)

    x_coords, y_coords = transform.xy(affine_transform, row_idx, col_idx, offset="center")
    return np.column_stack((x_coords, y_coords))
91188

92-
return distance_matrix
189+
190+
@njit
def _point_to_segment_distance(px: Number, py: Number, x1: Number, y1: Number, x2: Number, y2: Number) -> float:
    """
    Calculate the minimum distance from a point to a line segment.

    Args:
        px: X coordinate of the point.
        py: Y coordinate of the point.
        x1: X coordinate of the segment start.
        y1: Y coordinate of the segment start.
        x2: X coordinate of the segment end.
        y2: Y coordinate of the segment end.

    Returns:
        The Euclidean distance from the point to the closest location on the segment.
    """
    dx, dy = x2 - x1, y2 - y1
    length_sq = dx * dx + dy * dy
    if length_sq == 0:
        # Zero-length segment (e.g. repeated vertices): treat it as a single point.
        # Testing the squared length also covers the case where it underflows to 0,
        # which would otherwise divide by zero below.
        return np.sqrt((px - x1) ** 2 + (py - y1) ** 2)
    # Projection parameter of the point onto the segment's line, clamped to the segment
    t = max(0.0, min(1.0, ((px - x1) * dx + (py - y1) * dy) / length_sq))
    nearest_x, nearest_y = x1 + t * dx, y1 + t * dy
    return np.sqrt((px - nearest_x) ** 2 + (py - nearest_y) ** 2)
200+
201+
202+
@njit
def _point_in_polygon(px: Number, py: Number, polygon_coords: np.ndarray, polygon_indices: np.ndarray) -> bool:
    """
    Determine if a point is inside any polygon using the ray-casting algorithm.

    Args:
        px: X coordinate of the point.
        py: Y coordinate of the point.
        polygon_coords: Flat array of ring vertices laid out as x, y, x, y, ...
        polygon_indices: Vertex offsets into `polygon_coords`; ring `r` covers vertices
            `polygon_indices[r]` up to (but excluding) `polygon_indices[r + 1]`.

    Returns:
        True if the point is inside at least one ring, otherwise False.
    """
    for p_start, p_end in zip(polygon_indices[:-1], polygon_indices[1:]):
        n = p_end - p_start
        if n < 3:
            # A ring with fewer than three vertices encloses nothing. Skipping it also
            # avoids the modulo by zero and out-of-range read an empty ring would cause.
            continue
        inside = False
        xints = 0.0
        p1x, p1y = polygon_coords[2 * p_start], polygon_coords[2 * p_start + 1]
        # Walk every edge (p1 -> p2), wrapping around to the first vertex at the end
        for i in range(n + 1):
            p2x, p2y = polygon_coords[2 * (p_start + i % n)], polygon_coords[2 * (p_start + i % n) + 1]
            if py > min(p1y, p2y):
                if py <= max(p1y, p2y):
                    if px <= max(p1x, p2x):
                        if p1y != p2y:
                            # X coordinate where the edge crosses the horizontal line at py
                            xints = (py - p1y) * (p2x - p1x) / (p2y - p1y) + p1x
                        if p1x == p2x or px <= xints:
                            inside = not inside
            p1x, p1y = p2x, p2y
        if inside:
            return True
    return False
223+
224+
225+
# @beartype
226+
# def distance_computation(
227+
# geodataframe: gpd.GeoDataFrame,
228+
# raster_profile: Union[profiles.Profile, dict],
229+
# max_distance: Optional[Number] = None
230+
# ) -> np.ndarray:
231+
# """Calculate distance from raster cell to nearest geometry.
232+
233+
# Args:
234+
# geodataframe: The GeoDataFrame with geometries to determine distance to.
235+
# raster_profile: The raster profile of the raster in which the distances
236+
# to the nearest geometry are determined.
237+
# max_distance: The maximum distance in the output array.
238+
239+
# Returns:
240+
# A 2D numpy array with the distances computed.
241+
242+
# Raises:
243+
# NonMatchingCrsException: The input raster profile and geodataframe have mismatching CRS.
244+
# EmptyDataFrameException: The input geodataframe is empty.
245+
# NumericValueSignException: Max distance is defined and is not a positive number.
246+
# """
247+
# if raster_profile.get("crs") != geodataframe.crs:
248+
# raise exceptions.NonMatchingCrsException(
249+
# "Expected coordinate systems to match between raster and GeoDataFrame."
250+
# )
251+
# if geodataframe.shape[0] == 0:
252+
# raise exceptions.EmptyDataFrameException("Expected GeoDataFrame to not be empty.")
253+
# if max_distance is not None and max_distance <= 0:
254+
# raise exceptions.NumericValueSignException("Expected max distance to be a positive number.")
255+
256+
# check_raster_profile(raster_profile=raster_profile)
257+
258+
# raster_width = raster_profile.get("width")
259+
# raster_height = raster_profile.get("height")
260+
# raster_transform = raster_profile.get("transform")
261+
262+
# distance_matrix = _distance_computation(
263+
# raster_width=raster_width,
264+
# raster_height=raster_height,
265+
# raster_transform=raster_transform,
266+
# geodataframe=geodataframe,
267+
# )
268+
# if max_distance is not None:
269+
# distance_matrix[distance_matrix > max_distance] = max_distance
270+
271+
# return distance_matrix
272+
273+
274+
# def _calculate_row_distances(
275+
# row: int,
276+
# cols: np.ndarray,
277+
# raster_transform: transform.Affine,
278+
# geometries_unary_union: Union[BaseGeometry, BaseMultipartGeometry],
279+
# ) -> np.ndarray:
280+
# row_distances = np.array(
281+
# [
282+
# point.distance(geometries_unary_union)
283+
# for point in row_points(row=row, cols=cols, raster_transform=raster_transform)
284+
# ]
285+
# )
286+
# return row_distances
287+
288+
289+
# def _distance_computation(
290+
# raster_width: int, raster_height: int, raster_transform: transform.Affine, geodataframe: gpd.GeoDataFrame
291+
# ) -> np.ndarray:
292+
293+
# cols = np.arange(raster_width)
294+
# rows = np.arange(raster_height)
295+
296+
# geometries_unary_union = geodataframe.geometry.unary_union
297+
298+
# distance_matrix = np.array(
299+
# [
300+
# _calculate_row_distances(
301+
# row=row, cols=cols, raster_transform=raster_transform, geometries_unary_union=geometries_unary_union
302+
# )
303+
# for row in rows
304+
# ]
305+
# )
306+
307+
# return distance_matrix

environment.yml

+1
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,6 @@ dependencies:
2222
- imbalanced-learn >= 0.11.0
2323
- mapclassify >= 2.6.1
2424
- esda >= 2.5.1
25+
- numba >= 0.60.0
2526
# Dependencies for testing
2627
- pytest >=7.2.1

0 commit comments

Comments
 (0)