feat: Parallelise Model Loading #360

Open · wants to merge 4 commits into main

Changes from 3 commits
3 changes: 3 additions & 0 deletions exo/inference/inference_engine.py
@@ -8,6 +8,9 @@

class InferenceEngine(ABC):
  @abstractmethod
  async def preload_model(self, shard: Shard) -> None:
    pass
  @abstractmethod
  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
    pass

18 changes: 13 additions & 5 deletions exo/inference/mlx/sharded_inference_engine.py
@@ -9,6 +9,7 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path

class MLXDynamicShardInferenceEngine(InferenceEngine):
  def __init__(self, shard_downloader: ShardDownloader):
@@ -43,8 +44,15 @@ async def ensure_shard(self, shard: Shard):
    model_path = await self.shard_downloader.ensure_shard(shard)

    if self.shard != shard:
-      loop = asyncio.get_running_loop()
-      def load_shard_wrapper(): return asyncio.run(load_shard(model_path, shard))
-      model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
-      self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
-      self.shard = shard
+      await self.load_model(model_path, shard)
+
+  async def load_model(self, model_path: Path, shard: Shard):
+    loop = asyncio.get_running_loop()
+    def load_shard_wrapper(): return asyncio.run(load_shard(model_path, shard))
+    model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
+    self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
+    self.shard = shard
+
+  # already loaded into memory by ensure_shard,
+  async def preload_model(self, shard: Shard) -> None:
+    await self.ensure_shard(shard)
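
The new load_model helper keeps the existing off-loop loading trick: load_shard is itself a coroutine, so it is handed to the engine's thread pool via run_in_executor with a small wrapper that calls asyncio.run, giving the worker thread its own event loop and keeping weight loading off the node's main loop. A minimal, self-contained sketch of that pattern (slow_async_loader and load_wrapper are illustrative stand-ins, not code from this PR):

import asyncio
from concurrent.futures import ThreadPoolExecutor

async def slow_async_loader(name: str) -> str:
  await asyncio.sleep(1)  # stands in for reading model weights from disk
  return f"model:{name}"

async def main():
  loop = asyncio.get_running_loop()
  executor = ThreadPoolExecutor(max_workers=1)
  # asyncio.run gives the worker thread its own event loop, so the async
  # loader runs to completion without blocking the caller's loop.
  def load_wrapper(): return asyncio.run(slow_async_loader("demo"))
  model = await loop.run_in_executor(executor, load_wrapper)
  print(model)

asyncio.run(main())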
36 changes: 25 additions & 11 deletions exo/main.py
@@ -47,6 +47,7 @@
parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
parser.add_argument("--preload-models", type=str, help="Comma-separated list of models to preload")
args = parser.parse_args()

print_yellow_exo()
@@ -100,16 +101,18 @@
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
)
def preemptively_start_download(request_id: str, opaque_status: str):
-  try:
-    status = json.loads(opaque_status)
-    if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
-      current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
-      if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
-      asyncio.create_task(shard_downloader.ensure_shard(current_shard))
-  except Exception as e:
-    if DEBUG >= 2:
-      print(f"Failed to preemptively start download: {e}")
-      traceback.print_exc()
+  try:
+    status = json.loads(opaque_status)
+    if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
+      current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
+      if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
+      asyncio.create_task(shard_downloader.ensure_shard(current_shard))
+      return current_shard
+  except Exception as e:
+    if DEBUG >= 2:
+      print(f"Failed to preemptively start download: {e}")
+      traceback.print_exc()
+  return None
node.on_opaque_status.register("start_download").on_next(preemptively_start_download)

if args.prometheus_client_port:
@@ -176,7 +179,6 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
async def main():
  loop = asyncio.get_running_loop()

-  # Use a more direct approach to handle signals
  def handle_exit():
    asyncio.ensure_future(shutdown(signal.SIGTERM, loop))

@@ -195,6 +197,18 @@ def handle_exit():
    asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
    await asyncio.Event().wait()

  if args.preload_models:
    models_to_preload = [Shard.from_dict(model_base_shards[model_name][inference_engine.__class__.__name__])
                         for model_name in args.preload_models.split(',')]
    for shard in models_to_preload:
      current_shard = preemptively_start_download(str(uuid.uuid4()), json.dumps({
        "type": "node_status",
        "status": "start_process_prompt",
        "shard": shard.to_dict()
      }))
      if current_shard:
        await node.preload_models([current_shard])


def run():
loop = asyncio.new_event_loop()
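With the new flag, a node could be started with preloading requested up front, for example: exo --preload-models llama-3.1-8b,llama-3.1-70b (the model names here are illustrative; each must be a key of model_base_shards for the active inference engine). For every listed model, the block above resolves its shard, synthesises the same start_process_prompt status JSON that preemptively_start_download already understands, and, if a shard comes back, hands it to node.preload_models.
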
5 changes: 5 additions & 0 deletions exo/orchestration/node.py
@@ -31,6 +31,10 @@ async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarr
  async def collect_topology(self, visited: set[str] = set(), max_depth: int = 2) -> Topology:
    pass

  @abstractmethod
  async def preload_models(self, shards: List[Shard]) -> None:
    pass

  @property
  @abstractmethod
  def current_topology(self) -> Topology:
@@ -45,3 +49,4 @@ def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
  @abstractmethod
  def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
    pass

6 changes: 6 additions & 0 deletions exo/orchestration/standard_node.py
@@ -432,3 +432,9 @@ async def send_status_to_peer(peer):
  @property
  def current_topology(self) -> Topology:
    return self.topology

  async def preload_models(self, shards: List[Shard]) -> None:
    preload_tasks = []
    for shard in shards:
      preload_tasks.append(asyncio.create_task(self.inference_engine.preload_model(shard)))
    await asyncio.gather(*preload_tasks)
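
preload_models fans the per-shard loads out with asyncio.create_task and asyncio.gather, so the shards load concurrently rather than back to back, which is where the parallelism in the PR title comes from. A self-contained sketch of that effect (fake_preload is an illustrative stand-in for inference_engine.preload_model, not code from this PR):

import asyncio
import time

async def fake_preload(name: str, seconds: float) -> str:
  await asyncio.sleep(seconds)  # stands in for download plus weight loading
  return name

async def main():
  start = time.perf_counter()
  await asyncio.gather(fake_preload("shard-a", 1.0), fake_preload("shard-b", 1.0))
  print(f"both shards ready in ~{time.perf_counter() - start:.1f}s")  # ~1s, not ~2s

asyncio.run(main())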