diff --git a/core/dbt/parser/base.py b/core/dbt/parser/base.py
index 9074a9b24c0..837bfd3df38 100644
--- a/core/dbt/parser/base.py
+++ b/core/dbt/parser/base.py
@@ -335,6 +335,7 @@ def update_parsed_node_config(
         patch_config_dict=None,
         patch_file_id=None,
         validate_config_call_dict: bool = False,
+        patch_file_index: Optional[str] = None,
     ) -> None:
         """Given the ContextConfig used for parsing and the parsed node,
         generate and set the true values to use, overriding the temporary parse
@@ -413,8 +414,23 @@ def update_parsed_node_config(
             if patch_file and isinstance(patch_file, SchemaSourceFile):
                 schema_key = resource_types_to_schema_file_keys.get(parsed_node.resource_type)
                 if schema_key:
+                    lookup_name = parsed_node.name
+                    lookup_version = getattr(parsed_node, "version", None)
+
+                    # Generic test configs are stored in the patch file under
+                    # "<attached node name>_<patch_file_index>", not the test's own name.
+                    attached_node_id = getattr(parsed_node, "attached_node", None)
+                    if (
+                        parsed_node.resource_type == NodeType.Test
+                        and attached_node_id
+                        and patch_file_index is not None
+                    ):
+                        if attached_node := self.manifest.nodes.get(attached_node_id):
+                            lookup_name = f"{attached_node.name}_{patch_file_index}"
+                            lookup_version = getattr(attached_node, "version", None)
+
                     if unrendered_patch_config := patch_file.get_unrendered_config(
-                        schema_key, parsed_node.name, getattr(parsed_node, "version", None)
+                        schema_key, lookup_name, lookup_version
                     ):
                         patch_config_dict = deep_merge(
                             patch_config_dict, unrendered_patch_config
diff --git a/core/dbt/parser/common.py b/core/dbt/parser/common.py
index 378d459705d..9c2e2f020ce 100644
--- a/core/dbt/parser/common.py
+++ b/core/dbt/parser/common.py
@@ -35,6 +35,7 @@
     "semantic_models": NodeType.SemanticModel,
     "saved_queries": NodeType.SavedQuery,
     "functions": NodeType.Function,
+    "data_tests": NodeType.Test,
 }
 
 resource_types_to_schema_file_keys = {
@@ -183,6 +184,7 @@ class GenericTestBlock(TestBlock[Testable], Generic[Testable]):
     column_name: Optional[str]
     tags: List[str]
     version: Optional[NodeVersion]
+    test_index: Optional[int]
 
     @classmethod
     def from_test_block(
@@ -192,6 +194,7 @@ def from_test_block(
         column_name: Optional[str],
         tags: List[str],
         version: Optional[NodeVersion],
+        test_index: Optional[int] = None,
     ) -> "GenericTestBlock":
         return cls(
             file=src.file,
@@ -201,6 +204,7 @@ def from_test_block(
             column_name=column_name,
             tags=tags,
             version=version,
+            test_index=test_index,
         )
 
 
diff --git a/core/dbt/parser/schema_generic_tests.py b/core/dbt/parser/schema_generic_tests.py
index 932547adbae..f265ab9b99a 100644
--- a/core/dbt/parser/schema_generic_tests.py
+++ b/core/dbt/parser/schema_generic_tests.py
@@ -81,8 +81,8 @@ def parse_column_tests(
         if not column.data_tests:
             return
 
-        for data_test in column.data_tests:
-            self.parse_test(block, data_test, column, version)
+        for data_test_idx, data_test in enumerate(column.data_tests):
+            self.parse_test(block, data_test, column, version, data_test_idx)
 
     def create_test_node(
         self,
@@ -161,6 +161,7 @@ def parse_generic_test(
         column_name: Optional[str],
         schema_file_id: str,
         version: Optional[NodeVersion],
+        test_index: Optional[int] = None,
     ) -> GenericTestNode:
         try:
             builder = TestBuilder(
@@ -233,7 +234,7 @@ def parse_generic_test(
             file_key_name=file_key_name,
             description=builder.description,
         )
-        self.render_test_update(node, config, builder, schema_file_id)
+        self.render_test_update(node, config, builder, schema_file_id, test_index)
 
         return node
 
@@ -278,18 +279,38 @@ def store_env_vars(self, target, schema_file_id, env_vars):
     # In the future we will look at generalizing this
     # more to handle additional macros or to use static
     # parsing to avoid jinja overhead.
-    def render_test_update(self, node, config, builder, schema_file_id):
+    def render_test_update(
+        self, node, config, builder, schema_file_id, test_index: Optional[int] = None
+    ):
         macro_unique_id = self.macro_resolver.get_macro_id(
             node.package_name, "test_" + builder.name
         )
         # Add the depends_on here so we can limit the macros added
         # to the context in rendering processing
         node.depends_on.add_macro(macro_unique_id)
+
+        # Set attached_node for generic test nodes, if available.
+        # Generic test node inherits attached node's group config value.
+        attached_node = self._lookup_attached_node(builder.target, builder.version)
+        if attached_node:
+            node.attached_node = attached_node.unique_id
+            node.group = attached_node.group
+
+        # Suffix of the key under which this test's unrendered config is stored in the
+        # patch file: "<column>_<index>" for column tests, "<index>" otherwise. Left as
+        # None when no index was supplied so that no "..._None" lookup key is built.
+        patch_file_index: Optional[str] = None
+        if test_index is not None:
+            column_prefix = f"{node.column_name}_" if node.column_name else ""
+            patch_file_index = f"{column_prefix}{test_index}"
+
         if macro_unique_id in ["macro.dbt.test_not_null", "macro.dbt.test_unique"]:
             config_call_dict = builder.config
             config._config_call_dict = config_call_dict
             # This sets the config from dbt_project
-            self.update_parsed_node_config(node, config)
+            self.update_parsed_node_config(
+                node, config, patch_file_id=schema_file_id, patch_file_index=patch_file_index
+            )
             # source node tests are processed at patch_source time
             if isinstance(builder.target, UnpatchedSourceDefinition):
                 sources = [builder.target.fqn[-2], builder.target.fqn[-1]]
@@ -312,19 +333,14 @@ def render_test_update(self, node, config, builder, schema_file_id):
                 add_rendered_test_kwargs(context, node, capture_macros=True)
                 # the parsed node is not rendered in the native context.
                 get_rendered(node.raw_code, context, node, capture_macros=True)
-                self.update_parsed_node_config(node, config)
+                self.update_parsed_node_config(
+                    node, config, patch_file_id=schema_file_id, patch_file_index=patch_file_index
+                )
                 # env_vars should have been updated in the context env_var method
             except ValidationError as exc:
                 # we got a ValidationError - probably bad types in config()
                 raise SchemaConfigError(exc, node=node) from exc
 
-        # Set attached_node for generic test nodes, if available.
-        # Generic test node inherits attached node's group config value.
-        attached_node = self._lookup_attached_node(builder.target, builder.version)
-        if attached_node:
-            node.attached_node = attached_node.unique_id
-            node.group, node.group = attached_node.group, attached_node.group
-
     def parse_node(self, block: GenericTestBlock) -> GenericTestNode:
         """In schema parsing, we rewrite most of the part of parse_node that builds
         the initial node to be parsed, but rendering is basically the
