Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion core/dbt/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ def update_parsed_node_config(
patch_config_dict=None,
patch_file_id=None,
validate_config_call_dict: bool = False,
patch_file_index: Optional[str] = None,
) -> None:
"""Given the ContextConfig used for parsing and the parsed node,
generate and set the true values to use, overriding the temporary parse
Expand Down Expand Up @@ -413,8 +414,20 @@ def update_parsed_node_config(
if patch_file and isinstance(patch_file, SchemaSourceFile):
schema_key = resource_types_to_schema_file_keys.get(parsed_node.resource_type)
if schema_key:
lookup_name = parsed_node.name
lookup_version = getattr(parsed_node, "version", None)

# Test lookup needs to consider attached node and indexing
if (
parsed_node.resource_type == NodeType.Test
and hasattr(parsed_node, "attached_node") and parsed_node.attached_node
):
if attached_node := self.manifest.nodes.get(parsed_node.attached_node):
lookup_name = f"{attached_node.name}_{patch_file_index}"
lookup_version = getattr(attached_node, "version", None)

if unrendered_patch_config := patch_file.get_unrendered_config(
schema_key, parsed_node.name, getattr(parsed_node, "version", None)
schema_key, lookup_name, lookup_version
):
patch_config_dict = deep_merge(
patch_config_dict, unrendered_patch_config
Expand Down
4 changes: 4 additions & 0 deletions core/dbt/parser/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"semantic_models": NodeType.SemanticModel,
"saved_queries": NodeType.SavedQuery,
"functions": NodeType.Function,
"data_tests": NodeType.Test,
}

resource_types_to_schema_file_keys = {
Expand Down Expand Up @@ -183,6 +184,7 @@ class GenericTestBlock(TestBlock[Testable], Generic[Testable]):
column_name: Optional[str]
tags: List[str]
version: Optional[NodeVersion]
test_index: int

@classmethod
def from_test_block(
Expand All @@ -192,6 +194,7 @@ def from_test_block(
column_name: Optional[str],
tags: List[str],
version: Optional[NodeVersion],
test_index: int,
) -> "GenericTestBlock":
return cls(
file=src.file,
Expand All @@ -201,6 +204,7 @@ def from_test_block(
column_name=column_name,
tags=tags,
version=version,
test_index=test_index,
)


Expand Down
45 changes: 30 additions & 15 deletions core/dbt/parser/schema_generic_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def parse_column_tests(
if not column.data_tests:
return

for data_test in column.data_tests:
self.parse_test(block, data_test, column, version)
for data_test_idx, data_test in enumerate(column.data_tests):
self.parse_test(block, data_test, column, version, data_test_idx)

def create_test_node(
self,
Expand Down Expand Up @@ -161,6 +161,7 @@ def parse_generic_test(
column_name: Optional[str],
schema_file_id: str,
version: Optional[NodeVersion],
test_index: Optional[int] = None,
) -> GenericTestNode:
try:
builder = TestBuilder(
Expand Down Expand Up @@ -233,7 +234,7 @@ def parse_generic_test(
file_key_name=file_key_name,
description=builder.description,
)
self.render_test_update(node, config, builder, schema_file_id)
self.render_test_update(node, config, builder, schema_file_id, test_index)

return node

Expand Down Expand Up @@ -278,18 +279,33 @@ def store_env_vars(self, target, schema_file_id, env_vars):
# In the future we will look at generalizing this
# more to handle additional macros or to use static
# parsing to avoid jinja overhead.
def render_test_update(self, node, config, builder, schema_file_id):
def render_test_update(self, node, config, builder, schema_file_id, test_index: int):
macro_unique_id = self.macro_resolver.get_macro_id(
node.package_name, "test_" + builder.name
)
# Add the depends_on here so we can limit the macros added
# to the context in rendering processing
node.depends_on.add_macro(macro_unique_id)

# Set attached_node for generic test nodes, if available.
# Generic test node inherits attached node's group config value.
attached_node = self._lookup_attached_node(builder.target, builder.version)
if attached_node:
node.attached_node = attached_node.unique_id
node.group, node.group = attached_node.group, attached_node.group

# Index for lookups on patch file, used when setting unrendered_config for tests
patch_file_index = (
f"{node.column_name}_{test_index}" if node.column_name else str(test_index)
)

if macro_unique_id in ["macro.dbt.test_not_null", "macro.dbt.test_unique"]:
config_call_dict = builder.config
config._config_call_dict = config_call_dict
# This sets the config from dbt_project
self.update_parsed_node_config(node, config)
self.update_parsed_node_config(
node, config, patch_file_id=schema_file_id, patch_file_index=patch_file_index
)
# source node tests are processed at patch_source time
if isinstance(builder.target, UnpatchedSourceDefinition):
sources = [builder.target.fqn[-2], builder.target.fqn[-1]]
Expand All @@ -312,19 +328,14 @@ def render_test_update(self, node, config, builder, schema_file_id):
add_rendered_test_kwargs(context, node, capture_macros=True)
# the parsed node is not rendered in the native context.
get_rendered(node.raw_code, context, node, capture_macros=True)
self.update_parsed_node_config(node, config)
self.update_parsed_node_config(
node, config, patch_file_id=schema_file_id, patch_file_index=patch_file_index
)
# env_vars should have been updated in the context env_var method
except ValidationError as exc:
# we got a ValidationError - probably bad types in config()
raise SchemaConfigError(exc, node=node) from exc

# Set attached_node for generic test nodes, if available.
# Generic test node inherits attached node's group config value.
attached_node = self._lookup_attached_node(builder.target, builder.version)
if attached_node:
node.attached_node = attached_node.unique_id
node.group, node.group = attached_node.group, attached_node.group

def parse_node(self, block: GenericTestBlock) -> GenericTestNode:
"""In schema parsing, we rewrite most of the part of parse_node that
builds the initial node to be parsed, but rendering is basically the
Expand All @@ -337,6 +348,7 @@ def parse_node(self, block: GenericTestBlock) -> GenericTestNode:
column_name=block.column_name,
schema_file_id=block.file.file_id,
version=block.version,
test_index=block.test_index,
)
self.add_test_node(block, node)
return node
Expand Down Expand Up @@ -371,6 +383,7 @@ def parse_test(
data_test: TestDef,
column: Optional[UnparsedColumn],
version: Optional[NodeVersion],
test_index: int,
) -> None:
if isinstance(data_test, str):
data_test = {data_test: {}}
Expand All @@ -395,15 +408,17 @@ def parse_test(
column_name=column_name,
tags=column_tags,
version=version,
test_index=test_index,
)
self.parse_node(block)

