Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update _slice_collection to use toBands and select #91

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 10 additions & 22 deletions xee/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,31 +679,19 @@ def _slice_collection(self, image_slice: slice) -> ee.Image:
# Get the right range of Images in the collection, either a single image or
# a range of images...
start, stop, stride = image_slice.indices(self.shape[0])

# If the input images have IDs, just slice them. Otherwise, we need to do
# an expensive `toList()` operation.
if self.store.image_ids:
imgs = self.store.image_ids[start:stop:stride]
selectors = list(range(start, stop, stride))
col = self.store.image_collection.select(self.variable_name)
if self.shape[0] <= 5000: # 5000 == max bands in an Image
col_as_image = col.toBands()
return col_as_image.select(selectors)
elif stop < 5000: # 5000 == max bands in an Image
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wont the stride effect this number?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so. Even if the start/stride arguments would make the final image have fewer than 5000 bands, I'm avoiding toList by converting to a multiband image before slicing so I need to grab the first "stop" bands and toBands will fail if it would return an image with more than 5000 bands.

col_as_image = col.limit(stop).toBands()
return col_as_image.select(selectors)
else:
# TODO(alxr, mahrsee): Find a way to make this case more efficient.
list_range = stop - start
col0 = self.store.image_collection
imgs = col0.toList(list_range, offset=start).slice(0, list_range, stride)

col = ee.ImageCollection(imgs)

# For a more efficient slice of the series of images, we reduce each
# image in the collection to bands on a single image.
def reduce_bands(x, acc):
return ee.Image(acc).addBands(x, [self.variable_name])

aggregate_images_as_bands = ee.Image(col.iterate(reduce_bands, ee.Image()))
# Remove the first "constant" band from the reduction.
target_image = aggregate_images_as_bands.select(
aggregate_images_as_bands.bandNames().slice(1)
)

return target_image
imgs = col.toList(list_range, offset=start).slice(0, list_range, stride)
return ee.ImageCollection(imgs).toBands()

def _raw_indexing_method(
self, key: tuple[Union[int, slice], ...]
Expand Down
63 changes: 63 additions & 0 deletions xee/ext_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,19 @@ def test_can_chunk__opened_dataset(self):
except ValueError:
self.fail('Chunking failed.')

def test_can_slice_past_5000(self):
ds = xr.open_dataset(
'NASA/GPM_L3/IMERG_V06',
crs='EPSG:4326',
scale=0.25,
engine=xee.EarthEngineBackendEntrypoint,
).isel(time=slice(4999, 5001), lon=slice(0, 1), lat=slice(0, 1))

try:
ds.chunk().compute()
except ValueError:
self.fail('Chunking failed.')

def test_honors_geometry(self):
ic = ee.ImageCollection('ECMWF/ERA5_LAND/HOURLY').filterDate(
'1992-10-05', '1993-03-31'
Expand Down Expand Up @@ -378,5 +391,55 @@ def test_validate_band_attrs(self):
for _, value in variable.attrs.items():
self.assertIsInstance(value, valid_types)

def test_rename_bands(self):
point = ee.Geometry.Point((-122.45, 37.79))
col = ee.ImageCollection('LANDSAT/LC08/C02/T1_L2').filterBounds(point)
col = col.map(lambda im: im.regexpRename('$', '_new'))
b1, b2 = col.first().bandNames().getInfo()[:2]

ds = xr.open_dataset(
col,
engine=xee.EarthEngineBackendEntrypoint,
scale=120,
crs='epsg:32610',
geometry=point.buffer(512).bounds(),
)

ds['sum'] = ds[b1] + ds[b2]

self.assertTrue('sum' in ds)

def test_add_new_bands(self):
s2 = ee.ImageCollection('COPERNICUS/S2_HARMONIZED')
geometry = ee.Geometry.Polygon([[
[82.60642647743225, 27.16350437805251],
[82.60984897613525, 27.1618529901377],
[82.61088967323303, 27.163695288375266],
[82.60757446289062, 27.16517483230927]
]])
filtered = s2 \
.filter(ee.Filter.date('2017-01-01', '2018-01-01')) \
.filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 30)) \
.filter(ee.Filter.bounds(geometry))

def addNDVI(image):
ndvi = image.normalizedDifference(['B8', 'B4']).rename('ndvi')
return image.addBands(ndvi)

withNdvi = filtered.map(addNDVI)

ds = xr.open_dataset(
withNdvi.select('ndvi'),
engine=xee.EarthEngineBackendEntrypoint,
crs='EPSG:3857',
scale=10,
geometry=geometry,
)

original_ts = ds.ndvi.chunk('auto')
original_ts = original_ts.interp(X=82.607376, Y=27.164335)

self.assertIsInstance(original_ts.values, np.ndarray)

if __name__ == '__main__':
absltest.main()