Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-82 Save references between assets and triggers #43826

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions airflow/assets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@

from sqlalchemy.orm.session import Session

from airflow.triggers.base import BaseTrigger


__all__ = ["Asset", "AssetAll", "AssetAny", "Dataset"]


Expand Down Expand Up @@ -276,20 +279,43 @@ class Asset(os.PathLike, BaseAsset):
uri: str
group: str
extra: dict[str, Any]
watchers: list[BaseTrigger]

asset_type: ClassVar[str] = "asset"
__version__: ClassVar[int] = 1

@overload
def __init__(self, name: str, uri: str, *, group: str = "", extra: dict | None = None) -> None:
def __init__(
self,
name: str,
uri: str,
*,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
"""Canonical; both name and uri are provided."""

@overload
def __init__(self, name: str, *, group: str = "", extra: dict | None = None) -> None:
def __init__(
self,
name: str,
*,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
"""It's possible to only provide the name, either by keyword or as the only positional argument."""

@overload
def __init__(self, *, uri: str, group: str = "", extra: dict | None = None) -> None:
def __init__(
self,
*,
uri: str,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
"""It's possible to only provide the URI as a keyword argument."""

def __init__(
Expand All @@ -299,6 +325,7 @@ def __init__(
*,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
if name is None and uri is None:
raise TypeError("Asset() requires either 'name' or 'uri'")
Expand All @@ -311,6 +338,7 @@ def __init__(
self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri))
self.group = _validate_identifier(self, fields["group"], group) if group else self.asset_type
self.extra = _set_extra_default(extra)
self.watchers = watchers or []

def __fspath__(self) -> str:
return self.uri
Expand Down
74 changes: 74 additions & 0 deletions airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
)
from airflow.models.dag import DAG, DagModel, DagOwnerAttributes, DagTag
from airflow.models.dagrun import DagRun
from airflow.models.trigger import Trigger
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.timezone import utcnow
from airflow.utils.types import DagRunType
Expand All @@ -55,6 +56,7 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from airflow.triggers.base import BaseTrigger
from airflow.typing_compat import Self

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -425,3 +427,75 @@ def add_task_asset_references(
for task_id, asset_id in referenced_outlets
if (task_id, asset_id) not in orm_refs
)

def add_asset_trigger_references(
self, dags: dict[str, DagModel], assets: dict[tuple[str, str], AssetModel], *, session: Session
) -> None:
# Update references from assets being used
refs_to_add: dict[tuple[str, str], set[str]] = {}
refs_to_remove: dict[tuple[str, str], set[str]] = {}
triggers: dict[str, BaseTrigger] = {}
for name_uri, asset in self.assets.items():
asset_model = assets[name_uri]
trigger_class_path_to_trigger_dict: dict[str, BaseTrigger] = {
trigger.serialize()[0]: trigger for trigger in asset.watchers
}
triggers.update(trigger_class_path_to_trigger_dict)

trigger_class_paths_from_asset: set[str] = set(trigger_class_path_to_trigger_dict.keys())
trigger_class_paths_from_asset_model: set[str] = {
trigger.classpath for trigger in asset_model.triggers
}

# Optimization: no diff between the DB and DAG definitions, no update needed
if trigger_class_paths_from_asset == trigger_class_paths_from_asset_model:
continue

diff_to_add = trigger_class_paths_from_asset - trigger_class_paths_from_asset_model
diff_to_remove = trigger_class_paths_from_asset_model - trigger_class_paths_from_asset
if diff_to_add:
refs_to_add[name_uri] = diff_to_add
if diff_to_remove:
refs_to_remove[name_uri] = diff_to_remove

if refs_to_add:
all_classpaths = {classpath for classpaths in refs_to_add.values() for classpath in classpaths}
orm_triggers: dict[str, Trigger] = {
trigger.classpath: trigger
for trigger in session.scalars(select(Trigger).where(Trigger.classpath.in_(all_classpaths)))
}

# Create new triggers
new_trigger_models = [
trigger
for trigger in [
Trigger.from_object(triggers[classpath])
for classpath in all_classpaths
if classpath not in orm_triggers
]
]
session.add_all(new_trigger_models)
orm_triggers.update((trigger.classpath, trigger) for trigger in new_trigger_models)

# Add new references
for name_uri, classpaths in refs_to_add.items():
asset_model = assets[name_uri]
asset_model.triggers.extend(
[orm_triggers.get(trigger_class_path) for trigger_class_path in classpaths]
)

if refs_to_remove:
# Remove old references
for name_uri, classpaths in refs_to_remove.items():
asset_model = assets[name_uri]
asset_model.triggers = [
trigger for trigger in asset_model.triggers if trigger.classpath not in classpaths
]

# Remove references from assets no longer used
orphan_assets = session.scalars(
select(AssetModel).filter(~AssetModel.consuming_dags.any()).filter(AssetModel.triggers.any())
)
for asset_model in orphan_assets:
if (asset_model.name, asset_model.uri) not in self.assets:
asset_model.triggers = []
3 changes: 3 additions & 0 deletions airflow/decorators/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

if TYPE_CHECKING:
from airflow.io.path import ObjectStoragePath
from airflow.triggers.base import BaseTrigger


class _AssetMainOperator(PythonOperator):
Expand Down Expand Up @@ -116,6 +117,7 @@ class asset:
uri: str | ObjectStoragePath | None = None
group: str = ""
extra: dict[str, Any] = attrs.field(factory=dict)
watchers: list[BaseTrigger] = attrs.field(factory=list)

def __call__(self, f: Callable) -> AssetDefinition:
if (name := f.__name__) != f.__qualname__:
Expand All @@ -126,6 +128,7 @@ def __call__(self, f: Callable) -> AssetDefinition:
uri=name if self.uri is None else str(self.uri),
group=self.group,
extra=self.extra,
watchers=self.watchers,
function=f,
schedule=self.schedule,
)
1 change: 1 addition & 0 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1842,6 +1842,7 @@ def bulk_write_to_db(
asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session)
asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session)
asset_op.add_task_asset_references(orm_dags, orm_assets, session=session)
asset_op.add_asset_trigger_references(orm_dags, orm_assets, session=session)
session.flush()

@provide_session
Expand Down
3 changes: 1 addition & 2 deletions airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from airflow.serialization.pydantic.trigger import TriggerPydantic
from airflow.triggers.base import BaseTrigger


Expand Down Expand Up @@ -141,7 +140,7 @@ def rotate_fernet_key(self):
@classmethod
@internal_api_call
@provide_session
def from_object(cls, trigger: BaseTrigger, session=NEW_SESSION) -> Trigger | TriggerPydantic:
def from_object(cls, trigger: BaseTrigger, session=NEW_SESSION) -> Trigger:
"""Alternative constructor that creates a trigger row based directly off of a Trigger object."""
classpath, kwargs = trigger.serialize()
return cls(classpath=classpath, kwargs=kwargs)
Expand Down
6 changes: 3 additions & 3 deletions task_sdk/src/airflow/sdk/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,9 +557,9 @@ def __lt__(self, other):
def __hash__(self):
hash_components: list[Any] = [type(self)]
for c in _DAG_HASH_ATTRS:
# task_ids returns a list and lists can't be hashed
if c == "task_ids":
val = tuple(self.task_dict)
# If it is a list, convert to tuple because lists can't be hashed
vincbeck marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(getattr(self, c, None), list):
val = tuple(getattr(self, c))
else:
val = getattr(self, c, None)
try:
Expand Down