diff --git a/exo/inference/inference_engine.py b/exo/inference/inference_engine.py
index b94654932..2fa45e702 100644
--- a/exo/inference/inference_engine.py
+++ b/exo/inference/inference_engine.py
@@ -8,6 +8,10 @@ class InferenceEngine(ABC):
   @abstractmethod
+  async def preload_model(self, shard: Shard) -> None:
+    pass
+
+  @abstractmethod
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     pass
diff --git a/exo/inference/mlx/sharded_inference_engine.py b/exo/inference/mlx/sharded_inference_engine.py
index 122628d52..585cc9e04 100644
--- a/exo/inference/mlx/sharded_inference_engine.py
+++ b/exo/inference/mlx/sharded_inference_engine.py
@@ -9,6 +9,7 @@
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
+from pathlib import Path
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
@@ -43,8 +44,15 @@ async def ensure_shard(self, shard: Shard):
     model_path = await self.shard_downloader.ensure_shard(shard)
     if self.shard != shard:
-      loop = asyncio.get_running_loop()
-      def load_shard_wrapper(): return asyncio.run(load_shard(model_path, shard))
-      model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
-      self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
-      self.shard = shard
+      await self.load_model(model_path, shard)
+
+  async def load_model(self, model_path: Path, shard: Shard):
+    loop = asyncio.get_running_loop()
+    def load_shard_wrapper(): return asyncio.run(load_shard(model_path, shard))
+    model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
+    self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
+    self.shard = shard
+
+  # ensure_shard downloads the shard if needed and loads it into memory, so preloading just delegates to it
+  async def preload_model(self, shard: Shard) -> None:
+    await self.ensure_shard(shard)
diff --git a/exo/main.py b/exo/main.py
index 49774da97..0dbcfb269 100644
--- a/exo/main.py
+++ b/exo/main.py
@@ -47,6 +47,7 @@
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
+parser.add_argument("--preload-models", type=str, help="Comma-separated list of models to preload")
 args = parser.parse_args()
 
 print_yellow_exo()
@@ -99,18 +100,26 @@
 node.on_token.register("update_topology_viz").on_next(
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
 )
-def preemptively_start_download(request_id: str, opaque_status: str):
-  try:
-    status = json.loads(opaque_status)
-    if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
-      current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
-      if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
-      asyncio.create_task(shard_downloader.ensure_shard(current_shard))
-  except Exception as e:
-    if DEBUG >= 2:
-      print(f"Failed to preemptively start download: {e}")
-      traceback.print_exc()
-node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
+
+async def preemptively_start_download(request_id: str, opaque_status: str):
+  try:
+    status = json.loads(opaque_status)
+    if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
+      current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
+      if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
+      await shard_downloader.ensure_shard(current_shard)
+      # Preload the model once the shard download has completed
+      await node.preload_models([current_shard])
+      if DEBUG >= 2: print(f"Preloaded model for {current_shard}")
+      return current_shard
+  except Exception as e:
+    if DEBUG >= 2:
+      print(f"Failed to preemptively start download or preload: {e}")
+      traceback.print_exc()
+    return None
+
+# preemptively_start_download is a coroutine, so schedule it as a task when the callback fires
+node.on_opaque_status.register("start_download").on_next(lambda request_id, opaque_status: asyncio.create_task(preemptively_start_download(request_id, opaque_status)))
 
 if args.prometheus_client_port:
   from exo.stats.metrics import start_metrics_server
@@ -176,7 +185,6 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
 async def main():
   loop = asyncio.get_running_loop()
 
-  # Use a more direct approach to handle signals
   def handle_exit():
     asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
 
@@ -195,6 +203,18 @@ def handle_exit():
 
   asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
+  if args.preload_models:
+    # model_base_shards maps model name -> engine class name -> base Shard (no from_dict needed)
+    models_to_preload = [model_base_shards[model_name][inference_engine.__class__.__name__]
+                         for model_name in args.preload_models.split(',')]
+    for shard in models_to_preload:
+      # await each download/preload here, before the event loop blocks below
+      await preemptively_start_download(str(uuid.uuid4()), json.dumps({
+        "type": "node_status",
+        "status": "start_process_prompt",
+        "shard": shard.to_dict()
+      }))
+
   await asyncio.Event().wait()
 
 def run():
   loop = asyncio.new_event_loop()
diff --git a/exo/orchestration/node.py b/exo/orchestration/node.py
index 60b729748..97b4e8f79 100644
--- a/exo/orchestration/node.py
+++ b/exo/orchestration/node.py
@@ -31,6 +31,10 @@ async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarr
   async def collect_topology(self, visited: set[str] = set(), max_depth: int = 2) -> Topology:
     pass
 
+  @abstractmethod
+  async def preload_models(self, shards: List[Shard]) -> None:
+    pass
+
   @property
   @abstractmethod
   def current_topology(self) -> Topology:
     pass
diff --git a/exo/orchestration/standard_node.py b/exo/orchestration/standard_node.py
index 76683f562..71c85e24f 100644
--- a/exo/orchestration/standard_node.py
+++ b/exo/orchestration/standard_node.py
@@ -432,3 +432,9 @@ async def send_status_to_peer(peer):
   @property
   def current_topology(self) -> Topology:
     return self.topology
+
+  async def preload_models(self, shards: List[Shard]) -> None:
+    # Preload all shards concurrently; each task downloads the shard if
+    # needed and loads it into the inference engine's memory.
+    preload_tasks = [asyncio.create_task(self.inference_engine.preload_model(shard)) for shard in shards]
+    await asyncio.gather(*preload_tasks)
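
For reference, a minimal sketch of how the pieces above fit together at startup. This is illustrative only, not part of the diff: it assumes node and inference_engine are constructed as in exo/main.py, that model_base_shards is importable from exo.models, and the helper name preload_at_startup and the model name "llama-3.1-8b" are made up for the example:

  from exo.models import model_base_shards

  async def preload_at_startup(node, inference_engine, preload_arg: str):
    # Resolve each requested model name to the base shard registered for the
    # active engine (keyed by the engine's class name, e.g.
    # "MLXDynamicShardInferenceEngine"); unknown names are skipped.
    engine_name = inference_engine.__class__.__name__
    shards = []
    for name in preload_arg.split(","):
      shard = model_base_shards.get(name, {}).get(engine_name)
      if shard is None:
        print(f"Skipping unsupported model: {name}")
        continue
      shards.append(shard)
    # Node.preload_models fans out to InferenceEngine.preload_model, which
    # downloads each shard if necessary and loads it into memory.
    await node.preload_models(shards)

  # Hypothetical usage: await preload_at_startup(node, inference_engine, "llama-3.1-8b")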