@@ -337,6 +353,7 @@ def parse_node(self, block: GenericTestBlock) -> GenericTestNode:
             column_name=block.column_name,
             schema_file_id=block.file.file_id,
             version=block.version,
+            test_index=block.test_index,
         )
         self.add_test_node(block, node)
         return node
@@ -371,6 +388,7 @@ def parse_test(
         data_test: TestDef,
         column: Optional[UnparsedColumn],
         version: Optional[NodeVersion],
+        test_index: Optional[int] = None,
     ) -> None:
         if isinstance(data_test, str):
             data_test = {data_test: {}}
@@ -395,15 +413,17 @@ def parse_test(
             column_name=column_name,
             tags=column_tags,
             version=version,
+            test_index=test_index,
         )
         self.parse_node(block)
 
     def parse_tests(self, block: TestBlock) -> None:
+        # TODO: plumb indexing here
         for column in block.columns:
             self.parse_column_tests(block, column, None)
 
-        for data_test in block.data_tests:
-            self.parse_test(block, data_test, None, None)
+        for data_test_idx, data_test in enumerate(block.data_tests):
+            self.parse_test(block, data_test, None, None, data_test_idx)
 
     def parse_versioned_tests(self, block: VersionedTestBlock) -> None:
         if not block.target.versions:
diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py
index 754c8df4472..cae1ae39694 100644
--- a/core/dbt/parser/schemas.py
+++ b/core/dbt/parser/schemas.py
@@ -466,6 +466,33 @@ def get_key_dicts(self) -> Iterable[Dict[str, Any]]:
             if "config" in entry:
                 unrendered_config = entry["config"]
 
+            # Unrendered configs of generic data tests, keyed by "<entry>_<index>" for
+            # entry-level tests and "<entry>_<column>_<index>" for column-level tests.
+            unrendered_data_test_configs = {}
+            if isinstance(entry.get("data_tests"), list):
+                for data_test_idx, data_test in enumerate(entry["data_tests"]):
+                    if isinstance(data_test, dict) and len(data_test) == 1:
+                        data_test_definition = list(data_test.values())[0]
+                        if isinstance(data_test_definition, dict) and data_test_definition.get(
+                            "config"
+                        ):
+                            unrendered_data_test_configs[f"{entry['name']}_{data_test_idx}"] = (
+                                data_test_definition["config"]
+                            )
+
+            if isinstance(entry.get("columns"), list):
+                for column in entry["columns"]:
+                    if isinstance(column, dict) and isinstance(column.get("data_tests"), list):
+                        for data_test_idx, data_test in enumerate(column["data_tests"]):
+                            if isinstance(data_test, dict) and len(data_test) == 1:
+                                data_test_definition = list(data_test.values())[0]
+                                if isinstance(
+                                    data_test_definition, dict
+                                ) and data_test_definition.get("config"):
+                                    unrendered_data_test_configs[
+                                        f"{entry['name']}_{column['name']}_{data_test_idx}"
+                                    ] = data_test_definition["config"]
+
             unrendered_version_configs = {}
             if "versions" in entry:
                 for version in entry["versions"]:
@@ -486,6 +513,11 @@ def get_key_dicts(self) -> Iterable[Dict[str, Any]]:
             if unrendered_config:
                 schema_file.add_unrendered_config(unrendered_config, self.key, entry["name"])
 
+            for test_name, unrendered_data_test_config in unrendered_data_test_configs.items():
+                schema_file.add_unrendered_config(
+                    unrendered_data_test_config, "data_tests", test_name
+                )
+
             for version, unrendered_version_config in unrendered_version_configs.items():
                 schema_file.add_unrendered_config(
                     unrendered_version_config, self.key, entry["name"], version