Skip to content

Commit 81a0b31

Browse files
committed
feat(invalidate): add --cleanup-snapshots flag for scoped snapshot cleanup
Adds a `--cleanup-snapshots` flag to `sqlmesh invalidate` that immediately deletes physical snapshot tables exclusively owned by the target environment, without affecting snapshots shared with other environments (e.g. prod).

Previously, users had to run `sqlmesh janitor --ignore-ttl` separately after invalidating, which performed a global cleanup across all environments. The new flag provides a scoped alternative that:

1. Captures the environment's snapshot IDs before invalidation
2. Filters to only those not referenced by any other active environment
3. Drops the physical tables and removes the state records for those snapshots

Changes:

- cli/main.py: add --cleanup-snapshots flag to the invalidate command
- core/context.py: pass cleanup_snapshots through to invalidate_environment
- core/janitor.py: add delete_snapshots_for_environment() helper function
- core/state_sync/base.py: add target_snapshot_ids param to get/delete_expired_snapshots
- core/state_sync/db/facade.py: thread target_snapshot_ids through facade
- core/state_sync/db/snapshot.py: filter expired query by target_snapshot_ids when provided
- core/state_sync/cache.py: add target_snapshot_ids param to CachingStateSync

Closes #5844

Signed-off-by: mday-io <mdaytn@gmail.com>
1 parent 7b7e539 commit 81a0b31

9 files changed

Lines changed: 347 additions & 4 deletions

File tree

sqlmesh/cli/main.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,11 @@ def run(ctx: click.Context, environment: t.Optional[str] = None, **kwargs: t.Any
620620
is_flag=True,
621621
help="Wait for the environment to be deleted before returning. If not specified, the environment will be deleted asynchronously by the janitor process. This option requires a connection to the data warehouse.",
622622
)
623+
@click.option(
624+
"--cleanup-snapshots",
625+
is_flag=True,
626+
help="After invalidating, immediately delete physical snapshot tables that are exclusively owned by this environment (not referenced by any other environment). Implies --sync.",
627+
)
623628
@click.pass_context
624629
@error_handler
625630
@cli_analytics

