Skip to content

Commit 43f8919

Browse files
authoredFeb 22, 2024··
Add reclassify raster (#197)
1 parent 999424d commit 43f8919

File tree

4 files changed

+613
-0
lines changed

4 files changed

+613
-0
lines changed
 

‎docs/raster_processing/reclassify.md

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Reclassify raster
2+
3+
::: eis_toolkit.raster_processing.reclassify
+413
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,413 @@
1+
from numbers import Number
2+
3+
import mapclassify as mc
4+
import numpy as np
5+
import rasterio
6+
from beartype import beartype
7+
from beartype.typing import Optional, Sequence, Tuple
8+
9+
from eis_toolkit.utilities.checks.raster import check_raster_bands
10+
11+
12+
def _reclassify_with_manual_breaks(  # type: ignore[no-any-unimported]
    band: np.ndarray,
    breaks: Sequence[int],
) -> np.ndarray:
    """Label every cell of a band with the index of the break interval it falls into.

    Args:
        band: Band values to classify.
        breaks: Break values handed to `np.digitize` as bin edges.

    Returns:
        Array shaped like `band` holding the bin index of each cell.
    """
    # Default np.digitize semantics: a cell gets index i when breaks[i - 1] <= cell < breaks[i].
    return np.digitize(band, breaks)
20+
21+
22+
@beartype
def reclassify_with_manual_breaks(  # type: ignore[no-any-unimported]
    raster: rasterio.io.DatasetReader,
    breaks: Sequence[int],
    bands: Optional[Sequence[int]] = None,
) -> Tuple[np.ndarray, dict]:
    """Classify raster with manual breaks.

    If bands are not given, all bands are used for classification.

    Args:
        raster: Raster to be classified.
        breaks: List of break values for the classification.
        bands: Selected bands from multiband raster. Indexing begins from one. Defaults to None.

    Returns:
        Raster classified with manual breaks and metadata. The metadata is a copy of the input
        raster metadata with `count` set to the number of classified bands.

    Raises:
        InvalidParameterValueException: Documented for invalid band selections; it is not raised
            directly in this function, so it presumably originates from `check_raster_bands`
            (to be confirmed).
    """
    if bands is None or len(bands) == 0:
        bands = range(1, raster.count + 1)
    else:
        # NOTE(review): the return value of check_raster_bands is not inspected here.
        check_raster_bands(raster, bands)

    # np.empty defaults to float64, independent of the source raster dtype.
    out_image = np.empty((len(bands), raster.height, raster.width))
    out_meta = raster.meta.copy()
    # Keep the metadata band count in step with out_image when only a subset of bands is selected.
    out_meta["count"] = len(bands)

    for i, band in enumerate(bands):
        band_data = raster.read(band)
        out_image[i] = _reclassify_with_manual_breaks(band_data, breaks)

    return out_image, out_meta
57+
58+
59+
def _reclassify_with_defined_intervals(  # type: ignore[no-any-unimported]
    band: np.ndarray,
    interval_size: int,
) -> np.ndarray:
    """Classify a band using the bin edges of an equal-width histogram.

    Args:
        band: Band values to classify.
        interval_size: Passed to `np.histogram` as `bins`, i.e. used as the number of
            equal-width bins spanning the value range of the band.

    Returns:
        Array shaped like `band` with the index of the histogram bin each cell falls into.
    """
    # Only the edges are needed; the histogram counts are discarded.
    bin_edges = np.histogram(band, bins=interval_size)[1]
    return np.digitize(band, bin_edges)
69+
70+
71+
@beartype
def reclassify_with_defined_intervals(  # type: ignore[no-any-unimported]
    raster: rasterio.io.DatasetReader,
    interval_size: int,
    bands: Optional[Sequence[int]] = None,
) -> Tuple[np.ndarray, dict]:
    """Classify raster with defined intervals.

    If bands are not given, all bands are used for classification.

    Args:
        raster: Raster to be classified.
        interval_size: Forwarded to the classification helper, which passes it to `np.histogram`
            as the number of equal-width bins.
        bands: Selected bands from multiband raster. Indexing begins from one. Defaults to None.

    Returns:
        Raster classified with defined intervals and metadata. The metadata is a copy of the input
        raster metadata with `count` set to the number of classified bands.

    Raises:
        InvalidParameterValueException: Documented for invalid band selections; it is not raised
            directly in this function, so it presumably originates from `check_raster_bands`
            (to be confirmed).
    """
    if bands is None or len(bands) == 0:
        bands = range(1, raster.count + 1)
    else:
        # NOTE(review): the return value of check_raster_bands is not inspected here.
        check_raster_bands(raster, bands)

    # np.empty defaults to float64, independent of the source raster dtype.
    out_image = np.empty((len(bands), raster.height, raster.width))
    out_meta = raster.meta.copy()
    # Keep the metadata band count in step with out_image when only a subset of bands is selected.
    out_meta["count"] = len(bands)

    for i, band in enumerate(bands):
        band_data = raster.read(band)
        out_image[i] = _reclassify_with_defined_intervals(band_data, interval_size)

    return out_image, out_meta
106+
107+
108+
def _reclassify_with_equal_intervals(  # type: ignore[no-any-unimported]
    band: np.ndarray,
    number_of_intervals: int,
) -> np.ndarray:
    """Classify a band using break values taken at evenly spaced percentiles.

    Args:
        band: Band values to classify.
        number_of_intervals: Number of evenly spaced percentile positions between 0 and 100
            (both included) that are used as break values.

    Returns:
        Array shaped like `band` with the index of the interval each cell falls into.
    """
    # Break values are the band's percentiles at evenly spaced positions in [0, 100].
    cut_points = np.percentile(band, np.linspace(0, 100, number_of_intervals))
    return np.digitize(band, cut_points)
120+
121+
122+
@beartype
def reclassify_with_equal_intervals(  # type: ignore[no-any-unimported]
    raster: rasterio.io.DatasetReader,
    number_of_intervals: int,
    bands: Optional[Sequence[int]] = None,
) -> Tuple[np.ndarray, dict]:
    """Classify raster with equal intervals.

    If bands are not given, all bands are used for classification.

    Args:
        raster: Raster to be classified.
        number_of_intervals: The number of intervals. Forwarded to the classification helper, which
            uses it as the number of evenly spaced percentile positions for the break values.
        bands: Selected bands from multiband raster. Indexing begins from one. Defaults to None.

    Returns:
        Raster classified with equal intervals and metadata. The metadata is a copy of the input
        raster metadata with `count` set to the number of classified bands.

    Raises:
        InvalidParameterValueException: Documented for invalid band selections; it is not raised
            directly in this function, so it presumably originates from `check_raster_bands`
            (to be confirmed).
    """
    if bands is None or len(bands) == 0:
        bands = range(1, raster.count + 1)
    else:
        # NOTE(review): the return value of check_raster_bands is not inspected here.
        check_raster_bands(raster, bands)

    # np.empty defaults to float64, independent of the source raster dtype.
    out_image = np.empty((len(bands), raster.height, raster.width))
    out_meta = raster.meta.copy()
    # Keep the metadata band count in step with out_image when only a subset of bands is selected.
    out_meta["count"] = len(bands)

    for i, band in enumerate(bands):
        band_data = raster.read(band)
        out_image[i] = _reclassify_with_equal_intervals(band_data, number_of_intervals)

    return out_image, out_meta
157+
158+
159+
def _reclassify_with_quantiles(  # type: ignore[no-any-unimported]
    band: np.ndarray,
    number_of_quantiles: int,
) -> np.ndarray:
    """Classify a band by its quantiles.

    Args:
        band: Band values to classify.
        number_of_quantiles: Number of quantile classes. Break values are the band's percentiles
            at 0, 100/n, 2*100/n, ..., (n-1)*100/n.

    Returns:
        Array shaped like `band` with the index of the quantile interval each cell falls into.
    """
    percentages = [i * 100 / number_of_quantiles for i in range(number_of_quantiles)]
    # A single percentile call handles all positions at once instead of one pass per quantile.
    intervals = np.percentile(band, percentages)

    return np.digitize(band, intervals)
168+
169+
170+
@beartype
def reclassify_with_quantiles(  # type: ignore[no-any-unimported]
    raster: rasterio.io.DatasetReader,
    number_of_quantiles: int,
    bands: Optional[Sequence[int]] = None,
) -> Tuple[np.ndarray, dict]:
    """Classify raster with quantiles.

    If bands are not given, all bands are used for classification.

    Args:
        raster: Raster to be classified.
        number_of_quantiles: The number of quantiles.
        bands: Selected bands from multiband raster. Indexing begins from one. Defaults to None.

    Returns:
        Raster classified with quantiles and metadata. The metadata is a copy of the input
        raster metadata with `count` set to the number of classified bands.

    Raises:
        InvalidParameterValueException: Documented for invalid band selections; it is not raised
            directly in this function, so it presumably originates from `check_raster_bands`
            (to be confirmed).
    """
    if bands is None or len(bands) == 0:
        bands = range(1, raster.count + 1)
    else:
        # NOTE(review): the return value of check_raster_bands is not inspected here.
        check_raster_bands(raster, bands)

    # np.empty defaults to float64, independent of the source raster dtype.
    out_image = np.empty((len(bands), raster.height, raster.width))
    out_meta = raster.meta.copy()
    # Keep the metadata band count in step with out_image when only a subset of bands is selected.
    out_meta["count"] = len(bands)

    for i, band in enumerate(bands):
        band_data = raster.read(band)
        out_image[i] = _reclassify_with_quantiles(band_data, number_of_quantiles)

    return out_image, out_meta
205+
206+
207+
def _reclassify_with_natural_breaks(  # type: ignore[no-any-unimported]
    band: np.ndarray,
    number_of_classes: int,
) -> np.ndarray:
    """Classify a band with break values produced by mapclassify's Jenks-Caspall classifier.

    Args:
        band: Band values to classify.
        number_of_classes: Number of classes requested from the classifier.

    Returns:
        Array shaped like `band` with the index of the interval each cell falls into.
    """
    classifier = mc.JenksCaspall(band, number_of_classes)
    # The classifier's bins are sorted before use because np.digitize requires monotonic edges.
    sorted_bins = np.sort(classifier.bins)
    return np.digitize(band, sorted_bins)
216+
217+
218+
@beartype
def reclassify_with_natural_breaks(  # type: ignore[no-any-unimported]
    raster: rasterio.io.DatasetReader,
    number_of_classes: int,
    bands: Optional[Sequence[int]] = None,
) -> Tuple[np.ndarray, dict]:
    """Classify raster with natural breaks (Jenks Caspall).

    If bands are not given, all bands are used for classification.

    Args:
        raster: Raster to be classified.
        number_of_classes: The number of classes.
        bands: Selected bands from multiband raster. Indexing begins from one. Defaults to None.

    Returns:
        Raster classified with natural breaks (Jenks Caspall) and metadata. The metadata is a copy
        of the input raster metadata with `count` set to the number of classified bands.

    Raises:
        InvalidParameterValueException: Documented for invalid band selections; it is not raised
            directly in this function, so it presumably originates from `check_raster_bands`
            (to be confirmed).
    """
    if bands is None or len(bands) == 0:
        bands = range(1, raster.count + 1)
    else:
        # NOTE(review): the return value of check_raster_bands is not inspected here.
        check_raster_bands(raster, bands)

    # np.empty defaults to float64, independent of the source raster dtype.
    out_image = np.empty((len(bands), raster.height, raster.width))
    out_meta = raster.meta.copy()
    # Keep the metadata band count in step with out_image when only a subset of bands is selected.
    out_meta["count"] = len(bands)

    for i, band in enumerate(bands):
        band_data = raster.read(band)
        out_image[i] = _reclassify_with_natural_breaks(band_data, number_of_classes)

    return out_image, out_meta
253+
254+
255+
def _reclassify_with_geometrical_intervals(
    band: np.ndarray, number_of_classes: int, nodata_value: Optional[Number]
) -> np.ndarray:
    """Classify a band into signed classes whose widths grow geometrically away from the median.

    Cells equal to the median get class 0, cells above it positive classes and cells below it
    negative classes. The interval widths are derived from the longer tail of the distribution
    and mirrored around the median.

    Args:
        band: Band values to classify.
        number_of_classes: Number of geometric steps used to span the longer tail.
        nodata_value: Value marking nodata cells: a set number (e.g. -9999), NaN or None.

    Returns:
        Array shaped like `band` with the class of each cell. Nodata cells are left at 0.
    """
    # NaN never compares equal to anything (itself included), so `band == nan` would mask nothing.
    if nodata_value is not None and np.isnan(nodata_value):
        mask = np.isnan(band)
    else:
        mask = band == nodata_value
    masked_array = np.ma.masked_array(data=band, mask=mask)

    median_value = np.ma.median(masked_array)
    max_value = masked_array.max()
    min_value = masked_array.min()

    values_out = np.ma.zeros_like(masked_array)

    # Determine the tail with larger length. The tail values are shifted so that they start at a
    # small positive offset (a thousandth of the tail range), which keeps the ratio below finite.
    if (median_value - min_value) < (max_value - median_value):  # Large end tail longer
        tail_values = masked_array[np.ma.where(masked_array > median_value)]
        range_tail = max_value - median_value
        tail_values = tail_values - median_value + range_tail / 1000.0
    else:  # Small end tail longer
        tail_values = masked_array[np.ma.where(masked_array < median_value)]
        range_tail = median_value - min_value
        tail_values = tail_values - min_value + range_tail / 1000.0

    min_tail = np.ma.min(tail_values)
    max_tail = np.ma.max(tail_values)

    # Common ratio of the geometric progression that spans the tail in number_of_classes steps.
    factor = (max_tail / min_tail) ** (1 / number_of_classes)

    # Geometric break points along the tail and their distances from the first break point.
    break_points = [min_tail]
    width = [0]
    exponent = 0

    while break_points[-1] < max_tail:
        exponent += 1
        break_points.append(min_tail * factor**exponent)
        width.append(break_points[-1] - break_points[0])

    k = 0

    # Mirror the widths around the median: positive classes above it, negative classes below it.
    for j in range(1, len(width) - 2):
        values_out[
            np.ma.where(((median_value + width[j]) < masked_array) & (masked_array <= (median_value + width[j + 1])))
        ] = (j + 1)
        values_out[
            np.ma.where(((median_value - width[j]) > masked_array) & (masked_array >= (median_value - width[j + 1])))
        ] = (-j - 1)
        k = j

    # Everything beyond the last processed width falls into the outermost class on each side.
    values_out[np.ma.where((median_value + width[k + 1]) < masked_array)] = k + 1
    values_out[np.ma.where((median_value - width[k + 1]) > masked_array)] = -k - 1
    values_out[np.ma.where(median_value == masked_array)] = 0

    return np.array(values_out)
313+
314+
315+
@beartype
def reclassify_with_geometrical_intervals(  # type: ignore[no-any-unimported]
    raster: rasterio.io.DatasetReader, number_of_classes: int, bands: Optional[Sequence[int]] = None
) -> Tuple[np.ndarray, dict]:
    """Classify raster with geometrical intervals.

    If bands are not given, all bands are used for classification.

    Args:
        raster: Raster to be classified.
        number_of_classes: The number of classes. The true number of classes is at most double the amount,
            depending how symmetrical the input data is.
        bands: Selected bands from multiband raster. Indexing begins from one. Defaults to None.

    Returns:
        Raster classified with geometrical intervals and metadata. The metadata is a copy of the
        input raster metadata with `count` set to the number of classified bands.

    Raises:
        InvalidParameterValueException: Documented for invalid band selections; it is not raised
            directly in this function, so it presumably originates from `check_raster_bands`
            (to be confirmed).
    """
    if bands is None or len(bands) == 0:
        bands = range(1, raster.count + 1)
    else:
        # NOTE(review): the return value of check_raster_bands is not inspected here.
        check_raster_bands(raster, bands)

    # np.empty defaults to float64, independent of the source raster dtype.
    out_image = np.empty((len(bands), raster.height, raster.width))
    out_meta = raster.meta.copy()
    # Keep the metadata band count in step with out_image when only a subset of bands is selected.
    out_meta["count"] = len(bands)
    # The raster's nodata value is handed to the helper so that those cells are masked out.
    nodata_value = raster.nodata

    for i, band in enumerate(bands):
        band_data = raster.read(band)
        out_image[i] = _reclassify_with_geometrical_intervals(band_data, number_of_classes, nodata_value)

    return out_image, out_meta
350+
351+
352+
def _reclassify_with_standard_deviation(  # type: ignore[no-any-unimported]
    band: np.ndarray,
    number_of_intervals: int,
) -> np.ndarray:
    """Classify a band by its position within one standard deviation of the mean.

    Cells more than one standard deviation below the mean get `-number_of_intervals`, cells more
    than one standard deviation above it get `number_of_intervals`. The range in between is cut
    into `number_of_intervals` equal steps whose indices are shifted by `number_of_intervals // 2`.

    Args:
        band: Band values to classify.
        number_of_intervals: Number of steps used to divide the range mean +/- one standard deviation.

    Returns:
        Array shaped like `band` with the signed class of each cell.
    """
    stddev = np.nanstd(band)
    mean = np.nanmean(band)
    interval_size = 2 * stddev / number_of_intervals

    # Class labels are signed. Promote the output dtype so that unsigned or narrow integer bands
    # cannot wrap around or overflow when negative labels are written.
    classified = np.empty_like(band, dtype=np.result_type(band.dtype, np.int64))

    below_mean = band < (mean - stddev)
    above_mean = band > (mean + stddev)

    classified[below_mean] = -number_of_intervals
    classified[above_mean] = number_of_intervals

    in_between = ~below_mean & ~above_mean
    # Step index counted from the lower bound (mean - stddev), truncated to an integer.
    interval = ((band - (mean - stddev)) / interval_size).astype(int)
    classified[in_between] = interval[in_between] - number_of_intervals // 2

    return classified
377+
378+
379+
@beartype
def reclassify_with_standard_deviation(  # type: ignore[no-any-unimported]
    raster: rasterio.io.DatasetReader,
    number_of_intervals: int,
    bands: Optional[Sequence[int]] = None,
) -> Tuple[np.ndarray, dict]:
    """Classify raster with standard deviation.

    If bands are not given, all bands are used for classification.

    Args:
        raster: Raster to be classified.
        number_of_intervals: The number of intervals.
        bands: Selected bands from multiband raster. Indexing begins from one. Defaults to None.

    Returns:
        Raster classified with standard deviation and metadata. The metadata is a copy of the
        input raster metadata with `count` set to the number of classified bands.

    Raises:
        InvalidParameterValueException: Documented for invalid band selections; it is not raised
            directly in this function, so it presumably originates from `check_raster_bands`
            (to be confirmed).
    """
    if bands is None or len(bands) == 0:
        bands = range(1, raster.count + 1)
    else:
        # NOTE(review): the return value of check_raster_bands is not inspected here.
        check_raster_bands(raster, bands)

    # np.empty defaults to float64, independent of the source raster dtype.
    out_image = np.empty((len(bands), raster.height, raster.width))
    out_meta = raster.meta.copy()
    # Keep the metadata band count in step with out_image when only a subset of bands is selected.
    out_meta["count"] = len(bands)

    for i, band in enumerate(bands):
        band_data = raster.read(band)
        out_image[i] = _reclassify_with_standard_deviation(band_data, number_of_intervals)

    return out_image, out_meta

‎environment.yml

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dependencies:
2121
- rtree >= 1.0.1
2222
- typer >=0.9.0
2323
- imbalanced-learn >= 0.11.0
24+
- mapclassify >= 2.6.1
2425
- esda >= 2.5.1
2526
# Dependencies for testing
2627
- pytest >=7.2.1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import numpy as np
2+
import rasterio
3+
from beartype.typing import Tuple
4+
5+
from eis_toolkit.raster_processing import reclassify
6+
from tests.raster_processing.clip_test import raster_path as SMALL_RASTER_PATH
7+
8+
# Shared 4x4 integer fixture (values 0-100, some repeated) used by the helper-level tests below.
TEST_ARRAY = np.array([[0, 10, 20, 30], [40, 50, 50, 60], [80, 80, 90, 90], [100, 100, 100, 100]])
9+
10+
11+
def test_reclassify_with_defined_intervals():
    """Check the defined-intervals helper against the expected class array."""
    expected = np.array([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]])

    classified = reclassify._reclassify_with_defined_intervals(TEST_ARRAY, 3)

    assert isinstance(classified, np.ndarray)
    np.testing.assert_allclose(classified, expected)
