diff --git a/airflow-core/src/airflow/utils/dag_version_inflation_checker.py b/airflow-core/src/airflow/utils/dag_version_inflation_checker.py index ca7f58b26026f..e36fae7544b01 100644 --- a/airflow-core/src/airflow/utils/dag_version_inflation_checker.py +++ b/airflow-core/src/airflow/utils/dag_version_inflation_checker.py @@ -59,7 +59,7 @@ class DagVersionInflationCheckResult: def __init__(self, check_level: DagVersionInflationCheckLevel): self.check_level: DagVersionInflationCheckLevel = check_level - self.warnings: list[RuntimeVaryingValueWarning] = [] + self.warnings: dict[int, RuntimeVaryingValueWarning] = {} self.runtime_varying_values: dict = {} def format_warnings(self) -> str | None: @@ -72,7 +72,7 @@ def format_warnings(self) -> str | None: "It causes the Dag version to increase as values change on every Dag parse.", "", ] - for w in self.warnings: + for w in self.warnings.values(): lines.extend( [ f"Line {w.line}, Col {w.col}", @@ -137,6 +137,7 @@ class WarningContext(str, Enum): TASK_CONSTRUCTOR = "Task constructor" DAG_CONSTRUCTOR = "Dag constructor" + TASK_DECORATOR = "Task decorator" class RuntimeVaryingValueAnalyzer: @@ -305,17 +306,23 @@ def __init__(self, from_imports: dict[str, tuple[str, str]]): self.from_imports: dict[str, tuple[str, str]] = from_imports self.dag_instances: set[str] = set() self.is_in_dag_context: bool = False + self.function_def_context: str | None = None def is_dag_constructor(self, node: ast.Call) -> bool: """Check if a call is a Dag constructor.""" - if not isinstance(node.func, ast.Name): - return False - - func_name = node.func.id + # to handle use case "from airflow import sdk" and "with sdk.DAG()" + if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): + if node.func.value.id in self.from_imports: + module, original = self.from_imports[node.func.value.id] + if (module == "airflow" or module.startswith("airflow.")) and node.func.attr in ( + "DAG", + "dag", + ): + return True - # "from airflow 
import DAG" form or "from airflow.decorator import dag" - if func_name in self.from_imports: - module, original = self.from_imports[func_name] + # to handle use case "from airflow import DAG" form or "from airflow.decorators import dag" + if isinstance(node.func, ast.Name) and node.func.id in self.from_imports: + module, original = self.from_imports[node.func.id] if (module == "airflow" or module.startswith("airflow.")) and original in ("DAG", "dag"): return True @@ -329,8 +336,8 @@ def is_task_constructor(self, node: ast.Call) -> bool: 1. All calls within a Dag with block 2. Calls that receive a Dag instance as an argument (dag=...) """ - # Inside Dag with block - if self.is_in_dag_context: + # Check whether it is inside a Dag with block and has a task-pattern name + if self.is_in_dag_context and self.check_is_task_by_name(node.func): return True # Passing Dag instance as argument @@ -345,6 +352,28 @@ def is_task_constructor(self, node: ast.Call) -> bool: return False + def is_task_decorator(self, node: ast.expr): + if isinstance(node, ast.Name) or isinstance(node, ast.Attribute): + return self.check_is_task_by_name(node) + if isinstance(node, ast.Call): + return self.is_task_decorator(node.func) + return False + + def check_is_task_by_name(self, node: ast.expr): + """Check if task function has task name.""" + + def check_is_task_function_name(name): + return name.lower().endswith("operator") or name.lower().endswith("task") + + if isinstance(node, ast.Name): + return check_is_task_function_name(node.id) + if isinstance(node, ast.Attribute): + if check_is_task_function_name(node.attr): + return True + return self.check_is_task_by_name(node.value) + + return False + def register_dag_instance(self, var_name: str): """Register a Dag instance variable name.""" self.dag_instances.add(var_name) @@ -375,6 +404,7 @@ def __init__(self, check_level: DagVersionInflationCheckLevel = DagVersionInflat self.imports: dict[str, str] = {} self.from_imports: dict[str, tuple[str, str]] = 
{} self.varying_vars: dict[str, tuple[int, str]] = {} + self.varying_functions: dict[str, RuntimeVaryingValueWarning] = {} self.check_level = check_level # Helper objects @@ -424,6 +454,9 @@ def visit_Call(self, node: ast.Call): Check not assign but just call the function or Dag definition via decorator. """ + if isinstance(node.func, ast.Name) and (warning := self.varying_functions.get(node.func.id)): + self.static_check_result.warnings[warning.line] = warning + if self.dag_detector.is_dag_constructor(node): self._check_and_warn(node, WarningContext.DAG_CONSTRUCTOR) @@ -464,6 +497,10 @@ def visit_With(self, node: ast.With): # check the value defined in with statement to detect entering Dag with block is_with_dag_context = True + # add dag variable defined in with Dag statement + if item.optional_vars and isinstance(item.optional_vars, ast.Name): + self._register_dag_instances([item.optional_vars]) + if is_with_dag_context: self.dag_detector.enter_dag_context() @@ -473,6 +510,21 @@ def visit_With(self, node: ast.With): # Exit Dag with block self.dag_detector.exit_dag_context() + def visit_FunctionDef(self, node: ast.FunctionDef): + for decorator in node.decorator_list: + if self.dag_detector.is_task_decorator(decorator): + if isinstance(decorator, ast.Call): + self._check_and_warn(decorator, WarningContext.TASK_DECORATOR) + return + self.visit(decorator) + + self.dag_detector.function_def_context = node.name + + for body in node.body: + self.visit(body) + + self.dag_detector.function_def_context = None + def _register_dag_instances(self, targets: list): """Register Dag instance variable names.""" for target in targets: @@ -489,19 +541,22 @@ def _track_varying_assignment(self, node: ast.Assign): def _check_and_warn(self, call: ast.Call, context: WarningContext): """Check function call arguments and generate warnings.""" if self.value_analyzer.get_varying_source(call): - self.static_check_result.warnings.append( - RuntimeVaryingValueWarning( - line=call.lineno, - 
col=call.col_offset, - code=ast.unparse(call), - message=self._get_warning_message(context), - ) + warning = RuntimeVaryingValueWarning( + line=call.lineno, + col=call.col_offset, + code=ast.unparse(call), + message=self._get_warning_message(context), ) + if self.dag_detector.function_def_context: + self.varying_functions[self.dag_detector.function_def_context] = warning + else: + self.static_check_result.warnings[warning.line] = warning + def _get_warning_message(self, context: WarningContext) -> str: """Get appropriate warning message based on context.""" if self.dag_detector.is_in_dag_context and context == WarningContext.TASK_CONSTRUCTOR: - return "Don't use runtime-varying values as function arguments within with Dag block" + return "Don't use runtime-varying values as arguments of task within with Dag block" return f"Don't use runtime-varying value as argument in {context.value}" diff --git a/airflow-core/tests/unit/utils/test_dag_version_inflation_checker.py b/airflow-core/tests/unit/utils/test_dag_version_inflation_checker.py index fdc2160328ffe..8d858523be445 100644 --- a/airflow-core/tests/unit/utils/test_dag_version_inflation_checker.py +++ b/airflow-core/tests/unit/utils/test_dag_version_inflation_checker.py @@ -363,6 +363,28 @@ def test_is_task_constructor__true_when_dag_in_positional_args(self): result = self.detector.is_task_constructor(call_node) assert result is True + def test_is_task_decorator__check_when_normal_decorator(self): + code = """ +@task(task_id='task_id') +def test(): + print("test") +""" + call_node = ast.parse(code).body + + assert isinstance(call_node[0], ast.FunctionDef) + assert self.detector.is_task_decorator(call_node[0].decorator_list[0]) is True + + def test_is_task_decorator__check_when_attribute_decorator(self): + code = """ +@sdk.task +def test(): + print("test") + """ + call_node = ast.parse(code).body + + assert isinstance(call_node[0], ast.FunctionDef) + assert 
self.detector.is_task_decorator(call_node[0].decorator_list[0]) is True + def test_enter_and_exit_dag_context(self): """Properly track entering and exiting Dag with-blocks.""" assert self.detector.is_in_dag_context is False @@ -467,7 +489,7 @@ def test_visit_assign__warns_on_dag_with_varying_value(self): self.checker.visit(tree) assert len(self.checker.static_check_result.warnings) == 1 - assert any("Dag constructor" in w.message for w in self.checker.static_check_result.warnings) + assert any("Dag constructor" in w.message for w in self.checker.static_check_result.warnings.values()) def test_visit_call__detects_task_in_dag_context(self): """Detect task creation inside Dag with block.""" @@ -484,7 +506,7 @@ def test_visit_call__detects_task_in_dag_context(self): self.checker.visit(tree) assert len(self.checker.static_check_result.warnings) == 1 - assert any("PythonOperator" in w.code for w in self.checker.static_check_result.warnings) + assert any("PythonOperator" in w.code for w in self.checker.static_check_result.warnings.values()) def test_visit_for__warns_on_varying_range(self): """Warn when for-loop range is runtime-varying.""" @@ -498,7 +520,7 @@ def test_visit_for__warns_on_varying_range(self): schedule_interval='@daily', ) as dag: for i in [datetime.now(), "3"]: - task = BashOperator( + task = BashTask( task_id='print_bash_hello_{i}', bash_command=f'echo "Hello from DAG {i}!"', # !problem dag=dag, @@ -507,10 +529,10 @@ def test_visit_for__warns_on_varying_range(self): tree = ast.parse(code) self.checker.visit(tree) - warnings = self.checker.static_check_result.warnings + warnings = self.checker.static_check_result.warnings.values() assert len(warnings) == 1 - assert any("BashOperator" in w.code for w in warnings) + assert any("BashTask" in w.code for w in warnings) def test_check_and_warn__creates_warning_for_varying_arg(self): """Create a warning when detecting varying positional argument.""" @@ -522,7 +544,7 @@ def 
test_check_and_warn__creates_warning_for_varying_arg(self): self.checker._check_and_warn(call_node, WarningContext.DAG_CONSTRUCTOR) assert len(self.checker.static_check_result.warnings) == 1 - warning = self.checker.static_check_result.warnings[0] + warning = next(iter(self.checker.static_check_result.warnings.values())) assert WarningContext.DAG_CONSTRUCTOR.value in warning.message assert "datetime.now()" in warning.code @@ -536,7 +558,7 @@ def test_check_and_warn__creates_warning_for_varying_kwarg(self): self.checker._check_and_warn(call_node, WarningContext.TASK_CONSTRUCTOR) assert len(self.checker.static_check_result.warnings) == 1 - warning = self.checker.static_check_result.warnings[0] + warning = next(iter(self.checker.static_check_result.warnings.values())) assert "dag_id" in warning.code assert "datetime.now()" in warning.code @@ -552,7 +574,7 @@ def _check_code(self, code: str) -> list[RuntimeVaryingValueWarning]: tree = ast.parse(code) checker = AirflowRuntimeVaryingValueChecker() checker.visit(tree) - return checker.static_check_result.warnings + return list(checker.static_check_result.warnings.values()) def test_antipattern__dynamic_dag_id_with_timestamp(self): """ANTI-PATTERN: Using timestamps in Dag IDs.""" @@ -663,6 +685,7 @@ def test_dag_decorator_pattern__currently_not_detected(self): code = """ from airflow.decorators import dag, task from datetime import datetime +from random import random @dag(dag_id=f"my_dag_{datetime.now()}") # !problem def my_dag_function(): @@ -677,6 +700,10 @@ def my_task(): assert len(warnings) == 1 def test_dag_generated_in_for_or_function_statement(self): + """ + There are runtime-varying case in create_dag function. 
+ But that function is never called here, so no warning is raised + """ code = """ from airflow import DAG from airflow.operators.bash import BashOperator @@ -691,7 +718,7 @@ def create_dag(dag_id, task_id): with DAG( dag_id, - default_args=default_args, # !problem + default_args=default_args, # not problem, because the function create_dag not called in statement ) as dag: task1 = BashOperator( task_id=task_id @@ -732,4 +759,56 @@ def create_dag(dag_id, task_id): task1 >> task2 >> task3 """ warnings = self._check_code(code) - assert len(warnings) == 5 + assert len(warnings) == 4 + + def test_import_dag_from_sdk_module(self): + code = """ +from airflow import sdk +from airflow.providers.standard.operators.python import PythonOperator +from time import sleep +import datetime + +with sdk.DAG( + dag_id="test1", + schedule="* * * * *", + start_date=datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(minutes=1), + tags=["example"], +) as dag: + PythonOperator( + task_id="test1_task", + python_callable=lambda: print(datetime.now()), + ) +""" + warnings = self._check_code(code) + assert len(warnings) == 1 + + def test_python_task_with_task_decorator(self): + code = """ +from airflow.sdk import task, DAG +import datetime + +with DAG( + dag_id="test1", + schedule="* * * * *", + tags=["example"], +) as dag: + # the function is serialized with code string byte, so it doesn't affect dag version + + @task + def add_one_task(x: int): + from random import random + + for i in range(int(random() * 50)): + do_task(f"Simulating work... {i+1}") + sleep(1) + return x + 1 + + #!problem + @task(task_id=datetime.now()) + def test_task(): + print("test") + + add_one(3) +""" + warnings = self._check_code(code) + assert len(warnings) == 1