diff --git a/airflow-core/src/airflow/serialization/helpers.py b/airflow-core/src/airflow/serialization/helpers.py index 723c113709a87..5367865673f04 100644 --- a/airflow-core/src/airflow/serialization/helpers.py +++ b/airflow-core/src/airflow/serialization/helpers.py @@ -21,6 +21,7 @@ import contextlib from typing import TYPE_CHECKING, Any +from airflow._shared.module_loading import qualname from airflow._shared.secrets_masker import redact from airflow.configuration import conf from airflow.settings import json @@ -33,6 +34,9 @@ def serialize_template_field(template_field: Any, name: str) -> str | dict | lis """ Return a serializable representation of the templated field. + If ``templated_field`` is provided via a callable, compute MD5 hash of source + and return following serialized value: `` Any: try: serialized = template_field.serialize() except AttributeError: - serialized = str(template_field) + if callable(template_field): + full_qualified_name = qualname(template_field, True) + serialized = f"" + else: + serialized = str(template_field) if len(serialized) > max_length: rendered = redact(serialized, name) return ( diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index bc9cd305a2422..5e054ea2114a2 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -1639,6 +1639,31 @@ def test_task_resources(self): assert deserialized_task.resources == task.resources assert isinstance(deserialized_task.resources, Resources) + def test_template_field_via_callable_serialization(self): + """ + Test operator template fields serialization when provided as a callable. + """ + + def fn_template_field_callable(context, jinja_env): + pass + + def fn_returns_callable(): + def get_arg(context, jinja_env): + pass + + return get_arg + + task = MockOperator(task_id="task1", arg1=fn_template_field_callable, arg2=fn_returns_callable()) + serialized_task = OperatorSerialization.serialize_operator(task) + assert ( + serialized_task.get("arg1") + == ".fn_template_field_callable>" + ) + assert ( + serialized_task.get("arg2") + == ".fn_returns_callable..get_arg>" + ) + def test_task_group_serialization(self): """ Test TaskGroup serialization/deserialization. diff --git a/shared/module_loading/src/airflow_shared/module_loading/__init__.py b/shared/module_loading/src/airflow_shared/module_loading/__init__.py index 3268dc3ddc23e..8868cc704ab34 100644 --- a/shared/module_loading/src/airflow_shared/module_loading/__init__.py +++ b/shared/module_loading/src/airflow_shared/module_loading/__init__.py @@ -63,10 +63,13 @@ def import_string(dotted_path: str): raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class') -def qualname(o: object | Callable) -> str: - """Convert an attribute/class/function to a string importable by ``import_string``.""" - if callable(o) and hasattr(o, "__module__") and hasattr(o, "__name__"): - return f"{o.__module__}.{o.__name__}" +def qualname(o: object | Callable, use_qualname: bool = False) -> str: + """Convert an attribute/class/callable to a string importable by ``import_string``.""" + if callable(o) and hasattr(o, "__module__"): + if use_qualname and hasattr(o, "__qualname__"): + return f"{o.__module__}.{o.__qualname__}" + if hasattr(o, "__name__"): + return f"{o.__module__}.{o.__name__}" cls = o