Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions providers/edge3/src/airflow/providers/edge3/cli/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@
help="Comma delimited list of queues to add or remove.",
required=True,
)
# Mandatory ``-c/--concurrency`` option used by the ``set-worker-concurrency`` command.
# ``type=int`` only converts the value to an integer; zero and negative numbers still
# parse successfully here, and the command handler is what rejects non-positive values.
ARG_CONCURRENCY_REQUIRED = Arg(
("-c", "--concurrency"),
type=int,
help="The number of worker processes. Must be a positive integer.",
required=True,
)
ARG_WAIT_MAINT = Arg(
("-w", "--wait"),
default=False,
Expand Down Expand Up @@ -229,6 +235,15 @@
func=lazy_load_command("airflow.providers.edge3.cli.edge_command.shutdown_all_workers"),
args=(ARG_YES,),
),
ActionCommand(
name="set-worker-concurrency",
help="Set the concurrency of a remote edge worker.",
func=lazy_load_command("airflow.providers.edge3.cli.edge_command.set_remote_worker_concurrency"),
args=(
ARG_REQUIRED_EDGE_HOSTNAME,
ARG_CONCURRENCY_REQUIRED,
),
),
]


Expand Down
24 changes: 24 additions & 0 deletions providers/edge3/src/airflow/providers/edge3/cli/edge_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,27 @@ def remove_worker_queues(args) -> None:
except TypeError as e:
logger.error(str(e))
raise SystemExit


@cli_utils.action_cli(check_db=False)
@providers_configuration_loaded
def set_remote_worker_concurrency(args) -> None:
    """
    Set the concurrency of a remote edge worker.

    :param args: Parsed CLI arguments. ``args.edge_hostname`` names the worker and
        ``args.concurrency`` is the requested concurrency.
    :raises SystemExit: If the requested concurrency is not a positive integer, or if
        the update is rejected by ``set_worker_concurrency`` (the rejection reason is
        logged as an error before exiting).
    """
    # Validate the user-supplied value first so that an invalid request fails fast,
    # before any database access is attempted.
    if args.concurrency <= 0:
        raise SystemExit("Error: Concurrency must be a positive integer.")

    _check_valid_db_connection()
    _check_if_registered_edge_host(hostname=args.edge_hostname)
    from airflow.providers.edge3.models.edge_worker import set_worker_concurrency

    try:
        set_worker_concurrency(args.edge_hostname, args.concurrency)
    except (TypeError, ValueError) as e:
        # TypeError: the worker is in a state that does not accept the change.
        # ValueError: the worker was not found (e.g. removed after the registration
        # check above). Both are reported the same way instead of as a traceback.
        logger.error(str(e))
        raise SystemExit

    logger.info(
        "Concurrency set to %d for Edge Worker host %s by %s.",
        args.concurrency,
        args.edge_hostname,
        getuser(),
    )
7 changes: 7 additions & 0 deletions providers/edge3/src/airflow/providers/edge3/cli/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,13 @@ async def heartbeat(self, new_maintenance_comments: str | None = None) -> bool:
new_maintenance_comments,
)
self.queues = worker_info.queues
if worker_info.concurrency is not None and worker_info.concurrency != self.concurrency:
logger.info(
"Concurrency updated from %d to %d by remote request.",
self.concurrency,
worker_info.concurrency,
)
self.concurrency = worker_info.concurrency
if worker_info.state == EdgeWorkerState.MAINTENANCE_REQUEST:
logger.info("Maintenance mode requested!")
self.maintenance_mode = True
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Add concurrency column to edge_worker table.