22+
23+
24+
def test_reclassify_with_defined_intervals_main():
    """Run the public defined-intervals function on the small test raster and check the return types."""
    with rasterio.open(SMALL_RASTER_PATH) as src:
        output = reclassify.reclassify_with_defined_intervals(raster=src, interval_size=3, bands=[1])

    assert isinstance(output, Tuple)
    assert isinstance(output[0], np.ndarray)
    assert isinstance(output[1], dict)
36+
37+
38+
def test_reclassify_with_equal_intervals():
    """Check the equal-intervals helper against the expected class array."""
    expected = np.array([[1, 1, 2, 2], [3, 4, 4, 5], [6, 6, 7, 7], [10, 10, 10, 10]])

    classified = reclassify._reclassify_with_equal_intervals(TEST_ARRAY, 10)

    assert isinstance(classified, np.ndarray)
    np.testing.assert_allclose(classified, expected)
49+
50+
51+
def test_reclassify_with_equal_intervals_main():
    """Test raster with equal intervals parameters."""
    with rasterio.open(SMALL_RASTER_PATH) as raster:
        # Exercise the equal-intervals entry point; the defined-intervals one has its own test.
        result = reclassify.reclassify_with_equal_intervals(
            raster=raster,
            number_of_intervals=10,
            bands=[1],
        )

    assert isinstance(result, Tuple)
    assert isinstance(result[0], np.ndarray)
    assert isinstance(result[1], dict)
