Skip to content

Commit 6a36444

Browse files
Add batch context object to microbatch jinja context (#11031) (#11064)
* Add `batch_id` to jinja context of microbatch batches * Add changie doc * Update `format_batch_start` to assume `batch_start` is always provided * Add "runtime only" property `batch_context` to `ModelNode` By it being "runtime only" we mean that it doesn't exist on the artifact and thus won't be written out to the manifest artifact. * Begin populating `batch_context` during materialization execution for microbatch batches * Fix circular import * Fixup MicrobatchBuilder.batch_id property method * Ensure MicrobatchModelRunner doesn't double compile batches We were compiling the node for each batch _twice_. Besides making microbatch models more expensive than they needed to be, double compiling wasn't causing any issue. However the first compilation was happening _before_ we had added the batch context information to the model node for the batch. This was leading to models which try to access the `batch_context` information on the model to blow up, which was undesirable. As such, we've now gone and skipped the first compilation. We've done this similar to how SavedQuery nodes skip compilation. * Add `__post_serialize__` method to `BatchContext` to ensure correct dict shape This is weird, but necessary, I apologize. Mashumaro handles the dictification of this class via a compile time generated `to_dict` method based off of the _typing_ of the class. By default `datetime` types are converted to strings. We don't want that, we want them to stay datetimes. * Update tests to check for `batch_context` * Update `resolve_event_time_filter` to use new `batch_context` * Stop testing for batchless compiled code for microbatch models In 45daec7 we stopped an extra compilation that was happening per batch prior to the batch_context being loaded. Stopping this extra compilation means that compiled sql for the microbatch model without the event time filter / batch context is no longer produced. 
We have discussed this and _believe_ it is okay given that this is a new node type that has not hit GA yet. * Rename `ModelNode.batch_context` to `ModelNode.batch` * Rename `build_batch_context` to `build_jinja_context_for_batch` The name `build_batch_context` was confusing as 1) We have a `BatchContext` object, which the method was not building 2) The method builds the jinja context for the batch As such it felt appropriate to rename the method to more accurately communicate what it does. * Rename test macro `invalid_batch_context_macro_sql` to `invalid_batch_jinja_context_macro_sql` This rename was to make it more clear that the jinja context for a batch was being checked, as a batch_context has a slightly different connotation. * Update changie doc (cherry picked from commit c3d87b8) Co-authored-by: Quigley Malcolm <[email protected]>
1 parent 65f05e0 commit 6a36444

File tree

8 files changed

+90
-41
lines changed

8 files changed

+90
-41
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
kind: Features
2+
body: Add `batch` context object to model jinja context
3+
time: 2024-11-21T12:56:30.715473-06:00
4+
custom:
5+
Author: QMalcolm
6+
Issue: "11025"

core/dbt/context/providers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,9 +244,10 @@ def resolve_event_time_filter(self, target: ManifestNode) -> Optional[EventTimeF
244244
and self.model.config.materialized == "incremental"
245245
and self.model.config.incremental_strategy == "microbatch"
246246
and self.manifest.use_microbatch_batches(project_name=self.config.project_name)
247+
and self.model.batch is not None
247248
):
248-
start = self.model.config.get("__dbt_internal_microbatch_event_time_start")
249-
end = self.model.config.get("__dbt_internal_microbatch_event_time_end")
249+
start = self.model.batch.event_time_start
250+
end = self.model.batch.event_time_end
250251

251252
if start is not None or end is not None:
252253
event_time_filter = EventTimeFilter(

core/dbt/contracts/graph/nodes.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
ConstraintType,
9494
ModelLevelConstraint,
9595
)
96+
from dbt_common.dataclass_schema import dbtClassMixin
9697
from dbt_common.events.contextvars import set_log_contextvars
9798
from dbt_common.events.functions import warn_or_error
9899

@@ -442,9 +443,30 @@ def resource_class(cls) -> Type[HookNodeResource]:
442443
return HookNodeResource
443444

444445

