Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show downloaded models, improve error handling #456

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 53 additions & 18 deletions exo/api/chatgpt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from exo.orchestration import Node
from exo.models import build_base_shard, model_cards, get_repo, pretty_name
from typing import Callable
import os
from exo.download.hf.hf_helpers import get_hf_home
from exo.download.hf.hf_shard_download import HFShardDownloader


class Message:
Expand Down Expand Up @@ -201,24 +204,56 @@ async def handle_root(self, request):
return web.FileResponse(self.static_dir/"index.html")

async def handle_model_support(self, request):
  """Report every model the current topology can serve, with download state.

  Returns a JSON body of the form
    {"model pool": {model_id: {"name", "downloaded", "download_percentage"}}}
  where download_percentage is the mean per-file completion (0-100) or None
  when no local download information is available.

  On any unexpected failure, responds with HTTP 500 and a {"detail": ...} body.
  """
  try:
    # Engines required across the whole topology, deduplicated while
    # preserving order.  This set does not depend on the model being
    # inspected, so compute it once instead of once per model.
    required_engines = list(dict.fromkeys(
      [
        inference_engine_classes.get(engine_name, None)
        for engine_list in self.node.topology_inference_engines_pool
        for engine_name in engine_list
        if engine_name is not None
      ] + [self.inference_engine_classname]
    ))

    model_pool = {}
    for model_name, pretty in pretty_name.items():
      # Only models we have card metadata for can be offered.
      if model_name not in model_cards:
        continue
      model_info = model_cards[model_name]

      # Skip models whose repo mapping does not cover every required engine.
      if not all(engine in model_info["repo"] for engine in required_engines):
        continue

      shard = build_base_shard(model_name, self.inference_engine_classname)
      if not shard:
        continue

      # Probe local download progress for this model's shard.
      downloader = HFShardDownloader()
      downloader.current_shard = shard
      downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
      status = await downloader.get_shard_download_status()
      if DEBUG >= 2:
        print(f"Download status for {model_name}: {status}")

      # Overall percentage = mean of the per-file percentages, when known.
      download_percentage = None
      if status:
        percentages = list(status.values())
        if percentages:
          download_percentage = sum(percentages) / len(percentages)
          if DEBUG >= 2:
            print(f"Calculated download percentage for {model_name}: {download_percentage}")

      model_pool[model_name] = {
        "name": pretty,
        "downloaded": download_percentage == 100 if download_percentage is not None else False,
        "download_percentage": download_percentage,
      }

    return web.json_response({"model pool": model_pool})
  except Exception as e:
    print(f"Error in handle_model_support: {str(e)}")
    traceback.print_exc()
    return web.json_response(
      {"detail": f"Server error: {str(e)}"},
      status=500
    )

async def handle_get_models(self, request):
  """List every known model in an OpenAI-compatible /models response shape."""
  models = []
  for model_name in model_cards:
    models.append({
      "id": model_name,
      "object": "model",
      "owned_by": "exo",
      "ready": True,
    })
  return web.json_response(models)

Expand Down
66 changes: 64 additions & 2 deletions exo/download/hf/hf_shard_download.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import asyncio
import traceback
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Optional
from exo.inference.shard import Shard
from exo.download.shard_download import ShardDownloader
from exo.download.download_progress import RepoProgressEvent
from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
from exo.download.hf.hf_helpers import (
download_repo_files, RepoProgressEvent, get_weight_map,
get_allow_patterns, get_repo_root, fetch_file_list, get_local_snapshot_dir
)
from exo.helpers import AsyncCallbackSystem, DEBUG
from exo.models import model_cards, get_repo
import aiohttp
from aiofiles import os as aios


class HFShardDownloader(ShardDownloader):
Expand All @@ -17,8 +22,13 @@ def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
self.active_downloads: Dict[Shard, asyncio.Task] = {}
self.completed_downloads: Dict[Shard, Path] = {}
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
self.current_shard: Optional[Shard] = None
self.current_repo_id: Optional[str] = None
self.revision: str = "main"

async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
self.current_shard = shard
self.current_repo_id = get_repo(shard.model_id, inference_engine_name)
repo_name = get_repo(shard.model_id, inference_engine_name)
if shard in self.completed_downloads:
return self.completed_downloads[shard]
Expand Down Expand Up @@ -77,3 +87,55 @@ async def wrapped_progress_callback(event: RepoProgressEvent):
@property
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
  # Read-only accessor for the callback system that download tasks fire
  # with (shard, progress-event) updates as files are fetched.
  return self._on_progress