62+
63+
64+
def test_reclassify_with_geometrical_intervals():
    """Check the geometrical-intervals helper on an array that contains one nodata cell."""
    nodata = -9999
    band = np.array([[nodata, 10, 20, 30], [40, 50, 50, 60], [80, 80, 90, 90], [100, 100, 100, 100]])
    expected = np.array([[0, -9, -9, -9], [-9, -9, -9, -9], [0, 0, 8, 8], [9, 9, 9, 9]])

    classified = reclassify._reclassify_with_geometrical_intervals(band, 10, nodata)

    assert isinstance(classified, np.ndarray)
    np.testing.assert_allclose(classified, expected)
80+
81+
82+
def test_reclassify_with_geometrical_intervals_main():
    """Run the public geometrical-intervals function on the small test raster and check the return types."""
    with rasterio.open(SMALL_RASTER_PATH) as src:
        output = reclassify.reclassify_with_geometrical_intervals(raster=src, number_of_classes=10, bands=[1])

    assert isinstance(output, Tuple)
    assert isinstance(output[0], np.ndarray)
    assert isinstance(output[1], dict)
93+
94+
95+
def test_reclassify_with_manual_breaks():
    """Check the manual-breaks helper against the expected class array."""
    expected = np.array([[0, 0, 1, 1], [2, 2, 2, 3], [4, 4, 4, 4], [4, 4, 4, 4]])

    classified = reclassify._reclassify_with_manual_breaks(TEST_ARRAY, [20, 40, 60, 80])

    assert isinstance(classified, np.ndarray)
    np.testing.assert_allclose(classified, expected)
