Show downloaded models, improve error handling, ability to delete models, side bar with more detail, button to go back to chat history #456

Merged: 57 commits, Dec 6, 2024
Commits
c7dd312
adding logic to check which models are downloaded
cadenmackenzie Nov 13, 2024
de09e2a
reusing helper function to get cached directory
cadenmackenzie Nov 13, 2024
7d7bdd8
removing unnecessary console logs and fixing order of variables in ind…
cadenmackenzie Nov 13, 2024
fb32a85
removing error separation so I can put in different PR
cadenmackenzie Nov 13, 2024
59f5b6d
adding back in set error message
cadenmackenzie Nov 13, 2024
25d67f5
cleaning up logging in index.js
cadenmackenzie Nov 13, 2024
95ce665
removing unnecessary css
cadenmackenzie Nov 13, 2024
3eb726c
removing sorting of models by name
cadenmackenzie Nov 13, 2024
cbeb1b3
fix safari issue
dtnewman Nov 14, 2024
372d873
Merge pull request #1 from dtnewman/dn/downloadModelsV2
cadenmackenzie Nov 14, 2024
d9aabd7
working versions
cadenmackenzie Nov 14, 2024
dfcf513
removing is_model_downloaded method and changing how downloaded varia…
cadenmackenzie Nov 14, 2024
972074e
reducing redundant checks
cadenmackenzie Nov 14, 2024
dd38924
removing checking of percentage for models that are not found locally
cadenmackenzie Nov 14, 2024
bd2985a
Merge pull request #2 from cadenmackenzie/downloadedModelsV2Revisions
cadenmackenzie Nov 14, 2024
649157d
creating HFShardDownloader with quick_check true so it doesn't start d…
cadenmackenzie Nov 17, 2024
c923ef6
modifying how it's being displayed because now calculating overall per…
cadenmackenzie Nov 18, 2024
c61f40c
adding helper function to check file download. also modifying downloa…
cadenmackenzie Nov 18, 2024
dec79ac
modify get_shard_download_status to use helper function
cadenmackenzie Nov 18, 2024
4c6fda7
modifying helper function checking size to follow redirect for .safet…
cadenmackenzie Nov 18, 2024
3ac8687
adding redirect for all requests
cadenmackenzie Nov 18, 2024
3256051
comment
cadenmackenzie Nov 18, 2024
db610f5
removing traceback
cadenmackenzie Nov 18, 2024
6a7de04
removing path update
cadenmackenzie Nov 18, 2024
fad0591
Merge pull request #4 from cadenmackenzie/hf_helperRefactor
cadenmackenzie Nov 18, 2024
b77362b
moving os import
cadenmackenzie Nov 18, 2024
695ab34
removing import get_hf_home
cadenmackenzie Nov 18, 2024
8135437
fixing formatting
cadenmackenzie Nov 19, 2024
91276cc
fixing formatting
cadenmackenzie Nov 19, 2024
8ee6cc3
yapf formatting
cadenmackenzie Nov 19, 2024
0d50167
yapf in download_file
cadenmackenzie Nov 19, 2024
2cdd55d
Merge branch 'main' into downloadedModelsV2
cadenmackenzie Nov 21, 2024
1ca11ea
defining optional
cadenmackenzie Nov 21, 2024
7a8c722
Merge pull request #5 from cadenmackenzie/main
cadenmackenzie Nov 21, 2024
7e6c69f
removing console log
cadenmackenzie Nov 21, 2024
31ce70f
working with side bar to choose model, show download percentage, sele…
cadenmackenzie Nov 21, 2024
fb3baf5
adding amount that has been downloaded if model is not fully downloaded
cadenmackenzie Nov 21, 2024
619df1d
adding functionality to delete the models if there is part of the mod…
cadenmackenzie Nov 22, 2024
a9838a8
formatting handle_delete_model
cadenmackenzie Nov 22, 2024
bc905cd
formatting deleteModel
cadenmackenzie Nov 22, 2024
39139c1
fixing required engines definition
cadenmackenzie Nov 22, 2024
c469d53
Merge pull request #6 from cadenmackenzie/modelSideBarV2
cadenmackenzie Nov 24, 2024
e16170c
backend endpoint now uses SSE to send each model as it's loaded. also …
cadenmackenzie Nov 24, 2024
db45ed6
Merge pull request #7 from cadenmackenzie/downloadedModelsV2_dynamica…
cadenmackenzie Nov 24, 2024
bc83d1f
Merge pull request #8 from cadenmackenzie/main
cadenmackenzie Nov 24, 2024
ded80b0
ensuring requests do not stack up by moving polling to while loop wit…
cadenmackenzie Nov 26, 2024
e99a739
adding a fetch to get initial model object to show models before goin…
cadenmackenzie Nov 26, 2024
445ba7a
Merge pull request #9 from cadenmackenzie/downloadedModelsV2_showingM…
cadenmackenzie Nov 26, 2024
5968a93
Merge pull request #10 from cadenmackenzie/main
cadenmackenzie Nov 27, 2024
ac32170
removing console log in initial models
cadenmackenzie Nov 27, 2024
9b4d030
adding loading icon to sidebar for model that is being downloaded
cadenmackenzie Nov 29, 2024
a2a2d87
Merge branch 'main' into downloadedModelsV2
cadenmackenzie Nov 29, 2024
26a22fb
Merge branch 'exo-explore:main' into downloadedModelsV2
cadenmackenzie Dec 3, 2024
438310b
adding option to download model from sidebar
cadenmackenzie Dec 4, 2024
095f67a
Merge pull request #12 from cadenmackenzie/downloadedModelsV2_useDown…
cadenmackenzie Dec 4, 2024
12bb315
adding disable to chatbox and send if download is in progress
cadenmackenzie Dec 5, 2024
f8cc54b
previously was checking all nodes for download status which led to i…
cadenmackenzie Dec 5, 2024
132 changes: 123 additions & 9 deletions exo/api/chatgpt_api.py
@@ -2,6 +2,7 @@
 import time
 import asyncio
 import json
+import os
 from pathlib import Path
 from transformers import AutoTokenizer
 from typing import List, Literal, Union, Dict
@@ -14,10 +15,12 @@
from exo.helpers import PrefixDict, shutdown
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
from exo.apputil import create_animation_mp4
from exo.models import build_base_shard, model_cards, get_repo, pretty_name
from typing import Callable, Optional
import tempfile
from exo.download.hf.hf_shard_download import HFShardDownloader
import shutil
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
from exo.apputil import create_animation_mp4

class Message:
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -175,6 +178,8 @@ def __init__(self, node: Node, inference_engine_classname: str, response_timeout
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
     cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
+    cors.add(self.app.router.add_delete("/models/{model_name}", self.handle_delete_model), {"*": cors_options})
+    cors.add(self.app.router.add_get("/initial_models", self.handle_get_initial_models), {"*": cors_options})
     cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
     cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})

@@ -216,12 +221,56 @@ async def handle_healthcheck(self, request):
     return web.json_response({"status": "ok"})

   async def handle_model_support(self, request):
-    return web.json_response({
-      "model pool": {
-        model_name: pretty_name.get(model_name, model_name)
-        for model_name in get_supported_models(self.node.topology_inference_engines_pool)
-      }
-    })
+    try:
+      response = web.StreamResponse(
+        status=200,
+        reason='OK',
+        headers={
+          'Content-Type': 'text/event-stream',
+          'Cache-Control': 'no-cache',
+          'Connection': 'keep-alive',
+        }
+      )
+      await response.prepare(request)
+
+      for model_name, pretty in pretty_name.items():
+        if model_name in model_cards:
+          model_info = model_cards[model_name]
+
+          if self.inference_engine_classname in model_info.get("repo", {}):
+            shard = build_base_shard(model_name, self.inference_engine_classname)
+            if shard:
+              downloader = HFShardDownloader(quick_check=True)
+              downloader.current_shard = shard
+              downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+              status = await downloader.get_shard_download_status()
+
+              download_percentage = status.get("overall") if status else None
+              total_size = status.get("total_size") if status else None
+              total_downloaded = status.get("total_downloaded") if status else False
+
+              model_data = {
+                model_name: {
+                  "name": pretty,
+                  "downloaded": download_percentage == 100 if download_percentage is not None else False,
+                  "download_percentage": download_percentage,
+                  "total_size": total_size,
+                  "total_downloaded": total_downloaded
+                }
+              }
+
+              await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
+
+      await response.write(b"data: [DONE]\n\n")
+      return response
+
+    except Exception as e:
+      print(f"Error in handle_model_support: {str(e)}")
+      traceback.print_exc()
+      return web.json_response(
+        {"detail": f"Server error: {str(e)}"},
+        status=500
+      )
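The rewritten `/modelpool` handler writes one `data:` event per model and finishes with a `data: [DONE]` sentinel. A minimal sketch of how a client might decode such a stream, assuming the response body has already been split into text lines; `iter_sse_payloads` and `collect_models` are hypothetical helper names and are not part of this PR:

```python
import json
from typing import Any, Dict, Iterable, Iterator


def iter_sse_payloads(lines: Iterable[str]) -> Iterator[Dict[str, Any]]:
    """Yield the JSON payload of each `data:` line until the [DONE] sentinel."""
    for raw in lines:
        line = raw.strip()
        if not line.startswith("data:"):
            continue  # blank separator lines and other SSE fields are skipped
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            return  # the server signals end of stream with this sentinel
        yield json.loads(payload)


def collect_models(lines: Iterable[str]) -> Dict[str, Dict[str, Any]]:
    """Merge the single-model events into one dict keyed by model name."""
    models: Dict[str, Dict[str, Any]] = {}
    for event in iter_sse_payloads(lines):
        models.update(event)
    return models
```

Because each event carries exactly one model, a client can update its list as events arrive instead of waiting for every download check to finish.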

   async def handle_get_models(self, request):
     return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
@@ -372,6 +421,71 @@ def on_result(_request_id: str, tokens: List[int], is_finished: bool):
       deregistered_callback = self.node.on_token.deregister(callback_id)
       if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")

+  async def handle_delete_model(self, request):
+    try:
+      model_name = request.match_info.get('model_name')
+      if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
+
+      if not model_name or model_name not in model_cards:
+        return web.json_response(
+          {"detail": f"Invalid model name: {model_name}"},
+          status=400
+        )
+
+      shard = build_base_shard(model_name, self.inference_engine_classname)
+      if not shard:
+        return web.json_response(
+          {"detail": "Could not build shard for model"},
+          status=400
+        )
+
+      repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+      if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
+
+      # Get the HF cache directory using the helper function
+      hf_home = get_hf_home()
+      cache_dir = get_repo_root(repo_id)
+
+      if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}")
+
+      if os.path.exists(cache_dir):
+        if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
+        try:
+          shutil.rmtree(cache_dir)
+          return web.json_response({
+            "status": "success",
+            "message": f"Model {model_name} deleted successfully",
+            "path": str(cache_dir)
+          })
+        except Exception as e:
+          return web.json_response({
+            "detail": f"Failed to delete model files: {str(e)}"
+          }, status=500)
+      else:
+        return web.json_response({
+          "detail": f"Model files not found at {cache_dir}"
+        }, status=404)
+
+    except Exception as e:
+      print(f"Error in handle_delete_model: {str(e)}")
+      traceback.print_exc()
+      return web.json_response({
+        "detail": f"Server error: {str(e)}"
+      }, status=500)
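The delete handler has three outcomes once the model name is validated: the cache directory is missing (404), removal fails (500), or removal succeeds (200). A small sketch of that branching, isolated from aiohttp so it can be exercised against a temporary directory; `delete_model_dir` is a hypothetical name, and catching `OSError` rather than `Exception` is a simplification made here, not what the handler does:

```python
import shutil
from pathlib import Path
from typing import Dict, Tuple


def delete_model_dir(cache_dir: Path) -> Tuple[int, Dict[str, str]]:
    """Remove cache_dir and return (http_status, json_body) for the outcome."""
    if not cache_dir.exists():
        return 404, {"detail": f"Model files not found at {cache_dir}"}
    try:
        # rmtree removes the directory and everything below it
        shutil.rmtree(cache_dir)
    except OSError as e:
        return 500, {"detail": f"Failed to delete model files: {e}"}
    return 200, {"status": "success", "path": str(cache_dir)}
```

Calling it twice on the same path should give 200 and then 404, which matches how the endpoint responds to a repeated delete request.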

+  async def handle_get_initial_models(self, request):
+    model_data = {}
+    for model_name, pretty in pretty_name.items():
+      model_data[model_name] = {
+        "name": pretty,
+        "downloaded": None, # Initially unknown
+        "download_percentage": None, # Change from 0 to null
+        "total_size": None,
+        "total_downloaded": None,
+        "loading": True # Add loading state
+      }
+    return web.json_response(model_data)
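`/initial_models` returns a placeholder for every model with `loading: True`, so a UI can render the list before any download status is known. One way a consumer might overlay later status updates onto those placeholders is sketched below; `initial_models` and `apply_status` are hypothetical names, and clearing `loading` on the first update is an assumption about the intended behavior:

```python
from typing import Any, Dict

PLACEHOLDER_FIELDS = ("downloaded", "download_percentage", "total_size", "total_downloaded")


def initial_models(pretty_names: Dict[str, str]) -> Dict[str, Dict[str, Any]]:
    """Build the placeholder entries: every status field unknown, loading set."""
    return {
        name: {"name": pretty, **{field: None for field in PLACEHOLDER_FIELDS}, "loading": True}
        for name, pretty in pretty_names.items()
    }


def apply_status(models: Dict[str, Dict[str, Any]], update: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Overlay one status update and mark the affected models as loaded."""
    for name, fields in update.items():
        entry = models.setdefault(name, {"name": name})
        entry.update(fields)
        entry["loading"] = False  # assumption: the first real status ends the loading state
    return models
```

Models that never receive an update keep `loading: True`, so the UI can tell "not checked yet" apart from "checked, 0% downloaded".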

   async def handle_create_animation(self, request):
     try:
       data = await request.json()
64 changes: 62 additions & 2 deletions exo/download/hf/hf_helpers.py
@@ -166,10 +166,18 @@ async def download_file(
     downloaded_size = local_file_size
     downloaded_this_session = 0
     mode = 'ab' if use_range_request else 'wb'
-    if downloaded_size == total_size:
+    percentage = await get_file_download_percentage(
+      session,
+      repo_id,
+      revision,
+      file_path,
+      Path(save_directory)
+    )
+
+    if percentage == 100:
       if DEBUG >= 2: print(f"File already downloaded: {file_path}")
       if progress_callback:
-        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, total_size, 0, total_size, 0, timedelta(0), "complete"))
       return

     if response.status == 200:
@@ -432,6 +440,57 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
   if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)

+async def get_file_download_percentage(
+  session: aiohttp.ClientSession,
+  repo_id: str,
+  revision: str,
+  file_path: str,
+  snapshot_dir: Path,
+) -> float:
+  """
+  Calculate the download percentage for a file by comparing local and remote sizes.
+  """
+  try:
+    local_path = snapshot_dir / file_path
+    if not await aios.path.exists(local_path):
+      return 0
+
+    # Get local file size first
+    local_size = await aios.path.getsize(local_path)
+    if local_size == 0:
+      return 0
+
+    # Check remote size
+    base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
+    url = urljoin(base_url, file_path)
+    headers = await get_auth_headers()
+
+    # Use HEAD request with redirect following for all files
+    async with session.head(url, headers=headers, allow_redirects=True) as response:
+      if response.status != 200:
+        if DEBUG >= 2:
+          print(f"Failed to get remote file info for {file_path}: {response.status}")
+        return 0
+
+      remote_size = int(response.headers.get('Content-Length', 0))
+
+      if remote_size == 0:
+        if DEBUG >= 2:
+          print(f"Remote size is 0 for {file_path}")
+        return 0
+
+      # Only return 100% if sizes match exactly
+      if local_size == remote_size:
+        return 100.0
+
+      # Calculate percentage based on sizes
+      return (local_size / remote_size) * 100 if remote_size > 0 else 0
+
+  except Exception as e:
+    if DEBUG >= 2:
+      print(f"Error checking file download status for {file_path}: {e}")
+    return 0
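Once the I/O is set aside, `get_file_download_percentage` reduces to a size comparison. A sketch of that arithmetic as a pure function, which makes the edge cases easy to check; `download_percentage` is a hypothetical name, not a helper from this PR:

```python
def download_percentage(local_size: int, remote_size: int) -> float:
    """Percentage of a file present locally, given local and remote byte counts."""
    if local_size <= 0 or remote_size <= 0:
        return 0.0  # missing/empty local file, or unknown remote size
    if local_size == remote_size:
        return 100.0  # only an exact size match counts as complete
    return (local_size / remote_size) * 100
```

One property worth noting: with this formula a local file that is larger than the remote one yields a value above 100, so callers that treat `== 100` as "complete" would regard such a file as incomplete.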

 async def has_hf_home_read_access() -> bool:
   hf_home = get_hf_home()
   try: return await aios.access(hf_home, os.R_OK)
@@ -441,3 +500,4 @@ async def has_hf_home_write_access() -> bool:
   hf_home = get_hf_home()
   try: return await aios.access(hf_home, os.W_OK)
   except OSError: return False

92 changes: 90 additions & 2 deletions exo/download/hf/hf_shard_download.py
@@ -1,13 +1,20 @@
 import asyncio
 import traceback
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional, Union
 from exo.inference.shard import Shard
 from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
-from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
+from exo.download.hf.hf_helpers import (
+  download_repo_files, RepoProgressEvent, get_weight_map,
+  get_allow_patterns, get_repo_root, fetch_file_list,
+  get_local_snapshot_dir, get_file_download_percentage,
+  filter_repo_objects
+)
 from exo.helpers import AsyncCallbackSystem, DEBUG
 from exo.models import model_cards, get_repo
+import aiohttp
+from aiofiles import os as aios


class HFShardDownloader(ShardDownloader):
@@ -17,8 +24,13 @@ def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
     self.active_downloads: Dict[Shard, asyncio.Task] = {}
     self.completed_downloads: Dict[Shard, Path] = {}
     self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
+    self.current_shard: Optional[Shard] = None
+    self.current_repo_id: Optional[str] = None
+    self.revision: str = "main"

   async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    self.current_shard = shard
+    self.current_repo_id = get_repo(shard.model_id, inference_engine_name)
     repo_name = get_repo(shard.model_id, inference_engine_name)
     if shard in self.completed_downloads:
       return self.completed_downloads[shard]
@@ -77,3 +89,79 @@ async def wrapped_progress_callback(event: RepoProgressEvent):
   @property
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     return self._on_progress
+
+  async def get_shard_download_status(self) -> Optional[Dict[str, Union[float, int]]]:
+    if not self.current_shard or not self.current_repo_id:
+      if DEBUG >= 2:
+        print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
+      return None
+
+    try:
+      # If no snapshot directory exists, return None - no need to check remote files
+      snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
+      if not snapshot_dir:
+        if DEBUG >= 2:
+          print(f"No snapshot directory found for {self.current_repo_id}")
+        return None
+
+      # Get the weight map to know what files we need
+      weight_map = await get_weight_map(self.current_repo_id, self.revision)
+      if not weight_map:
+        if DEBUG >= 2:
+          print(f"No weight map found for {self.current_repo_id}")
+        return None
+
+      # Get all files needed for this shard
+      patterns = get_allow_patterns(weight_map, self.current_shard)
+
+      # Check download status for all relevant files
+      status = {}
+      total_bytes = 0
+      downloaded_bytes = 0
+
+      async with aiohttp.ClientSession() as session:
+        file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
+        relevant_files = list(
+          filter_repo_objects(
+            file_list, allow_patterns=patterns, key=lambda x: x["path"]))
+
+        for file in relevant_files:
+          file_size = file["size"]
+          total_bytes += file_size
+
+          percentage = await get_file_download_percentage(
+            session,
+            self.current_repo_id,
+            self.revision,
+            file["path"],
+            snapshot_dir,
+          )
+          status[file["path"]] = percentage
+          downloaded_bytes += (file_size * (percentage / 100))
+
+        # Add overall progress weighted by file size
+        if total_bytes > 0:
+          status["overall"] = (downloaded_bytes / total_bytes) * 100
+        else:
+          status["overall"] = 0
+
+        # Add total size in bytes
+        status["total_size"] = total_bytes
+        if status["overall"] != 100:
+          status["total_downloaded"] = downloaded_bytes
+
+
+        if DEBUG >= 2:
+          print(f"Download calculation for {self.current_repo_id}:")
+          print(f"Total bytes: {total_bytes}")
+          print(f"Downloaded bytes: {downloaded_bytes}")
+          for file in relevant_files:
+            print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
+
+        return status
+
+    except Exception as e:
+      if DEBUG >= 2:
+        print(f"Error getting shard download status: {e}")
+        traceback.print_exc()
+      return None
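The `overall` figure computed in `get_shard_download_status` is weighted by file size, so a large, fully downloaded weights file counts for more than a small missing config file. A minimal sketch of that calculation on plain data; `overall_percentage` and the `(size, percentage)` tuple layout are assumptions made for illustration:

```python
from typing import Dict, Tuple


def overall_percentage(files: Dict[str, Tuple[int, float]]) -> float:
    """Size-weighted completion across files given as path -> (size_bytes, percentage)."""
    total_bytes = sum(size for size, _ in files.values())
    if total_bytes == 0:
        return 0.0  # nothing to download, mirrors the zero-size branch
    downloaded_bytes = sum(size * (pct / 100) for size, pct in files.values())
    return (downloaded_bytes / total_bytes) * 100
```

For example, a complete 300-byte file plus an empty 100-byte file gives 75%, whereas an unweighted mean of the two per-file percentages would give 50%.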
12 changes: 11 additions & 1 deletion exo/download/shard_download.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Dict
 from pathlib import Path
 from exo.inference.shard import Shard
 from exo.download.download_progress import RepoProgressEvent
@@ -26,6 +26,16 @@ async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     pass

+  @abstractmethod
+  async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
+    """Get the download status of shards.
+
+    Returns:
+      Optional[Dict[str, float]]: A dictionary mapping shard IDs to their download percentage (0-100),
+      or None if status cannot be determined
+    """
+    pass
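Adding an `@abstractmethod` to the base class means every concrete downloader has to override it before it can be instantiated. The sketch below shows that Python behavior with stand-in classes; `StatusDownloader`, `IncompleteDownloader`, `NoStatusDownloader` and `can_instantiate` are hypothetical names, not classes from this repository:

```python
from abc import ABC, abstractmethod
from typing import Dict, Optional


class StatusDownloader(ABC):
    """Stand-in for a downloader base class that declares only the new method."""

    @abstractmethod
    async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
        ...


class IncompleteDownloader(StatusDownloader):
    """Does not override the abstract method, so it cannot be instantiated."""


class NoStatusDownloader(StatusDownloader):
    async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
        return None  # "status cannot be determined"


def can_instantiate(cls: type) -> bool:
    """Return True if cls() succeeds, False if abstract methods are still missing."""
    try:
        cls()
        return True
    except TypeError:
        return False
```

A no-op subclass that simply returns `None` satisfies the contract, which is consistent with the docstring's "or None if status cannot be determined".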


 class NoopShardDownloader(ShardDownloader):
   async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path: