diff --git a/providers/edge3/src/airflow/providers/edge3/cli/definition.py b/providers/edge3/src/airflow/providers/edge3/cli/definition.py index 5cd611cc3a438..a1819e9a9a087 100644 --- a/providers/edge3/src/airflow/providers/edge3/cli/definition.py +++ b/providers/edge3/src/airflow/providers/edge3/cli/definition.py @@ -59,6 +59,12 @@ help="Comma delimited list of queues to add or remove.", required=True, ) +ARG_CONCURRENCY_REQUIRED = Arg( + ("-c", "--concurrency"), + type=int, + help="The number of worker processes. Must be a positive integer.", + required=True, +) ARG_WAIT_MAINT = Arg( ("-w", "--wait"), default=False, @@ -229,6 +235,15 @@ func=lazy_load_command("airflow.providers.edge3.cli.edge_command.shutdown_all_workers"), args=(ARG_YES,), ), + ActionCommand( + name="set-worker-concurrency", + help="Set the concurrency of a remote edge worker.", + func=lazy_load_command("airflow.providers.edge3.cli.edge_command.set_remote_worker_concurrency"), + args=( + ARG_REQUIRED_EDGE_HOSTNAME, + ARG_CONCURRENCY_REQUIRED, + ), + ), ] diff --git a/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py b/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py index 864d6851ace18..c18fbdc7f905b 100644 --- a/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py +++ b/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py @@ -427,3 +427,27 @@ def remove_worker_queues(args) -> None: except TypeError as e: logger.error(str(e)) raise SystemExit + + +@cli_utils.action_cli(check_db=False) +@providers_configuration_loaded +def set_remote_worker_concurrency(args) -> None: + """Set the concurrency of a remote edge worker.""" + _check_valid_db_connection() + _check_if_registered_edge_host(hostname=args.edge_hostname) + from airflow.providers.edge3.models.edge_worker import set_worker_concurrency + + if args.concurrency <= 0: + raise SystemExit("Error: Concurrency must be a positive integer.") + + try: + set_worker_concurrency(args.edge_hostname, args.concurrency) + logger.info( + "Concurrency set to %d for Edge Worker host %s by %s.", + args.concurrency, + args.edge_hostname, + getuser(), + ) + except TypeError as e: + logger.error(str(e)) + raise SystemExit diff --git a/providers/edge3/src/airflow/providers/edge3/cli/worker.py b/providers/edge3/src/airflow/providers/edge3/cli/worker.py index c4aa1d735c509..6ac43e388f82c 100644 --- a/providers/edge3/src/airflow/providers/edge3/cli/worker.py +++ b/providers/edge3/src/airflow/providers/edge3/cli/worker.py @@ -401,6 +401,13 @@ async def heartbeat(self, new_maintenance_comments: str | None = None) -> bool: new_maintenance_comments, ) self.queues = worker_info.queues + if worker_info.concurrency is not None and worker_info.concurrency != self.concurrency: + logger.info( + "Concurrency updated from %d to %d by remote request.", + self.concurrency, + worker_info.concurrency, + ) + self.concurrency = worker_info.concurrency if worker_info.state == EdgeWorkerState.MAINTENANCE_REQUEST: logger.info("Maintenance mode requested!") self.maintenance_mode = True diff --git a/providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_2_0_add_concurrency_to_edge_worker.py b/providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_2_0_add_concurrency_to_edge_worker.py new file mode 100644 index 0000000000000..edf6cc7f39ccd --- /dev/null +++ b/providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_2_0_add_concurrency_to_edge_worker.py @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Add concurrency column to edge_worker table. + +Revision ID: b3c4d5e6f7a8 +Revises: 9d34dfc2de06 +Create Date: 2026-03-04 00:00:00.000000 +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "b3c4d5e6f7a8" +down_revision = "9d34dfc2de06" +branch_labels = None +depends_on = None +edge3_version = "3.2.0" + + +def upgrade() -> None: + bind = op.get_bind() + inspector = sa.inspect(bind) + existing_columns = {col["name"] for col in inspector.get_columns("edge_worker")} + if "concurrency" not in existing_columns: + op.add_column("edge_worker", sa.Column("concurrency", sa.Integer(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("edge_worker", "concurrency") diff --git a/providers/edge3/src/airflow/providers/edge3/models/db.py b/providers/edge3/src/airflow/providers/edge3/models/db.py index 1c98cb40235e1..e98b3cc24b97f 100644 --- a/providers/edge3/src/airflow/providers/edge3/models/db.py +++ b/providers/edge3/src/airflow/providers/edge3/models/db.py @@ -31,6 +31,7 @@ _REVISION_HEADS_MAP: dict[str, str] = { "3.0.0": "9d34dfc2de06", + "3.2.0": "b3c4d5e6f7a8", } @@ -45,6 +46,33 @@ class EdgeDBManager(BaseDBManager): supports_table_dropping = True revision_heads_map = _REVISION_HEADS_MAP + def initdb(self): + """ + Initialize the database, handling pre-alembic installations. + + If the edge3 tables already exist but the alembic version table does not + (e.g. created via create_all before the migration chain was introduced), + stamp to the first revision and run the incremental upgrade so every + migration is applied rather than jumping straight to head. + """ + db_exists = self.get_current_revision() + if db_exists: + self.upgradedb() + else: + from airflow import settings + + existing_tables = set(inspect(settings.engine).get_table_names()) + if any(table in existing_tables for table in self.metadata.tables): + script = self.get_script_object() + base_revision = next(r.revision for r in script.walk_revisions() if r.down_revision is None) + config = self.get_alembic_config() + from alembic import command + + command.stamp(config, base_revision) + self.upgradedb() + else: + self.create_db_from_orm() + def drop_tables(self, connection): """Drop only edge3 tables in reverse dependency order.""" if not self.supports_table_dropping: diff --git a/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py b/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py index e4b0698e7e051..5b037f2903cb0 100644 --- a/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py +++ b/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py @@ -103,6 +103,7 @@ class EdgeWorkerModel(Base, LoggingMixin): jobs_success: Mapped[int] = mapped_column(Integer, default=0) jobs_failed: Mapped[int] = mapped_column(Integer, default=0) sysinfo: Mapped[str | None] = mapped_column(String(256)) + concurrency: Mapped[int | None] = mapped_column(Integer, nullable=True) def __init__( self, @@ -392,3 +393,23 @@ def remove_worker_queues(worker_name: str, queues: list[str], session: Session = logger.error(error_message) raise TypeError(error_message) worker.remove_queues(queues) + + +@provide_session +def set_worker_concurrency(worker_name: str, concurrency: int, session: Session = NEW_SESSION) -> None: + """Set the concurrency of an edge worker.""" + query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) + worker: EdgeWorkerModel | None = session.scalar(query) + if not worker: + raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers") + if worker.state in ( + EdgeWorkerState.OFFLINE, + EdgeWorkerState.OFFLINE_MAINTENANCE, + EdgeWorkerState.UNKNOWN, + ): + error_message = ( + f"Cannot set concurrency for edge worker {worker_name} as it is in {worker.state} state!" + ) + logger.error(error_message) + raise TypeError(error_message) + worker.concurrency = concurrency diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py b/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py index c25ff53a93437..fc780b8766219 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py @@ -200,3 +200,10 @@ class WorkerSetStateReturn(BaseModel): str | None, Field(description="Comments about the maintenance state of the worker."), ] = None + concurrency: Annotated[ + int | None, + Field( + description="Desired concurrency for the worker set by an administrator. " + "None means no remote override; the worker uses its startup value.", + ), + ] = None diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py index 8e3dce56e1564..34368c98c4ae2 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py @@ -238,7 +238,10 @@ def set_state( ) _assert_version(body.sysinfo) # Exception only after worker state is in the DB return WorkerSetStateReturn( - state=worker.state, queues=worker.queues, maintenance_comments=worker.maintenance_comment + state=worker.state, + queues=worker.queues, + maintenance_comments=worker.maintenance_comment, + concurrency=worker.concurrency, ) diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml index 94feded9ff924..a5daeacb86fbd 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml @@ -1413,6 +1413,13 @@ components: - type: 'null' title: Maintenance Comments description: Comments about the maintenance state of the worker. + concurrency: + anyOf: + - type: integer + - type: 'null' + title: Concurrency + description: Desired concurrency for the worker set by an administrator. + None means no remote override; the worker uses its startup value. type: object required: - state diff --git a/providers/edge3/tests/unit/edge3/cli/test_definition.py b/providers/edge3/tests/unit/edge3/cli/test_definition.py index a99bbc7565e7f..0c1dac8e6d3e7 100644 --- a/providers/edge3/tests/unit/edge3/cli/test_definition.py +++ b/providers/edge3/tests/unit/edge3/cli/test_definition.py @@ -53,8 +53,8 @@ def test_edge_cli_commands_count(self): assert len(commands) == 1 def test_edge_commands_count(self): - """Test that EDGE_COMMANDS contains all 13 subcommands.""" - assert len(EDGE_COMMANDS) == 13 + """Test that EDGE_COMMANDS contains all 14 subcommands.""" + assert len(EDGE_COMMANDS) == 14 @pytest.mark.parametrize( "command", @@ -234,3 +234,17 @@ def test_shutdown_all_workers_args(self): params = ["edge", "shutdown-all-workers", "--yes"] args = self.arg_parser.parse_args(params) assert args.yes is True + + def test_set_worker_concurrency_args(self): + """Test set-worker-concurrency command with required arguments.""" + params = [ + "edge", + "set-worker-concurrency", + "--edge-hostname", + "remote-worker-1", + "--concurrency", + "16", + ] + args = self.arg_parser.parse_args(params) + assert args.edge_hostname == "remote-worker-1" + assert args.concurrency == 16 diff --git a/providers/edge3/tests/unit/edge3/cli/test_worker.py b/providers/edge3/tests/unit/edge3/cli/test_worker.py index b8fa906dbd3f6..a9c16d108221f 100644 --- a/providers/edge3/tests/unit/edge3/cli/test_worker.py +++ b/providers/edge3/tests/unit/edge3/cli/test_worker.py @@ -392,6 +392,33 @@ async def test_heartbeat( assert "queue1" in (queue_list) assert "queue2" in (queue_list) + @patch("airflow.providers.edge3.cli.worker.worker_set_state") + async def test_heartbeat_adopts_remote_concurrency(self, mock_set_state, worker_with_job: EdgeWorker): + EdgeWorker.jobs = [] + EdgeWorker.drain = False + EdgeWorker.maintenance_mode = False + mock_set_state.return_value = WorkerSetStateReturn( + state=EdgeWorkerState.IDLE, queues=None, concurrency=32 + ) + with conf_vars({("edge", "api_url"): "https://invalid-api-test-endpoint"}): + await worker_with_job.heartbeat() + assert worker_with_job.concurrency == 32 + + @patch("airflow.providers.edge3.cli.worker.worker_set_state") + async def test_heartbeat_no_concurrency_override_keeps_startup_value( + self, mock_set_state, worker_with_job: EdgeWorker + ): + EdgeWorker.jobs = [] + EdgeWorker.drain = False + EdgeWorker.maintenance_mode = False + original_concurrency = worker_with_job.concurrency + mock_set_state.return_value = WorkerSetStateReturn( + state=EdgeWorkerState.IDLE, queues=None, concurrency=None + ) + with conf_vars({("edge", "api_url"): "https://invalid-api-test-endpoint"}): + await worker_with_job.heartbeat() + assert worker_with_job.concurrency == original_concurrency + @patch("airflow.providers.edge3.cli.worker.worker_set_state") async def test_version_mismatch(self, mock_set_state, worker_with_job): mock_set_state.side_effect = EdgeWorkerVersionException("") diff --git a/providers/edge3/tests/unit/edge3/models/test_db.py b/providers/edge3/tests/unit/edge3/models/test_db.py index 4424c4b1e59f7..3bef0a1e563ea 100644 --- a/providers/edge3/tests/unit/edge3/models/test_db.py +++ b/providers/edge3/tests/unit/edge3/models/test_db.py @@ -168,11 +168,13 @@ def test_create_db_from_orm(self, mock_command, session): __import__("airflow.providers.edge3.models.db", fromlist=["EdgeDBManager"]).EdgeDBManager, "get_current_revision", ) - def test_initdb_new_db(self, mock_get_rev, mock_create, mock_upgrade, session): + @mock.patch("airflow.providers.edge3.models.db.inspect") + def test_initdb_new_db(self, mock_inspect, mock_get_rev, mock_create, mock_upgrade, session): """Test that initdb calls create_db_from_orm for new databases.""" from airflow.providers.edge3.models.db import EdgeDBManager mock_get_rev.return_value = None + mock_inspect.return_value.get_table_names.return_value = [] # no tables exist manager = EdgeDBManager(session) manager.initdb() @@ -205,11 +207,80 @@ def test_initdb_existing_db(self, mock_get_rev, mock_create, mock_upgrade, sessi mock_create.assert_not_called() def test_revision_heads_map_populated(self): - """Test that _REVISION_HEADS_MAP is populated with the initial migration.""" + """Test that _REVISION_HEADS_MAP is populated with all known migrations.""" from airflow.providers.edge3.models.db import _REVISION_HEADS_MAP assert "3.0.0" in _REVISION_HEADS_MAP assert _REVISION_HEADS_MAP["3.0.0"] == "9d34dfc2de06" + assert "3.2.0" in _REVISION_HEADS_MAP + assert _REVISION_HEADS_MAP["3.2.0"] == "b3c4d5e6f7a8" + + def test_initdb_stamps_and_upgrades_when_tables_exist_without_version(self, session): + """Test that initdb runs incremental migrations when tables exist but alembic version table does not.""" + from sqlalchemy import inspect, text + + from airflow import settings + from airflow.providers.edge3.models.db import EdgeDBManager + + manager = EdgeDBManager(session) + + # Simulate pre-alembic state: tables exist but no version table and no concurrency column + with settings.engine.begin() as conn: + inspector = inspect(conn) + if inspector.has_table("alembic_version_edge3"): + conn.execute(text("DELETE FROM alembic_version_edge3")) + if "concurrency" in {col["name"] for col in inspector.get_columns("edge_worker")}: + from alembic.migration import MigrationContext + from alembic.operations import Operations + + mc = MigrationContext.configure(conn, opts={"render_as_batch": True}) + ops = Operations(mc) + ops.drop_column("edge_worker", "concurrency") + + # initdb() should detect tables exist, stamp to base, then upgrade + manager.initdb() + + with settings.engine.connect() as conn: + version = conn.execute(text("SELECT version_num FROM alembic_version_edge3")).scalar() + columns = {col["name"] for col in inspect(conn).get_columns("edge_worker")} + + assert version == "b3c4d5e6f7a8" + assert "concurrency" in columns + + def test_migration_adds_concurrency_column(self, session): + """Test that upgrading from 3.0.0 actually adds the concurrency column.""" + from alembic import command + from alembic.migration import MigrationContext + from alembic.operations import Operations + from sqlalchemy import inspect + + from airflow import settings + from airflow.providers.edge3.models.db import EdgeDBManager + + manager = EdgeDBManager(session) + config = manager.get_alembic_config() + + # DDL must be committed before alembic opens its own connection — use engine.begin() + # so the DROP is visible to the fresh connection that upgradedb() creates internally. + with settings.engine.begin() as conn: + inspector = inspect(conn) + if "concurrency" in {col["name"] for col in inspector.get_columns("edge_worker")}: + mc = MigrationContext.configure(conn, opts={"render_as_batch": True}) + ops = Operations(mc) + ops.drop_column("edge_worker", "concurrency") + + # Stamp to old revision (pre-concurrency) using alembic's own connection + command.stamp(config, "9d34dfc2de06") + + # Run the upgrade — migration 0002 should detect the missing column and add it + manager.upgradedb() + + # Verify with a fresh connection (upgradedb also uses its own connection) + with settings.engine.connect() as conn: + inspector = inspect(conn) + columns = {col["name"] for col in inspector.get_columns("edge_worker")} + + assert "concurrency" in columns, "Migration 0002 should have added the concurrency column" def test_drop_tables_handles_missing_tables(self, session): """Test that drop_tables handles missing tables gracefully.""" diff --git a/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py b/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py index 57a0bd810213d..fc5cd111a1d23 100644 --- a/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py +++ b/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py @@ -26,7 +26,11 @@ from airflow.providers.common.compat.sdk import timezone from airflow.providers.edge3.cli.worker import EdgeWorker -from airflow.providers.edge3.models.edge_worker import EdgeWorkerModel, EdgeWorkerState +from airflow.providers.edge3.models.edge_worker import ( + EdgeWorkerModel, + EdgeWorkerState, + set_worker_concurrency, +) from airflow.providers.edge3.worker_api.datamodels import WorkerQueueUpdateBody, WorkerStateBody from airflow.providers.edge3.worker_api.routes.worker import ( _assert_version, @@ -248,6 +252,90 @@ def test_set_state(self, session: Session, cli_worker: EdgeWorker): assert worker[0].queues == queues assert return_queues == ["default", "default2"] + def test_set_state_returns_concurrency(self, session: Session, cli_worker: EdgeWorker): + """set_state includes the DB-stored concurrency override in its response.""" + rwm = EdgeWorkerModel( + worker_name="test2_worker", + state=EdgeWorkerState.IDLE, + queues=["default"], + first_online=timezone.utcnow(), + ) + rwm.concurrency = 16 + session.add(rwm) + session.commit() + + body = WorkerStateBody( + state=EdgeWorkerState.RUNNING, + jobs_active=0, + queues=["default"], + sysinfo=cli_worker._get_sysinfo(), + ) + result = set_state("test2_worker", body, session) + assert result.concurrency == 16 + + def test_set_state_returns_none_concurrency_when_not_overridden( + self, session: Session, cli_worker: EdgeWorker + ): + """set_state returns None for concurrency when no override is set.""" + rwm = EdgeWorkerModel( + worker_name="test2_worker", + state=EdgeWorkerState.IDLE, + queues=["default"], + first_online=timezone.utcnow(), + ) + session.add(rwm) + session.commit() + + body = WorkerStateBody( + state=EdgeWorkerState.RUNNING, + jobs_active=0, + queues=["default"], + sysinfo=cli_worker._get_sysinfo(), + ) + result = set_state("test2_worker", body, session) + assert result.concurrency is None + + def test_set_worker_concurrency(self, session: Session): + rwm = EdgeWorkerModel( + worker_name="test2_worker", + state=EdgeWorkerState.IDLE, + queues=["default"], + first_online=timezone.utcnow(), + ) + session.add(rwm) + session.commit() + + set_worker_concurrency("test2_worker", 16, session=session) + session.commit() + + worker = session.scalars(select(EdgeWorkerModel)).one() + assert worker.concurrency == 16 + + @pytest.mark.parametrize( + "offline_state", + [ + pytest.param(EdgeWorkerState.OFFLINE, id="offline"), + pytest.param(EdgeWorkerState.OFFLINE_MAINTENANCE, id="offline-maintenance"), + pytest.param(EdgeWorkerState.UNKNOWN, id="unknown"), + ], + ) + def test_set_worker_concurrency_rejects_offline(self, session: Session, offline_state: EdgeWorkerState): + rwm = EdgeWorkerModel( + worker_name="test2_worker", + state=offline_state, + queues=["default"], + first_online=timezone.utcnow(), + ) + session.add(rwm) + session.commit() + + with pytest.raises(TypeError, match="Cannot set concurrency"): + set_worker_concurrency("test2_worker", 8, session=session) + + def test_set_worker_concurrency_raises_for_unknown_worker(self, session: Session): + with pytest.raises(ValueError, match="not found"): + set_worker_concurrency("nonexistent", 8, session=session) + @pytest.mark.parametrize( ("add_queues", "remove_queues", "expected_queues"), [