106+
107+
108+
def test_reclassify_with_manual_breaks_main():
    """Run the public manual-breaks function on the small test raster and check the return types."""
    with rasterio.open(SMALL_RASTER_PATH) as src:
        output = reclassify.reclassify_with_manual_breaks(raster=src, breaks=[2, 5, 9], bands=[1])

    assert isinstance(output, Tuple)
    assert isinstance(output[0], np.ndarray)
    assert isinstance(output[1], dict)
119+
120+
121+
def test_reclassify_with_natural_breaks():
    """Check the natural-breaks helper against the expected class array."""
    expected = np.array([[0, 1, 1, 2], [3, 4, 4, 5], [6, 6, 7, 7], [8, 8, 8, 8]])

    classified = reclassify._reclassify_with_natural_breaks(TEST_ARRAY, 10)

    assert isinstance(classified, np.ndarray)
    np.testing.assert_allclose(classified, expected)
132+
133+
134+
def test_reclassify_with_natural_breaks_main():
    """Run the public natural-breaks function on the small test raster and check the return types."""
    with rasterio.open(SMALL_RASTER_PATH) as src:
        output = reclassify.reclassify_with_natural_breaks(raster=src, number_of_classes=10, bands=[1])

    assert isinstance(output, Tuple)
    assert isinstance(output[0], np.ndarray)
    assert isinstance(output[1], dict)
