Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-84 Add ability to update dag run note in PATCH dag_run endpoint #43508

Merged
Merged
12 changes: 9 additions & 3 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2855,10 +2855,16 @@ components:
DAGRunPatchBody:
properties:
state:
$ref: '#/components/schemas/DAGRunPatchStates'
anyOf:
- $ref: '#/components/schemas/DAGRunPatchStates'
- type: 'null'
note:
anyOf:
- type: string
maxLength: 1000
- type: 'null'
title: Note
type: object
required:
- state
title: DAGRunPatchBody
description: DAG Run Serializer for PATCH requests.
DAGRunPatchStates:
Expand Down
46 changes: 30 additions & 16 deletions airflow/api_fastapi/core_api/routes/public/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from fastapi import Depends, HTTPException, Query, Request, status
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, joinedload
from typing_extensions import Annotated

from airflow.api.common.mark_tasks import (
Expand Down Expand Up @@ -108,7 +108,10 @@ async def patch_dag_run_state(
update_mask: list[str] | None = Query(None),
) -> DAGRunResponse:
"""Modify a DAG Run."""
dag_run = session.scalar(select(DagRun).filter_by(dag_id=dag_id, run_id=dag_run_id))
ALLOWED_FIELD_MASK = ["state", "note"]
dag_run = session.scalar(
select(DagRun).filter_by(dag_id=dag_id, run_id=dag_run_id).options(joinedload(DagRun.dag_run_note))
)
if dag_run is None:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
Expand All @@ -121,23 +124,34 @@ async def patch_dag_run_state(
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found")

if update_mask:
if update_mask != ["state"]:
raise HTTPException(
status.HTTP_400_BAD_REQUEST, "Only `state` field can be updated through the REST API"
)
for each in update_mask:
if each not in ALLOWED_FIELD_MASK:
raise HTTPException(400, f"Invalid field `{each}` in update mask")
else:
update_mask = ["state"]
update_mask = ALLOWED_FIELD_MASK

for attr_name in update_mask:
if attr_name == "state":
state = getattr(patch_body, attr_name)
if state == DAGRunPatchStates.SUCCESS:
set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True)
elif state == DAGRunPatchStates.QUEUED:
set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, commit=True)
else:
set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, commit=True)
if "state" in update_mask:
attr_value = getattr(patch_body, "state")
if attr_value == DAGRunPatchStates.SUCCESS:
set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
elif attr_value == DAGRunPatchStates.QUEUED:
set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
elif attr_value == DAGRunPatchStates.FAILED:
set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, commit=True, session=session)

dag_run = session.get(DagRun, dag_run.id)

for attr_name in update_mask:
attr_value = getattr(patch_body, attr_name)
if attr_value is None:
continue
rawwar marked this conversation as resolved.
Show resolved Hide resolved
if attr_name == "note":
# Once Authentication is implemented in this FastAPI app,
# user id will be added when updating dag run note
# Refer to https://github.com/apache/airflow/issues/43534
if dag_run.dag_run_note is None:
dag_run.note = (attr_value, None)
else:
dag_run.dag_run_note.content = attr_value
pierrejeambrun marked this conversation as resolved.
Show resolved Hide resolved

return DAGRunResponse.model_validate(dag_run, from_attributes=True)
3 changes: 2 additions & 1 deletion airflow/api_fastapi/core_api/serializers/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class DAGRunPatchStates(str, Enum):
class DAGRunPatchBody(BaseModel):
"""DAG Run Serializer for PATCH requests."""

state: DAGRunPatchStates
state: DAGRunPatchStates | None = None
note: str | None = Field(None, max_length=1000)