async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
  """Estimate per-file download completion for the current shard.

  Compares local file sizes under the snapshot directory against the
  remote file sizes reported by the hub, for the weight files this shard
  needs (per its allow patterns).

  Returns:
    Dict mapping file pattern -> completion percentage (0-100), or None
    when no shard/repo is set, no local snapshot exists, the weight map
    is unavailable, or an unexpected error occurs.
  """
  if not self.current_shard or not self.current_repo_id:
    if DEBUG >= 2: print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
    return None

  try:
    snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
    if not snapshot_dir:
      if DEBUG >= 2: print(f"No snapshot directory found for {self.current_repo_id}")
      return None

    # Get the weight map to know what files we need.
    weight_map = await get_weight_map(self.current_repo_id, self.revision)
    if not weight_map:
      if DEBUG >= 2: print(f"No weight map found for {self.current_repo_id}")
      return None

    # File patterns this shard requires.
    patterns = get_allow_patterns(weight_map, self.current_shard)

    # First check which weight files already exist locally and record
    # their on-disk sizes.  Patterns containing wildcards simply won't
    # exist as literal paths and are skipped.
    status = {}
    local_files = []
    local_sizes = {}

    for pattern in patterns:
      if pattern.endswith('safetensors') or pattern.endswith('mlx'):
        file_path = snapshot_dir / pattern
        if await aios.path.exists(file_path):
          local_size = await aios.path.getsize(file_path)
          local_files.append(pattern)
          local_sizes[pattern] = local_size

    # Only fetch remote metadata if we found local files to compare against.
    if local_files:
      async with aiohttp.ClientSession() as session:
        file_list = await fetch_file_list(session, self.current_repo_id, self.revision)

        for pattern in local_files:
          for file in file_list:
            if file["path"].endswith(pattern):
              # Guard against a zero-size remote entry to avoid
              # ZeroDivisionError; such files carry no useful progress.
              if file["size"]:
                status[pattern] = (local_sizes[pattern] / file["size"]) * 100
              break

    return status

  except Exception as e:
    if DEBUG >= 2:
      print(f"Error getting shard download status: {e}")
      traceback.print_exc()
    return None
12 changes: 11 additions & 1 deletion exo/download/shard_download.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional, Tuple
from typing import Optional, Tuple, Dict
from pathlib import Path
from exo.inference.shard import Shard
from exo.download.download_progress import RepoProgressEvent
Expand All @@ -26,6 +26,16 @@ async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
pass

@abstractmethod
async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
  """Report how far along each shard's download is.

  Returns:
    Optional[Dict[str, float]]: mapping of shard identifier to its
    download completion percentage (0-100). None when the status
    cannot be determined.
  """
  pass


class NoopShardDownloader(ShardDownloader):
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
Expand Down
10 changes: 5 additions & 5 deletions exo/tinychat/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@
<body>
<main x-data="state" x-init="console.log(endpoint)">
<!-- Error Toast -->
<div x-show="errorMessage" x-transition.opacity class="toast">
<div x-show="errorMessage !== null" x-transition.opacity class="toast">
<div class="toast-header">
<span class="toast-error-message" x-text="errorMessage.basic"></span>
<span class="toast-error-message" x-text="errorMessage?.basic || ''"></span>
<div class="toast-header-buttons">
<button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }"
class="toast-expand-button"
x-show="errorMessage.stack">
x-show="errorMessage?.stack">
<span x-text="errorExpanded ? 'Hide Details' : 'Show Details'"></span>
</button>
<button @click="errorMessage = null; errorExpanded = false;" class="toast-close-button">
Expand All @@ -41,11 +41,11 @@
</div>
</div>
<div class="toast-content" x-show="errorExpanded" x-transition>
<span x-text="errorMessage.stack"></span>
<span x-text="errorMessage?.stack || ''"></span>
</div>
</div>
<div class="model-selector">
<select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" x-init="await populateSelector()" class='model-select'>
<select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" class='model-select'>
</select>
</div>
<div @popstate.window="
Expand Down
Loading