Skip to content

Commit 230194c

Browse files
omkar-fossellisms
authored andcommitted
Migrate public endpoint Get Task to FastAPI (apache#43718)
* Migrate public endpoint Get Task to FastAPI, with main resynced * Re-run static checks * Remove extra router line
1 parent de13c93 commit 230194c

File tree

15 files changed

+1313
-4
lines changed

15 files changed

+1313
-4
lines changed

airflow/api_connexion/endpoints/task_endpoint.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@
2525
from airflow.auth.managers.models.resource_details import DagAccessEntity
2626
from airflow.exceptions import TaskNotFound
2727
from airflow.utils.airflow_flask_app import get_airflow_app
28+
from airflow.utils.api_migration import mark_fastapi_migration_done
2829

2930
if TYPE_CHECKING:
3031
from airflow import DAG
3132
from airflow.api_connexion.types import APIResponse
3233

3334

35+
@mark_fastapi_migration_done
3436
@security.requires_access_dag("GET", DagAccessEntity.TASK)
3537
def get_task(*, dag_id: str, task_id: str) -> APIResponse:
3638
"""Get simplified representation of a task."""

airflow/api_fastapi/common/types.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,71 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
import inspect
20+
from datetime import timedelta
1921
from typing import Annotated
2022

21-
from pydantic import AfterValidator, AwareDatetime
23+
from pydantic import AfterValidator, AliasGenerator, AwareDatetime, BaseModel, BeforeValidator, ConfigDict
2224

25+
from airflow.models.mappedoperator import MappedOperator
26+
from airflow.serialization.serialized_objects import SerializedBaseOperator
2327
from airflow.utils import timezone
2428

2529
UtcDateTime = Annotated[AwareDatetime, AfterValidator(lambda d: d.astimezone(timezone.utc))]
2630
"""UTCDateTime is a datetime with timezone information"""
31+
32+
33+
def _validate_timedelta_field(td: timedelta | None) -> TimeDelta | None:
34+
"""Validate the execution_timeout property."""
35+
if td is None:
36+
return None
37+
return TimeDelta(
38+
days=td.days,
39+
seconds=td.seconds,
40+
microseconds=td.microseconds,
41+
)
42+
43+
44+
class TimeDelta(BaseModel):
45+
"""TimeDelta can be used to interact with datetime.timedelta objects."""
46+
47+
object_type: str = "TimeDelta"
48+
days: int
49+
seconds: int
50+
microseconds: int
51+
52+
model_config = ConfigDict(
53+
alias_generator=AliasGenerator(
54+
serialization_alias=lambda field_name: {
55+
"object_type": "__type",
56+
}.get(field_name, field_name),
57+
)
58+
)
59+
60+
61+
TimeDeltaWithValidation = Annotated[TimeDelta, BeforeValidator(_validate_timedelta_field)]
62+
63+
64+
def get_class_ref(obj) -> dict[str, str | None]:
65+
"""Return the class_ref dict for obj."""
66+
is_mapped_or_serialized = isinstance(obj, (MappedOperator, SerializedBaseOperator))
67+
68+
module_path = None
69+
if is_mapped_or_serialized:
70+
module_path = obj._task_module
71+
else:
72+
module_type = inspect.getmodule(obj)
73+
module_path = module_type.__name__ if module_type else None
74+
75+
class_name = None
76+
if is_mapped_or_serialized:
77+
class_name = obj._task_type
78+
elif obj.__class__ is type:
79+
class_name = obj.__name__
80+
else:
81+
class_name = type(obj).__name__
82+
83+
return {
84+
"module_path": module_path,
85+
"class_name": class_name,
86+
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
from collections import abc
21+
from datetime import datetime
22+
23+
from pydantic import BaseModel, computed_field, field_validator
24+
25+
from airflow.api_fastapi.common.types import TimeDeltaWithValidation
26+
from airflow.serialization.serialized_objects import encode_priority_weight_strategy
27+
from airflow.task.priority_strategy import PriorityWeightStrategy
28+
29+
30+
class TaskResponse(BaseModel):
31+
"""Task serializer for responses."""
32+
33+
task_id: str | None
34+
task_display_name: str | None
35+
owner: str | None
36+
start_date: datetime | None
37+
end_date: datetime | None
38+
trigger_rule: str | None
39+
depends_on_past: bool
40+
wait_for_downstream: bool
41+
retries: float | None
42+
queue: str | None
43+
pool: str | None
44+
pool_slots: float | None
45+
execution_timeout: TimeDeltaWithValidation | None
46+
retry_delay: TimeDeltaWithValidation | None
47+
retry_exponential_backoff: bool
48+
priority_weight: float | None
49+
weight_rule: str | None
50+
ui_color: str | None
51+
ui_fgcolor: str | None
52+
template_fields: list[str] | None
53+
downstream_task_ids: list[str] | None
54+
doc_md: str | None
55+
operator_name: str | None
56+
params: abc.MutableMapping | None
57+
class_ref: dict | None
58+
is_mapped: bool | None
59+
60+
@field_validator("weight_rule", mode="before")
61+
@classmethod
62+
def validate_weight_rule(cls, wr: str | PriorityWeightStrategy | None) -> str | None:
63+
"""Validate the weight_rule property."""
64+
if wr is None:
65+
return None
66+
if isinstance(wr, str):
67+
return wr
68+
return encode_priority_weight_strategy(wr)
69+
70+
@field_validator("params", mode="before")
71+
@classmethod
72+
def get_params(cls, params: abc.MutableMapping | None) -> dict | None:
73+
"""Convert params attribute to dict representation."""
74+
if params is None:
75+
return None
76+
return {param_name: param_val.dump() for param_name, param_val in params.items()}
77+
78+
# Mypy issue https://github.com/python/mypy/issues/1362
79+
@computed_field # type: ignore[misc]
80+
@property
81+
def extra_links(self) -> list[str]:
82+
"""Extract and return extra_links."""
83+
return getattr(self, "operator_extra_links", [])

0 commit comments

Comments
 (0)