Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 1e21cd5

Browse files
Support for capacity provider (#407)
Co-authored-by: nate nowack <[email protected]>
1 parent f9e5e9b commit 1e21cd5

File tree

2 files changed

+70
-8
lines changed

2 files changed

+70
-8
lines changed

prefect_aws/workers/ecs_worker.py

+34-8
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@
7070
from pydantic import VERSION as PYDANTIC_VERSION
7171

7272
if PYDANTIC_VERSION.startswith("2."):
73-
from pydantic.v1 import Field, root_validator
73+
from pydantic.v1 import BaseModel, Field, root_validator
7474
else:
75-
from pydantic import Field, root_validator
75+
from pydantic import Field, root_validator, BaseModel
7676

7777
from slugify import slugify
7878
from tenacity import retry, stop_after_attempt, wait_fixed, wait_random
@@ -126,6 +126,7 @@
126126
taskRoleArn: "{{ task_role_arn }}"
127127
tags: "{{ labels }}"
128128
taskDefinition: "{{ task_definition_arn }}"
129+
capacityProviderStrategy: "{{ capacity_provider_strategy }}"
129130
"""
130131

131132
# Create task run retry settings
@@ -245,6 +246,16 @@ def mask_api_key(task_run_request):
245246
)
246247

247248

249+
class CapacityProvider(BaseModel):
250+
"""
251+
The capacity provider strategy to use when running the task.
252+
"""
253+
254+
capacityProvider: str
255+
weight: int
256+
base: int
257+
258+
248259
class ECSJobConfiguration(BaseJobConfiguration):
249260
"""
250261
Job configuration for an ECS worker.
@@ -425,6 +436,14 @@ class ECSVariables(BaseVariables):
425436
),
426437
)
427438
)
439+
capacity_provider_strategy: Optional[List[CapacityProvider]] = Field(
440+
default_factory=list,
441+
description=(
442+
"The capacity provider strategy to use when running the task. "
443+
"If a capacity provider strategy is specified, the selected launch"
444+
" type will be ignored."
445+
),
446+
)
428447
image: Optional[str] = Field(
429448
default=None,
430449
description=(
@@ -1449,17 +1468,24 @@ def _prepare_task_run_request(
14491468

14501469
task_run_request.setdefault("taskDefinition", task_definition_arn)
14511470
assert task_run_request["taskDefinition"] == task_definition_arn
1471+
capacityProviderStrategy = task_run_request.get("capacityProviderStrategy")
14521472

1453-
if task_run_request.get("launchType") == "FARGATE_SPOT":
1473+
if capacityProviderStrategy:
1474+
# Should not be provided at all if capacityProviderStrategy is set, see https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RunTask.html#ECS-RunTask-request-capacityProviderStrategy # noqa
1475+
self._logger.warning(
1476+
"Found capacityProviderStrategy. "
1477+
"Removing launchType from task run request."
1478+
)
1479+
task_run_request.pop("launchType", None)
1480+
1481+
elif task_run_request.get("launchType") == "FARGATE_SPOT":
14541482
# Should not be provided at all for FARGATE SPOT
14551483
task_run_request.pop("launchType", None)
14561484

14571485
# A capacity provider strategy is required for FARGATE SPOT
1458-
task_run_request.setdefault(
1459-
"capacityProviderStrategy",
1460-
[{"capacityProvider": "FARGATE_SPOT", "weight": 1}],
1461-
)
1462-
1486+
task_run_request["capacityProviderStrategy"] = [
1487+
{"capacityProvider": "FARGATE_SPOT", "weight": 1}
1488+
]
14631489
overrides = task_run_request.get("overrides", {})
14641490
container_overrides = overrides.get("containerOverrides", [])
14651491

tests/workers/test_ecs_worker.py

+36
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,7 @@ async def test_launch_types(
506506
# Instead, it requires a capacity provider strategy but this is not supported
507507
# by moto and is not present on the task even when provided so we assert on the
508508
# mock call to ensure it is sent
509+
509510
assert mock_run_task.call_args[0][1].get("capacityProviderStrategy") == [
510511
{"capacityProvider": "FARGATE_SPOT", "weight": 1}
511512
]
@@ -2050,6 +2051,41 @@ async def test_user_defined_environment_variables_in_task_definition_template(
20502051
]
20512052

20522053

2054+
@pytest.mark.usefixtures("ecs_mocks")
2055+
async def test_user_defined_capacity_provider_strategy(
2056+
aws_credentials: AwsCredentials, flow_run: FlowRun
2057+
):
2058+
configuration = await construct_configuration(
2059+
aws_credentials=aws_credentials,
2060+
capacity_provider_strategy=[
2061+
{"base": 0, "weight": 1, "capacityProvider": "r6i.large"}
2062+
],
2063+
)
2064+
session = aws_credentials.get_boto3_session()
2065+
ecs_client = session.client("ecs")
2066+
2067+
async with ECSWorker(work_pool_name="test") as worker:
2068+
# Capture the task run call because moto does not track
2069+
# 'capacityProviderStrategy'
2070+
original_run_task = worker._create_task_run
2071+
mock_run_task = MagicMock(side_effect=original_run_task)
2072+
worker._create_task_run = mock_run_task
2073+
2074+
result = await run_then_stop_task(worker, configuration, flow_run)
2075+
2076+
assert result.status_code == 0
2077+
_, task_arn = parse_identifier(result.identifier)
2078+
2079+
task = describe_task(ecs_client, task_arn)
2080+
assert not task.get("launchType")
2081+
# Instead, it requires a capacity provider strategy but this is not supported
2082+
# by moto and is not present on the task even when provided so we assert on the
2083+
# mock call to ensure it is sent
2084+
assert mock_run_task.call_args[0][1].get("capacityProviderStrategy") == [
2085+
{"base": 0, "weight": 1, "capacityProvider": "r6i.large"},
2086+
]
2087+
2088+
20532089
@pytest.mark.usefixtures("ecs_mocks")
20542090
async def test_user_defined_environment_variables_in_task_run_request_template(
20552091
aws_credentials: AwsCredentials, flow_run: FlowRun

0 commit comments

Comments
 (0)