diff --git a/docs/Flash_Deploy_Guide.md b/docs/Flash_Deploy_Guide.md index 8518c6a6..75b85829 100644 --- a/docs/Flash_Deploy_Guide.md +++ b/docs/Flash_Deploy_Guide.md @@ -289,13 +289,14 @@ result = await vllm.post("/v1/completions", {"prompt": "hello"}) ### Persistent Storage ```python -from runpod_flash import Endpoint, GpuGroup, NetworkVolume +from runpod_flash import Endpoint, GpuGroup, DataCenter, NetworkVolume, PodTemplate -vol = NetworkVolume(id="vol_abc123") +vol = NetworkVolume(name="model-cache", size=100, datacenter=DataCenter.US_GA_1) @Endpoint( name="model-server", gpu=GpuGroup.AMPERE_80, + datacenter=DataCenter.US_GA_1, volume=vol, template=PodTemplate(containerDiskInGb=100), ) @@ -304,6 +305,24 @@ async def serve(data: dict) -> dict: ... ``` +Multiple volumes across datacenters: + +```python +volumes = [ + NetworkVolume(name="models-us", size=100, datacenter=DataCenter.US_GA_1), + NetworkVolume(name="models-eu", size=100, datacenter=DataCenter.EU_RO_1), +] + +@Endpoint( + name="global-server", + gpu=GpuGroup.AMPERE_80, + datacenter=[DataCenter.US_GA_1, DataCenter.EU_RO_1], + volume=volumes, +) +async def serve(data: dict) -> dict: + ... +``` + ## Troubleshooting ### Build Issues diff --git a/docs/Flash_SDK_Reference.md b/docs/Flash_SDK_Reference.md index 8e78eec0..13c119ed 100644 --- a/docs/Flash_SDK_Reference.md +++ b/docs/Flash_SDK_Reference.md @@ -20,8 +20,8 @@ Endpoint( dependencies: Optional[List[str]] = None, system_dependencies: Optional[List[str]] = None, accelerate_downloads: bool = True, - volume: Optional[NetworkVolume] = None, - datacenter: DataCenter = DataCenter.EU_RO_1, + volume: Optional[Union[NetworkVolume, List[NetworkVolume]]] = None, + datacenter: Optional[Union[DataCenter, List[DataCenter], str, List[str]]] = None, env: Optional[Dict[str, str]] = None, gpu_count: int = 1, execution_timeout_ms: int = 0, @@ -46,8 +46,8 @@ Endpoint( | `dependencies` | `list[str]` | `None` | Python packages to install (e.g., `["torch", "numpy==1.24"]`). | | `system_dependencies` | `list[str]` | `None` | System packages to install. | | `accelerate_downloads` | `bool` | `True` | Enable accelerated downloads. | -| `volume` | `NetworkVolume` | `None` | Network volume for persistent storage. | -| `datacenter` | `DataCenter` | `EU_RO_1` | Preferred datacenter. | +| `volume` | `NetworkVolume` or list | `None` | Network volume(s) for persistent storage. One volume per datacenter. | +| `datacenter` | `DataCenter`, list, `str`, or `None` | `None` | Datacenter(s) to deploy into. `None` means all available DCs. Accepts a single value, a list, or string DC IDs. CPU endpoints must use DCs in `CPU_DATACENTERS`. | | `env` | `dict[str, str]` | `None` | Environment variables for the endpoint. | | `gpu_count` | `int` | `1` | GPUs per worker. | | `execution_timeout_ms` | `int` | `0` | Max execution time in ms. 0 = no limit. | @@ -335,8 +335,20 @@ CPU instance selection. Can also be passed as a string to `cpu=`. | Value | Location | |-------|----------| -| `DataCenter.EU_RO_1` | Europe - Romania (default) | -| `DataCenter.US_TX_3` | US - Texas | +| `DataCenter.US_GA_1` | US - Georgia | +| `DataCenter.US_KS_1` | US - Kansas | +| `DataCenter.US_TX_1` | US - Texas | +| `DataCenter.US_OR_1` | US - Oregon | +| `DataCenter.CA_MTL_1` | Canada - Montreal | +| `DataCenter.EU_NL_1` | Europe - Netherlands | +| `DataCenter.EU_CZ_1` | Europe - Czech Republic | +| `DataCenter.EU_RO_1` | Europe - Romania | +| `DataCenter.EU_NO_1` | Europe - Norway | +| `DataCenter.EU_SE_1` | Europe - Sweden | + +When `datacenter=None` (the default), the endpoint is available in all data centers. + +CPU endpoints are restricted to the `CPU_DATACENTERS` subset: `EU_RO_1`, `US_TX_1`, `EU_SE_1`. ### CudaVersion @@ -350,12 +362,22 @@ CPU instance selection. Can also be passed as a string to `cpu=`. ### NetworkVolume -Persistent storage that survives worker restarts. +Persistent storage that survives worker restarts. Each volume is tied to a specific datacenter. ```python -from runpod_flash import NetworkVolume +from runpod_flash import NetworkVolume, DataCenter +# existing volume by ID vol = NetworkVolume(id="vol_abc123") + +# create a new volume in a specific datacenter +vol = NetworkVolume(name="my-models", size=100, datacenter=DataCenter.US_GA_1) + +# multiple volumes across datacenters (one per DC) +volumes = [ + NetworkVolume(name="models-us", size=100, datacenter=DataCenter.US_GA_1), + NetworkVolume(name="models-eu", size=100, datacenter=DataCenter.EU_RO_1), +] ``` ### PodTemplate @@ -392,6 +414,7 @@ from runpod_flash import ( CpuInstanceType, CudaVersion, DataCenter, + CPU_DATACENTERS, NetworkVolume, PodTemplate, ServerlessScalerType, diff --git a/src/runpod_flash/__init__.py b/src/runpod_flash/__init__.py index b5943ac7..00a1e9dd 100644 --- a/src/runpod_flash/__init__.py +++ b/src/runpod_flash/__init__.py @@ -17,6 +17,7 @@ from .client import remote from .endpoint import Endpoint, EndpointJob from .core.resources import ( + CPU_DATACENTERS, CpuInstanceType, CpuLiveLoadBalancer, CpuLiveServerless, @@ -58,6 +59,7 @@ _RESOURCE_NAMES = frozenset( { + "CPU_DATACENTERS", "CpuInstanceType", "CpuLiveLoadBalancer", "CpuLiveServerless", @@ -104,6 +106,7 @@ def __getattr__(name): return remote elif name in _RESOURCE_NAMES: from .core.resources import ( + CPU_DATACENTERS, CpuInstanceType, CpuLiveLoadBalancer, CpuLiveServerless, @@ -126,6 +129,7 @@ def __getattr__(name): ) attrs = { + "CPU_DATACENTERS": CPU_DATACENTERS, "CpuInstanceType": CpuInstanceType, "CpuLiveLoadBalancer": CpuLiveLoadBalancer, "CpuLiveServerless": CpuLiveServerless, @@ -173,6 +177,7 @@ def __getattr__(name): "Endpoint", "EndpointJob", "remote", + "CPU_DATACENTERS", "CpuInstanceType", "CpuLiveLoadBalancer", "CpuLiveServerless", diff --git a/src/runpod_flash/cli/commands/build_utils/manifest.py b/src/runpod_flash/cli/commands/build_utils/manifest.py index d01ca0dd..659f3e18 100644 --- a/src/runpod_flash/cli/commands/build_utils/manifest.py +++ b/src/runpod_flash/cli/commands/build_utils/manifest.py @@ -218,6 +218,9 @@ def _extract_config_properties(config: Dict[str, Any], resource_config) -> None: ): config["scalerValue"] = resource_config.scalerValue + if hasattr(resource_config, "locations") and resource_config.locations: + config["locations"] = resource_config.locations + if hasattr(resource_config, "env") and resource_config.env: env_dict = dict(resource_config.env) env_dict.pop("RUNPOD_API_KEY", None) diff --git a/src/runpod_flash/core/api/runpod.py b/src/runpod_flash/core/api/runpod.py index 17a3f42a..7ab92a45 100644 --- a/src/runpod_flash/core/api/runpod.py +++ b/src/runpod_flash/core/api/runpod.py @@ -273,6 +273,10 @@ async def save_endpoint(self, input_data: Dict[str, Any]) -> Dict[str, Any]: locations name networkVolumeId + networkVolumeIds { + networkVolumeId + dataCenterId + } flashEnvironmentId scalerType scalerValue diff --git a/src/runpod_flash/core/resources/__init__.py b/src/runpod_flash/core/resources/__init__.py index d12fe752..8029d62e 100644 --- a/src/runpod_flash/core/resources/__init__.py +++ b/src/runpod_flash/core/resources/__init__.py @@ -18,7 +18,7 @@ ) from .serverless_cpu import CpuServerlessEndpoint from .template import PodTemplate -from .network_volume import NetworkVolume, DataCenter +from .network_volume import NetworkVolume, DataCenter, CPU_DATACENTERS from .load_balancer_sls_resource import ( CpuLoadBalancerSlsResource, LoadBalancerSlsResource, @@ -33,6 +33,7 @@ "CpuLoadBalancerSlsResource", "CpuServerlessEndpoint", "CudaVersion", + "CPU_DATACENTERS", "DataCenter", "DeployableResource", "GpuGroup", diff --git a/src/runpod_flash/core/resources/load_balancer_sls_resource.py b/src/runpod_flash/core/resources/load_balancer_sls_resource.py index ee0cf4ef..7a481d24 100644 --- a/src/runpod_flash/core/resources/load_balancer_sls_resource.py +++ b/src/runpod_flash/core/resources/load_balancer_sls_resource.py @@ -342,6 +342,7 @@ class CpuLoadBalancerSlsResource(CpuEndpointMixin, LoadBalancerSlsResource): "allowedCudaVersions", "imageName", "networkVolume", + "networkVolumes", "python_version", } diff --git a/src/runpod_flash/core/resources/network_volume.py b/src/runpod_flash/core/resources/network_volume.py index 66551d1f..9655031f 100644 --- a/src/runpod_flash/core/resources/network_volume.py +++ b/src/runpod_flash/core/resources/network_volume.py @@ -6,6 +6,7 @@ from pydantic import ( Field, field_serializer, + model_validator, ) from ..api.runpod import RunpodRestClient @@ -17,12 +18,44 @@ class DataCenter(str, Enum): - """ - Enum representing available data centers for network volumes. - #TODO: Add more data centers as needed. Lock this to the available data center. - """ - + """Enum representing available RunPod data centers.""" + + US_GA_1 = "US-GA-1" + US_KS_1 = "US-KS-1" + US_TX_1 = "US-TX-1" + US_OR_1 = "US-OR-1" + CA_MTL_1 = "CA-MTL-1" + EU_NL_1 = "EU-NL-1" + EU_CZ_1 = "EU-CZ-1" EU_RO_1 = "EU-RO-1" + EU_NO_1 = "EU-NO-1" + EU_SE_1 = "EU-SE-1" + + @classmethod + def from_string(cls, value: str) -> "DataCenter": + """Parse a datacenter ID string into a DataCenter enum. + + Accepts the canonical form (e.g. "EU-RO-1") as well as common + variations like lowercase or underscore-separated. + """ + normalized = value.strip().upper().replace("_", "-") + try: + return cls(normalized) + except ValueError: + valid = ", ".join(dc.value for dc in cls) + raise ValueError( + f"Unknown datacenter '{value}'. Valid datacenters: {valid}" + ) + + +# data centers that support CPU serverless endpoints +CPU_DATACENTERS: frozenset[DataCenter] = frozenset( + { + DataCenter.EU_RO_1, + DataCenter.US_TX_1, + DataCenter.EU_SE_1, + } +) class NetworkVolume(DeployableResource): @@ -41,13 +74,24 @@ class NetworkVolume(DeployableResource): "name", } - # Internal fixed value - dataCenterId: DataCenter = Field(default=DataCenter.EU_RO_1, frozen=True) + # public alias -- users pass datacenter=, which syncs to dataCenterId for the API + datacenter: Optional[DataCenter] = Field(default=None, exclude=True) + dataCenterId: DataCenter = Field(default=DataCenter.EU_RO_1) id: Optional[str] = Field(default=None) name: str size: Optional[int] = Field(default=100, gt=0) # Size in GB + @model_validator(mode="before") + @classmethod + def sync_datacenter_alias(cls, data): + """Allow datacenter= as a user-friendly alias for dataCenterId.""" + if isinstance(data, dict): + dc = data.pop("datacenter", None) + if dc is not None and "dataCenterId" not in data: + data["dataCenterId"] = dc + return data + def __str__(self) -> str: return f"{self.__class__.__name__}:{self.id}" diff --git a/src/runpod_flash/core/resources/serverless.py b/src/runpod_flash/core/resources/serverless.py index 3300a741..ae340fa3 100644 --- a/src/runpod_flash/core/resources/serverless.py +++ b/src/runpod_flash/core/resources/serverless.py @@ -28,7 +28,7 @@ from .environment import EnvironmentVars from .cpu import CpuInstanceType from .gpu import GpuGroup, GpuType -from .network_volume import NetworkVolume, DataCenter +from .network_volume import NetworkVolume, DataCenter, CPU_DATACENTERS from .template import KeyValuePair, PodTemplate from .resource_manager import ResourceManager @@ -52,14 +52,6 @@ def get_env_vars() -> Dict[str, str]: log = logging.getLogger(__name__) -def _is_prod_environment() -> bool: - env = os.getenv("RUNPOD_ENV") - if env: - return env.lower() == "prod" - api_base = os.getenv("RUNPOD_API_BASE_URL", "https://api.runpod.io") - return "api.runpod.io" in api_base or "api.runpod.ai" in api_base - - class ServerlessScalerType(Enum): QUEUE_DELAY = "QUEUE_DELAY" REQUEST_COUNT = "REQUEST_COUNT" @@ -110,6 +102,7 @@ class ServerlessResource(DeployableResource): "flashEnvironmentId", "imageName", "networkVolume", + "networkVolumes", "python_version", } @@ -158,7 +151,9 @@ class ServerlessResource(DeployableResource): gpus: Optional[List[GpuGroup | GpuType]] = [GpuGroup.ANY] # for gpuIds imageName: Optional[str] = "" # for template.imageName networkVolume: Optional[NetworkVolume] = None - datacenter: DataCenter = Field(default=DataCenter.EU_RO_1) + networkVolumes: Optional[List[NetworkVolume]] = None + # accepts a single DataCenter or a list for multi-dc deployments + datacenter: Optional[List[DataCenter] | DataCenter] = Field(default=None) python_version: Optional[str] = Field( default=None, description="Python version for runtime image selection. Defaults to the local interpreter version at build time.", @@ -240,6 +235,64 @@ def serialize_type(self, value: Optional[ServerlessType]) -> Optional[str]: return None return value.value if isinstance(value, ServerlessType) else value + @model_validator(mode="after") + def normalize_network_volumes(self): + """Merge networkVolume (singular) into networkVolumes list. + + Validates that no two volumes share the same datacenter. + """ + volumes: list[NetworkVolume] = [] + if self.networkVolumes: + volumes.extend(self.networkVolumes) + if self.networkVolume and self.networkVolume not in volumes: + volumes.append(self.networkVolume) + + if not volumes: + return self + + # validate one volume per datacenter + seen_dcs: dict[str, str] = {} + for v in volumes: + dc_id = v.dataCenterId.value + if dc_id in seen_dcs: + raise ValueError( + f"Multiple volumes in datacenter {dc_id} " + f"('{seen_dcs[dc_id]}' and '{v.name}'). " + f"Only one network volume is allowed per datacenter." + ) + seen_dcs[dc_id] = v.name + + self.networkVolumes = volumes + # keep networkVolume pointing at the first for backward compat + self.networkVolume = volumes[0] if volumes else None + + return self + + @field_validator("datacenter", mode="before") + @classmethod + def normalize_datacenter(cls, value): + """Normalize datacenter to a list of DataCenter enums. + + Accepts a single DataCenter, a string, a list of either, or None. + """ + if value is None: + return None + if isinstance(value, DataCenter): + return [value] + if isinstance(value, str): + return [DataCenter.from_string(value)] + if isinstance(value, list): + result = [] + for item in value: + if isinstance(item, DataCenter): + result.append(item) + elif isinstance(item, str): + result.append(DataCenter.from_string(item)) + else: + raise ValueError(f"Invalid datacenter value: {item!r}") + return result + raise ValueError(f"Invalid datacenter value: {value!r}") + @field_validator("gpus") @classmethod def validate_gpus(cls, value: List[GpuGroup | GpuType]) -> List[GpuGroup | GpuType]: @@ -312,8 +365,7 @@ def sync_input_fields(self): self.name = self.name[:-3] self.name += "-fb" - # Sync datacenter to locations field for API (only if not already set) - # Allow overrides in non-prod via env + # sync datacenter list to locations field for API env_locations = os.getenv("RUNPOD_DEFAULT_LOCATIONS") env_datacenter = os.getenv("RUNPOD_DEFAULT_DATACENTER") if env_locations: @@ -324,24 +376,63 @@ def sync_input_fields(self): self.locations = DataCenter(env_datacenter).value except ValueError: self.locations = env_datacenter - elif _is_prod_environment(): - self.locations = self.datacenter.value + elif self.datacenter: + dc_list = ( + self.datacenter + if isinstance(self.datacenter, list) + else [self.datacenter] + ) + self.locations = ",".join(dc.value for dc in dc_list) - # Validate datacenter consistency between endpoint and network volume - if self.networkVolume and self.networkVolume.dataCenterId != self.datacenter: - raise ValueError( - f"Network volume datacenter ({self.networkVolume.dataCenterId.value}) " - f"must match endpoint datacenter ({self.datacenter.value})" + # validate that all network volume DCs are within the endpoint's datacenter list + all_volumes = self.networkVolumes or ( + [self.networkVolume] if self.networkVolume else [] + ) + if all_volumes and self.datacenter: + dc_list = ( + self.datacenter + if isinstance(self.datacenter, list) + else [self.datacenter] ) + for vol in all_volumes: + if vol.dataCenterId not in dc_list: + dc_values = ", ".join(dc.value for dc in dc_list) + raise ValueError( + f"Network volume datacenter ({vol.dataCenterId.value}) " + f"is not in the endpoint's datacenter list ({dc_values})" + ) + # backward compat: sync single volume ID for legacy code paths if self.networkVolume and self.networkVolume.is_created: - # Volume already exists, use its ID self.networkVolumeId = self.networkVolume.id self._sync_input_fields_gpu() return self + @model_validator(mode="after") + def validate_cpu_datacenters(self): + """Ensure CPU endpoints only target data centers that support CPU.""" + if not self._has_cpu_instances(): + return self + if not self.datacenter: + return self + + dc_list = ( + self.datacenter if isinstance(self.datacenter, list) else [self.datacenter] + ) + unsupported = [dc for dc in dc_list if dc not in CPU_DATACENTERS] + if unsupported: + unsupported_str = ", ".join(dc.value for dc in unsupported) + supported_str = ", ".join( + dc.value for dc in sorted(CPU_DATACENTERS, key=lambda d: d.value) + ) + raise ValueError( + f"CPU endpoints are not available in: {unsupported_str}. " + f"Supported CPU data centers: {supported_str}" + ) + return self + @model_validator(mode="after") def validate_worker_range(self): """Ensure worker scaling bounds are valid.""" @@ -487,16 +578,32 @@ def _sync_input_fields_gpu(self): return self async def _ensure_network_volume_deployed(self) -> None: + """Ensures all network volumes are deployed. + + Deploys each volume in networkVolumes and collects their IDs. + Sets networkVolumeId (singular) for backward compat with the first volume. + Populates _deployed_volume_ids for multi-volume API payloads. """ - Ensures network volume is deployed and ready if one is specified. - Updates networkVolumeId with the deployed volume ID. - """ + self._deployed_volume_ids: list[str] = [] + if self.networkVolumeId: - return + self._deployed_volume_ids.append(self.networkVolumeId) - if self.networkVolume: - deployedNetworkVolume = await self.networkVolume.deploy() - self.networkVolumeId = deployedNetworkVolume.id + volumes = self.networkVolumes or ( + [self.networkVolume] if self.networkVolume else [] + ) + for vol in volumes: + if vol.is_created and vol.id: + if vol.id not in self._deployed_volume_ids: + self._deployed_volume_ids.append(vol.id) + else: + deployed = await vol.deploy() + if deployed.id and deployed.id not in self._deployed_volume_ids: + self._deployed_volume_ids.append(deployed.id) + + # backward compat: set singular field from first volume + if self._deployed_volume_ids and not self.networkVolumeId: + self.networkVolumeId = self._deployed_volume_ids[0] async def is_deployed(self) -> bool: """ @@ -696,13 +803,22 @@ async def _do_deploy(self) -> "DeployableResource": self.env = env_dict - # Ensure network volume is deployed first + # Ensure network volumes are deployed first await self._ensure_network_volume_deployed() async with RunpodGraphQLClient() as client: payload = self.model_dump( exclude=self._payload_exclude(), exclude_none=True, mode="json" ) + + # inject multi-volume IDs if available + deployed_ids = getattr(self, "_deployed_volume_ids", []) + if len(deployed_ids) > 1: + payload["networkVolumeIds"] = [ + {"networkVolumeId": vid} for vid in deployed_ids + ] + payload.pop("networkVolumeId", None) + result = await client.save_endpoint(payload) if endpoint := self.__class__(**result): @@ -745,7 +861,7 @@ async def update(self, new_config: "ServerlessResource") -> "ServerlessResource" if not self._has_structural_changes(new_config): log.info(f"Updating endpoint '{self.name}' (ID: {self.id})") - # Ensure network volume is deployed if specified + # Ensure network volumes are deployed if specified await new_config._ensure_network_volume_deployed() async with RunpodGraphQLClient() as client: @@ -757,6 +873,14 @@ async def update(self, new_config: "ServerlessResource") -> "ServerlessResource" ) payload["id"] = self.id # Critical: include ID for update + # inject multi-volume IDs if available + deployed_ids = getattr(new_config, "_deployed_volume_ids", []) + if len(deployed_ids) > 1: + payload["networkVolumeIds"] = [ + {"networkVolumeId": vid} for vid in deployed_ids + ] + payload.pop("networkVolumeId", None) + result = await client.save_endpoint(payload) resolved_template_id = ( result.get("templateId") or self.templateId or new_config.templateId diff --git a/src/runpod_flash/core/resources/serverless_cpu.py b/src/runpod_flash/core/resources/serverless_cpu.py index 41d380e3..1143f0d4 100644 --- a/src/runpod_flash/core/resources/serverless_cpu.py +++ b/src/runpod_flash/core/resources/serverless_cpu.py @@ -123,6 +123,7 @@ class CpuServerlessEndpoint(CpuEndpointMixin, ServerlessEndpoint): "flashEnvironmentId", "imageName", "networkVolume", + "networkVolumes", "python_version", } diff --git a/src/runpod_flash/endpoint.py b/src/runpod_flash/endpoint.py index 5d3f51dc..d05c0023 100644 --- a/src/runpod_flash/endpoint.py +++ b/src/runpod_flash/endpoint.py @@ -258,6 +258,22 @@ def _normalize_cpu( ) +def _normalize_volumes( + volume: Optional[Union[NetworkVolume, List[NetworkVolume]]], +) -> Optional[List[NetworkVolume]]: + """Normalize volume parameter to a list of NetworkVolume.""" + if volume is None: + return None + if isinstance(volume, NetworkVolume): + return [volume] + if isinstance(volume, list): + return volume or None + raise ValueError( + f"volume must be a NetworkVolume or list of NetworkVolume, " + f"got {type(volume).__name__}" + ) + + class Endpoint: """unified configuration and decorator for flash endpoints. @@ -334,8 +350,10 @@ def __init__( dependencies: Optional[List[str]] = None, system_dependencies: Optional[List[str]] = None, accelerate_downloads: bool = True, - volume: Optional[NetworkVolume] = None, - datacenter: DataCenter = DataCenter.EU_RO_1, + volume: Optional[Union[NetworkVolume, List[NetworkVolume]]] = None, + datacenter: Optional[ + Union[DataCenter, List[DataCenter], str, List[str]] + ] = None, env: Optional[Dict[str, str]] = None, gpu_count: int = 1, execution_timeout_ms: int = 0, @@ -367,7 +385,7 @@ def __init__( self.dependencies = dependencies self.system_dependencies = system_dependencies self.accelerate_downloads = accelerate_downloads - self.volume = volume + self.volume = _normalize_volumes(volume) self.datacenter = datacenter self.env = env self.gpu_count = gpu_count @@ -467,9 +485,7 @@ def _build_resource_config(self): "idleTimeout": self.idle_timeout, "executionTimeoutMs": self.execution_timeout_ms, "flashboot": self.flashboot, - "datacenter": self.datacenter.value - if hasattr(self.datacenter, "value") - else self.datacenter, + "datacenter": self.datacenter, "scalerType": self.scaler_type.value if hasattr(self.scaler_type, "value") else self.scaler_type, @@ -482,9 +498,13 @@ def _build_resource_config(self): kwargs["template"] = self.template.model_dump(exclude_none=True) if self.volume is not None: - # serialize to dict to avoid pydantic model identity issues + # serialize to dicts to avoid pydantic model identity issues # when modules get re-imported across different test/import contexts - kwargs["networkVolume"] = self.volume.model_dump(exclude_none=True) + volumes_dicts = [v.model_dump(exclude_none=True) for v in self.volume] + if len(volumes_dicts) == 1: + kwargs["networkVolume"] = volumes_dicts[0] + else: + kwargs["networkVolumes"] = volumes_dicts if self.env is not None: kwargs["env"] = self.env diff --git a/src/runpod_flash/runtime/resource_provisioner.py b/src/runpod_flash/runtime/resource_provisioner.py index 356f32f2..77940176 100644 --- a/src/runpod_flash/runtime/resource_provisioner.py +++ b/src/runpod_flash/runtime/resource_provisioner.py @@ -143,6 +143,8 @@ def create_resource_from_manifest( deployment_kwargs["scalerValue"] = resource_data["scalerValue"] if "instanceIds" in resource_data: deployment_kwargs["instanceIds"] = resource_data["instanceIds"] + if "locations" in resource_data: + deployment_kwargs["locations"] = resource_data["locations"] # Reconstruct NetworkVolume from manifest data if present if "networkVolume" in resource_data: diff --git a/tests/unit/resources/test_network_volume.py b/tests/unit/resources/test_network_volume.py index 17adec1d..49aace8b 100644 --- a/tests/unit/resources/test_network_volume.py +++ b/tests/unit/resources/test_network_volume.py @@ -142,6 +142,19 @@ async def test_deploy_multiple_times_same_name_is_idempotent( ) # Only called once assert result1.id == result2.id == "vol-123456" + def test_datacenter_alias(self): + """Test that datacenter= works as an alias for dataCenterId=.""" + volume = NetworkVolume(name="test", datacenter=DataCenter.EU_RO_1) + assert volume.dataCenterId == DataCenter.EU_RO_1 + + def test_datacenter_alias_does_not_override_explicit(self): + """Test that dataCenterId= takes precedence over datacenter=.""" + volume = NetworkVolume( + name="test", + dataCenterId=DataCenter.EU_RO_1, + ) + assert volume.dataCenterId == DataCenter.EU_RO_1 + def test_resource_id_based_on_name_and_datacenter(self): """Test that resource_id is based on name and datacenter for named volumes.""" # Arrange & Act diff --git a/tests/unit/resources/test_serverless.py b/tests/unit/resources/test_serverless.py index 22bf4b34..55cd29e6 100644 --- a/tests/unit/resources/test_serverless.py +++ b/tests/unit/resources/test_serverless.py @@ -271,33 +271,64 @@ def test_flashboot_appends_to_name(self): assert serverless.name == "test-serverless-fb" - def test_datacenter_defaults_to_eu_ro_1(self): - """Test datacenter defaults to EU_RO_1.""" + def test_datacenter_defaults_to_none(self): + """Test datacenter defaults to None (all datacenters).""" serverless = ServerlessResource(name="test") - assert serverless.datacenter == DataCenter.EU_RO_1 + assert serverless.datacenter is None - def test_datacenter_can_be_overridden(self): - """Test datacenter can be overridden by user.""" - # This would work if we had other datacenters defined + def test_datacenter_single_value(self): + """Test datacenter accepts a single DataCenter and normalizes to list.""" serverless = ServerlessResource(name="test", datacenter=DataCenter.EU_RO_1) - assert serverless.datacenter == DataCenter.EU_RO_1 + assert serverless.datacenter == [DataCenter.EU_RO_1] - def test_locations_synced_from_datacenter(self, monkeypatch): - """Test locations field gets synced from datacenter in prod.""" - monkeypatch.setenv("RUNPOD_ENV", "prod") - serverless = ServerlessResource(name="test") + def test_datacenter_multiple_values(self): + """Test datacenter accepts a list of DataCenter values.""" + serverless = ServerlessResource( + name="test", + datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_1], + ) + assert serverless.datacenter == [DataCenter.EU_RO_1, DataCenter.US_GA_1] + + def test_datacenter_string_value(self): + """Test datacenter accepts string values.""" + serverless = ServerlessResource(name="test", datacenter="EU-RO-1") + assert serverless.datacenter == [DataCenter.EU_RO_1] + + def test_datacenter_string_list(self): + """Test datacenter accepts list of strings.""" + serverless = ServerlessResource(name="test", datacenter=["EU-RO-1", "US-GA-1"]) + assert serverless.datacenter == [DataCenter.EU_RO_1, DataCenter.US_GA_1] - # Should automatically set locations from datacenter in prod + def test_datacenter_invalid_string_raises(self): + """Test that an invalid datacenter string raises ValueError.""" + with pytest.raises(ValueError, match="Unknown datacenter"): + ServerlessResource(name="test", datacenter="INVALID-DC") + + def test_locations_synced_from_datacenter(self): + """Test locations field gets synced from datacenter.""" + serverless = ServerlessResource(name="test", datacenter=DataCenter.EU_RO_1) assert serverless.locations == "EU-RO-1" + def test_locations_synced_from_multi_datacenter(self): + """Test locations field gets synced from multiple datacenters.""" + serverless = ServerlessResource( + name="test", + datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_1], + ) + assert serverless.locations == "EU-RO-1,US-GA-1" + + def test_no_datacenter_no_locations(self): + """Test that no datacenter means no locations restriction.""" + serverless = ServerlessResource(name="test") + assert serverless.locations is None + def test_explicit_locations_not_overridden(self): """Test explicit locations field is not overridden.""" - serverless = ServerlessResource(name="test", locations="US-WEST-1") + serverless = ServerlessResource(name="test", locations="US-GA-1") - # Explicit locations should not be overridden - assert serverless.locations == "US-WEST-1" + assert serverless.locations == "US-GA-1" def test_datacenter_validation_matching_datacenters(self): """Test that matching datacenters between endpoint and volume work.""" @@ -306,36 +337,25 @@ def test_datacenter_validation_matching_datacenters(self): name="test", datacenter=DataCenter.EU_RO_1, networkVolume=volume ) - # Should not raise any validation error - assert serverless.datacenter == DataCenter.EU_RO_1 + assert serverless.datacenter == [DataCenter.EU_RO_1] assert serverless.networkVolume.dataCenterId == DataCenter.EU_RO_1 - def test_datacenter_validation_logic_exists(self): - """Test that datacenter validation logic exists in sync_input_fields.""" - # Test by examining the validation code directly - # Since we can't easily mock frozen fields, we'll test the logic exists - volume = NetworkVolume(name="test-volume", dataCenterId=DataCenter.EU_RO_1) - _ = ServerlessResource( - name="test", datacenter=DataCenter.EU_RO_1, networkVolume=volume - ) - - # Create a mock volume with mismatched datacenter for direct validation test - mock_volume = MagicMock() - mock_volume.dataCenterId.value = "US-WEST-1" - mock_datacenter = MagicMock() - mock_datacenter.value = "EU-RO-1" - - # Test the validation logic directly + def test_datacenter_validation_volume_not_in_dc_list(self): + """Test that a volume DC not in endpoint's DC list raises an error.""" + volume = NetworkVolume(name="test-volume", dataCenterId=DataCenter.US_GA_1) with pytest.raises( ValueError, - match="Network volume datacenter.*must match endpoint datacenter", + match="Network volume datacenter.*is not in the endpoint's datacenter list", ): - # Simulate the validation check - if mock_volume.dataCenterId != mock_datacenter: - raise ValueError( - f"Network volume datacenter ({mock_volume.dataCenterId.value}) " - f"must match endpoint datacenter ({mock_datacenter.value})" - ) + ServerlessResource( + name="test", datacenter=DataCenter.EU_RO_1, networkVolume=volume + ) + + def test_volume_dc_allowed_when_no_datacenter_set(self): + """Test that any volume DC is allowed when no datacenter restriction is set.""" + volume = NetworkVolume(name="test-volume", dataCenterId=DataCenter.US_GA_1) + serverless = ServerlessResource(name="test", networkVolume=volume) + assert serverless.networkVolume.dataCenterId == DataCenter.US_GA_1 def test_no_flashboot_keeps_name(self): """Test flashboot=False keeps original name.""" @@ -404,6 +424,103 @@ def test_reverse_sync_cuda_versions(self): assert CudaVersion.V11_8 in serverless.cudaVersions +class TestMultiVolumeValidation: + """Test multiple network volume support.""" + + def test_single_volume_compat(self): + """Test single networkVolume still works.""" + vol = NetworkVolume(name="v1", dataCenterId=DataCenter.EU_RO_1) + s = ServerlessResource(name="test", networkVolume=vol) + assert s.networkVolume is vol + assert s.networkVolumes == [vol] + + def test_multiple_volumes_via_list(self): + """Test networkVolumes accepts multiple volumes.""" + v1 = NetworkVolume(name="v1", dataCenterId=DataCenter.EU_RO_1) + v2 = NetworkVolume(name="v2", dataCenterId=DataCenter.US_GA_1) + s = ServerlessResource(name="test", networkVolumes=[v1, v2]) + assert len(s.networkVolumes) == 2 + assert s.networkVolume is v1 + + def test_duplicate_dc_raises(self): + """Test two volumes in the same DC raises.""" + v1 = NetworkVolume(name="v1", dataCenterId=DataCenter.EU_RO_1) + v2 = NetworkVolume(name="v2", dataCenterId=DataCenter.EU_RO_1) + with pytest.raises(ValueError, match="Multiple volumes in datacenter EU-RO-1"): + ServerlessResource(name="test", networkVolumes=[v1, v2]) + + def test_volumes_dc_outside_endpoint_dc_raises(self): + """Test volume DC not in endpoint's DC list raises.""" + vol = NetworkVolume(name="v1", dataCenterId=DataCenter.US_GA_1) + with pytest.raises( + ValueError, + match="is not in the endpoint's datacenter list", + ): + ServerlessResource( + name="test", + datacenter=DataCenter.EU_RO_1, + networkVolumes=[vol], + ) + + def test_volumes_dc_within_endpoint_dc_list(self): + """Test volume DCs all within endpoint DC list works.""" + v1 = NetworkVolume(name="v1", dataCenterId=DataCenter.EU_RO_1) + v2 = NetworkVolume(name="v2", dataCenterId=DataCenter.US_GA_1) + s = ServerlessResource( + name="test", + datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_1], + networkVolumes=[v1, v2], + ) + assert len(s.networkVolumes) == 2 + + +class TestCpuDatacenterValidation: + """Test CPU datacenter validation.""" + + def test_cpu_endpoint_in_supported_dc(self): + """Test CPU endpoint in supported datacenter works.""" + endpoint = CpuServerlessEndpoint( + name="test-cpu", + imageName="test/cpu:latest", + datacenter=DataCenter.EU_RO_1, + ) + assert endpoint.datacenter == [DataCenter.EU_RO_1] + + def test_cpu_endpoint_in_unsupported_dc_raises(self): + """Test CPU endpoint in unsupported datacenter raises.""" + with pytest.raises(ValueError, match="CPU endpoints are not available in"): + CpuServerlessEndpoint( + name="test-cpu", + imageName="test/cpu:latest", + datacenter=DataCenter.US_GA_1, + ) + + def test_cpu_endpoint_mixed_dcs_raises(self): + """Test CPU endpoint with mix of supported/unsupported DCs raises.""" + with pytest.raises(ValueError, match="CPU endpoints are not available in"): + CpuServerlessEndpoint( + name="test-cpu", + imageName="test/cpu:latest", + datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_1], + ) + + def test_cpu_endpoint_no_datacenter_ok(self): + """Test CPU endpoint with no datacenter (all DCs) is allowed.""" + endpoint = CpuServerlessEndpoint( + name="test-cpu", + imageName="test/cpu:latest", + ) + assert endpoint.datacenter is None + + def test_gpu_endpoint_any_dc_ok(self): + """Test GPU endpoint in any datacenter is allowed.""" + serverless = ServerlessResource( + name="test-gpu", + datacenter=DataCenter.US_GA_1, + ) + assert serverless.datacenter == [DataCenter.US_GA_1] + + class TestJobOutput: """Test JobOutput model.""" @@ -522,14 +639,14 @@ async def test_deploy_already_deployed(self): @pytest.mark.asyncio async def test_deploy_success_with_network_volume( - self, mock_runpod_client, deployment_response, monkeypatch + self, mock_runpod_client, deployment_response ): """Test successful deployment with network volume integration.""" - monkeypatch.setenv("RUNPOD_ENV", "prod") serverless = ServerlessResource( name="test-serverless", gpus=[GpuGroup.AMPERE_48], cudaVersions=[CudaVersion.V12_1], + datacenter=DataCenter.EU_RO_1, ) mock_runpod_client.save_endpoint.return_value = deployment_response diff --git a/tests/unit/test_deprecations.py b/tests/unit/test_deprecations.py index 31a691b3..2d0c391e 100644 --- a/tests/unit/test_deprecations.py +++ b/tests/unit/test_deprecations.py @@ -20,6 +20,7 @@ _NON_DEPRECATED = [ "Endpoint", "EndpointJob", + "CPU_DATACENTERS", "CpuInstanceType", "CudaVersion", "DataCenter", diff --git a/tests/unit/test_endpoint.py b/tests/unit/test_endpoint.py index a3a8cfd1..39b7864a 100644 --- a/tests/unit/test_endpoint.py +++ b/tests/unit/test_endpoint.py @@ -123,12 +123,46 @@ def test_all_params(self): assert ep.dependencies == ["torch"] assert ep.system_dependencies == ["ffmpeg"] assert ep.accelerate_downloads is False - assert ep.volume is vol + assert ep.volume == [vol] assert ep.env == {"MY_VAR": "value"} assert ep.gpu_count == 2 assert ep.execution_timeout_ms == 5000 assert ep.flashboot is False + def test_volume_single(self): + vol = NetworkVolume(name="v1", size=50) + ep = Endpoint(name="test", volume=vol) + assert ep.volume == [vol] + + def test_volume_list(self): + v1 = NetworkVolume(name="v1", size=50, dataCenterId=DataCenter.EU_RO_1) + v2 = NetworkVolume(name="v2", size=50, dataCenterId=DataCenter.US_GA_1) + ep = Endpoint(name="test", volume=[v1, v2]) + assert ep.volume == [v1, v2] + + def test_volume_none(self): + ep = Endpoint(name="test") + assert ep.volume is None + + def test_datacenter_single(self): + ep = Endpoint(name="test", datacenter=DataCenter.US_GA_1) + assert ep.datacenter == DataCenter.US_GA_1 + + def test_datacenter_string(self): + ep = Endpoint(name="test", datacenter="US-GA-1") + assert ep.datacenter == "US-GA-1" + + def test_datacenter_list(self): + ep = Endpoint( + name="test", + datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_1], + ) + assert ep.datacenter == [DataCenter.EU_RO_1, DataCenter.US_GA_1] + + def test_datacenter_none_default(self): + ep = Endpoint(name="test") + assert ep.datacenter is None + def test_scaler_type_defaults(self): ep = Endpoint(name="test") assert ep.scaler_type == ServerlessScalerType.QUEUE_DELAY diff --git a/tests/unit/test_p2_gaps.py b/tests/unit/test_p2_gaps.py index 430e293d..abc2907c 100644 --- a/tests/unit/test_p2_gaps.py +++ b/tests/unit/test_p2_gaps.py @@ -303,14 +303,22 @@ def test_default_locations_overrides_resource(self): @patch.dict(os.environ, {}, clear=True) def test_no_default_locations_uses_resource_default(self): - """Without env var, resource uses its own default.""" + """Without env var or datacenter, locations is None (all DCs).""" from runpod_flash.core.resources import LiveServerless resource = LiveServerless(name="loc-test") - # Without the env var, locations uses the default datacenter from the model - assert isinstance(resource.locations, str) - # Verify it's the model default, not an env var override - assert resource.locations == resource.datacenter.value + # no datacenter specified means no location restriction + assert resource.locations is None + assert resource.datacenter is None + + @patch.dict(os.environ, {}, clear=True) + def test_datacenter_syncs_to_locations(self): + """When datacenter is set, locations is synced from it.""" + from runpod_flash.core.resources import LiveServerless + from runpod_flash.core.resources.network_volume import DataCenter + + resource = LiveServerless(name="loc-test", datacenter=DataCenter.EU_RO_1) + assert resource.locations == "EU-RO-1" # ---------------------------------------------------------------------------