446+
@dataclass
447+
class BatchContext(dbtClassMixin):
448+
id: str
449+
event_time_start: datetime
450+
event_time_end: datetime
451+
452+
def __post_serialize__(self, data, context):
453+
# This is insane, but necessary, I apologize. Mashumaro handles the
454+
# dictification of this class via a compile time generated `to_dict`
455+
# method based off of the _typing_ of the class. By default `datetime`
456+
# types are converted to strings. We don't want that, we want them to
457+
# stay datetimes.
458+
# Note: This is safe because the `BatchContext` isn't part of the artifact
459+
# and thus doesn't get written out.
460+
new_data = super().__post_serialize__(data, context)
461+
new_data["event_time_start"] = self.event_time_start
462+
new_data["event_time_end"] = self.event_time_end
463+
return new_data
464+
465+
445466
@dataclass
446467
class ModelNode(ModelResource, CompiledNode):
447468
previous_batch_results: Optional[BatchResults] = None
469+
batch: Optional[BatchContext] = None
448470
_has_this: Optional[bool] = None
449471

450472
def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):

core/dbt/materializations/incremental/microbatch.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -100,25 +100,25 @@ def build_batches(self, start: datetime, end: datetime) -> List[BatchType]:
100100

101101
return batches
102102

103-
def build_batch_context(self, incremental_batch: bool) -> Dict[str, Any]:
103+
def build_jinja_context_for_batch(self, incremental_batch: bool) -> Dict[str, Any]:
104104
"""
105105
Create context with entries that reflect microbatch model + incremental execution state
106106
107107
Assumes self.model has been (re)-compiled with necessary batch filters applied.
108108
"""
109-
batch_context: Dict[str, Any] = {}
109+
jinja_context: Dict[str, Any] = {}
110110

111111
# Microbatch model properties
112-
batch_context["model"] = self.model.to_dict()
113-
batch_context["sql"] = self.model.compiled_code
114-
batch_context["compiled_code"] = self.model.compiled_code
112+
jinja_context["model"] = self.model.to_dict()
113+
jinja_context["sql"] = self.model.compiled_code
114+
jinja_context["compiled_code"] = self.model.compiled_code
115115

116116
# Add incremental context variables for batches running incrementally
117117
if incremental_batch:
118-
batch_context["is_incremental"] = lambda: True
119-
batch_context["should_full_refresh"] = lambda: False
118+
jinja_context["is_incremental"] = lambda: True
119+
jinja_context["should_full_refresh"] = lambda: False
120120

121-
return batch_context
121+
return jinja_context
122122

123123
@staticmethod
124124
def offset_timestamp(timestamp: datetime, batch_size: BatchSize, offset: int) -> datetime:
@@ -193,12 +193,11 @@ def truncate_timestamp(timestamp: datetime, batch_size: BatchSize) -> datetime:
193193
return truncated
194194

195195
@staticmethod
196-
def format_batch_start(
197-
batch_start: Optional[datetime], batch_size: BatchSize
198-
) -> Optional[str]:
199-
if batch_start is None:
200-
return batch_start
196+
def batch_id(start_time: datetime, batch_size: BatchSize) -> str:
197+
return MicrobatchBuilder.format_batch_start(start_time, batch_size).replace("-", "")
201198

199+
@staticmethod
200+
def format_batch_start(batch_start: datetime, batch_size: BatchSize) -> str:
202201
return str(
203202
batch_start.date() if (batch_start and batch_size != BatchSize.hour) else batch_start
204203
)