Revision ID: b3c4d5e6f7a8
Revises: 9d34dfc2de06
Create Date: 2026-03-04 00:00:00.000000
"""

from __future__ import annotations

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "b3c4d5e6f7a8"
# Parent revision: this migration is applied on top of it.
down_revision = "9d34dfc2de06"
branch_labels = None
depends_on = None
# edge3 provider version paired with this revision; presumably this must stay in
# sync with the version -> revision map kept by the edge3 DB manager (verify there).
edge3_version = "3.2.0"


def upgrade() -> None:
    """Add the nullable integer ``concurrency`` column to ``edge_worker`` if it is missing."""
    columns = sa.inspect(op.get_bind()).get_columns("edge_worker")
    if any(column["name"] == "concurrency" for column in columns):
        # Column is already there, so there is nothing to add.
        return
    op.add_column("edge_worker", sa.Column("concurrency", sa.Integer(), nullable=True))


def downgrade() -> None:
    """
    Remove the ``concurrency`` column from ``edge_worker``.

    Mirrors the existence check done by ``upgrade()``: the column is only dropped
    when it is actually present, so running the downgrade against a table that
    never received the column is a no-op instead of a database error.
    """
    inspector = sa.inspect(op.get_bind())
    existing_columns = {col["name"] for col in inspector.get_columns("edge_worker")}
    if "concurrency" in existing_columns:
        op.drop_column("edge_worker", "concurrency")
28 changes: 28 additions & 0 deletions providers/edge3/src/airflow/providers/edge3/models/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

# Maps an edge3 provider version string to an Alembic revision id; judging by the
# name, the revision is the migration head for that version. Add an entry here
# whenever a new migration revision is introduced.
_REVISION_HEADS_MAP: dict[str, str] = {
"3.0.0": "9d34dfc2de06",
"3.2.0": "b3c4d5e6f7a8",
}


Expand All @@ -45,6 +46,33 @@ class EdgeDBManager(BaseDBManager):
supports_table_dropping = True
revision_heads_map = _REVISION_HEADS_MAP

def initdb(self):
    """
    Bring the edge3 schema up to date, covering installations that predate alembic.

    Three situations are handled:

    * A revision is already recorded: run the normal upgrade.
    * No revision is recorded but at least one edge3 table exists (e.g. created via
      create_all before the migration chain was introduced): stamp the database with
      the first revision of the chain and then upgrade, so every later migration is
      applied rather than jumping straight to head.
    * No revision and no edge3 tables: create the schema from the ORM metadata.
    """
    if self.get_current_revision():
        self.upgradedb()
        return

    from airflow import settings

    tables_in_db = set(inspect(settings.engine).get_table_names())
    if tables_in_db.isdisjoint(self.metadata.tables):
        # Nothing of ours exists yet: fresh database.
        self.create_db_from_orm()
        return

    # Tables exist without a version stamp: start from the root revision, i.e. the
    # one that has no parent.
    script = self.get_script_object()
    base_revision = next(rev.revision for rev in script.walk_revisions() if rev.down_revision is None)
    config = self.get_alembic_config()
    from alembic import command

    command.stamp(config, base_revision)
    self.upgradedb()

def drop_tables(self, connection):
"""Drop only edge3 tables in reverse dependency order."""
if not self.supports_table_dropping:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class EdgeWorkerModel(Base, LoggingMixin):
jobs_success: Mapped[int] = mapped_column(Integer, default=0)
jobs_failed: Mapped[int] = mapped_column(Integer, default=0)
sysinfo: Mapped[str | None] = mapped_column(String(256))
concurrency: Mapped[int | None] = mapped_column(Integer, nullable=True)

def __init__(
self,
Expand Down Expand Up @@ -392,3 +393,23 @@ def remove_worker_queues(worker_name: str, queues: list[str], session: Session =
logger.error(error_message)
raise TypeError(error_message)
worker.remove_queues(queues)


@provide_session
def set_worker_concurrency(worker_name: str, concurrency: int, session: Session = NEW_SESSION) -> None:
    """
    Store a new concurrency value on a registered edge worker.

    :param worker_name: Name of the worker row to update.
    :param concurrency: Value written to the worker's ``concurrency`` column.
    :param session: Database session used for the lookup and update.
    :raises ValueError: If no worker with ``worker_name`` exists.
    :raises TypeError: If the worker is in the OFFLINE, OFFLINE_MAINTENANCE or UNKNOWN state.
    """
    worker: EdgeWorkerModel | None = session.scalar(
        select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
    )
    if not worker:
        raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers")

    rejected_states = (
        EdgeWorkerState.OFFLINE,
        EdgeWorkerState.OFFLINE_MAINTENANCE,
        EdgeWorkerState.UNKNOWN,
    )
    if worker.state not in rejected_states:
        worker.concurrency = concurrency
        return

    error_message = f"Cannot set concurrency for edge worker {worker_name} as it is in {worker.state} state!"
    logger.error(error_message)
    raise TypeError(error_message)
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,10 @@ class WorkerSetStateReturn(BaseModel):
str | None,
Field(description="Comments about the maintenance state of the worker."),
] = None
concurrency: Annotated[
int | None,
Field(
description="Desired concurrency for the worker set by an administrator. "
"None means no remote override; the worker uses its startup value.",
),
] = None
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,10 @@ def set_state(
)
_assert_version(body.sysinfo) # Exception only after worker state is in the DB
return WorkerSetStateReturn(
state=worker.state, queues=worker.queues, maintenance_comments=worker.maintenance_comment
state=worker.state,
queues=worker.queues,
maintenance_comments=worker.maintenance_comment,
concurrency=worker.concurrency,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1413,6 +1413,13 @@ components:
- type: 'null'
title: Maintenance Comments
description: Comments about the maintenance state of the worker.
concurrency:
anyOf:
- type: integer
- type: 'null'
title: Concurrency
description: Desired concurrency for the worker set by an administrator.
None means no remote override; the worker uses its startup value.
type: object
required:
- state
Expand Down
18 changes: 16 additions & 2 deletions providers/edge3/tests/unit/edge3/cli/test_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def test_edge_cli_commands_count(self):
assert len(commands) == 1

def test_edge_commands_count(self):
"""Test that EDGE_COMMANDS contains all 13 subcommands."""
assert len(EDGE_COMMANDS) == 13
"""Test that EDGE_COMMANDS contains all 14 subcommands."""
assert len(EDGE_COMMANDS) == 14

@pytest.mark.parametrize(
"command",
Expand Down Expand Up @@ -234,3 +234,17 @@ def test_shutdown_all_workers_args(self):
params = ["edge", "shutdown-all-workers", "--yes"]
args = self.arg_parser.parse_args(params)
assert args.yes is True

def test_set_worker_concurrency_args(self):
    """Parse ``edge set-worker-concurrency`` with both required options and check the parsed values."""
    args = self.arg_parser.parse_args(
        ["edge", "set-worker-concurrency", "--edge-hostname", "remote-worker-1", "--concurrency", "16"]
    )
    # The concurrency option is declared with type=int, so the parsed value is an integer.
    assert (args.edge_hostname, args.concurrency) == ("remote-worker-1", 16)
27 changes: 27 additions & 0 deletions providers/edge3/tests/unit/edge3/cli/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,33 @@ async def test_heartbeat(
assert "queue1" in (queue_list)
assert "queue2" in (queue_list)

@patch("airflow.providers.edge3.cli.worker.worker_set_state")
async def test_heartbeat_adopts_remote_concurrency(self, mock_set_state, worker_with_job: EdgeWorker):
    """A heartbeat reply that carries a concurrency value replaces the worker's own value."""
    EdgeWorker.jobs = []
    EdgeWorker.drain = False
    EdgeWorker.maintenance_mode = False
    remote_reply = WorkerSetStateReturn(state=EdgeWorkerState.IDLE, queues=None, concurrency=32)
    mock_set_state.return_value = remote_reply

    api_override = {("edge", "api_url"): "https://invalid-api-test-endpoint"}
    with conf_vars(api_override):
        await worker_with_job.heartbeat()

    assert worker_with_job.concurrency == 32