sqlmesh/core/context.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,11 @@
108108
StateReader,
109109
StateSync,
110110
)
111-
from sqlmesh.core.janitor import cleanup_expired_views, delete_expired_snapshots
111+
from sqlmesh.core.janitor import (
112+
cleanup_expired_views,
113+
delete_expired_snapshots,
114+
delete_snapshots_for_environment,
115+
)
112116
from sqlmesh.core.table_diff import TableDiff
113117
from sqlmesh.core.test import (
114118
ModelTextTestResult,
@@ -1835,18 +1839,44 @@ def apply(
18351839
)
18361840

18371841
@python_api_analytics
def invalidate_environment(
    self, name: str, sync: bool = False, cleanup_snapshots: bool = False
) -> None:
    """Invalidate an environment by moving its expiration timestamp to now.

    Args:
        name: Name of the environment to invalidate; it is sanitized before use.
        sync: When True, block until the environment has been deleted. When False, the
            janitor process deletes the environment asynchronously.
        cleanup_snapshots: When True, also delete right away the physical snapshot tables
            that belong only to this environment (i.e. are not referenced by any other
            environment). Implies sync=True.

    Raises:
        SQLMeshError: If the scoped snapshot cleanup reported failures and the janitor
            config is not set to merely warn on delete failures.
    """
    name = Environment.sanitize_name(name)

    # The snapshot ids have to be read before the environment is invalidated so the
    # cleanup further down can be restricted to what this environment referenced.
    owned_snapshot_ids = set()
    if cleanup_snapshots:
        environment = self.state_sync.get_environment(name)
        if environment:
            owned_snapshot_ids = {s.snapshot_id for s in environment.snapshots}

    self.state_sync.invalidate_environment(name)

    if not (sync or cleanup_snapshots):
        # Asynchronous path: nothing else to do here.
        self.console.log_success(f"Environment '{name}' invalidated.")
        return

    self._cleanup_environments(name=name)

    cleanup_failures = (
        delete_snapshots_for_environment(
            self.state_sync,
            self.snapshot_evaluator,
            owned_snapshot_ids,
            console=self.console,
        )
        if owned_snapshot_ids
        else []
    )
    if cleanup_failures:
        failure_report = "\n".join(cleanup_failures)
        if not self.config.janitor.warn_on_delete_failure:
            raise SQLMeshError(f"Snapshot cleanup completed with failures:\n{failure_report}")
        self.console.log_warning(f"Snapshot cleanup completed with failures:\n{failure_report}")

    self.console.log_success(f"Environment '{name}' deleted.")

sqlmesh/core/janitor.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sqlmesh.core.console import Console
99
from sqlmesh.core.dialect import schema_
1010
from sqlmesh.core.environment import Environment
11-
from sqlmesh.core.snapshot import SnapshotEvaluator
11+
from sqlmesh.core.snapshot import SnapshotEvaluator, SnapshotId
1212
from sqlmesh.core.state_sync import StateSync
1313
from sqlmesh.core.state_sync.common import (
1414
logger,
@@ -193,3 +193,76 @@ def delete_expired_snapshots(
193193
failures.append(message)
194194
logger.info("Cleaned up %s expired snapshots", num_expired_snapshots)
195195
return failures
196+
197+
198+
def delete_snapshots_for_environment(
    state_sync: StateSync,
    snapshot_evaluator: SnapshotEvaluator,
    target_snapshot_ids: t.Collection[SnapshotId],
    *,
    force_delete: bool = False,
    console: t.Optional[Console] = None,
) -> t.List[str]:
    """Delete snapshots that are exclusively owned by a specific (now-deleted) environment.

    This performs a scoped cleanup: only the provided snapshot IDs are considered for deletion,
    and only those reported as expired by the state sync (queried with ``ignore_ttl=True``)
    are removed.

    Args:
        state_sync: StateSync instance to query and delete snapshot state from.
        snapshot_evaluator: SnapshotEvaluator instance to clean up physical tables.
        target_snapshot_ids: The snapshot IDs to consider for deletion (typically from the
            environment that was just invalidated/deleted). An empty collection is a no-op.
        force_delete: If True, delete snapshot state records even when physical table cleanup fails.
        console: Optional console for reporting progress.

    Returns:
        List of failure messages encountered during cleanup. Empty when everything succeeded
        or there was nothing to clean up.
    """
    if not target_snapshot_ids:
        return []

    failures: t.List[str] = []
    batch = state_sync.get_expired_snapshots(
        ignore_ttl=True,
        batch_range=ExpiredBatchRange.all_batch_range(),
        target_snapshot_ids=target_snapshot_ids,
    )
    if batch is None:
        return failures

    logger.info(
        "Cleaning up %s snapshots exclusively owned by invalidated environment",
        len(batch.expired_snapshot_ids),
    )

    cleanup_succeeded = True
    if batch.cleanup_tasks:
        # Best-effort: a failed physical drop is reported to the caller rather than raised,
        # so the caller decides whether it is a warning or an error.
        try:
            snapshot_evaluator.cleanup(
                target_snapshots=batch.cleanup_tasks,
                on_complete=console.update_cleanup_progress if console else None,
            )
        except Exception as failed_drops:
            message = f"Failed to clean up: {failed_drops}"
            logger.warning(message)
            failures.append(message)
            cleanup_succeeded = False

    # Restrict the state deletion to exactly the snapshots in the batch that was just
    # processed. Re-querying with the full target set could pick up a snapshot that became
    # unreferenced in the meantime and drop its state record without its physical table
    # ever having been cleaned up.
    # The emptiness check matters: an empty target collection presumably leaves the
    # expired-snapshot query unfiltered, which would turn this into a global cleanup.
    if (cleanup_succeeded or force_delete) and batch.expired_snapshot_ids:
        try:
            state_sync.delete_expired_snapshots(
                batch_range=ExpiredBatchRange.all_batch_range(),
                ignore_ttl=True,
                target_snapshot_ids=batch.expired_snapshot_ids,
            )
            logger.info(
                "Cleaned up %s snapshots from invalidated environment",
                len(batch.expired_snapshot_ids),
            )
        except Exception as e:
            message = f"Failed to delete snapshot state records: {e}"
            logger.warning(message)
            failures.append(message)

    return failures

sqlmesh/core/state_sync/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,13 +308,16 @@ def get_expired_snapshots(
308308
batch_range: ExpiredBatchRange,
309309
current_ts: t.Optional[int] = None,
310310
ignore_ttl: bool = False,
311+
target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None,
311312
) -> t.Optional[ExpiredSnapshotBatch]:
312313
"""Returns a single batch of expired snapshots ordered by (updated_ts, name, identifier).
313314
314315
Args:
315316
current_ts: Timestamp used to evaluate expiration.
316317
ignore_ttl: If True, include snapshots regardless of TTL (only checks if unreferenced).
317318
batch_range: The range of the batch to fetch.
319+
target_snapshot_ids: If provided, only consider snapshots with these IDs. Useful for
320+
scoped cleanup after environment invalidation.
318321
319322
Returns:
320323
A batch describing expired snapshots or None if no snapshots are pending cleanup.
@@ -368,6 +371,7 @@ def delete_expired_snapshots(
368371
batch_range: ExpiredBatchRange,
369372
ignore_ttl: bool = False,
370373
current_ts: t.Optional[int] = None,
374+
target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None,
371375
) -> None:
372376
"""Removes expired snapshots.
373377
@@ -379,6 +383,8 @@ def delete_expired_snapshots(
379383
ignore_ttl: Ignore the TTL on the snapshot when considering it expired. This has the effect of deleting
380384
all snapshots that are not referenced in any environment
381385
current_ts: Timestamp used to evaluate expiration.
386+
target_snapshot_ids: If provided, only delete snapshots with these IDs. Useful for
387+
scoped cleanup after environment invalidation.
382388
"""
383389

384390
@abc.abstractmethod

sqlmesh/core/state_sync/cache.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,14 @@ def delete_expired_snapshots(
113113
batch_range: ExpiredBatchRange,
114114
ignore_ttl: bool = False,
115115
current_ts: t.Optional[int] = None,
116+
target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None,
116117
) -> None:
117118
self.snapshot_cache.clear()
118119
self.state_sync.delete_expired_snapshots(
119120
batch_range=batch_range,
120121
ignore_ttl=ignore_ttl,
121122
current_ts=current_ts,
123+
target_snapshot_ids=target_snapshot_ids,
122124
)
123125

124126
def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None:

sqlmesh/core/state_sync/db/facade.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,13 +267,15 @@ def get_expired_snapshots(
267267
batch_range: ExpiredBatchRange,
268268
current_ts: t.Optional[int] = None,
269269
ignore_ttl: bool = False,
270+
target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None,
270271
) -> t.Optional[ExpiredSnapshotBatch]:
271272
current_ts = current_ts or now_timestamp()
272273
return self.snapshot_state.get_expired_snapshots(
273274
environments=self.environment_state.get_environments(),
274275
current_ts=current_ts,
275276
ignore_ttl=ignore_ttl,
276277
batch_range=batch_range,
278+
target_snapshot_ids=target_snapshot_ids,
277279
)
278280

279281
def get_expired_environments(
@@ -287,11 +289,13 @@ def delete_expired_snapshots(
287289
batch_range: ExpiredBatchRange,
288290
ignore_ttl: bool = False,
289291
current_ts: t.Optional[int] = None,
292+
target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None,
290293
) -> None:
291294
batch = self.get_expired_snapshots(
292295
ignore_ttl=ignore_ttl,
293296
current_ts=current_ts,
294297
batch_range=batch_range,
298+
target_snapshot_ids=target_snapshot_ids,
295299
)
296300
if batch and batch.expired_snapshot_ids:
297301
self.snapshot_state.delete_snapshots(batch.expired_snapshot_ids)

sqlmesh/core/state_sync/db/snapshot.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ def get_expired_snapshots(
170170
current_ts: int,
171171
ignore_ttl: bool,
172172
batch_range: ExpiredBatchRange,
173+
target_snapshot_ids: t.Optional[t.Collection[SnapshotIdLike]] = None,
173174
) -> t.Optional[ExpiredSnapshotBatch]:
174175
expired_query = exp.select("name", "identifier", "version", "updated_ts").from_(
175176
self.snapshots_table
@@ -180,6 +181,17 @@ def get_expired_snapshots(
180181
(exp.column("updated_ts") + exp.column("ttl_ms")) <= current_ts
181182
)
182183

184+
if target_snapshot_ids is not None:
185+
target_conditions = list(
186+
snapshot_id_filter(
187+
self.engine_adapter,
188+
target_snapshot_ids,
189+
batch_size=self.SNAPSHOT_BATCH_SIZE,
190+
)
191+
)
192+
if target_conditions:
193+
expired_query = expired_query.where(exp.or_(*target_conditions))
194+
183195
expired_query = expired_query.where(batch_range.where_filter)
184196

185197
promoted_snapshot_ids = {

tests/core/integration/test_aux_commands.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,76 @@ def test_invalidating_environment(sushi_context: Context):
481481
assert start_schemas - schemas_after_janitor == {"sushi__dev"}
482482

483483

484+
def test_invalidate_environment_cleanup_snapshots_scoped(tmp_path: Path):
    """Test that --cleanup-snapshots only deletes snapshots exclusively owned by the invalidated env."""
    models_dir = tmp_path / "models"
    models_dir.mkdir()
    (models_dir / "model1.sql").write_text("MODEL(name test.model1, kind FULL); SELECT 1 AS col")
    (models_dir / "model2.sql").write_text("MODEL(name test.model2, kind FULL); SELECT 2 AS col")

    ctx = Context(
        paths=[tmp_path],
        config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")),
    )

    # Apply both models to prod and dev.
    ctx.plan("prod", no_prompts=True, auto_apply=True)
    ctx.plan("dev", no_prompts=True, auto_apply=True, include_unmodified=True)

    prod_env = ctx.state_sync.get_environment("prod")
    dev_env = ctx.state_sync.get_environment("dev")
    assert prod_env is not None
    assert dev_env is not None

    prod_snapshot_ids = {s.snapshot_id for s in prod_env.snapshots}
    dev_snapshot_ids = {s.snapshot_id for s in dev_env.snapshots}

    # In a virtual environment, dev shares snapshots with prod.
    # Shared snapshots must NOT be deleted when invalidating dev with --cleanup-snapshots.
    shared_snapshot_ids = prod_snapshot_ids & dev_snapshot_ids
    # Without this precondition the assertions below would pass vacuously if the two
    # environments happened to share nothing.
    assert shared_snapshot_ids

    ctx.invalidate_environment("dev", cleanup_snapshots=True)

    # The dev environment record should be gone.
    assert ctx.state_sync.get_environment("dev") is None

    # Shared snapshots (also in prod) must still exist.
    remaining_snapshots = ctx.state_sync.get_snapshots(list(shared_snapshot_ids))
    assert set(remaining_snapshots.keys()) == shared_snapshot_ids

    # Prod environment should be unaffected.
    assert ctx.state_sync.get_environment("prod") is not None
524+
525+
def test_invalidate_environment_cleanup_snapshots_exclusive(tmp_path: Path):
    """Check that --cleanup-snapshots removes snapshots referenced only by the invalidated env."""
    model_root = tmp_path / "models"
    model_root.mkdir()
    model_file = model_root / "model1.sql"
    model_file.write_text("MODEL(name test.model1, kind FULL); SELECT 1 AS col")

    project_config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))
    context = Context(paths=[tmp_path], config=project_config)

    # Only dev is planned (prod never is), so dev is the sole environment
    # referencing the resulting snapshots.
    context.plan("dev", no_prompts=True, auto_apply=True)

    dev_environment = context.state_sync.get_environment("dev")
    assert dev_environment is not None
    exclusive_ids = {snapshot.snapshot_id for snapshot in dev_environment.snapshots}
    assert exclusive_ids

    context.invalidate_environment("dev", cleanup_snapshots=True)

    # Invalidating with cleanup removes the environment record itself ...
    assert context.state_sync.get_environment("dev") is None

    # ... as well as every snapshot that only dev referenced.
    leftover_snapshots = context.state_sync.get_snapshots(list(exclusive_ids))
    assert not leftover_snapshots
553+
484554
@time_machine.travel("2023-01-08 15:00:00 UTC")
485555
def test_evaluate_uncategorized_snapshot(init_and_plan_context: t.Callable):
486556
context, plan = init_and_plan_context("examples/sushi")

0 commit comments

Comments
 (0)