Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions airflow-core/src/airflow/api_fastapi/auth/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def _get_kid_from_header(self, unvalidated: str) -> str:
raise jwt.InvalidTokenError("Missing 'kid' in token header")
return header["kid"]

async def _get_validation_key(self, unvalidated: str) -> str | jwt.PyJWK:
async def get_validation_key(self, unvalidated: str) -> str | jwt.PyJWK:
if self.secret_key:
return self.secret_key

Expand All @@ -324,7 +324,7 @@ async def avalidated_claims(
self, unvalidated: str, required_claims: dict[str, Any] | None = None
) -> dict[str, Any]:
"""Decode the JWT token, returning the validated claims or raising an exception."""
key = await self._get_validation_key(unvalidated)
key = await self.get_validation_key(unvalidated)
claims = jwt.decode(
unvalidated,
key,
Expand Down
44 changes: 23 additions & 21 deletions airflow-core/src/airflow/api_fastapi/execution_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,27 +133,29 @@ async def dispatch(self, request: Request, call_next):

response: Response = await call_next(request)

refreshed_token: str | None = None
auth_header = request.headers.get("authorization")
if auth_header and auth_header.lower().startswith("bearer "):
token = auth_header.split(" ", 1)[1]
try:
async with svcs.Container(request.app.state.svcs_registry) as services:
validator: JWTValidator = await services.aget(JWTValidator)
claims = await validator.avalidated_claims(token, {})

now = int(time.time())
validity = conf.getint("execution_api", "jwt_expiration_time")
refresh_when_less_than = max(int(validity * 0.20), 30)
valid_left = int(claims.get("exp", 0)) - now
if valid_left <= refresh_when_less_than:
generator: JWTGenerator = await services.aget(JWTGenerator)
refreshed_token = generator.generate(claims)
except Exception as err:
# Do not block the response if refreshing fails; log a warning for visibility
logger.warning(
"JWT reissue middleware failed to refresh token", error=str(err), exc_info=True
)
refreshed_token: str | None = getattr(request.state, "refreshed_token", None)

if not refreshed_token:
auth_header = request.headers.get("authorization")
if auth_header and auth_header.lower().startswith("bearer "):
token = auth_header.split(" ", 1)[1]
try:
async with svcs.Container(request.app.state.svcs_registry) as services:
validator: JWTValidator = await services.aget(JWTValidator)
claims = await validator.avalidated_claims(token, {})

now = int(time.time())
validity = conf.getint("execution_api", "jwt_expiration_time")
refresh_when_less_than = max(int(validity * 0.20), 30)
valid_left = int(claims.get("exp", 0)) - now
if valid_left <= refresh_when_less_than:
generator: JWTGenerator = await services.aget(JWTGenerator)
refreshed_token = generator.generate(claims)
except Exception as err:
# Do not block the response if refreshing fails; log a warning for visibility
logger.warning(
"JWT reissue middleware failed to refresh token", error=str(err), exc_info=True
)

if refreshed_token:
response.headers["Refreshed-API-Token"] = refreshed_token
Expand Down
96 changes: 94 additions & 2 deletions airflow-core/src/airflow/api_fastapi/execution_api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,41 @@

from typing import Any

import jwt
import structlog
import svcs
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBearer
from sqlalchemy import select

from airflow.api_fastapi.auth.tokens import JWTValidator
from airflow.api_fastapi.auth.tokens import JWTGenerator, JWTValidator
from airflow.api_fastapi.common.db.common import AsyncSessionDep
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.configuration import conf
from airflow.models import DagModel, TaskInstance
from airflow.models.dagbundle import DagBundleModel
from airflow.models.team import Team
from airflow.utils.session import create_session_async
from airflow.utils.state import TaskInstanceState

log = structlog.get_logger(logger_name=__name__)

# Valid states for token refresh - task must be queued or running
REFRESHABLE_TASK_STATES = frozenset({TaskInstanceState.QUEUED, TaskInstanceState.RUNNING})


async def _is_task_in_refreshable_state(task_instance_id: str) -> bool:
"""
Check if a task instance is in a state that allows token refresh.

Only tasks in QUEUED or RUNNING state can have their tokens refreshed.
This prevents refreshing tokens for completed/failed tasks.
"""
async with create_session_async() as session:
stmt = select(TaskInstance.state).where(TaskInstance.id == task_instance_id)
state = await session.scalar(stmt)
return state in REFRESHABLE_TASK_STATES


# See https://github.com/fastapi/fastapi/issues/13056
async def _container(request: Request):
Expand Down Expand Up @@ -88,14 +107,87 @@ async def __call__( # type: ignore[override]
validators = self.required_claims
claims = await validator.avalidated_claims(creds.credentials, validators)
return TIToken(id=claims["sub"], claims=claims)
except jwt.ExpiredSignatureError:
# Token expired - try to refresh if task is still in a valid state
log.debug("JWT token expired, attempting to refresh")
return await self._handle_expired_token(request, creds.credentials, validator, services)
except Exception as err:
log.warning(
"Failed to validate JWT",
exc_info=True,
token=creds.credentials,
)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid auth token: {err}")

async def _handle_expired_token(
self,
request: Request,
token: str,
validator: JWTValidator,
services,
) -> TIToken:
"""Handle an expired JWT by refreshing if task is still active."""
try:
key = await validator.get_validation_key(token)
claims = jwt.decode(
token,
key,
audience=validator.audience,
issuer=validator.issuer,
algorithms=validator.algorithm,
options={"verify_exp": False},
leeway=validator.leeway,
)

task_instance_id = claims.get("sub")
if not task_instance_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid auth token: missing subject claim",
)

if self.required_claims:
for claim, expected_value in self.required_claims.items():
if expected_value.get("essential") and (
claim not in claims or claims[claim] != expected_value.get("value")
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Invalid auth token: invalid claim '{claim}'",
)

if self.path_param_name:
path_id = request.path_params.get(self.path_param_name)
if path_id != task_instance_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid auth token: subject mismatch",
)

