Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 75 additions & 48 deletions rock/sandbox/sandbox_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,26 @@

logger = init_logger(__name__)

# ---------------------------------------------------------------------------
# Transition table: (current_state, action) -> handler method name
# actions that not defined in the table will raise BadRequestRockError
# ---------------------------------------------------------------------------
_NOT_EXIST = "NOT_EXIST"
TRANSITION_MAP: dict[tuple[str, str], str] = {
# start actions
(_NOT_EXIST, "start"): "_handle_start",
# stop actions
(State.PENDING, "stop"): "_handle_stop",
(State.RUNNING, "stop"): "_handle_stop",
(State.STOPPED, "stop"): "_handle_stop_noop",
(_NOT_EXIST, "stop"): "_handle_stop_noop",
# get_status actions
(State.PENDING, "get_status"): "_handle_get_status",
(State.RUNNING, "get_status"): "_handle_get_status",
(State.STOPPED, "get_status"): "_handle_get_status_stopped",
(_NOT_EXIST, "get_status"): "_handle_get_status_not_found",
}


class SandboxManager(BaseManager):
_ray_namespace: str = None
Expand All @@ -68,6 +88,19 @@ def __init__(
self._proxy_service = SandboxProxyService(rock_config=rock_config, meta_store=meta_store)
logger.info("sandbox service init success")

async def get_transition_handler(self, sandbox_id: str, action: str):
info = await self._meta_store.get(sandbox_id)
if info is None:
state = _NOT_EXIST
else:
state = info.get("state")
if state is None:
raise InternalServerRockError(f"Sandbox {sandbox_id} exists in store but has no state field")
handler_name = TRANSITION_MAP.get((state, action))
if handler_name is None:
raise BadRequestRockError(f"Action '{action}' is not allowed in state '{state}'")
return getattr(self, handler_name)

async def refresh_aes_key(self):
try:
await self.rock_config.update()
Expand All @@ -77,12 +110,6 @@ async def refresh_aes_key(self):
logger.error(f"update aes key failed, error: {e}")
raise InternalServerRockError(f"update aes key failed, {str(e)}")

async def _check_sandbox_exists_in_redis(self, config: DeploymentConfig):
if isinstance(config, DockerDeploymentConfig) and config.container_name:
sandbox_id = config.container_name
if await self._meta_store.exists(sandbox_id):
raise BadRequestRockError(f"Sandbox {sandbox_id} already exists")

def _setup_sandbox_actor_metadata(self, sandbox_actor: SandboxActor, user_info: UserInfo) -> None:
user_id = user_info.get("user_id", "default")
experiment_id = user_info.get("experiment_id", "default")
Expand Down Expand Up @@ -110,7 +137,9 @@ async def _build_sandbox_info_metadata(
async def start_async(
self, config: DeploymentConfig, user_info: UserInfo = {}, cluster_info: ClusterInfo = {}
) -> SandboxStartResponse:
await self._check_sandbox_exists_in_redis(config)
if not isinstance(config, DockerDeploymentConfig):
raise BadRequestRockError(f"Unsupported config type: {type(config).__name__}")

self.validate_sandbox_spec(self.rock_config.runtime, config)
docker_deployment_config: DockerDeploymentConfig = await self.deployment_manager.init_config(config)

Expand All @@ -123,6 +152,23 @@ async def start_async(
)
docker_deployment_config.cpus = self.rock_config.runtime.standard_spec.cpus
docker_deployment_config.memory = self.rock_config.runtime.standard_spec.memory

handler = await self.get_transition_handler(sandbox_id, "start")
return await handler(
sandbox_id,
docker_deployment_config=docker_deployment_config,
user_info=user_info,
cluster_info=cluster_info,
)

async def _handle_start(
self,
sandbox_id: str,
*,
docker_deployment_config: DockerDeploymentConfig,
user_info: UserInfo,
cluster_info: ClusterInfo,
) -> SandboxStartResponse:
sandbox_info: SandboxInfo = await self._operator.submit(docker_deployment_config, user_info)
await self._build_sandbox_info_metadata(sandbox_info, user_info, cluster_info)
timeout_info = SandboxTimeoutHelper.make_timeout_info(docker_deployment_config.auto_clear_time)
Expand All @@ -140,33 +186,19 @@ async def start_async(

@monitor_sandbox_operation()
async def start(self, config: DeploymentConfig) -> SandboxStartResponse:
docker_deployment_config: DockerDeploymentConfig = await self.deployment_manager.init_config(config)

sandbox_id = docker_deployment_config.container_name
actor_name = self.deployment_manager.get_actor_name(sandbox_id)
deployment = docker_deployment_config.get_deployment()

sandbox_actor: SandboxActor = await deployment.creator_actor(actor_name)

await self._ray_service.async_ray_get(sandbox_actor.start.remote())
logger.info(f"sandbox {sandbox_id} is started")

while not await self._is_actor_alive(sandbox_id):
logger.debug(f"wait actor for sandbox alive, sandbox_id: {sandbox_id}")
# TODO: timeout check
response = await self.start_async(config, user_info={}, cluster_info={})
while True:
status = await self.get_status(response.sandbox_id)
if status.is_alive:
return response
await asyncio.sleep(1)
await self.get_status(sandbox_id)

self._sandbox_meta[sandbox_id] = {"image": docker_deployment_config.image}

return SandboxStartResponse(
sandbox_id=sandbox_id,
host_name=await self._ray_service.async_ray_get(sandbox_actor.host_name.remote()),
host_ip=await self._ray_service.async_ray_get(sandbox_actor.host_ip.remote()),
)

@monitor_sandbox_operation()
async def stop(self, sandbox_id):
handler = await self.get_transition_handler(sandbox_id, "stop")
await handler(sandbox_id)

async def _handle_stop(self, sandbox_id: str) -> None:
logger.info(f"stop sandbox {sandbox_id}")
sandbox_info: SandboxInfo | None = await self._meta_store.get(sandbox_id)
if sandbox_info is None:
Expand All @@ -181,13 +213,12 @@ async def stop(self, sandbox_id):
logger.error(f"ray get actor, actor {sandbox_id} not exist", exc_info=e)
await self._meta_store.archive(sandbox_id, sandbox_info)
return
try:
self._sandbox_meta.pop(sandbox_id)
except KeyError:
logger.debug(f"{sandbox_id} key not found")
logger.info(f"sandbox {sandbox_id} stopped")
await self._meta_store.archive(sandbox_id, sandbox_info)

async def _handle_stop_noop(self, sandbox_id: str) -> None:
logger.info(f"Sandbox {sandbox_id} already stopped or not found, skipping")

async def get_mount(self, sandbox_id):
async with self._ray_service.get_ray_rwlock().read_lock():
actor_name = self.deployment_manager.get_actor_name(sandbox_id)
Expand Down Expand Up @@ -215,6 +246,10 @@ async def commit(self, sandbox_id, image_tag: str, username: str, password: str)

@monitor_sandbox_operation()
async def get_status(self, sandbox_id) -> SandboxStatusResponse:
handler = await self.get_transition_handler(sandbox_id, "get_status")
return await handler(sandbox_id)

async def _handle_get_status(self, sandbox_id: str) -> SandboxStatusResponse:
sandbox_info: SandboxInfo = await self._operator.get_status(sandbox_id=sandbox_id)
is_alive = sandbox_info.get("state") == State.RUNNING
if sandbox_info.get("state") == State.STOPPED:
Expand Down Expand Up @@ -244,27 +279,19 @@ async def get_status(self, sandbox_id) -> SandboxStatusResponse:
disk_limit_log=sandbox_info.get("disk_limit_log"),
)

async def build_sandbox_info_from_redis(self, sandbox_id: str, deployment_info: SandboxInfo) -> SandboxInfo | None:
sandbox_info_from_store = await self._meta_store.get(sandbox_id)
if sandbox_info_from_store:
sandbox_info = sandbox_info_from_store
remote_info = {
k: v for k, v in deployment_info.items() if k in ["phases", "port_mapping", "alive", "state"]
}
if "phases" in remote_info and remote_info["phases"]:
remote_info["phases"] = {name: phase.to_dict() for name, phase in remote_info["phases"].items()}
sandbox_info.update(remote_info)
else:
sandbox_info = deployment_info
return sandbox_info

def _update_sandbox_alive_info(self, sandbox_info: SandboxInfo, is_alive: bool) -> None:
if is_alive:
sandbox_info["state"] = State.RUNNING
# Set start_time for the first time the sandbox becomes alive
if sandbox_info.get("start_time") is None:
sandbox_info["start_time"] = get_iso8601_timestamp()

async def _handle_get_status_stopped(self, sandbox_id: str) -> SandboxStatusResponse:
raise BadRequestRockError(f"Sandbox {sandbox_id} is already stopped")

async def _handle_get_status_not_found(self, sandbox_id: str) -> SandboxStatusResponse:
raise BadRequestRockError(f"Sandbox {sandbox_id} not found")

async def get_status_v2(self, sandbox_id) -> SandboxStatusResponse:
"""
Deprecated: Use get_status(sandbox_id, use_rocklet=True) instead.
Expand Down
Loading
Loading