|
29 | 29 | from covalent._shared_files.config import get_config
|
30 | 30 | from covalent._shared_files.logger import app_log
|
31 | 31 | from covalent_aws_plugins import AWSExecutor
|
| 32 | +from pydantic import BaseModel |
32 | 33 |
|
33 | 34 | from .utils import _execute_partial_in_threadpool, _load_pickle_file
|
34 | 35 |
|
35 |
| -_EXECUTOR_PLUGIN_DEFAULTS = { |
36 |
| - "credentials": "", |
37 |
| - "profile": "", |
38 |
| - "region": "", |
39 |
| - "s3_bucket_name": "covalent-fargate-task-resources", |
40 |
| - "ecs_cluster_name": "covalent-fargate-cluster", |
41 |
| - "ecs_task_execution_role_name": "ecsTaskExecutionRole", |
42 |
| - "ecs_task_role_name": "CovalentFargateTaskRole", |
43 |
| - "ecs_task_subnet_id": "", |
44 |
| - "ecs_task_security_group_id": "", |
45 |
| - "ecs_task_log_group_name": "covalent-fargate-task-logs", |
46 |
| - "vcpu": 0.25, |
47 |
| - "memory": 0.5, |
48 |
| - "cache_dir": "/tmp/covalent", |
49 |
| - "poll_freq": 10, |
50 |
| -} |
| 36 | + |
| 37 | +class ExecutorPluginDefaults(BaseModel): |
| 38 | + credentials: str = "" |
| 39 | + profile: str = "" |
| 40 | + region: str = "" |
| 41 | + s3_bucket_name: str = "covalent-fargate-task-resources" |
| 42 | + ecs_cluster_name: str = "covalent-fargate-cluster" |
| 43 | + ecs_task_execution_role_name: str = "ecsTaskExecutionRole" |
| 44 | + ecs_task_role_name: str = "CovalentFargateTaskRole" |
| 45 | + ecs_task_subnet_id: str = "" |
| 46 | + ecs_task_security_group_id: str = "" |
| 47 | + ecs_task_log_group_name: str = "covalent-fargate-task-logs" |
| 48 | + vcpu: float = 0.25 |
| 49 | + memory: float = 0.5 |
| 50 | + cache_dir: str = "/tmp/covalent" |
| 51 | + poll_freq: int = 10 |
| 52 | + |
| 53 | + |
| 54 | +class ExecutorInfraDefaults(BaseModel): |
| 55 | + """ |
| 56 | + Configuration values for provisioning AWS Batch cloud infrastructure |
| 57 | + """ |
| 58 | + |
| 59 | + prefix: str = "" |
| 60 | + credentials: str = "" |
| 61 | + profile: str = "" |
| 62 | + region: str = "" |
| 63 | + s3_bucket_name: str = "covalent-fargate-task-resources" |
| 64 | + ecs_cluster_name: str = "covalent-fargate-cluster" |
| 65 | + ecs_task_execution_role_name: str = "ecsTaskExecutionRole" |
| 66 | + ecs_task_role_name: str = "CovalentFargateTaskRole" |
| 67 | + ecs_task_subnet_id: str = "" |
| 68 | + ecs_task_security_group_id: str = "" |
| 69 | + ecs_task_log_group_name: str = "covalent-fargate-task-logs" |
| 70 | + vcpu: float = 0.25 |
| 71 | + memory: float = 0.5 |
| 72 | + cache_dir: str = "/tmp/covalent" |
| 73 | + poll_freq: int = 10 |
| 74 | + |
| 75 | + |
| 76 | +_EXECUTOR_PLUGIN_DEFAULTS = ExecutorPluginDefaults().dict() |
51 | 77 |
|
52 | 78 | EXECUTOR_PLUGIN_NAME = "ECSExecutor"
|
53 | 79 |
|
@@ -215,7 +241,7 @@ async def submit_task(self, task_metadata: Dict, identity: Dict) -> Any:
|
215 | 241 | ],
|
216 | 242 | },
|
217 | 243 | ],
|
218 |
| - cpu=str(int(self.vcpu * 1024)), |
| 244 | + cpu=str(int(self.vcpu)), |
219 | 245 | memory=str(int(self.memory * 1024)),
|
220 | 246 | )
|
221 | 247 | await _execute_partial_in_threadpool(partial_func)
|
|
0 commit comments