if not await _is_task_in_refreshable_state(task_instance_id):
log.warning(
"Token refresh rejected: task not in refreshable state", task_instance_id=task_instance_id
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Token expired and task is not in a refreshable state",
)

generator: JWTGenerator = await services.aget(JWTGenerator)
refreshed_token = generator.generate(claims)
request.state.refreshed_token = refreshed_token

log.info("Refreshed expired JWT token", task_instance_id=task_instance_id)
return TIToken(id=claims["sub"], claims=claims)

except HTTPException:
raise
except Exception as err:
log.warning("Failed to refresh expired JWT token", exc_info=True, error=str(err))
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Token expired and refresh failed: {err}",
)


JWTBearerDep: TIToken = Depends(JWTBearer())

Expand Down
121 changes: 121 additions & 0 deletions airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,17 @@
# under the License.
from __future__ import annotations

import time
from unittest.mock import AsyncMock, patch
from uuid import uuid4

import jwt
import pytest
from fastapi.testclient import TestClient

from airflow.api_fastapi.app import cached_app
from airflow.api_fastapi.auth.tokens import JWTGenerator, JWTValidator
from airflow.api_fastapi.execution_api.app import lifespan
from airflow.api_fastapi.execution_api.datamodels.taskinstance import TaskInstance
from airflow.api_fastapi.execution_api.versions import bundle

Expand Down Expand Up @@ -94,3 +103,115 @@ def test_multiple_requests_with_different_correlation_ids(self, client):

# Verify they didn't interfere with each other
assert correlation_id_1 != correlation_id_2


class TestExpiredTokenRefresh:
"""Tests for expired JWT token refresh functionality."""

@pytest.fixture
def expired_token_client(self):
import svcs

app = cached_app(apps="execution")
original_registry = lifespan.registry
lifespan.registry = svcs.Registry()

with TestClient(app) as client:
yield client

lifespan.registry = original_registry

@pytest.fixture
def setup_variable(self):
from airflow.models.variable import Variable

Variable.set(key="test_var", value="test_value")
yield
try:
Variable.delete(key="test_var")
except Exception:
pass

def test_expired_token_refreshed_for_running_task(self, expired_token_client, setup_variable):
"""Expired token is refreshed when task is in RUNNING state."""
task_id = str(uuid4())

mock_validator = AsyncMock(spec=JWTValidator)
mock_validator.avalidated_claims.side_effect = jwt.ExpiredSignatureError("Token has expired")
mock_validator.audience = "test-audience"
mock_validator.issuer = None
mock_validator.algorithm = ["HS256"]
mock_validator.leeway = 0
mock_validator.get_validation_key = AsyncMock(return_value="test-secret-key")

mock_generator = AsyncMock(spec=JWTGenerator)
mock_generator.generate.return_value = "new-refreshed-token"

lifespan.registry.register_value(JWTValidator, mock_validator)
lifespan.registry.register_value(JWTGenerator, mock_generator)

expired_token = jwt.encode(
{
"sub": task_id,
"exp": int(time.time()) - 3600,
"iat": int(time.time()) - 7200,
"nbf": int(time.time()) - 7200,
"aud": "test-audience",
},
"test-secret-key",
algorithm="HS256",
)

with patch(
"airflow.api_fastapi.execution_api.deps._is_task_in_refreshable_state",
new_callable=AsyncMock,
return_value=True,
):
response = expired_token_client.get(
"/execution/variables/test_var",
headers={"Authorization": f"Bearer {expired_token}"},
)

assert response.status_code == 200
assert response.headers.get("Refreshed-API-Token") == "new-refreshed-token"
mock_generator.generate.assert_called()

def test_expired_token_rejected_for_completed_task(self, expired_token_client):
"""Expired token is rejected when task is not in RUNNING/QUEUED state."""
task_id = str(uuid4())

mock_validator = AsyncMock(spec=JWTValidator)
mock_validator.avalidated_claims.side_effect = jwt.ExpiredSignatureError("Token has expired")
mock_validator.audience = "test-audience"
mock_validator.issuer = None
mock_validator.algorithm = ["HS256"]
mock_validator.leeway = 0
mock_validator.get_validation_key = AsyncMock(return_value="test-secret-key")

lifespan.registry.register_value(JWTValidator, mock_validator)

expired_token = jwt.encode(
{
"sub": task_id,
"exp": int(time.time()) - 3600,
"iat": int(time.time()) - 7200,
"nbf": int(time.time()) - 7200,
"aud": "test-audience",
},
"test-secret-key",
algorithm="HS256",
)

with patch(
"airflow.api_fastapi.execution_api.deps._is_task_in_refreshable_state",
new_callable=AsyncMock,
return_value=False,
):
response = expired_token_client.get(
"/execution/variables/test_var",
headers={"Authorization": f"Bearer {expired_token}"},
)

assert response.status_code == 403
assert "not in a refreshable state" in response.json()["detail"]
assert "Refreshed-API-Token" not in response.headers
Loading