Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-84 Add ability to update dag run note in PATCH dag_run endpoint #43508

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1318,9 +1318,9 @@ paths:
patch:
tags:
- DagRun
summary: Patch Dag Run State
summary: Patch Dag Run
description: Modify a DAG Run.
operationId: patch_dag_run_state
operationId: patch_dag_run
parameters:
- name: dag_id
in: path
Expand Down Expand Up @@ -3694,10 +3694,16 @@ components:
DAGRunPatchBody:
properties:
state:
$ref: '#/components/schemas/DAGRunPatchStates'
anyOf:
- $ref: '#/components/schemas/DAGRunPatchStates'
- type: 'null'
note:
anyOf:
- type: string
maxLength: 1000
- type: 'null'
title: Note
type: object
required:
- state
title: DAGRunPatchBody
description: DAG Run Serializer for PATCH requests.
DAGRunPatchStates:
Expand Down
34 changes: 20 additions & 14 deletions airflow/api_fastapi/core_api/routes/public/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def delete_dag_run(dag_id: str, dag_run_id: str, session: Annotated[Session, Dep
]
),
)
def patch_dag_run_state(
def patch_dag_run(
dag_id: str,
dag_run_id: str,
patch_body: DAGRunPatchBody,
Expand All @@ -121,23 +121,29 @@ def patch_dag_run_state(
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found")

if update_mask:
if update_mask != ["state"]:
raise HTTPException(
status.HTTP_400_BAD_REQUEST, "Only `state` field can be updated through the REST API"
)
data = patch_body.model_dump(include=set(update_mask))
else:
update_mask = ["state"]
data = patch_body.model_dump()

for attr_name in update_mask:
for attr_name, attr_value in data.items():
if attr_name == "state":
state = getattr(patch_body, attr_name)
if state == DAGRunPatchStates.SUCCESS:
set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True)
elif state == DAGRunPatchStates.QUEUED:
set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, commit=True)
attr_value = getattr(patch_body, "state")
if attr_value == DAGRunPatchStates.SUCCESS:
set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
elif attr_value == DAGRunPatchStates.QUEUED:
set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
elif attr_value == DAGRunPatchStates.FAILED:
set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
elif attr_name == "note":
# Once Authentication is implemented in this FastAPI app,
# user id will be added when updating dag run note
# Refer to https://github.com/apache/airflow/issues/43534
dag_run = session.get(DagRun, dag_run.id)
if dag_run.dag_run_note is None:
dag_run.note = (attr_value, None)
else:
set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, commit=True)
dag_run.dag_run_note.content = attr_value
pierrejeambrun marked this conversation as resolved.
Show resolved Hide resolved

session.refresh(dag_run)
dag_run = session.get(DagRun, dag_run.id)

return DAGRunResponse.model_validate(dag_run, from_attributes=True)
3 changes: 2 additions & 1 deletion airflow/api_fastapi/core_api/serializers/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class DAGRunPatchStates(str, Enum):
class DAGRunPatchBody(BaseModel):
"""DAG Run Serializer for PATCH requests."""

state: DAGRunPatchStates
state: DAGRunPatchStates | None = None
note: str | None = Field(None, max_length=1000)