core/dbt/task/run.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from dbt.config import RuntimeConfig
2828
from dbt.context.providers import generate_runtime_model_context
2929
from dbt.contracts.graph.manifest import Manifest
30-
from dbt.contracts.graph.nodes import HookNode, ModelNode, ResultNode
30+
from dbt.contracts.graph.nodes import BatchContext, HookNode, ModelNode, ResultNode
3131
from dbt.events.types import (
3232
GenericExceptionOnRun,
3333
LogHookEndLine,
@@ -341,6 +341,13 @@ def __init__(self, config, adapter, node, node_index: int, num_nodes: int):
341341
self.batches: Dict[int, BatchType] = {}
342342
self.relation_exists: bool = False
343343

344+
def compile(self, manifest: Manifest):
345+
# The default compile function is _always_ called. However, we do our
346+
# compilation _later_ in `_execute_microbatch_materialization`. This
347+
# meant the node was being compiled _twice_ for each batch. To get around
348+
# this, we've overridden the default compile method to do nothing
349+
return self.node
350+
344351
def set_batch_idx(self, batch_idx: int) -> None:
345352
self.batch_idx = batch_idx
346353

@@ -353,7 +360,7 @@ def set_batches(self, batches: Dict[int, BatchType]) -> None:
353360
def describe_node(self) -> str:
354361
return f"{self.node.language} microbatch model {self.get_node_representation()}"
355362

356-
def describe_batch(self, batch_start: Optional[datetime]) -> str:
363+
def describe_batch(self, batch_start: datetime) -> str:
357364
# Only visualize date if batch_start year/month/day
358365
formatted_batch_start = MicrobatchBuilder.format_batch_start(
359366
batch_start, self.node.config.batch_size
@@ -530,10 +537,16 @@ def _execute_microbatch_materialization(
530537
# call materialization_macro to get a batch-level run result
531538
start_time = time.perf_counter()
532539
try:
533-
# Set start/end in context prior to re-compiling
540+
# LEGACY: Set start/end in context prior to re-compiling (Will be removed for 1.10+)
541+
# TODO: REMOVE before 1.10 GA
534542
model.config["__dbt_internal_microbatch_event_time_start"] = batch[0]
535543
model.config["__dbt_internal_microbatch_event_time_end"] = batch[1]
536-
544+
# Create batch context on model node prior to re-compiling
545+
model.batch = BatchContext(
546+
id=MicrobatchBuilder.batch_id(batch[0], model.config.batch_size),
547+
event_time_start=batch[0],
548+
event_time_end=batch[1],
549+
)
537550
# Recompile node to re-resolve refs with event time filters rendered, update context
538551
self.compiler.compile_node(
539552
model,
@@ -544,10 +557,10 @@ def _execute_microbatch_materialization(
544557
),
545558
)
546559
# Update jinja context with batch context members
547-
batch_context = microbatch_builder.build_batch_context(
560+
jinja_context = microbatch_builder.build_jinja_context_for_batch(
548561
incremental_batch=self.relation_exists
549562
)
550-
context.update(batch_context)
563+
context.update(jinja_context)
551564

552565
# Materialize batch and cache any materialized relations
553566
result = MacroGenerator(

tests/functional/microbatch/test_microbatch.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@
6464
select * from {{ ref('microbatch_model') }}
6565
"""
6666

67-
invalid_batch_context_macro_sql = """
68-
{% macro check_invalid_batch_context() %}
67+
invalid_batch_jinja_context_macro_sql = """
68+
{% macro check_invalid_batch_jinja_context() %}
6969
7070
{% if model is not mapping %}
7171
{{ exceptions.raise_compiler_error("`model` is invalid: expected mapping type") }}
@@ -83,9 +83,9 @@
8383
"""
8484

8585
microbatch_model_with_context_checks_sql = """
86-
{{ config(pre_hook="{{ check_invalid_batch_context() }}", materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }}
86+
{{ config(pre_hook="{{ check_invalid_batch_jinja_context() }}", materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }}
8787
88-
{{ check_invalid_batch_context() }}
88+
{{ check_invalid_batch_jinja_context() }}
8989
select * from {{ ref('input_model') }}
9090
"""
9191

@@ -404,7 +404,7 @@ class TestMicrobatchJinjaContext(BaseMicrobatchTest):
404404

405405
@pytest.fixture(scope="class")
406406
def macros(self):
407-
return {"check_batch_context.sql": invalid_batch_context_macro_sql}
407+
return {"check_batch_jinja_context.sql": invalid_batch_jinja_context_macro_sql}
408408

409409
@pytest.fixture(scope="class")
410410
def models(self):
@@ -498,6 +498,13 @@ def test_run_with_event_time(self, project):
498498
{{ config(materialized='incremental', incremental_strategy='microbatch', unique_key='id', event_time='event_time', batch_size='day', begin=modules.datetime.datetime(2020, 1, 1, 0, 0, 0)) }}
499499
{{ log("start: "~ model.config.__dbt_internal_microbatch_event_time_start, info=True)}}
500500
{{ log("end: "~ model.config.__dbt_internal_microbatch_event_time_end, info=True)}}
501+
{% if model.batch %}
502+
{{ log("batch.event_time_start: "~ model.batch.event_time_start, info=True)}}
503+
{{ log("batch.event_time_end: "~ model.batch.event_time_end, info=True)}}
504+
{{ log("batch.id: "~ model.batch.id, info=True)}}
505+
{{ log("start timezone: "~ model.batch.event_time_start.tzinfo, info=True)}}
506+
{{ log("end timezone: "~ model.batch.event_time_end.tzinfo, info=True)}}
507+
{% endif %}
501508
select * from {{ ref('input_model') }}
502509
"""
503510

@@ -516,12 +523,23 @@ def test_run_with_event_time_logs(self, project):
516523

517524
assert "start: 2020-01-01 00:00:00+00:00" in logs
518525
assert "end: 2020-01-02 00:00:00+00:00" in logs
526+
assert "batch.event_time_start: 2020-01-01 00:00:00+00:00" in logs
527+
assert "batch.event_time_end: 2020-01-02 00:00:00+00:00" in logs
528+
assert "batch.id: 20200101" in logs
529+
assert "start timezone: UTC" in logs
530+
assert "end timezone: UTC" in logs
519531

520532
assert "start: 2020-01-02 00:00:00+00:00" in logs
521533
assert "end: 2020-01-03 00:00:00+00:00" in logs
534+
assert "batch.event_time_start: 2020-01-02 00:00:00+00:00" in logs
535+
assert "batch.event_time_end: 2020-01-03 00:00:00+00:00" in logs
536+
assert "batch.id: 20200102" in logs
522537

523538
assert "start: 2020-01-03 00:00:00+00:00" in logs
524539
assert "end: 2020-01-03 13:57:00+00:00" in logs
540+
assert "batch.event_time_start: 2020-01-03 00:00:00+00:00" in logs
541+
assert "batch.event_time_end: 2020-01-03 13:57:00+00:00" in logs
542+
assert "batch.id: 20200103" in logs
525543

526544

527545
microbatch_model_failing_incremental_partition_sql = """
@@ -675,16 +693,6 @@ def test_run_with_event_time(self, project):
675693
with patch_microbatch_end_time("2020-01-03 13:57:00"):
676694
run_dbt(["run"])
677695

678-
# Compiled paths - compiled model without filter only
679-
assert read_file(
680-
project.project_root,
681-
"target",
682-
"compiled",
683-
"test",
684-
"models",
685-
"microbatch_model.sql",
686-
)
687-
688696
# Compiled paths - batch compilations
689697
assert read_file(
690698
project.project_root,

tests/unit/contracts/graph/test_manifest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
"deprecation_date",
9797
"defer_relation",
9898
"time_spine",
99+
"batch",
99100
}
100101
)
101102

tests/unit/materializations/incremental/test_microbatch.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -489,11 +489,11 @@ def test_build_batches(self, microbatch_model, start, end, batch_size, expected_
489489
assert len(actual_batches) == len(expected_batches)
490490
assert actual_batches == expected_batches
491491

492-
def test_build_batch_context_incremental_batch(self, microbatch_model):
492+
def test_build_jinja_context_for_incremental_batch(self, microbatch_model):
493493
microbatch_builder = MicrobatchBuilder(
494494
model=microbatch_model, is_incremental=True, event_time_start=None, event_time_end=None
495495
)
496-
context = microbatch_builder.build_batch_context(incremental_batch=True)
496+
context = microbatch_builder.build_jinja_context_for_batch(incremental_batch=True)
497497

498498
assert context["model"] == microbatch_model.to_dict()
499499
assert context["sql"] == microbatch_model.compiled_code
@@ -502,11 +502,11 @@ def test_build_batch_context_incremental_batch(self, microbatch_model):
502502
assert context["is_incremental"]() is True
503503
assert context["should_full_refresh"]() is False
504504

505-
def test_build_batch_context_incremental_batch_false(self, microbatch_model):
505+
def test_build_jinja_context_for_incremental_batch_false(self, microbatch_model):
506506
microbatch_builder = MicrobatchBuilder(
507507
model=microbatch_model, is_incremental=True, event_time_start=None, event_time_end=None
508508
)
509-
context = microbatch_builder.build_batch_context(incremental_batch=False)
509+
context = microbatch_builder.build_jinja_context_for_batch(incremental_batch=False)
510510

511511
assert context["model"] == microbatch_model.to_dict()
512512
assert context["sql"] == microbatch_model.compiled_code
@@ -605,7 +605,6 @@ def test_truncate_timestamp(self, timestamp, batch_size, expected_timestamp):
605605
@pytest.mark.parametrize(
606606
"batch_size,batch_start,expected_formatted_batch_start",
607607
[
608-
(None, None, None),
609608
(BatchSize.year, datetime(2020, 1, 1, 1), "2020-01-01"),
610609
(BatchSize.month, datetime(2020, 1, 1, 1), "2020-01-01"),
611610
(BatchSize.day, datetime(2020, 1, 1, 1), "2020-01-01"),

0 commit comments

Comments
 (0)