Refresh expired Execution API JWTs for task instances that are still QUEUED or RUNNING.

  * auth/tokens.py: JWTValidator._get_validation_key becomes the public get_validation_key.
  * execution_api/app.py: the JWT reissue middleware reuses request.state.refreshed_token when it is
    already set and only runs its own near-expiry refresh otherwise.
  * execution_api/deps.py: JWTBearer catches jwt.ExpiredSignatureError, re-decodes the token without the
    expiry check, re-checks subject and required claims, and issues a new token if the task instance
    state allows it; the raw token is no longer logged on validation failure.
  * tests: cover the refresh path for an active task and the 403 for an inactive one.

diff --git a/airflow-core/src/airflow/api_fastapi/auth/tokens.py b/airflow-core/src/airflow/api_fastapi/auth/tokens.py
index 276ae17153da0..090b2c011326d 100644
--- a/airflow-core/src/airflow/api_fastapi/auth/tokens.py
+++ b/airflow-core/src/airflow/api_fastapi/auth/tokens.py
@@ -305,7 +305,7 @@ def _get_kid_from_header(self, unvalidated: str) -> str:
             raise jwt.InvalidTokenError("Missing 'kid' in token header")
         return header["kid"]
 
-    async def _get_validation_key(self, unvalidated: str) -> str | jwt.PyJWK:
+    async def get_validation_key(self, unvalidated: str) -> str | jwt.PyJWK:
         if self.secret_key:
             return self.secret_key
 
@@ -324,7 +324,7 @@ async def avalidated_claims(
         self, unvalidated: str, required_claims: dict[str, Any] | None = None
     ) -> dict[str, Any]:
         """Decode the JWT token, returning the validated claims or raising an exception."""