@patch("airflow.providers.edge3.cli.worker.worker_set_state")
async def test_heartbeat_no_concurrency_override_keeps_startup_value(
    self, mock_set_state, worker_with_job: EdgeWorker
):
    """A heartbeat reply with ``concurrency=None`` leaves the worker's concurrency untouched."""
    EdgeWorker.jobs = []
    EdgeWorker.drain = False
    EdgeWorker.maintenance_mode = False
    # Remember the value the worker had before the heartbeat.
    concurrency_before = worker_with_job.concurrency
    remote_reply = WorkerSetStateReturn(state=EdgeWorkerState.IDLE, queues=None, concurrency=None)
    mock_set_state.return_value = remote_reply

    api_override = {("edge", "api_url"): "https://invalid-api-test-endpoint"}
    with conf_vars(api_override):
        await worker_with_job.heartbeat()

    assert worker_with_job.concurrency == concurrency_before

@patch("airflow.providers.edge3.cli.worker.worker_set_state")
async def test_version_mismatch(self, mock_set_state, worker_with_job):
mock_set_state.side_effect = EdgeWorkerVersionException("")
Expand Down
75 changes: 73 additions & 2 deletions providers/edge3/tests/unit/edge3/models/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,13 @@ def test_create_db_from_orm(self, mock_command, session):
__import__("airflow.providers.edge3.models.db", fromlist=["EdgeDBManager"]).EdgeDBManager,
"get_current_revision",
)
def test_initdb_new_db(self, mock_get_rev, mock_create, mock_upgrade, session):
@mock.patch("airflow.providers.edge3.models.db.inspect")
def test_initdb_new_db(self, mock_inspect, mock_get_rev, mock_create, mock_upgrade, session):
"""Test that initdb calls create_db_from_orm for new databases."""
from airflow.providers.edge3.models.db import EdgeDBManager

