Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 74 additions & 19 deletions airflow-core/src/airflow/utils/dag_version_inflation_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class DagVersionInflationCheckResult:

def __init__(self, check_level: DagVersionInflationCheckLevel):
self.check_level: DagVersionInflationCheckLevel = check_level
self.warnings: list[RuntimeVaryingValueWarning] = []
self.warnings: dict[int, RuntimeVaryingValueWarning] = {}
self.runtime_varying_values: dict = {}

def format_warnings(self) -> str | None:
Expand All @@ -72,7 +72,7 @@ def format_warnings(self) -> str | None:
"It causes the Dag version to increase as values change on every Dag parse.",
"",
]
for w in self.warnings:
for w in self.warnings.values():
lines.extend(
[
f"Line {w.line}, Col {w.col}",
Expand Down Expand Up @@ -137,6 +137,7 @@ class WarningContext(str, Enum):

TASK_CONSTRUCTOR = "Task constructor"
DAG_CONSTRUCTOR = "Dag constructor"
TASK_DECORATOR = "Task decorator"


class RuntimeVaryingValueAnalyzer:
Expand Down Expand Up @@ -305,17 +306,23 @@ def __init__(self, from_imports: dict[str, tuple[str, str]]):
self.from_imports: dict[str, tuple[str, str]] = from_imports
self.dag_instances: set[str] = set()
self.is_in_dag_context: bool = False
self.function_def_context: str | None = None

def is_dag_constructor(self, node: ast.Call) -> bool:
"""Check if a call is a Dag constructor."""
if not isinstance(node.func, ast.Name):
return False

func_name = node.func.id
# to handle use case "from airflow import sdk" and "with sdk.DAG()"
if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name):
if node.func.value.id in self.from_imports:
module, original = self.from_imports[node.func.value.id]
if (module == "airflow" or module.startswith("airflow.")) and node.func.attr in (
"DAG",
"dag",
):
return True

# "from airflow import DAG" form or "from airflow.decorator import dag"
if func_name in self.from_imports:
module, original = self.from_imports[func_name]
# to handle use case "from airflow import DAG" form or "from airflow.decorator import dag"
if isinstance(node.func, ast.Name) and node.func.id in self.from_imports:
module, original = self.from_imports[node.func.id]
if (module == "airflow" or module.startswith("airflow.")) and original in ("DAG", "dag"):
return True

Expand All @@ -329,8 +336,8 @@ def is_task_constructor(self, node: ast.Call) -> bool:
1. All calls within a Dag with block
2. Calls that receive a Dag instance as an argument (dag=...)
"""
# Inside Dag with block
if self.is_in_dag_context:
# Check whether it is nside Dag with block and has task pattern name
if self.is_in_dag_context and self.check_is_task_by_name(node.func):
return True

# Passing Dag instance as argument
Expand All @@ -345,6 +352,28 @@ def is_task_constructor(self, node: ast.Call) -> bool:

return False

def is_task_decorator(self, node: ast.expr):
if isinstance(node, ast.Name) or isinstance(node, ast.Attribute):
return self.check_is_task_by_name(node)
if isinstance(node, ast.Call):
return self.is_task_decorator(node.func)
return False

def check_is_task_by_name(self, node: ast.expr):
"""Check if task function has task name."""

def check_is_task_function_name(name):
return name.lower().endswith("operator") or name.lower().endswith("task")

if isinstance(node, ast.Name):
return check_is_task_function_name(node.id)
if isinstance(node, ast.Attribute):
if check_is_task_function_name(node.attr):
return True
return self.check_is_task_by_name(node.value)

return False

def register_dag_instance(self, var_name: str):
"""Register a Dag instance variable name."""
self.dag_instances.add(var_name)
Expand Down Expand Up @@ -375,6 +404,7 @@ def __init__(self, check_level: DagVersionInflationCheckLevel = DagVersionInflat
self.imports: dict[str, str] = {}
self.from_imports: dict[str, tuple[str, str]] = {}
self.varying_vars: dict[str, tuple[int, str]] = {}
self.varying_functions: dict[str, RuntimeVaryingValueWarning] = {}
self.check_level = check_level

# Helper objects
Expand Down Expand Up @@ -424,6 +454,9 @@ def visit_Call(self, node: ast.Call):

Check not assign but just call the function or Dag definition via decorator.
"""
if isinstance(node.func, ast.Name) and (warning := self.varying_functions.get(node.func.id)):
self.static_check_result.warnings[warning.line] = warning

if self.dag_detector.is_dag_constructor(node):
self._check_and_warn(node, WarningContext.DAG_CONSTRUCTOR)

Expand Down Expand Up @@ -464,6 +497,10 @@ def visit_With(self, node: ast.With):
# check the value defined in with statement to detect entering Dag with block
is_with_dag_context = True

# add dag variable defined in with Dag statement
if item.optional_vars and isinstance(item.optional_vars, ast.Name):
self._register_dag_instances([item.optional_vars])

if is_with_dag_context:
self.dag_detector.enter_dag_context()

Expand All @@ -473,6 +510,21 @@ def visit_With(self, node: ast.With):
# Exit Dag with block
self.dag_detector.exit_dag_context()

def visit_FunctionDef(self, node: ast.FunctionDef):
for decorator in node.decorator_list:
if self.dag_detector.is_task_decorator(decorator):
if isinstance(decorator, ast.Call):
self._check_and_warn(decorator, WarningContext.TASK_DECORATOR)
return
self.visit(decorator)

self.dag_detector.function_def_context = node.name

for body in node.body:
self.visit(body)

self.dag_detector.function_def_context = None

def _register_dag_instances(self, targets: list):
"""Register Dag instance variable names."""
for target in targets:
Expand All @@ -489,19 +541,22 @@ def _track_varying_assignment(self, node: ast.Assign):
def _check_and_warn(self, call: ast.Call, context: WarningContext):
"""Check function call arguments and generate warnings."""
if self.value_analyzer.get_varying_source(call):
self.static_check_result.warnings.append(
RuntimeVaryingValueWarning(
line=call.lineno,
col=call.col_offset,
code=ast.unparse(call),
message=self._get_warning_message(context),
)
warning = RuntimeVaryingValueWarning(
line=call.lineno,
col=call.col_offset,
code=ast.unparse(call),
message=self._get_warning_message(context),
)

if self.dag_detector.function_def_context:
self.varying_functions[self.dag_detector.function_def_context] = warning
else:
self.static_check_result.warnings[warning.line] = warning

def _get_warning_message(self, context: WarningContext) -> str:
"""Get appropriate warning message based on context."""
if self.dag_detector.is_in_dag_context and context == WarningContext.TASK_CONSTRUCTOR:
return "Don't use runtime-varying values as function arguments within with Dag block"
return "Don't use runtime-varying values as arguments of task within with Dag block"
return f"Don't use runtime-varying value as argument in {context.value}"


Expand Down
99 changes: 89 additions & 10 deletions airflow-core/tests/unit/utils/test_dag_version_inflation_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,28 @@ def test_is_task_constructor__true_when_dag_in_positional_args(self):
result = self.detector.is_task_constructor(call_node)
assert result is True

def test_is_task_decorator__check_when_normal_decorator(self):
code = """
@task(task_id='task_id')
def test():
print("test")
"""
call_node = ast.parse(code).body

assert isinstance(call_node[0], ast.FunctionDef)
assert self.detector.is_task_decorator(call_node[0].decorator_list[0]) is True

def test_is_task_decorator__check_when_attribute_decorator(self):
code = """
@sdk.task
def test():
print("test")
"""
call_node = ast.parse(code).body

assert isinstance(call_node[0], ast.FunctionDef)
assert self.detector.is_task_decorator(call_node[0].decorator_list[0]) is True

def test_enter_and_exit_dag_context(self):
"""Properly track entering and exiting Dag with-blocks."""
assert self.detector.is_in_dag_context is False
Expand Down Expand Up @@ -467,7 +489,7 @@ def test_visit_assign__warns_on_dag_with_varying_value(self):
self.checker.visit(tree)

assert len(self.checker.static_check_result.warnings) == 1
assert any("Dag constructor" in w.message for w in self.checker.static_check_result.warnings)
assert any("Dag constructor" in w.message for w in self.checker.static_check_result.warnings.values())

def test_visit_call__detects_task_in_dag_context(self):
"""Detect task creation inside Dag with block."""
Expand All @@ -484,7 +506,7 @@ def test_visit_call__detects_task_in_dag_context(self):
self.checker.visit(tree)

assert len(self.checker.static_check_result.warnings) == 1
assert any("PythonOperator" in w.code for w in self.checker.static_check_result.warnings)
assert any("PythonOperator" in w.code for w in self.checker.static_check_result.warnings.values())

def test_visit_for__warns_on_varying_range(self):
"""Warn when for-loop range is runtime-varying."""
Expand All @@ -498,7 +520,7 @@ def test_visit_for__warns_on_varying_range(self):
schedule_interval='@daily',
) as dag:
for i in [datetime.now(), "3"]:
task = BashOperator(
task = BashTask(
task_id='print_bash_hello_{i}',
bash_command=f'echo "Hello from DAG {i}!"', # !problem
dag=dag,
Expand All @@ -507,10 +529,10 @@ def test_visit_for__warns_on_varying_range(self):
tree = ast.parse(code)

self.checker.visit(tree)
warnings = self.checker.static_check_result.warnings
warnings = self.checker.static_check_result.warnings.values()

assert len(warnings) == 1
assert any("BashOperator" in w.code for w in warnings)
assert any("BashTask" in w.code for w in warnings)

def test_check_and_warn__creates_warning_for_varying_arg(self):
"""Create a warning when detecting varying positional argument."""
Expand All @@ -522,7 +544,7 @@ def test_check_and_warn__creates_warning_for_varying_arg(self):
self.checker._check_and_warn(call_node, WarningContext.DAG_CONSTRUCTOR)

assert len(self.checker.static_check_result.warnings) == 1
warning = self.checker.static_check_result.warnings[0]
warning = next(iter(self.checker.static_check_result.warnings.values()))
assert WarningContext.DAG_CONSTRUCTOR.value in warning.message
assert "datetime.now()" in warning.code

Expand All @@ -536,7 +558,7 @@ def test_check_and_warn__creates_warning_for_varying_kwarg(self):
self.checker._check_and_warn(call_node, WarningContext.TASK_CONSTRUCTOR)

assert len(self.checker.static_check_result.warnings) == 1
warning = self.checker.static_check_result.warnings[0]
warning = next(iter(self.checker.static_check_result.warnings.values()))
assert "dag_id" in warning.code
assert "datetime.now()" in warning.code

Expand All @@ -552,7 +574,7 @@ def _check_code(self, code: str) -> list[RuntimeVaryingValueWarning]:
tree = ast.parse(code)
checker = AirflowRuntimeVaryingValueChecker()
checker.visit(tree)
return checker.static_check_result.warnings
return list(checker.static_check_result.warnings.values())

def test_antipattern__dynamic_dag_id_with_timestamp(self):
"""ANTI-PATTERN: Using timestamps in Dag IDs."""
Expand Down Expand Up @@ -663,6 +685,7 @@ def test_dag_decorator_pattern__currently_not_detected(self):
code = """
from airflow.decorators import dag, task
from datetime import datetime
from random import random

@dag(dag_id=f"my_dag_{datetime.now()}") # !problem
def my_dag_function():
Expand All @@ -677,6 +700,10 @@ def my_task():
assert len(warnings) == 1

def test_dag_generated_in_for_or_function_statement(self):
"""
There are runtime-varying case in create_dag function.
But the function doesn't use in here so doesn't make warning
"""
code = """
from airflow import DAG
from airflow.operators.bash import BashOperator
Expand All @@ -691,7 +718,7 @@ def create_dag(dag_id, task_id):

with DAG(
dag_id,
default_args=default_args, # !problem
default_args=default_args, # not problem, because the function create_dag not called in statement
) as dag:
task1 = BashOperator(
task_id=task_id
Expand Down Expand Up @@ -732,4 +759,56 @@ def create_dag(dag_id, task_id):
task1 >> task2 >> task3
"""
warnings = self._check_code(code)
assert len(warnings) == 5
assert len(warnings) == 4

def test_import_dag_from_sdk_module(self):
code = """
from airflow import sdk
from airflow.providers.standard.operators.python import PythonOperator
from time import sleep
import datetime

with sdk.DAG(
dag_id="test1",
schedule="* * * * *",
start_date=datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(minutes=1),
tags=["example"],
) as dag:
PythonOperator(
task_id="test1_task",
python_callable=lambda: print(datetime.now()),
)
"""
warnings = self._check_code(code)
assert len(warnings) == 1

def test_python_task_with_task_decorator(self):
code = """
from airflow.sdk import task, DAG
import datetime

with DAG(
dag_id="test1",
schedule="* * * * *",
tags=["example"],
) as dag:
# the function is serialized with code string byte, so it doesn't affect dag version

@task
def add_one_task(x: int):
from random import random

for i in range(int(random() * 50)):
do_task(f"Simulating work... {i+1}")
sleep(1)
return x + 1

#!problem
@task(task_id=datetime.now())
def test_task():
print("test")

add_one(3)
"""
warnings = self._check_code(code)
assert len(warnings) == 1
Loading