class DAGRunResponse(BaseModel):
Expand Down
22 changes: 20 additions & 2 deletions airflow/ui/openapi-gen/requests/schemas.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -875,11 +875,29 @@ export const $DAGResponse = {
export const $DAGRunPatchBody = {
properties: {
state: {
$ref: "#/components/schemas/DAGRunPatchStates",
anyOf: [
{
$ref: "#/components/schemas/DAGRunPatchStates",
},
{
type: "null",
},
],
},
note: {
anyOf: [
{
type: "string",
maxLength: 1000,
},
{
type: "null",
},
],
title: "Note",
},
},
type: "object",
required: ["state"],
title: "DAGRunPatchBody",
description: "DAG Run Serializer for PATCH requests.",
} as const;
Expand Down
3 changes: 2 additions & 1 deletion airflow/ui/openapi-gen/requests/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ export type DAGResponse = {
* DAG Run Serializer for PATCH requests.
*/
export type DAGRunPatchBody = {
state: DAGRunPatchStates;
state?: DAGRunPatchStates | null;
note?: string | null;
};

/**
Expand Down
58 changes: 47 additions & 11 deletions tests/api_fastapi/core_api/routes/public/test_dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
DAG2_RUN2_TRIGGERED_BY = DagRunTriggeredByType.REST_API
START_DATE = datetime(2024, 6, 15, 0, 0, tzinfo=timezone.utc)
EXECUTION_DATE = datetime(2024, 6, 16, 0, 0, tzinfo=timezone.utc)
DAG1_NOTE = "test_note"
DAG1_RUN1_NOTE = "test_note"


@pytest.fixture(autouse=True)
Expand All @@ -66,13 +66,13 @@ def setup(dag_maker, session=None):
start_date=START_DATE,
):
EmptyOperator(task_id="task_1")
dag1 = dag_maker.create_dagrun(
dag_run1 = dag_maker.create_dagrun(
run_id=DAG1_RUN1_ID,
state=DAG1_RUN1_STATE,
run_type=DAG1_RUN1_RUN_TYPE,
triggered_by=DAG1_RUN1_TRIGGERED_BY,
)
dag1.note = (DAG1_NOTE, 1)
dag_run1.note = (DAG1_RUN1_NOTE, 1)

dag_maker.create_dagrun(
run_id=DAG1_RUN2_ID,
Expand Down Expand Up @@ -114,7 +114,14 @@ class TestGetDagRun:
@pytest.mark.parametrize(
"dag_id, run_id, state, run_type, triggered_by, dag_run_note",
[
(DAG1_ID, DAG1_RUN1_ID, DAG1_RUN1_STATE, DAG1_RUN1_RUN_TYPE, DAG1_RUN1_TRIGGERED_BY, DAG1_NOTE),
(
DAG1_ID,
DAG1_RUN1_ID,
DAG1_RUN1_STATE,
DAG1_RUN1_RUN_TYPE,
DAG1_RUN1_TRIGGERED_BY,
DAG1_RUN1_NOTE,
),
(DAG1_ID, DAG1_RUN2_ID, DAG1_RUN2_STATE, DAG1_RUN2_RUN_TYPE, DAG1_RUN2_TRIGGERED_BY, None),
(DAG2_ID, DAG2_RUN1_ID, DAG2_RUN1_STATE, DAG2_RUN1_RUN_TYPE, DAG2_RUN1_TRIGGERED_BY, None),
(DAG2_ID, DAG2_RUN2_ID, DAG2_RUN2_STATE, DAG2_RUN2_RUN_TYPE, DAG2_RUN2_TRIGGERED_BY, None),
Expand All @@ -140,20 +147,49 @@ def test_get_dag_run_not_found(self, test_client):

class TestPatchDagRun:
@pytest.mark.parametrize(
"dag_id, run_id, state, response_state",
"dag_id, run_id, patch_body, response_body",
[
(DAG1_ID, DAG1_RUN1_ID, DagRunState.FAILED, DagRunState.FAILED),
(DAG1_ID, DAG1_RUN2_ID, DagRunState.SUCCESS, DagRunState.SUCCESS),
(DAG2_ID, DAG2_RUN1_ID, DagRunState.QUEUED, DagRunState.QUEUED),
(
DAG1_ID,
DAG1_RUN1_ID,
{"state": DagRunState.FAILED},
{"state": DagRunState.FAILED, "note": DAG1_RUN1_NOTE},
),
(
DAG1_ID,
DAG1_RUN2_ID,
{"state": DagRunState.SUCCESS},
{"state": DagRunState.SUCCESS, "note": None},
),
(
DAG2_ID,
DAG2_RUN1_ID,
{"state": DagRunState.QUEUED},
{"state": DagRunState.QUEUED, "note": None},
),
(
DAG1_ID,
DAG1_RUN1_ID,
{"note": "updated note"},
{"state": DagRunState.SUCCESS, "note": "updated note"},
),
(
DAG1_ID,
DAG1_RUN2_ID,
{"note": "new note", "state": DagRunState.FAILED},
{"state": DagRunState.FAILED, "note": "new note"},
),
(DAG1_ID, DAG1_RUN2_ID, {"note": None}, {"state": DagRunState.FAILED, "note": None}),
],
)
def test_patch_dag_run(self, test_client, dag_id, run_id, state, response_state):
response = test_client.patch(f"/public/dags/{dag_id}/dagRuns/{run_id}", json={"state": state})
def test_patch_dag_run(self, test_client, dag_id, run_id, patch_body, response_body):
response = test_client.patch(f"/public/dags/{dag_id}/dagRuns/{run_id}", json=patch_body)
assert response.status_code == 200
body = response.json()
assert body["dag_id"] == dag_id
assert body["run_id"] == run_id
assert body["state"] == response_state
assert body.get("state") == response_body.get("state")
assert body.get("note") == response_body.get("note")

@pytest.mark.parametrize(
"query_params,patch_body, expected_status_code",
Expand Down