145+
146+
147+
def test_reclassify_with_standard_deviation():
    """Check the standard-deviation helper against the expected class array."""
    expected = np.array([[-75, -75, -75, -36], [-25, -14, -14, -3], [20, 20, 31, 31], [75, 75, 75, 75]])

    classified = reclassify._reclassify_with_standard_deviation(TEST_ARRAY, 75)

    assert isinstance(classified, np.ndarray)
    np.testing.assert_allclose(classified, expected)
158+
159+
160+
def test_reclassify_with_standard_deviation_main():
    """Run the public standard-deviation function on the small test raster and check the return types."""
    with rasterio.open(SMALL_RASTER_PATH) as src:
        output = reclassify.reclassify_with_standard_deviation(raster=src, number_of_intervals=75, bands=[1])

    assert isinstance(output, Tuple)
    assert isinstance(output[0], np.ndarray)
    assert isinstance(output[1], dict)
171+
172+
173+
def test_reclassify_with_quantiles():
    """Check the quantiles helper against the expected class array."""
    expected = np.array([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]])

    classified = reclassify._reclassify_with_quantiles(TEST_ARRAY, 4)

    assert isinstance(classified, np.ndarray)
    np.testing.assert_allclose(classified, expected)
184+
185+
186+
def test_reclassify_with_quantiles_main():
    """Run the public quantiles function on the small test raster and check the return types."""
    with rasterio.open(SMALL_RASTER_PATH) as src:
        output = reclassify.reclassify_with_quantiles(raster=src, number_of_quantiles=4, bands=[1])

    assert isinstance(output, Tuple)
    assert isinstance(output[0], np.ndarray)
    assert isinstance(output[1], dict)

0 commit comments

Comments
 (0)
Please sign in to comment.