Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion airflow-core/src/airflow/serialization/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import contextlib
from typing import TYPE_CHECKING, Any

from airflow._shared.module_loading import qualname
from airflow._shared.secrets_masker import redact
from airflow.configuration import conf
from airflow.settings import json
Expand All @@ -33,6 +34,9 @@ def serialize_template_field(template_field: Any, name: str) -> str | dict | lis
"""
Return a serializable representation of the templated field.

If ``templated_field`` is provided via a callable, compute MD5 hash of source
and return following serialized value: ``<callable fingerprint(MD5) hash_value``

If ``templated_field`` contains a class or instance that requires recursive
templating, store them as strings. Otherwise simply return the field as-is.
"""
Expand Down Expand Up @@ -71,7 +75,11 @@ def sort_dict_recursively(obj: Any) -> Any:
try:
serialized = template_field.serialize()
except AttributeError:
serialized = str(template_field)
if callable(template_field):
full_qualified_name = qualname(template_field, True)
serialized = f"<callable {full_qualified_name}>"
else:
serialized = str(template_field)
if len(serialized) > max_length:
rendered = redact(serialized, name)
return (
Expand Down
25 changes: 25 additions & 0 deletions airflow-core/tests/unit/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1639,6 +1639,31 @@ def test_task_resources(self):
assert deserialized_task.resources == task.resources
assert isinstance(deserialized_task.resources, Resources)

def test_template_field_via_callable_serialization(self):
"""
Test operator template fields serialization when provided as a callable.
"""

def fn_template_field_callable(context, jinja_env):
pass

def fn_returns_callable():
def get_arg(context, jinja_env):
pass

return get_arg

task = MockOperator(task_id="task1", arg1=fn_template_field_callable, arg2=fn_returns_callable())
serialized_task = OperatorSerialization.serialize_operator(task)
assert (
serialized_task.get("arg1")
== "<callable unit.serialization.test_dag_serialization.TestStringifiedDAGs.test_template_field_via_callable_serialization.<locals>.fn_template_field_callable>"
)
assert (
serialized_task.get("arg2")
== "<callable unit.serialization.test_dag_serialization.TestStringifiedDAGs.test_template_field_via_callable_serialization.<locals>.fn_returns_callable.<locals>.get_arg>"
)

def test_task_group_serialization(self):
"""
Test TaskGroup serialization/deserialization.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,13 @@ def import_string(dotted_path: str):
raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class')


def qualname(o: object | Callable) -> str:
"""Convert an attribute/class/function to a string importable by ``import_string``."""
if callable(o) and hasattr(o, "__module__") and hasattr(o, "__name__"):
return f"{o.__module__}.{o.__name__}"
def qualname(o: object | Callable, use_qualname: bool = False) -> str:
"""Convert an attribute/class/callable to a string importable by ``import_string``."""
if callable(o) and hasattr(o, "__module__"):
if use_qualname and hasattr(o, "__qualname__"):
return f"{o.__module__}.{o.__qualname__}"
if hasattr(o, "__name__"):
return f"{o.__module__}.{o.__name__}"

cls = o

Expand Down