Skip to content

Commit a258911

Browse files
committed
Add partition_update_enabled option
1 parent 1ea6c5d commit a258911

File tree

3 files changed

+66
-23
lines changed

3 files changed

+66
-23
lines changed

src/pypgstac/python/pypgstac/load.py

+45-22
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def load_partition(
270270
partition: Partition,
271271
items: Iterable[Dict[str, Any]],
272272
insert_mode: Optional[Methods] = Methods.insert,
273+
partition_update_enabled: Optional[bool] = True,
273274
) -> None:
274275
"""Load items data for a single partition."""
275276
conn = self.db.connect()
@@ -441,12 +442,17 @@ def load_partition(
441442
"Available modes are insert, ignore, upsert, and delsert."
442443
f"You entered {insert_mode}.",
443444
)
444-
cur.execute("SELECT update_partition_stats_q(%s);",(partition.name,))
445+
if partition_update_enabled:
446+
cur.execute("SELECT update_partition_stats_q(%s);",(partition.name,))
445447
logger.debug(
446448
f"Copying data for {partition} took {time.perf_counter() - t} seconds",
447449
)
448450

449-
def _partition_update(self, item: Dict[str, Any]) -> str:
451+
def _partition_update(
452+
self,
453+
item: Dict[str, Any],
454+
update_enabled: Optional[bool] = True,
455+
) -> str:
450456
"""Update the cached partition with the item information and return the name.
451457
452458
This method will mark the partition as dirty if the bounds of the partition
@@ -512,20 +518,24 @@ def _partition_update(self, item: Dict[str, Any]) -> str:
512518
partition = self._partition_cache[partition_name]
513519

514520
if partition:
515-
# Only update the partition if the item is outside the current bounds
516-
if item["datetime"] < partition.datetime_range_min:
517-
partition.datetime_range_min = item["datetime"]
518-
partition.requires_update = True
519-
if item["datetime"] > partition.datetime_range_max:
520-
partition.datetime_range_max = item["datetime"]
521-
partition.requires_update = True
522-
if item["end_datetime"] < partition.end_datetime_range_min:
523-
partition.end_datetime_range_min = item["end_datetime"]
524-
partition.requires_update = True
525-
if item["end_datetime"] > partition.end_datetime_range_max:
526-
partition.end_datetime_range_max = item["end_datetime"]
527-
partition.requires_update = True
521+
if update_enabled:
522+
# Only update the partition if the item is outside the current bounds
523+
if item["datetime"] < partition.datetime_range_min:
524+
partition.datetime_range_min = item["datetime"]
525+
partition.requires_update = True
526+
if item["datetime"] > partition.datetime_range_max:
527+
partition.datetime_range_max = item["datetime"]
528+
partition.requires_update = True
529+
if item["end_datetime"] < partition.end_datetime_range_min:
530+
partition.end_datetime_range_min = item["end_datetime"]
531+
partition.requires_update = True
532+
if item["end_datetime"] > partition.end_datetime_range_max:
533+
partition.end_datetime_range_max = item["end_datetime"]
534+
partition.requires_update = True
528535
else:
536+
if not update_enabled:
537+
raise Exception(f"Partition {partition_name} does not exist.")
538+
529539
# No partition exists yet; create a new one from item
530540
partition = Partition(
531541
name=partition_name,
@@ -541,7 +551,11 @@ def _partition_update(self, item: Dict[str, Any]) -> str:
541551

542552
return partition_name
543553

544-
def read_dehydrated(self, file: Union[Path, str] = "stdin") -> Generator:
554+
def read_dehydrated(
555+
self,
556+
file: Union[Path, str] = "stdin",
557+
partition_update_enabled: Optional[bool] = True,
558+
) -> Generator:
545559
if file is None:
546560
file = "stdin"
547561
if isinstance(file, str):
@@ -572,15 +586,21 @@ def read_dehydrated(self, file: Union[Path, str] = "stdin") -> Generator:
572586
item[field] = content_value
573587
else:
574588
item[field] = tab_split[i]
575-
item["partition"] = self._partition_update(item)
589+
item["partition"] = self._partition_update(
590+
item,
591+
partition_update_enabled,
592+
)
576593
yield item
577594

578595
def read_hydrated(
579-
self, file: Union[Path, str, Iterator[Any]] = "stdin",
596+
self,
597+
file: Union[Path, str,
598+
Iterator[Any]] = "stdin",
599+
partition_update_enabled: Optional[bool] = True,
580600
) -> Generator:
581601
for line in read_json(file):
582602
item = self.format_item(line)
583-
item["partition"] = self._partition_update(item)
603+
item["partition"] = self._partition_update(item, partition_update_enabled)
584604
yield item
585605

586606
def load_items(
@@ -589,6 +609,7 @@ def load_items(
589609
insert_mode: Optional[Methods] = Methods.insert,
590610
dehydrated: Optional[bool] = False,
591611
chunksize: Optional[int] = 10000,
612+
partition_update_enabled: Optional[bool] = True,
592613
) -> None:
593614
"""Load items json records."""
594615
self.check_version()
@@ -599,15 +620,17 @@ def load_items(
599620
self._partition_cache = {}
600621

601622
if dehydrated and isinstance(file, str):
602-
items = self.read_dehydrated(file)
623+
items = self.read_dehydrated(file, partition_update_enabled)
603624
else:
604-
items = self.read_hydrated(file)
625+
items = self.read_hydrated(file, partition_update_enabled)
605626

606627
for chunkin in chunked_iterable(items, chunksize):
607628
chunk = list(chunkin)
608629
chunk.sort(key=lambda x: x["partition"])
609630
for k, g in itertools.groupby(chunk, lambda x: x["partition"]):
610-
self.load_partition(self._partition_cache[k], g, insert_mode)
631+
self.load_partition(
632+
self._partition_cache[k], g, insert_mode, partition_update_enabled,
633+
)
611634

612635
logger.debug(f"Adding data to database took {time.perf_counter() - t} seconds.")
613636

src/pypgstac/python/pypgstac/pypgstac.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,16 @@ def load(
6363
method: Optional[Methods] = Methods.insert,
6464
dehydrated: Optional[bool] = False,
6565
chunksize: Optional[int] = 10000,
66+
partition_update_enabled: Optional[bool] = True,
6667
) -> None:
6768
"""Load collections or items into PGStac."""
6869
loader = Loader(db=self._db)
6970
if table == "collections":
7071
loader.load_collections(file, method)
7172
if table == "items":
72-
loader.load_items(file, method, dehydrated, chunksize)
73+
loader.load_items(
74+
file, method, dehydrated, chunksize, partition_update_enabled,
75+
)
7376

7477
def loadextensions(self) -> None:
7578
conn = self._db.connect()

src/pypgstac/tests/test_load.py

+17
Original file line numberDiff line numberDiff line change
@@ -441,3 +441,20 @@ def test_load_items_nopartitionconstraint_succeeds(loader: Loader) -> None:
441441
str(TEST_ITEMS),
442442
insert_mode=Methods.upsert,
443443
)
444+
445+
446+
def test_load_items_when_partition_creation_disabled(loader: Loader) -> None:
    """Check that load_items fails when a required partition is missing.

    Collections are loaded first; the items are then loaded with
    ``partition_update_enabled=False``. With that flag off the loader
    raises instead of creating a partition that does not exist yet.

    NOTE(review): assumes the ``loader`` fixture starts from a database
    without item partitions for TEST_ITEMS -- confirm against the fixture.
    """
    loader.load_collections(
        str(TEST_COLLECTIONS_JSON),
        insert_mode=Methods.insert,
    )
    # The loader reports the missing partition with a plain ``Exception``
    # ("Partition ... does not exist."), which ``pytest.raises(ValueError)``
    # would not catch. Matching on the message keeps the check specific, so
    # unrelated failures still surface as test errors.
    with pytest.raises(Exception, match="does not exist"):
        loader.load_items(
            str(TEST_ITEMS),
            insert_mode=Methods.insert,
            partition_update_enabled=False,
        )

0 commit comments

Comments
 (0)