-        key = await self._get_validation_key(unvalidated)
+        key = await self.get_validation_key(unvalidated)
         claims = jwt.decode(
             unvalidated,
             key,
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
index 9d93f3bf84daf..909e3bf14b628 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
@@ -133,27 +133,29 @@ async def dispatch(self, request: Request, call_next):
 
         response: Response = await call_next(request)
 
-        refreshed_token: str | None = None
-        auth_header = request.headers.get("authorization")
-        if auth_header and auth_header.lower().startswith("bearer "):
-            token = auth_header.split(" ", 1)[1]
-            try:
-                async with svcs.Container(request.app.state.svcs_registry) as services:
-                    validator: JWTValidator = await services.aget(JWTValidator)
-                    claims = await validator.avalidated_claims(token, {})
-
-                    now = int(time.time())
-                    validity = conf.getint("execution_api", "jwt_expiration_time")
-                    refresh_when_less_than = max(int(validity * 0.20), 30)
-                    valid_left = int(claims.get("exp", 0)) - now
-                    if valid_left <= refresh_when_less_than:
-                        generator: JWTGenerator = await services.aget(JWTGenerator)
-                        refreshed_token = generator.generate(claims)
-            except Exception as err:
-                # Do not block the response if refreshing fails; log a warning for visibility
-                logger.warning(
-                    "JWT reissue middleware failed to refresh token", error=str(err), exc_info=True
-                )
+        refreshed_token: str | None = getattr(request.state, "refreshed_token", None)
+
+        if not refreshed_token:
+            auth_header = request.headers.get("authorization")
+            if auth_header and auth_header.lower().startswith("bearer "):
+                token = auth_header.split(" ", 1)[1]
+                try:
+                    async with svcs.Container(request.app.state.svcs_registry) as services:
+                        validator: JWTValidator = await services.aget(JWTValidator)
+                        claims = await validator.avalidated_claims(token, {})
+
+                        now = int(time.time())
+                        validity = conf.getint("execution_api", "jwt_expiration_time")
+                        refresh_when_less_than = max(int(validity * 0.20), 30)
+                        valid_left = int(claims.get("exp", 0)) - now
+                        if valid_left <= refresh_when_less_than:
+                            generator: JWTGenerator = await services.aget(JWTGenerator)
+                            refreshed_token = generator.generate(claims)
+                except Exception as err:
+                    # Do not block the response if refreshing fails; log a warning for visibility
+                    logger.warning(
+                        "JWT reissue middleware failed to refresh token", error=str(err), exc_info=True
+                    )
 
         if refreshed_token:
             response.headers["Refreshed-API-Token"] = refreshed_token
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
index fce188d48ed6d..fcc591402c9d9 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
@@ -20,22 +20,41 @@
 
 from typing import Any
 
+import jwt
 import structlog
 import svcs
 from fastapi import Depends, HTTPException, Request, status
 from fastapi.security import HTTPBearer
 from sqlalchemy import select
 
-from airflow.api_fastapi.auth.tokens import JWTValidator
+from airflow.api_fastapi.auth.tokens import JWTGenerator, JWTValidator
 from airflow.api_fastapi.common.db.common import AsyncSessionDep
 from airflow.api_fastapi.execution_api.datamodels.token import TIToken
 from airflow.configuration import conf
 from airflow.models import DagModel, TaskInstance
 from airflow.models.dagbundle import DagBundleModel
 from airflow.models.team import Team
+from airflow.utils.session import create_session_async
+from airflow.utils.state import TaskInstanceState
 
 log = structlog.get_logger(logger_name=__name__)
 
+# Valid states for token refresh - task must be queued or running
+REFRESHABLE_TASK_STATES = frozenset({TaskInstanceState.QUEUED, TaskInstanceState.RUNNING})
+
+
+async def _is_task_in_refreshable_state(task_instance_id: str) -> bool:
+    """
+    Check if a task instance is in a state that allows token refresh.
+
+    Only tasks in QUEUED or RUNNING state can have their tokens refreshed.
+    This prevents refreshing tokens for completed/failed tasks.
+    """
+    async with create_session_async() as session:
+        stmt = select(TaskInstance.state).where(TaskInstance.id == task_instance_id)
+        state = await session.scalar(stmt)
+        return state in REFRESHABLE_TASK_STATES
+
 
 # See https://github.com/fastapi/fastapi/issues/13056
 async def _container(request: Request):
@@ -88,14 +107,98 @@ async def __call__(  # type: ignore[override]
                 validators = self.required_claims
             claims = await validator.avalidated_claims(creds.credentials, validators)
             return TIToken(id=claims["sub"], claims=claims)
+        except jwt.ExpiredSignatureError:
+            # Token expired - try to refresh if task is still in a valid state
+            log.debug("JWT token expired, attempting to refresh")
+            return await self._handle_expired_token(request, creds.credentials, validator, services)
         except Exception as err:
             log.warning(
                 "Failed to validate JWT",
                 exc_info=True,
-                token=creds.credentials,
             )
             raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid auth token: {err}")
 
+    async def _handle_expired_token(
+        self,
+        request: Request,
+        token: str,
+        validator: JWTValidator,
+        services,
+    ) -> TIToken:
+        """
+        Handle an expired JWT by refreshing it if the task is still active.
+
+        The token is decoded again with only the expiry check disabled, then the subject,
+        ``required_claims`` and path parameter are re-checked. If the task instance is in a
+        refreshable state, a new token is generated and stored on ``request.state.refreshed_token``.
+
+        :raises HTTPException: 403 if a check fails, the task is not refreshable, or the refresh fails.
+        """
+        try:
+            key = await validator.get_validation_key(token)
+            # Only the expiry check is disabled; jwt.decode still applies its other default checks.
+            # NOTE(review): no upper bound on how long ago the token expired - confirm this is intended.
+            claims = jwt.decode(
+                token,
+                key,
+                audience=validator.audience,
+                issuer=validator.issuer,
+                algorithms=validator.algorithm,
+                options={"verify_exp": False},
+                leeway=validator.leeway,
+            )
+
+            task_instance_id = claims.get("sub")
+            if not task_instance_id:
+                raise HTTPException(
+                    status_code=status.HTTP_403_FORBIDDEN,
+                    detail="Invalid auth token: missing subject claim",
+                )
+
+            if self.required_claims:
+                for claim, expected_value in self.required_claims.items():
+                    if expected_value.get("essential") and (
+                        claim not in claims or claims[claim] != expected_value.get("value")
+                    ):
+                        raise HTTPException(
+                            status_code=status.HTTP_403_FORBIDDEN,
+                            detail=f"Invalid auth token: invalid claim '{claim}'",
+                        )
+
+            if self.path_param_name:
+                path_id = request.path_params.get(self.path_param_name)
+                if path_id != task_instance_id:
+                    raise HTTPException(
+                        status_code=status.HTTP_403_FORBIDDEN,
+                        detail="Invalid auth token: subject mismatch",
+                    )
+
+            if not await _is_task_in_refreshable_state(task_instance_id):
+                log.warning(
+                    "Token refresh rejected: task not in refreshable state", task_instance_id=task_instance_id
+                )
+                raise HTTPException(
+                    status_code=status.HTTP_403_FORBIDDEN,
+                    detail="Token expired and task is not in a refreshable state",
+                )
+
+            generator: JWTGenerator = await services.aget(JWTGenerator)
+            refreshed_token = generator.generate(claims)
+            # Read by the JWT reissue middleware, which returns it in the Refreshed-API-Token header.
+            request.state.refreshed_token = refreshed_token
+
+            log.info("Refreshed expired JWT token", task_instance_id=task_instance_id)
+            return TIToken(id=claims["sub"], claims=claims)
+
+        except HTTPException:
+            raise
+        except Exception as err:
+            log.warning("Failed to refresh expired JWT token", exc_info=True, error=str(err))
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail=f"Token expired and refresh failed: {err}",
+            ) from err
+
 
 JWTBearerDep: TIToken = Depends(JWTBearer())
 
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
index 640d920137c7b..38a80f0e9002a 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
@@ -16,8 +16,17 @@
 # under the License.
 from __future__ import annotations
 
+import time
+from unittest.mock import AsyncMock, patch
+from uuid import uuid4
+
+import jwt
 import pytest
+from fastapi.testclient import TestClient
 
+from airflow.api_fastapi.app import cached_app
+from airflow.api_fastapi.auth.tokens import JWTGenerator, JWTValidator
+from airflow.api_fastapi.execution_api.app import lifespan
 from airflow.api_fastapi.execution_api.datamodels.taskinstance import TaskInstance
 from airflow.api_fastapi.execution_api.versions import bundle
 
@@ -94,3 +103,116 @@ def test_multiple_requests_with_different_correlation_ids(self, client):
 
         # Verify they didn't interfere with each other
         assert correlation_id_1 != correlation_id_2
+
+
+class TestExpiredTokenRefresh:
+    """Tests for expired JWT token refresh functionality."""
+
+    @pytest.fixture
+    def expired_token_client(self):
+        """Yield an Execution API client whose svcs registry starts empty so tests can register mocks."""
+        import svcs
+
+        app = cached_app(apps="execution")
+        original_registry = lifespan.registry
+        lifespan.registry = svcs.Registry()
+        try:
+            with TestClient(app) as client:
+                yield client
+        finally:
+            lifespan.registry = original_registry
+
+    @pytest.fixture
+    def setup_variable(self):
+        from airflow.models.variable import Variable
+
+        Variable.set(key="test_var", value="test_value")
+        yield
+        try:
+            Variable.delete(key="test_var")
+        except Exception:
+            pass
+
+    def test_expired_token_refreshed_for_running_task(self, expired_token_client, setup_variable):
+        """Expired token is refreshed when task is in RUNNING state."""
+        task_id = str(uuid4())
+
+        mock_validator = AsyncMock(spec=JWTValidator)
+        mock_validator.avalidated_claims.side_effect = jwt.ExpiredSignatureError("Token has expired")
+        mock_validator.audience = "test-audience"
+        mock_validator.issuer = None
+        mock_validator.algorithm = ["HS256"]
+        mock_validator.leeway = 0
+        mock_validator.get_validation_key = AsyncMock(return_value="test-secret-key")
+
+        mock_generator = AsyncMock(spec=JWTGenerator)
+        mock_generator.generate.return_value = "new-refreshed-token"
+
+        lifespan.registry.register_value(JWTValidator, mock_validator)
+        lifespan.registry.register_value(JWTGenerator, mock_generator)
+
+        expired_token = jwt.encode(
+            {
+                "sub": task_id,
+                "exp": int(time.time()) - 3600,
+                "iat": int(time.time()) - 7200,
+                "nbf": int(time.time()) - 7200,
+                "aud": "test-audience",
+            },
+            "test-secret-key",
+            algorithm="HS256",
+        )
+
+        with patch(
+            "airflow.api_fastapi.execution_api.deps._is_task_in_refreshable_state",
+            new_callable=AsyncMock,
+            return_value=True,
+        ):
+            response = expired_token_client.get(
+                "/execution/variables/test_var",
+                headers={"Authorization": f"Bearer {expired_token}"},
+            )
+
+        assert response.status_code == 200
+        assert response.headers.get("Refreshed-API-Token") == "new-refreshed-token"
+        mock_generator.generate.assert_called()
+
+    def test_expired_token_rejected_for_completed_task(self, expired_token_client):
+        """Expired token is rejected when task is not in RUNNING/QUEUED state."""
+        task_id = str(uuid4())
+
+        mock_validator = AsyncMock(spec=JWTValidator)
+        mock_validator.avalidated_claims.side_effect = jwt.ExpiredSignatureError("Token has expired")
+        mock_validator.audience = "test-audience"
+        mock_validator.issuer = None
+        mock_validator.algorithm = ["HS256"]
+        mock_validator.leeway = 0
+        mock_validator.get_validation_key = AsyncMock(return_value="test-secret-key")
+
+        lifespan.registry.register_value(JWTValidator, mock_validator)
+
+        expired_token = jwt.encode(
+            {
+                "sub": task_id,
+                "exp": int(time.time()) - 3600,
+                "iat": int(time.time()) - 7200,
+                "nbf": int(time.time()) - 7200,
+                "aud": "test-audience",
+            },
+            "test-secret-key",
+            algorithm="HS256",
+        )
+
+        with patch(
+            "airflow.api_fastapi.execution_api.deps._is_task_in_refreshable_state",
+            new_callable=AsyncMock,
+            return_value=False,
+        ):
+            response = expired_token_client.get(
+                "/execution/variables/test_var",
+                headers={"Authorization": f"Bearer {expired_token}"},
+            )
+
+        assert response.status_code == 403
+        assert "not in a refreshable state" in response.json()["detail"]
+        assert "Refreshed-API-Token" not in response.headers