mock_get_rev.return_value = None
mock_inspect.return_value.get_table_names.return_value = [] # no tables exist

manager = EdgeDBManager(session)
manager.initdb()
Expand Down Expand Up @@ -205,11 +207,80 @@ def test_initdb_existing_db(self, mock_get_rev, mock_create, mock_upgrade, sessi
mock_create.assert_not_called()

def test_revision_heads_map_populated(self):
"""Test that _REVISION_HEADS_MAP is populated with the initial migration."""
"""Test that _REVISION_HEADS_MAP is populated with all known migrations."""
from airflow.providers.edge3.models.db import _REVISION_HEADS_MAP

assert "3.0.0" in _REVISION_HEADS_MAP
assert _REVISION_HEADS_MAP["3.0.0"] == "9d34dfc2de06"
assert "3.2.0" in _REVISION_HEADS_MAP
assert _REVISION_HEADS_MAP["3.2.0"] == "b3c4d5e6f7a8"

def test_initdb_stamps_and_upgrades_when_tables_exist_without_version(self, session):
"""Test that initdb runs incremental migrations when tables exist but alembic version table does not."""
from sqlalchemy import inspect, text

from airflow import settings
from airflow.providers.edge3.models.db import EdgeDBManager

manager = EdgeDBManager(session)

# Simulate pre-alembic state: tables exist, but no revision is recorded and there is
# no concurrency column. engine.begin() commits on exit of the block.
with settings.engine.begin() as conn:
inspector = inspect(conn)
# The version table itself is kept; only its rows are removed so that no
# revision is recorded for edge3.
if inspector.has_table("alembic_version_edge3"):
conn.execute(text("DELETE FROM alembic_version_edge3"))
# Remove the column so the upgrade has something to add.
if "concurrency" in {col["name"] for col in inspector.get_columns("edge_worker")}:
from alembic.migration import MigrationContext
from alembic.operations import Operations

# NOTE(review): render_as_batch is presumably set for backends with limited
# ALTER TABLE support (e.g. SQLite) — confirm.
mc = MigrationContext.configure(conn, opts={"render_as_batch": True})
ops = Operations(mc)
ops.drop_column("edge_worker", "concurrency")

# initdb() should detect tables exist, stamp to base, then upgrade
manager.initdb()

# Read the resulting state back through a separate connection.
with settings.engine.connect() as conn:
version = conn.execute(text("SELECT version_num FROM alembic_version_edge3")).scalar()
columns = {col["name"] for col in inspect(conn).get_columns("edge_worker")}

# After the upgrade the recorded revision must be b3c4d5e6f7a8 and the column must exist again.
assert version == "b3c4d5e6f7a8"
assert "concurrency" in columns

def test_migration_adds_concurrency_column(self, session):
    """Stamp the database at revision 9d34dfc2de06, upgrade, and check the concurrency column appears."""
    from alembic import command
    from alembic.migration import MigrationContext
    from alembic.operations import Operations
    from sqlalchemy import inspect

    from airflow import settings
    from airflow.providers.edge3.models.db import EdgeDBManager

    def _edge_worker_columns(connection):
        # Names of the columns currently present on the edge_worker table.
        return {col["name"] for col in inspect(connection).get_columns("edge_worker")}

    manager = EdgeDBManager(session)
    config = manager.get_alembic_config()

    # The column drop has to be committed before alembic opens its own connection,
    # hence engine.begin(): the change is then visible to the connection that
    # upgradedb() creates internally.
    with settings.engine.begin() as conn:
        if "concurrency" in _edge_worker_columns(conn):
            migration_ctx = MigrationContext.configure(conn, opts={"render_as_batch": True})
            Operations(migration_ctx).drop_column("edge_worker", "concurrency")

    # Mark the database as being at the pre-concurrency revision (alembic uses its own connection).
    command.stamp(config, "9d34dfc2de06")

    # The upgrade should notice the missing column and add it.
    manager.upgradedb()

    # Check the outcome through a fresh connection.
    with settings.engine.connect() as conn:
        columns_after = _edge_worker_columns(conn)

    assert "concurrency" in columns_after, "Migration 0002 should have added the concurrency column"

def test_drop_tables_handles_missing_tables(self, session):
"""Test that drop_tables handles missing tables gracefully."""
Expand Down
Loading