class DAGRunResponse(BaseModel):
Expand Down
4 changes: 2 additions & 2 deletions airflow/ui/openapi-gen/queries/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -935,8 +935,8 @@ export type DagServicePatchDagMutationResult = Awaited<
export type ConnectionServicePatchConnectionMutationResult = Awaited<
ReturnType<typeof ConnectionService.patchConnection>
>;
export type DagRunServicePatchDagRunStateMutationResult = Awaited<
ReturnType<typeof DagRunService.patchDagRunState>
export type DagRunServicePatchDagRunMutationResult = Awaited<
ReturnType<typeof DagRunService.patchDagRun>
>;
export type PoolServicePatchPoolMutationResult = Awaited<
ReturnType<typeof PoolService.patchPool>
Expand Down
8 changes: 4 additions & 4 deletions airflow/ui/openapi-gen/queries/queries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1912,7 +1912,7 @@ export const useConnectionServicePatchConnection = <
...options,
});
/**
* Patch Dag Run State
* Patch Dag Run
* Modify a DAG Run.
* @param data The data for the request.
* @param data.dagId
Expand All @@ -1922,8 +1922,8 @@ export const useConnectionServicePatchConnection = <
* @returns DAGRunResponse Successful Response
* @throws ApiError
*/
export const useDagRunServicePatchDagRunState = <
TData = Common.DagRunServicePatchDagRunStateMutationResult,
export const useDagRunServicePatchDagRun = <
TData = Common.DagRunServicePatchDagRunMutationResult,
TError = unknown,
TContext = unknown,
>(
Expand Down Expand Up @@ -1954,7 +1954,7 @@ export const useDagRunServicePatchDagRunState = <
TContext
>({
mutationFn: ({ dagId, dagRunId, requestBody, updateMask }) =>
DagRunService.patchDagRunState({
DagRunService.patchDagRun({
dagId,
dagRunId,
requestBody,
Expand Down
22 changes: 20 additions & 2 deletions airflow/ui/openapi-gen/requests/schemas.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -981,11 +981,29 @@ export const $DAGResponse = {
export const $DAGRunPatchBody = {
properties: {
state: {
$ref: "#/components/schemas/DAGRunPatchStates",
anyOf: [
{
$ref: "#/components/schemas/DAGRunPatchStates",
},
{
type: "null",
},
],
},
note: {
anyOf: [
{
type: "string",
maxLength: 1000,
},
{
type: "null",
},
],
title: "Note",
},
},
type: "object",
required: ["state"],
title: "DAGRunPatchBody",
description: "DAG Run Serializer for PATCH requests.",
} as const;
Expand Down
12 changes: 6 additions & 6 deletions airflow/ui/openapi-gen/requests/services.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ import type {
GetDagRunResponse,
DeleteDagRunData,
DeleteDagRunResponse,
PatchDagRunStateData,
PatchDagRunStateResponse,
PatchDagRunData,
PatchDagRunResponse,
GetDagSourceData,
GetDagSourceResponse,
GetEventLogData,
Expand Down Expand Up @@ -794,7 +794,7 @@ export class DagRunService {
}

/**
* Patch Dag Run State
* Patch Dag Run
* Modify a DAG Run.
* @param data The data for the request.
* @param data.dagId
Expand All @@ -804,9 +804,9 @@ export class DagRunService {
* @returns DAGRunResponse Successful Response
* @throws ApiError
*/
public static patchDagRunState(
data: PatchDagRunStateData,
): CancelablePromise<PatchDagRunStateResponse> {
public static patchDagRun(
data: PatchDagRunData,
): CancelablePromise<PatchDagRunResponse> {
return __request(OpenAPI, {
method: "PATCH",
url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}",
Expand Down
9 changes: 5 additions & 4 deletions airflow/ui/openapi-gen/requests/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ export type DAGResponse = {
* DAG Run Serializer for PATCH requests.
*/
export type DAGRunPatchBody = {
state: DAGRunPatchStates;
state?: DAGRunPatchStates | null;
note?: string | null;
};

/**
Expand Down Expand Up @@ -932,14 +933,14 @@ export type DeleteDagRunData = {

export type DeleteDagRunResponse = void;

export type PatchDagRunStateData = {
export type PatchDagRunData = {
dagId: string;
dagRunId: string;
requestBody: DAGRunPatchBody;
updateMask?: Array<string> | null;
};

export type PatchDagRunStateResponse = DAGRunResponse;
export type PatchDagRunResponse = DAGRunResponse;

export type GetDagSourceData = {
accept?: string;
Expand Down Expand Up @@ -1775,7 +1776,7 @@ export type $OpenApiTs = {
};
};
patch: {
req: PatchDagRunStateData;
req: PatchDagRunData;
res: {
/**
* Successful Response
Expand Down
88 changes: 72 additions & 16 deletions tests/api_fastapi/core_api/routes/public/test_dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
DAG2_RUN2_TRIGGERED_BY = DagRunTriggeredByType.REST_API
START_DATE = datetime(2024, 6, 15, 0, 0, tzinfo=timezone.utc)
EXECUTION_DATE = datetime(2024, 6, 16, 0, 0, tzinfo=timezone.utc)
DAG1_NOTE = "test_note"
DAG1_RUN1_NOTE = "test_note"


@pytest.fixture(autouse=True)
Expand All @@ -66,13 +66,13 @@ def setup(dag_maker, session=None):
start_date=START_DATE,
):
EmptyOperator(task_id="task_1")
dag1 = dag_maker.create_dagrun(
dag_run1 = dag_maker.create_dagrun(
run_id=DAG1_RUN1_ID,
state=DAG1_RUN1_STATE,
run_type=DAG1_RUN1_RUN_TYPE,
triggered_by=DAG1_RUN1_TRIGGERED_BY,
)
dag1.note = (DAG1_NOTE, 1)
dag_run1.note = (DAG1_RUN1_NOTE, 1)

dag_maker.create_dagrun(
run_id=DAG1_RUN2_ID,
Expand Down Expand Up @@ -114,7 +114,14 @@ class TestGetDagRun:
@pytest.mark.parametrize(
"dag_id, run_id, state, run_type, triggered_by, dag_run_note",
[
(DAG1_ID, DAG1_RUN1_ID, DAG1_RUN1_STATE, DAG1_RUN1_RUN_TYPE, DAG1_RUN1_TRIGGERED_BY, DAG1_NOTE),
(
DAG1_ID,
DAG1_RUN1_ID,
DAG1_RUN1_STATE,
DAG1_RUN1_RUN_TYPE,
DAG1_RUN1_TRIGGERED_BY,
DAG1_RUN1_NOTE,
),
(DAG1_ID, DAG1_RUN2_ID, DAG1_RUN2_STATE, DAG1_RUN2_RUN_TYPE, DAG1_RUN2_TRIGGERED_BY, None),
(DAG2_ID, DAG2_RUN1_ID, DAG2_RUN1_STATE, DAG2_RUN1_RUN_TYPE, DAG2_RUN1_TRIGGERED_BY, None),
(DAG2_ID, DAG2_RUN2_ID, DAG2_RUN2_STATE, DAG2_RUN2_RUN_TYPE, DAG2_RUN2_TRIGGERED_BY, None),
Expand All @@ -140,36 +147,85 @@ def test_get_dag_run_not_found(self, test_client):

class TestPatchDagRun:
@pytest.mark.parametrize(
"dag_id, run_id, state, response_state",
"dag_id, run_id, patch_body, response_body",
[
(DAG1_ID, DAG1_RUN1_ID, DagRunState.FAILED, DagRunState.FAILED),
(DAG1_ID, DAG1_RUN2_ID, DagRunState.SUCCESS, DagRunState.SUCCESS),
(DAG2_ID, DAG2_RUN1_ID, DagRunState.QUEUED, DagRunState.QUEUED),
(
DAG1_ID,
DAG1_RUN1_ID,
{"state": DagRunState.FAILED, "note": "new_note2"},
{"state": DagRunState.FAILED, "note": "new_note2"},
),
(
DAG1_ID,
DAG1_RUN2_ID,
{"state": DagRunState.SUCCESS},
{"state": DagRunState.SUCCESS, "note": None},
),
(
DAG2_ID,
DAG2_RUN1_ID,
{"state": DagRunState.QUEUED},
{"state": DagRunState.QUEUED, "note": None},
),
(
DAG1_ID,
DAG1_RUN1_ID,
{"note": "updated note"},
{"state": DagRunState.SUCCESS, "note": "updated note"},
),
(
DAG1_ID,
DAG1_RUN2_ID,
{"note": "new note", "state": DagRunState.FAILED},
{"state": DagRunState.FAILED, "note": "new note"},
),
(DAG1_ID, DAG1_RUN2_ID, {"note": None}, {"state": DagRunState.FAILED, "note": None}),
],
)
def test_patch_dag_run(self, test_client, dag_id, run_id, state, response_state):
response = test_client.patch(f"/public/dags/{dag_id}/dagRuns/{run_id}", json={"state": state})
def test_patch_dag_run(self, test_client, dag_id, run_id, patch_body, response_body):
response = test_client.patch(f"/public/dags/{dag_id}/dagRuns/{run_id}", json=patch_body)
assert response.status_code == 200
body = response.json()
assert body["dag_id"] == dag_id
assert body["run_id"] == run_id
assert body["state"] == response_state
assert body.get("state") == response_body.get("state")
assert body.get("note") == response_body.get("note")

@pytest.mark.parametrize(
"query_params, patch_body, expected_status_code",
"query_params, patch_body, response_body, expected_status_code",
[
({"update_mask": ["state"]}, {"state": DagRunState.SUCCESS}, 200),
({}, {"state": DagRunState.SUCCESS}, 200),
({"update_mask": ["random"]}, {"state": DagRunState.SUCCESS}, 400),
({"update_mask": ["state"]}, {"state": DagRunState.SUCCESS}, {"state": "success"}, 200),
(
{"update_mask": ["note"]},
{"state": DagRunState.FAILED, "note": "new_note1"},
{"note": "new_note1", "state": "success"},
200,
),
(
{},
{"state": DagRunState.FAILED, "note": "new_note2"},
{"note": "new_note2", "state": "failed"},
200,
),
({"update_mask": ["note"]}, {}, {"state": "success", "note": None}, 200),
(
{"update_mask": ["random"]},
{"state": DagRunState.FAILED},
{"state": "success", "note": "test_note"},
200,
),
],
)
def test_patch_dag_run_with_update_mask(
self, test_client, query_params, patch_body, expected_status_code
self, test_client, query_params, patch_body, response_body, expected_status_code
):
response = test_client.patch(
f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}", params=query_params, json=patch_body
)
response_json = response.json()
assert response.status_code == expected_status_code
for key, value in response_body.items():
assert response_json.get(key) == value

def test_patch_dag_run_not_found(self, test_client):
response = test_client.patch(
Expand Down