def parse_tests(self, block: TestBlock) -> None:
# TODO: plumb indexing here
for column in block.columns:
self.parse_column_tests(block, column, None)

for data_test in block.data_tests:
self.parse_test(block, data_test, None, None)
for data_test_idx, data_test in enumerate(block.data_tests):
self.parse_test(block, data_test, None, None, data_test_idx)

def parse_versioned_tests(self, block: VersionedTestBlock) -> None:
if not block.target.versions:
Expand Down
31 changes: 31 additions & 0 deletions core/dbt/parser/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,31 @@ def get_key_dicts(self) -> Iterable[Dict[str, Any]]:
if "config" in entry:
unrendered_config = entry["config"]

unrendered_data_test_configs = {}
if "data_tests" in entry:
for data_test_idx, data_test in enumerate(entry["data_tests"]):
if isinstance(data_test, dict) and len(data_test):
data_test_definition = list(data_test.values())[0]
if isinstance(data_test_definition, dict) and data_test_definition.get(
"config"
):
unrendered_data_test_configs[f"{entry['name']}_{data_test_idx}"] = (
data_test_definition["config"]
)

if "columns" in entry:
for column in entry["columns"]:
if isinstance(column, dict) and column.get("data_tests"):
for data_test_idx, data_test in enumerate(column["data_tests"]):
if isinstance(data_test, dict) and len(data_test) == 1:
data_test_definition = list(data_test.values())[0]
if isinstance(
data_test_definition, dict
) and data_test_definition.get("config"):
unrendered_data_test_configs[
f"{entry['name']}_{column['name']}_{data_test_idx}"
] = data_test_definition["config"]

unrendered_version_configs = {}
if "versions" in entry:
for version in entry["versions"]:
Expand All @@ -486,6 +511,12 @@ def get_key_dicts(self) -> Iterable[Dict[str, Any]]:
if unrendered_config:
schema_file.add_unrendered_config(unrendered_config, self.key, entry["name"])

for test_name, unrendered_data_test_config in unrendered_data_test_configs.items():
print(f"ADD: {test_name}")
schema_file.add_unrendered_config(
unrendered_data_test_config, "data_tests", test_name
)

for version, unrendered_version_config in unrendered_version_configs.items():
schema_file.add_unrendered_config(
unrendered_version_config, self.key, entry["name"], version
Expand Down
Loading