|
70 | 70 | from pydantic import VERSION as PYDANTIC_VERSION
|
71 | 71 |
|
72 | 72 | if PYDANTIC_VERSION.startswith("2."):
|
73 |
| - from pydantic.v1 import Field, root_validator |
| 73 | + from pydantic.v1 import BaseModel, Field, root_validator |
74 | 74 | else:
|
75 |
| - from pydantic import Field, root_validator |
| 75 | + from pydantic import Field, root_validator, BaseModel |
76 | 76 |
|
77 | 77 | from slugify import slugify
|
78 | 78 | from tenacity import retry, stop_after_attempt, wait_fixed, wait_random
|
|
126 | 126 | taskRoleArn: "{{ task_role_arn }}"
|
127 | 127 | tags: "{{ labels }}"
|
128 | 128 | taskDefinition: "{{ task_definition_arn }}"
|
| 129 | +capacityProviderStrategy: "{{ capacity_provider_strategy }}" |
129 | 130 | """
|
130 | 131 |
|
131 | 132 | # Create task run retry settings
|
@@ -245,6 +246,16 @@ def mask_api_key(task_run_request):
|
245 | 246 | )
|
246 | 247 |
|
247 | 248 |
|
| 249 | +class CapacityProvider(BaseModel): |
| 250 | + """ |
| 251 | + The capacity provider strategy to use when running the task. |
| 252 | + """ |
| 253 | + |
| 254 | + capacityProvider: str |
| 255 | + weight: int |
| 256 | + base: int |
| 257 | + |
| 258 | + |
248 | 259 | class ECSJobConfiguration(BaseJobConfiguration):
|
249 | 260 | """
|
250 | 261 | Job configuration for an ECS worker.
|
@@ -425,6 +436,14 @@ class ECSVariables(BaseVariables):
|
425 | 436 | ),
|
426 | 437 | )
|
427 | 438 | )
|
| 439 | + capacity_provider_strategy: Optional[List[CapacityProvider]] = Field( |
| 440 | + default_factory=list, |
| 441 | + description=( |
| 442 | + "The capacity provider strategy to use when running the task. " |
| 443 | + "If a capacity provider strategy is specified, the selected launch" |
| 444 | + " type will be ignored." |
| 445 | + ), |
| 446 | + ) |
428 | 447 | image: Optional[str] = Field(
|
429 | 448 | default=None,
|
430 | 449 | description=(
|
@@ -1449,17 +1468,24 @@ def _prepare_task_run_request(
|
1449 | 1468 |
|
1450 | 1469 | task_run_request.setdefault("taskDefinition", task_definition_arn)
|
1451 | 1470 | assert task_run_request["taskDefinition"] == task_definition_arn
|
| 1471 | + capacityProviderStrategy = task_run_request.get("capacityProviderStrategy") |
1452 | 1472 |
|
1453 |
| - if task_run_request.get("launchType") == "FARGATE_SPOT": |
| 1473 | + if capacityProviderStrategy: |
| 1474 | + # Should not be provided at all if capacityProviderStrategy is set, see https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RunTask.html#ECS-RunTask-request-capacityProviderStrategy # noqa |
| 1475 | + self._logger.warning( |
| 1476 | + "Found capacityProviderStrategy. " |
| 1477 | + "Removing launchType from task run request." |
| 1478 | + ) |
| 1479 | + task_run_request.pop("launchType", None) |
| 1480 | + |
| 1481 | + elif task_run_request.get("launchType") == "FARGATE_SPOT": |
1454 | 1482 | # Should not be provided at all for FARGATE SPOT
|
1455 | 1483 | task_run_request.pop("launchType", None)
|
1456 | 1484 |
|
1457 | 1485 | # A capacity provider strategy is required for FARGATE SPOT
|
1458 |
| - task_run_request.setdefault( |
1459 |
| - "capacityProviderStrategy", |
1460 |
| - [{"capacityProvider": "FARGATE_SPOT", "weight": 1}], |
1461 |
| - ) |
1462 |
| - |
| 1486 | + task_run_request["capacityProviderStrategy"] = [ |
| 1487 | + {"capacityProvider": "FARGATE_SPOT", "weight": 1} |
| 1488 | + ] |
1463 | 1489 | overrides = task_run_request.get("overrides", {})
|
1464 | 1490 | container_overrides = overrides.get("containerOverrides", [])
|
1465 | 1491 |
|
|
0 commit comments