From 52760b21453d370f3c4176781c1ac4ce965b2c5d Mon Sep 17 00:00:00 2001 From: Dheeraj Turaga Date: Wed, 4 Mar 2026 15:42:59 -0600 Subject: [PATCH 1/7] Add runtime concurrency control for remote Edge Workers via CLI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In distributed Airflow deployments using EdgeExecutor, edge workers run at remote sites that are often unreachable from the central control plane — behind firewalls, in air-gapped networks, or on edge computing nodes. Prior to this change, the worker's concurrency (the number of tasks it runs in parallel) could only be set at startup via the -c CLI flag. Adjusting it required direct shell access to the remote machine and a full worker restart, causing job execution downtime and defeating the operational value of the edge architecture. This contribution solves the problem entirely by introducing a server-driven concurrency control mechanism that works within the existing security and connectivity constraints of edge deployments. Rather than requiring a new communication channel, it repurposes the worker's existing heartbeat protocol — the only bidirectional link between the central Airflow API server and the remote worker — to deliver concurrency updates. An administrator runs: airflow edge set-worker-concurrency --edge-hostname --concurrency This writes the desired concurrency to the central database. On the worker's next heartbeat (typically within seconds), the API server returns the new value in its response, and the worker adopts it immediately — no restart, no remote access, no downtime. The design follows the established pattern for queue management introduced in this provider, but extends it to a runtime-mutable execution parameter for the first time. It required coordinated changes across five architectural layers: the database schema (new migration), the SQLAlchemy model, the FastAPI worker API response contract, the async worker heartbeat loop, and the CLI command surface. The solution preserves the core architectural guarantee of the edge executor — that workers never hold a direct database connection — while making a previously static configuration parameter dynamically controllable from the central site This capability was made possible by DB schema versioning introduced in #61155 --- .../airflow/providers/edge3/cli/definition.py | 15 ++++ .../providers/edge3/cli/edge_command.py | 24 +++++ .../src/airflow/providers/edge3/cli/worker.py | 7 ++ ...02_3_1_0_add_concurrency_to_edge_worker.py | 49 ++++++++++ .../src/airflow/providers/edge3/models/db.py | 1 + .../providers/edge3/models/edge_worker.py | 21 +++++ .../providers/edge3/worker_api/datamodels.py | 7 ++ .../edge3/worker_api/routes/worker.py | 5 +- .../edge3/worker_api/v2-edge-generated.yaml | 12 +-- .../tests/unit/edge3/cli/test_definition.py | 14 +++ .../edge3/tests/unit/edge3/cli/test_worker.py | 27 ++++++ .../edge3/worker_api/routes/test_worker.py | 90 ++++++++++++++++++- 12 files changed, 265 insertions(+), 7 deletions(-) create mode 100644 providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_1_0_add_concurrency_to_edge_worker.py diff --git a/providers/edge3/src/airflow/providers/edge3/cli/definition.py b/providers/edge3/src/airflow/providers/edge3/cli/definition.py index 5cd611cc3a438..a1819e9a9a087 100644 --- a/providers/edge3/src/airflow/providers/edge3/cli/definition.py +++ b/providers/edge3/src/airflow/providers/edge3/cli/definition.py @@ -59,6 +59,12 @@ help="Comma delimited list of queues to add or remove.", required=True, ) +ARG_CONCURRENCY_REQUIRED = Arg( + ("-c", "--concurrency"), + type=int, + help="The number of worker processes. Must be a positive integer.", + required=True, +) ARG_WAIT_MAINT = Arg( ("-w", "--wait"), default=False, @@ -229,6 +235,15 @@ func=lazy_load_command("airflow.providers.edge3.cli.edge_command.shutdown_all_workers"), args=(ARG_YES,), ), + ActionCommand( + name="set-worker-concurrency", + help="Set the concurrency of a remote edge worker.", + func=lazy_load_command("airflow.providers.edge3.cli.edge_command.set_remote_worker_concurrency"), + args=( + ARG_REQUIRED_EDGE_HOSTNAME, + ARG_CONCURRENCY_REQUIRED, + ), + ), ] diff --git a/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py b/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py index 864d6851ace18..c18fbdc7f905b 100644 --- a/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py +++ b/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py @@ -427,3 +427,27 @@ def remove_worker_queues(args) -> None: except TypeError as e: logger.error(str(e)) raise SystemExit + + +@cli_utils.action_cli(check_db=False) +@providers_configuration_loaded +def set_remote_worker_concurrency(args) -> None: + """Set the concurrency of a remote edge worker.""" + _check_valid_db_connection() + _check_if_registered_edge_host(hostname=args.edge_hostname) + from airflow.providers.edge3.models.edge_worker import set_worker_concurrency + + if args.concurrency <= 0: + raise SystemExit("Error: Concurrency must be a positive integer.") + + try: + set_worker_concurrency(args.edge_hostname, args.concurrency) + logger.info( + "Concurrency set to %d for Edge Worker host %s by %s.", + args.concurrency, + args.edge_hostname, + getuser(), + ) + except TypeError as e: + logger.error(str(e)) + raise SystemExit diff --git a/providers/edge3/src/airflow/providers/edge3/cli/worker.py b/providers/edge3/src/airflow/providers/edge3/cli/worker.py index c4aa1d735c509..6ac43e388f82c 100644 --- a/providers/edge3/src/airflow/providers/edge3/cli/worker.py +++ b/providers/edge3/src/airflow/providers/edge3/cli/worker.py @@ -401,6 +401,13 @@ async def heartbeat(self, new_maintenance_comments: str | None = None) -> bool: new_maintenance_comments, ) self.queues = worker_info.queues + if worker_info.concurrency is not None and worker_info.concurrency != self.concurrency: + logger.info( + "Concurrency updated from %d to %d by remote request.", + self.concurrency, + worker_info.concurrency, + ) + self.concurrency = worker_info.concurrency if worker_info.state == EdgeWorkerState.MAINTENANCE_REQUEST: logger.info("Maintenance mode requested!") self.maintenance_mode = True diff --git a/providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_1_0_add_concurrency_to_edge_worker.py b/providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_1_0_add_concurrency_to_edge_worker.py new file mode 100644 index 0000000000000..3fb585307d020 --- /dev/null +++ b/providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_1_0_add_concurrency_to_edge_worker.py @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Add concurrency column to edge_worker table. + +Revision ID: b3c4d5e6f7a8 +Revises: 9d34dfc2de06 +Create Date: 2026-03-04 00:00:00.000000 +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "b3c4d5e6f7a8" +down_revision = "9d34dfc2de06" +branch_labels = None +depends_on = None +edge3_version = "3.1.0" + + +def upgrade() -> None: + bind = op.get_bind() + inspector = sa.inspect(bind) + existing_columns = {col["name"] for col in inspector.get_columns("edge_worker")} + if "concurrency" not in existing_columns: + op.add_column("edge_worker", sa.Column("concurrency", sa.Integer(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("edge_worker", "concurrency") diff --git a/providers/edge3/src/airflow/providers/edge3/models/db.py b/providers/edge3/src/airflow/providers/edge3/models/db.py index 1c98cb40235e1..36faadd92cb14 100644 --- a/providers/edge3/src/airflow/providers/edge3/models/db.py +++ b/providers/edge3/src/airflow/providers/edge3/models/db.py @@ -31,6 +31,7 @@ _REVISION_HEADS_MAP: dict[str, str] = { "3.0.0": "9d34dfc2de06", + "3.1.0": "b3c4d5e6f7a8", } diff --git a/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py b/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py index e4b0698e7e051..5b037f2903cb0 100644 --- a/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py +++ b/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py @@ -103,6 +103,7 @@ class EdgeWorkerModel(Base, LoggingMixin): jobs_success: Mapped[int] = mapped_column(Integer, default=0) jobs_failed: Mapped[int] = mapped_column(Integer, default=0) sysinfo: Mapped[str | None] = mapped_column(String(256)) + concurrency: Mapped[int | None] = mapped_column(Integer, nullable=True) def __init__( self, @@ -392,3 +393,23 @@ def remove_worker_queues(worker_name: str, queues: list[str], session: Session = logger.error(error_message) raise TypeError(error_message) worker.remove_queues(queues) + + +@provide_session +def set_worker_concurrency(worker_name: str, concurrency: int, session: Session = NEW_SESSION) -> None: + """Set the concurrency of an edge worker.""" + query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) + worker: EdgeWorkerModel | None = session.scalar(query) + if not worker: + raise ValueError(f"Edge Worker {worker_name} not found in list of registered workers") + if worker.state in ( + EdgeWorkerState.OFFLINE, + EdgeWorkerState.OFFLINE_MAINTENANCE, + EdgeWorkerState.UNKNOWN, + ): + error_message = ( + f"Cannot set concurrency for edge worker {worker_name} as it is in {worker.state} state!" + ) + logger.error(error_message) + raise TypeError(error_message) + worker.concurrency = concurrency diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py b/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py index c25ff53a93437..fc780b8766219 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/datamodels.py @@ -200,3 +200,10 @@ class WorkerSetStateReturn(BaseModel): str | None, Field(description="Comments about the maintenance state of the worker."), ] = None + concurrency: Annotated[ + int | None, + Field( + description="Desired concurrency for the worker set by an administrator. " + "None means no remote override; the worker uses its startup value.", + ), + ] = None diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py index 8e3dce56e1564..34368c98c4ae2 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/routes/worker.py @@ -238,7 +238,10 @@ def set_state( ) _assert_version(body.sysinfo) # Exception only after worker state is in the DB return WorkerSetStateReturn( - state=worker.state, queues=worker.queues, maintenance_comments=worker.maintenance_comment + state=worker.state, + queues=worker.queues, + maintenance_comments=worker.maintenance_comment, + concurrency=worker.concurrency, ) diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml index 94feded9ff924..c1b87e6523f94 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml @@ -1247,11 +1247,6 @@ components: type: type: string title: Error Type - input: - title: Input - ctx: - type: object - title: Context type: object required: - loc @@ -1413,6 +1408,13 @@ components: - type: 'null' title: Maintenance Comments description: Comments about the maintenance state of the worker. + concurrency: + anyOf: + - type: integer + - type: 'null' + title: Concurrency + description: Desired concurrency for the worker set by an administrator. + None means no remote override; the worker uses its startup value. type: object required: - state diff --git a/providers/edge3/tests/unit/edge3/cli/test_definition.py b/providers/edge3/tests/unit/edge3/cli/test_definition.py index a99bbc7565e7f..0f4927ae52e1e 100644 --- a/providers/edge3/tests/unit/edge3/cli/test_definition.py +++ b/providers/edge3/tests/unit/edge3/cli/test_definition.py @@ -234,3 +234,17 @@ def test_shutdown_all_workers_args(self): params = ["edge", "shutdown-all-workers", "--yes"] args = self.arg_parser.parse_args(params) assert args.yes is True + + def test_set_worker_concurrency_args(self): + """Test set-worker-concurrency command with required arguments.""" + params = [ + "edge", + "set-worker-concurrency", + "--edge-hostname", + "remote-worker-1", + "--concurrency", + "16", + ] + args = self.arg_parser.parse_args(params) + assert args.edge_hostname == "remote-worker-1" + assert args.concurrency == 16 diff --git a/providers/edge3/tests/unit/edge3/cli/test_worker.py b/providers/edge3/tests/unit/edge3/cli/test_worker.py index b8fa906dbd3f6..a9c16d108221f 100644 --- a/providers/edge3/tests/unit/edge3/cli/test_worker.py +++ b/providers/edge3/tests/unit/edge3/cli/test_worker.py @@ -392,6 +392,33 @@ async def test_heartbeat( assert "queue1" in (queue_list) assert "queue2" in (queue_list) + @patch("airflow.providers.edge3.cli.worker.worker_set_state") + async def test_heartbeat_adopts_remote_concurrency(self, mock_set_state, worker_with_job: EdgeWorker): + EdgeWorker.jobs = [] + EdgeWorker.drain = False + EdgeWorker.maintenance_mode = False + mock_set_state.return_value = WorkerSetStateReturn( + state=EdgeWorkerState.IDLE, queues=None, concurrency=32 + ) + with conf_vars({("edge", "api_url"): "https://invalid-api-test-endpoint"}): + await worker_with_job.heartbeat() + assert worker_with_job.concurrency == 32 + + @patch("airflow.providers.edge3.cli.worker.worker_set_state") + async def test_heartbeat_no_concurrency_override_keeps_startup_value( + self, mock_set_state, worker_with_job: EdgeWorker + ): + EdgeWorker.jobs = [] + EdgeWorker.drain = False + EdgeWorker.maintenance_mode = False + original_concurrency = worker_with_job.concurrency + mock_set_state.return_value = WorkerSetStateReturn( + state=EdgeWorkerState.IDLE, queues=None, concurrency=None + ) + with conf_vars({("edge", "api_url"): "https://invalid-api-test-endpoint"}): + await worker_with_job.heartbeat() + assert worker_with_job.concurrency == original_concurrency + @patch("airflow.providers.edge3.cli.worker.worker_set_state") async def test_version_mismatch(self, mock_set_state, worker_with_job): mock_set_state.side_effect = EdgeWorkerVersionException("") diff --git a/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py b/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py index 57a0bd810213d..fc5cd111a1d23 100644 --- a/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py +++ b/providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py @@ -26,7 +26,11 @@ from airflow.providers.common.compat.sdk import timezone from airflow.providers.edge3.cli.worker import EdgeWorker -from airflow.providers.edge3.models.edge_worker import EdgeWorkerModel, EdgeWorkerState +from airflow.providers.edge3.models.edge_worker import ( + EdgeWorkerModel, + EdgeWorkerState, + set_worker_concurrency, +) from airflow.providers.edge3.worker_api.datamodels import WorkerQueueUpdateBody, WorkerStateBody from airflow.providers.edge3.worker_api.routes.worker import ( _assert_version, @@ -248,6 +252,90 @@ def test_set_state(self, session: Session, cli_worker: EdgeWorker): assert worker[0].queues == queues assert return_queues == ["default", "default2"] + def test_set_state_returns_concurrency(self, session: Session, cli_worker: EdgeWorker): + """set_state includes the DB-stored concurrency override in its response.""" + rwm = EdgeWorkerModel( + worker_name="test2_worker", + state=EdgeWorkerState.IDLE, + queues=["default"], + first_online=timezone.utcnow(), + ) + rwm.concurrency = 16 + session.add(rwm) + session.commit() + + body = WorkerStateBody( + state=EdgeWorkerState.RUNNING, + jobs_active=0, + queues=["default"], + sysinfo=cli_worker._get_sysinfo(), + ) + result = set_state("test2_worker", body, session) + assert result.concurrency == 16 + + def test_set_state_returns_none_concurrency_when_not_overridden( + self, session: Session, cli_worker: EdgeWorker + ): + """set_state returns None for concurrency when no override is set.""" + rwm = EdgeWorkerModel( + worker_name="test2_worker", + state=EdgeWorkerState.IDLE, + queues=["default"], + first_online=timezone.utcnow(), + ) + session.add(rwm) + session.commit() + + body = WorkerStateBody( + state=EdgeWorkerState.RUNNING, + jobs_active=0, + queues=["default"], + sysinfo=cli_worker._get_sysinfo(), + ) + result = set_state("test2_worker", body, session) + assert result.concurrency is None + + def test_set_worker_concurrency(self, session: Session): + rwm = EdgeWorkerModel( + worker_name="test2_worker", + state=EdgeWorkerState.IDLE, + queues=["default"], + first_online=timezone.utcnow(), + ) + session.add(rwm) + session.commit() + + set_worker_concurrency("test2_worker", 16, session=session) + session.commit() + + worker = session.scalars(select(EdgeWorkerModel)).one() + assert worker.concurrency == 16 + + @pytest.mark.parametrize( + "offline_state", + [ + pytest.param(EdgeWorkerState.OFFLINE, id="offline"), + pytest.param(EdgeWorkerState.OFFLINE_MAINTENANCE, id="offline-maintenance"), + pytest.param(EdgeWorkerState.UNKNOWN, id="unknown"), + ], + ) + def test_set_worker_concurrency_rejects_offline(self, session: Session, offline_state: EdgeWorkerState): + rwm = EdgeWorkerModel( + worker_name="test2_worker", + state=offline_state, + queues=["default"], + first_online=timezone.utcnow(), + ) + session.add(rwm) + session.commit() + + with pytest.raises(TypeError, match="Cannot set concurrency"): + set_worker_concurrency("test2_worker", 8, session=session) + + def test_set_worker_concurrency_raises_for_unknown_worker(self, session: Session): + with pytest.raises(ValueError, match="not found"): + set_worker_concurrency("nonexistent", 8, session=session) + @pytest.mark.parametrize( ("add_queues", "remove_queues", "expected_queues"), [ From 0c470b1bc92e5c4b1b60264e8669fdd603022789 Mon Sep 17 00:00:00 2001 From: Dheeraj Turaga Date: Thu, 5 Mar 2026 17:03:53 -0600 Subject: [PATCH 2/7] Fix unit test --- providers/edge3/tests/unit/edge3/cli/test_definition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/edge3/tests/unit/edge3/cli/test_definition.py b/providers/edge3/tests/unit/edge3/cli/test_definition.py index 0f4927ae52e1e..0c1dac8e6d3e7 100644 --- a/providers/edge3/tests/unit/edge3/cli/test_definition.py +++ b/providers/edge3/tests/unit/edge3/cli/test_definition.py @@ -53,8 +53,8 @@ def test_edge_cli_commands_count(self): assert len(commands) == 1 def test_edge_commands_count(self): - """Test that EDGE_COMMANDS contains all 13 subcommands.""" - assert len(EDGE_COMMANDS) == 13 + """Test that EDGE_COMMANDS contains all 14 subcommands.""" + assert len(EDGE_COMMANDS) == 14 @pytest.mark.parametrize( "command", From fbf6136635a9577d19484e8d7ec765493e0b5650 Mon Sep 17 00:00:00 2001 From: Dheeraj Turaga Date: Thu, 5 Mar 2026 18:04:41 -0600 Subject: [PATCH 3/7] update migration files to 3.2.0 --- ...e_worker.py => 0002_3_2_0_add_concurrency_to_edge_worker.py} | 2 +- providers/edge3/src/airflow/providers/edge3/models/db.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename providers/edge3/src/airflow/providers/edge3/migrations/versions/{0002_3_1_0_add_concurrency_to_edge_worker.py => 0002_3_2_0_add_concurrency_to_edge_worker.py} (98%) diff --git a/providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_1_0_add_concurrency_to_edge_worker.py b/providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_2_0_add_concurrency_to_edge_worker.py similarity index 98% rename from providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_1_0_add_concurrency_to_edge_worker.py rename to providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_2_0_add_concurrency_to_edge_worker.py index 3fb585307d020..edf6cc7f39ccd 100644 --- a/providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_1_0_add_concurrency_to_edge_worker.py +++ b/providers/edge3/src/airflow/providers/edge3/migrations/versions/0002_3_2_0_add_concurrency_to_edge_worker.py @@ -34,7 +34,7 @@ down_revision = "9d34dfc2de06" branch_labels = None depends_on = None -edge3_version = "3.1.0" +edge3_version = "3.2.0" def upgrade() -> None: diff --git a/providers/edge3/src/airflow/providers/edge3/models/db.py b/providers/edge3/src/airflow/providers/edge3/models/db.py index 36faadd92cb14..bf2af6dd3b61e 100644 --- a/providers/edge3/src/airflow/providers/edge3/models/db.py +++ b/providers/edge3/src/airflow/providers/edge3/models/db.py @@ -31,7 +31,7 @@ _REVISION_HEADS_MAP: dict[str, str] = { "3.0.0": "9d34dfc2de06", - "3.1.0": "b3c4d5e6f7a8", + "3.2.0": "b3c4d5e6f7a8", } From 90f734ded1f0c82b0f9bba80b79c42b2c644ba02 Mon Sep 17 00:00:00 2001 From: Dheeraj Turaga Date: Thu, 5 Mar 2026 18:08:39 -0600 Subject: [PATCH 4/7] Not sure why prek got rid of this --- .../providers/edge3/worker_api/v2-edge-generated.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml index c1b87e6523f94..a5daeacb86fbd 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml @@ -1247,6 +1247,11 @@ components: type: type: string title: Error Type + input: + title: Input + ctx: + type: object + title: Context type: object required: - loc From e54e1c2c4702eb57bc022a03381e8f14932a907b Mon Sep 17 00:00:00 2001 From: Dheeraj Turaga Date: Thu, 5 Mar 2026 22:30:59 -0600 Subject: [PATCH 5/7] Fix edge3 DB migration skipping incremental migrations on existing installs When edge3 tables existed before alembic tracking was introduced, airflow db migrate would stamp directly to head without applying incremental migrations, leaving the schema out of sync (e.g. missing the concurrency column added in 3.2.0). --- .../src/airflow/providers/edge3/models/db.py | 27 +++++++ .../edge3/tests/unit/edge3/models/test_db.py | 71 ++++++++++++++++++- 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/providers/edge3/src/airflow/providers/edge3/models/db.py b/providers/edge3/src/airflow/providers/edge3/models/db.py index bf2af6dd3b61e..b9d3609bd4be4 100644 --- a/providers/edge3/src/airflow/providers/edge3/models/db.py +++ b/providers/edge3/src/airflow/providers/edge3/models/db.py @@ -46,6 +46,33 @@ class EdgeDBManager(BaseDBManager): supports_table_dropping = True revision_heads_map = _REVISION_HEADS_MAP + def initdb(self): + """ + Initialize the database, handling pre-alembic installations. + + If the edge3 tables already exist but the alembic version table does not + (e.g. created via create_all before the migration chain was introduced), + stamp to the first revision and run upgradedb() so every incremental + migration is applied rather than jumping straight to head. + """ + db_exists = self.get_current_revision() + if db_exists: + self.upgradedb() + else: + from airflow import settings + + existing_tables = set(inspect(settings.engine).get_table_names()) + if any(table in existing_tables for table in self.metadata.tables): + script = self.get_script_object() + base_revision = next(r.revision for r in script.walk_revisions() if r.down_revision is None) + config = self.get_alembic_config() + from alembic import command + + command.stamp(config, base_revision) + self.upgradedb() + else: + self.create_db_from_orm() + def drop_tables(self, connection): """Drop only edge3 tables in reverse dependency order.""" if not self.supports_table_dropping: diff --git a/providers/edge3/tests/unit/edge3/models/test_db.py b/providers/edge3/tests/unit/edge3/models/test_db.py index 4424c4b1e59f7..bf4442727e819 100644 --- a/providers/edge3/tests/unit/edge3/models/test_db.py +++ b/providers/edge3/tests/unit/edge3/models/test_db.py @@ -205,11 +205,80 @@ def test_initdb_existing_db(self, mock_get_rev, mock_create, mock_upgrade, sessi mock_create.assert_not_called() def test_revision_heads_map_populated(self): - """Test that _REVISION_HEADS_MAP is populated with the initial migration.""" + """Test that _REVISION_HEADS_MAP is populated with all known migrations.""" from airflow.providers.edge3.models.db import _REVISION_HEADS_MAP assert "3.0.0" in _REVISION_HEADS_MAP assert _REVISION_HEADS_MAP["3.0.0"] == "9d34dfc2de06" + assert "3.2.0" in _REVISION_HEADS_MAP + assert _REVISION_HEADS_MAP["3.2.0"] == "b3c4d5e6f7a8" + + def test_initdb_stamps_and_upgrades_when_tables_exist_without_version(self, session): + """Test that initdb runs incremental migrations when tables exist but alembic version table does not.""" + from sqlalchemy import inspect, text + + from airflow import settings + from airflow.providers.edge3.models.db import EdgeDBManager + + manager = EdgeDBManager(session) + config = manager.get_alembic_config() + + # Simulate pre-alembic state: tables exist but no version table and no concurrency column + with settings.engine.begin() as conn: + conn.execute(text("DELETE FROM alembic_version_edge3")) + inspector = inspect(conn) + if "concurrency" in {col["name"] for col in inspector.get_columns("edge_worker")}: + from alembic.migration import MigrationContext + from alembic.operations import Operations + + mc = MigrationContext.configure(conn, opts={"render_as_batch": True}) + ops = Operations(mc) + ops.drop_column("edge_worker", "concurrency") + + # initdb() should detect tables exist, stamp to base, then upgrade + manager.initdb() + + with settings.engine.connect() as conn: + version = conn.execute(text("SELECT version_num FROM alembic_version_edge3")).scalar() + columns = {col["name"] for col in inspect(conn).get_columns("edge_worker")} + + assert version == "b3c4d5e6f7a8" + assert "concurrency" in columns + + def test_migration_adds_concurrency_column(self, session): + """Test that upgrading from 3.0.0 actually adds the concurrency column.""" + from alembic import command + from alembic.migration import MigrationContext + from alembic.operations import Operations + from sqlalchemy import inspect + + from airflow import settings + from airflow.providers.edge3.models.db import EdgeDBManager + + manager = EdgeDBManager(session) + config = manager.get_alembic_config() + + # DDL must be committed before alembic opens its own connection — use engine.begin() + # so the DROP is visible to the fresh connection that upgradedb() creates internally. + with settings.engine.begin() as conn: + inspector = inspect(conn) + if "concurrency" in {col["name"] for col in inspector.get_columns("edge_worker")}: + mc = MigrationContext.configure(conn, opts={"render_as_batch": True}) + ops = Operations(mc) + ops.drop_column("edge_worker", "concurrency") + + # Stamp to old revision (pre-concurrency) using alembic's own connection + command.stamp(config, "9d34dfc2de06") + + # Run the upgrade — migration 0002 should detect the missing column and add it + manager.upgradedb() + + # Verify with a fresh connection (upgradedb also uses its own connection) + with settings.engine.connect() as conn: + inspector = inspect(conn) + columns = {col["name"] for col in inspector.get_columns("edge_worker")} + + assert "concurrency" in columns, "Migration 0002 should have added the concurrency column" def test_drop_tables_handles_missing_tables(self, session): """Test that drop_tables handles missing tables gracefully.""" From 2a89a31442fdb13be8cef08a5bcff9e13d044ef1 Mon Sep 17 00:00:00 2001 From: Dheeraj Turaga Date: Fri, 6 Mar 2026 08:11:52 -0600 Subject: [PATCH 6/7] Fix tests --- providers/edge3/src/airflow/providers/edge3/models/db.py | 2 +- providers/edge3/tests/unit/edge3/models/test_db.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/providers/edge3/src/airflow/providers/edge3/models/db.py b/providers/edge3/src/airflow/providers/edge3/models/db.py index b9d3609bd4be4..e98b3cc24b97f 100644 --- a/providers/edge3/src/airflow/providers/edge3/models/db.py +++ b/providers/edge3/src/airflow/providers/edge3/models/db.py @@ -52,7 +52,7 @@ def initdb(self): If the edge3 tables already exist but the alembic version table does not (e.g. created via create_all before the migration chain was introduced), - stamp to the first revision and run upgradedb() so every incremental + stamp to the first revision and run the incremental upgrade so every migration is applied rather than jumping straight to head. """ db_exists = self.get_current_revision() diff --git a/providers/edge3/tests/unit/edge3/models/test_db.py b/providers/edge3/tests/unit/edge3/models/test_db.py index bf4442727e819..c5bbc5194d44b 100644 --- a/providers/edge3/tests/unit/edge3/models/test_db.py +++ b/providers/edge3/tests/unit/edge3/models/test_db.py @@ -168,11 +168,13 @@ def test_create_db_from_orm(self, mock_command, session): __import__("airflow.providers.edge3.models.db", fromlist=["EdgeDBManager"]).EdgeDBManager, "get_current_revision", ) - def test_initdb_new_db(self, mock_get_rev, mock_create, mock_upgrade, session): + @mock.patch("airflow.providers.edge3.models.db.inspect") + def test_initdb_new_db(self, mock_inspect, mock_get_rev, mock_create, mock_upgrade, session): """Test that initdb calls create_db_from_orm for new databases.""" from airflow.providers.edge3.models.db import EdgeDBManager mock_get_rev.return_value = None + mock_inspect.return_value.get_table_names.return_value = [] # no tables exist manager = EdgeDBManager(session) manager.initdb() @@ -221,7 +223,6 @@ def test_initdb_stamps_and_upgrades_when_tables_exist_without_version(self, sess from airflow.providers.edge3.models.db import EdgeDBManager manager = EdgeDBManager(session) - config = manager.get_alembic_config() # Simulate pre-alembic state: tables exist but no version table and no concurrency column with settings.engine.begin() as conn: From 7d2d3b5ff20b4d5e8ad117b551deae3fb0819c7a Mon Sep 17 00:00:00 2001 From: Dheeraj Turaga Date: Fri, 6 Mar 2026 13:40:56 -0600 Subject: [PATCH 7/7] Fix more tests --- providers/edge3/tests/unit/edge3/models/test_db.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/providers/edge3/tests/unit/edge3/models/test_db.py b/providers/edge3/tests/unit/edge3/models/test_db.py index c5bbc5194d44b..3bef0a1e563ea 100644 --- a/providers/edge3/tests/unit/edge3/models/test_db.py +++ b/providers/edge3/tests/unit/edge3/models/test_db.py @@ -226,8 +226,9 @@ def test_initdb_stamps_and_upgrades_when_tables_exist_without_version(self, sess # Simulate pre-alembic state: tables exist but no version table and no concurrency column with settings.engine.begin() as conn: - conn.execute(text("DELETE FROM alembic_version_edge3")) inspector = inspect(conn) + if inspector.has_table("alembic_version_edge3"): + conn.execute(text("DELETE FROM alembic_version_edge3")) if "concurrency" in {col["name"] for col in inspector.get_columns("edge_worker")}: from alembic.migration import MigrationContext from alembic.operations import Operations