Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show downloaded models, improve error handling #456

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 88 additions & 18 deletions exo/api/chatgpt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from exo.orchestration import Node
from exo.models import build_base_shard, model_cards, get_repo, pretty_name
from typing import Callable
import os
from exo.download.hf.hf_helpers import get_hf_home


class Message:
Expand Down Expand Up @@ -200,25 +202,93 @@ async def middleware(request):
async def handle_root(self, request):
    """Respond with the ``index.html`` file found in ``self.static_dir``."""
    index_page = self.static_dir / "index.html"
    return web.FileResponse(index_page)

def is_model_downloaded(self, model_name):
    """Report whether weights for ``model_name`` exist in the local Hugging Face cache.

    The repo id is resolved with ``get_repo`` for this node's inference engine and
    mapped to the cache folder ``<hf_home>/hub/models--<repo id with '/' -> '--'>/snapshots``.
    Only the most recently modified snapshot directory is inspected.

    Args:
        model_name: Model identifier passed to ``get_repo``.

    Returns:
        True if the latest snapshot contains ``model.safetensors``,
        ``model.safetensors.index.json`` (sharded models) or any ``*.mlx`` file;
        False otherwise, including when no repo is known for the model.
    """
    if DEBUG >= 2:
        print(f"\nChecking if model {model_name} is downloaded:")

    cache_dir = get_hf_home() / "hub"
    repo = get_repo(model_name, self.inference_engine_classname)

    if DEBUG >= 2:
        print(f" Cache dir: {cache_dir}")
        print(f" Repo: {repo}")
        print(f" Engine: {self.inference_engine_classname}")

    if not repo:
        return False

    # Convert repo id (e.g. "mlx-community/Llama-3.2-1B-Instruct-4bit")
    # to directory format (e.g. "models--mlx-community--Llama-3.2-1B-Instruct-4bit").
    # Replacing every "/" also covers ids without an organisation prefix
    # (e.g. "gpt2" -> "models--gpt2"), which indexing split parts would crash on.
    folder_name = "models--" + repo.replace("/", "--")
    repo_path = cache_dir / folder_name / "snapshots"

    if DEBUG >= 2:
        print(f" Looking in: {repo_path}")

    # Only directories are snapshots; stray files (e.g. .DS_Store) must not be
    # mistaken for the most recent snapshot.
    snapshots = []
    if repo_path.is_dir():
        snapshots = [entry for entry in repo_path.glob("*") if entry.is_dir()]

    if snapshots:
        latest_snapshot = max(snapshots, key=lambda p: p.stat().st_mtime)

        # Model is considered downloaded if we find either:
        # 1. model.safetensors file
        # 2. model.safetensors.index.json file (for sharded models)
        # 3. *.mlx file
        model_files = [
            *latest_snapshot.glob("model.safetensors"),
            *latest_snapshot.glob("model.safetensors.index.json"),
            *latest_snapshot.glob("*.mlx"),
        ]

        if DEBUG >= 2:
            print(f" Latest snapshot: {latest_snapshot}")
            print(f" Found files: {model_files}")

        if model_files:
            return True

    if DEBUG >= 2:
        print(" No valid model files found")
    return False

async def handle_model_support(self, request):
# NOTE(review): this span appears to come from a rendered diff whose +/- markers
# and indentation were stripped. The `return web.json_response({...})` expression
# directly below looks like the removed implementation and the `try:` block after
# it like its replacement -- confirm against the real file before relying on either.
#
# First body: responds with {"model pool": {model_id: display name}} for every
# model in model_cards whose "repo" entry contains all required engines; the
# display name falls back to the model id when pretty_name has no entry.
return web.json_response({
"model pool": {
model_name: pretty_name.get(model_name, model_name)
for model_name in [
model_id for model_id, model_info in model_cards.items()
if all(map(
lambda engine: engine in model_info["repo"],
list(dict.fromkeys([
inference_engine_classes.get(engine_name, None)
for engine_list in self.node.topology_inference_engines_pool
for engine_name in engine_list
if engine_name is not None
] + [self.inference_engine_classname]))
))
]
}
})

# Second body: responds with {"model pool": {model_id: {"name", "downloaded"}}}.
# It iterates pretty_name and skips ids that are missing from model_cards.
try:
model_pool = {}

for model_name, pretty in pretty_name.items():
if model_name in model_cards:
model_info = model_cards[model_name]

# Get required engines
# Names from every engine list in the topology are mapped through
# inference_engine_classes, this node's own engine class name is appended, and
# dict.fromkeys drops duplicates while keeping first-seen order.
# NOTE(review): the filter only removes None names; .get(..., None) can still
# yield None for a name missing from inference_engine_classes. The list does
# not depend on model_name yet is rebuilt on every iteration.
required_engines = list(dict.fromkeys([
inference_engine_classes.get(engine_name, None)
for engine_list in self.node.topology_inference_engines_pool
for engine_name in engine_list
if engine_name is not None
] + [self.inference_engine_classname]))

# Check if model supports required engines
if all(map(lambda engine: engine in model_info["repo"], required_engines)):
is_downloaded = self.is_model_downloaded(model_name)
if DEBUG >= 2:
print(f"Model {model_name} download status: {is_downloaded}")

model_pool[model_name] = {
"name": pretty,
"downloaded": is_downloaded
}

return web.json_response({"model pool": model_pool})
# Any failure is printed with its traceback and reported to the client as an
# HTTP 500 JSON body carrying the exception text.
except Exception as e:
print(f"Error in handle_model_support: {str(e)}")
traceback.print_exc()
return web.json_response(
{"detail": f"Server error: {str(e)}"},
status=500
)

async def handle_get_models(self, request):
    """Respond with a JSON list holding one model entry per key of ``model_cards``.

    Every entry carries the model id plus the fixed fields ``object``,
    ``owned_by`` and ``ready``.
    """
    listing = [
        {"id": model_name, "object": "model", "owned_by": "exo", "ready": True}
        for model_name in model_cards
    ]
    return web.json_response(listing)

Expand Down
10 changes: 5 additions & 5 deletions exo/tinychat/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@
<body>
<main x-data="state" x-init="console.log(endpoint)">
<!-- Error Toast -->
<div x-show="errorMessage" x-transition.opacity class="toast">
<div x-show="errorMessage !== null" x-transition.opacity class="toast">
<div class="toast-header">
<span class="toast-error-message" x-text="errorMessage.basic"></span>
<span class="toast-error-message" x-text="errorMessage?.basic || ''"></span>
<div class="toast-header-buttons">
<button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }"
class="toast-expand-button"
x-show="errorMessage.stack">
x-show="errorMessage?.stack">
<span x-text="errorExpanded ? 'Hide Details' : 'Show Details'"></span>
</button>
<button @click="errorMessage = null; errorExpanded = false;" class="toast-close-button">
Expand All @@ -41,11 +41,11 @@
</div>
</div>
<div class="toast-content" x-show="errorExpanded" x-transition>
<span x-text="errorMessage.stack"></span>
<span x-text="errorMessage?.stack || ''"></span>
</div>
</div>
<div class="model-selector">
<select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" x-init="await populateSelector()" class='model-select'>
<select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" class='model-select'>
</select>
</div>
<div @popstate.window="
Expand Down
114 changes: 35 additions & 79 deletions exo/tinychat/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ document.addEventListener("alpine:init", () => {
home: 0,
generating: false,
endpoint: `${window.location.origin}/v1`,

// Initialize error message structure
errorMessage: null,
errorExpanded: false,
errorTimeout: null,
Expand All @@ -38,6 +40,9 @@ document.addEventListener("alpine:init", () => {

// Start polling for download progress
this.startDownloadProgressPolling();

// Call populateSelector immediately after initialization
this.populateSelector();
},

removeHistory(cstate) {
Expand Down Expand Up @@ -77,50 +82,25 @@ document.addEventListener("alpine:init", () => {
async populateSelector() {
// Fetches `${window.location.origin}/modelpool` and fills the `.model-select`
// element with one <option> per entry of the response's "model pool" object.
// Failures are logged to the console and surfaced through the error state.
//
// NOTE(review): this span looks like a flattened diff -- old and new lines are
// interleaved with their +/- markers stripped. As written it declares `sel`
// twice, reads the body with both response.text() and response.json(), and the
// two forEach openers share a single closing `});`. Confirm against the real
// file before applying anything from here.
try {
const response = await fetch(`${window.location.origin}/modelpool`);
const responseText = await response.text(); // Get raw response text first

if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}

// Try to parse the response text
let responseJson;
try {
responseJson = JSON.parse(responseText);
} catch (parseError) {
console.error('Failed to parse JSON:', parseError);
throw new Error(`Invalid JSON response: ${responseText}`);
}

const sel = document.querySelector(".model-select");
if (!sel) {
throw new Error("Could not find model selector element");
// The next two lines attach the response body text to the HTTP error message.
const errorText = await response.text();
throw new Error(`HTTP error! status: ${response.status}\n${errorText}`);
}

// Clear the current options and add new ones
const data = await response.json();
const sel = document.querySelector('.model-select');
sel.innerHTML = '';

const modelDict = responseJson["model pool"];
if (!modelDict) {
throw new Error("Response missing 'model pool' property");
}

Object.entries(modelDict).forEach(([key, value]) => {
// Use the model pool entries in their original order
Object.entries(data["model pool"]).forEach(([key, value]) => {
const opt = document.createElement("option");
opt.value = key;
// First assignment treats `value` as the display string itself.
opt.textContent = value;
// Second assignment reads `value` as an object with `name` and `downloaded`
// fields and appends ' (downloaded)' when `downloaded` is truthy.
opt.textContent = `${value.name}${value.downloaded ? ' (downloaded)' : ''}`;
sel.appendChild(opt);
});

// Set initial value to the first model
const firstKey = Object.keys(modelDict)[0];
if (firstKey) {
sel.value = firstKey;
this.cstate.selectedModel = firstKey;
}
} catch (error) {
console.error("Error populating model selector:", error);
// Two reporting styles appear here: a plain string assigned to errorMessage,
// then the setError() helper, which overwrites it.
this.errorMessage = `Failed to load models: ${error.message}`;
this.setError(error);
}
},

Expand Down Expand Up @@ -169,29 +149,7 @@ document.addEventListener("alpine:init", () => {
this.processMessage(value);
} catch (error) {
console.error('error', error);
const errorDetails = {
message: error.message || 'Unknown error',
stack: error.stack,
name: error.name || 'Error'
};

this.errorMessage = {
basic: `${errorDetails.name}: ${errorDetails.message}`,
stack: errorDetails.stack
};

// Clear any existing timeout
if (this.errorTimeout) {
clearTimeout(this.errorTimeout);
}

// Only set the timeout if the error details aren't expanded
if (!this.errorExpanded) {
this.errorTimeout = setTimeout(() => {
this.errorMessage = null;
this.errorExpanded = false;
}, 30 * 1000);
}
this.setError(error);
this.generating = false;
}
},
Expand Down Expand Up @@ -309,29 +267,7 @@ document.addEventListener("alpine:init", () => {
}
} catch (error) {
console.error('error', error);
const errorDetails = {
message: error.message || 'Unknown error',
stack: error.stack,
name: error.name || 'Error'
};

this.errorMessage = {
basic: `${errorDetails.name}: ${errorDetails.message}`,
stack: errorDetails.stack
};

// Clear any existing timeout
if (this.errorTimeout) {
clearTimeout(this.errorTimeout);
}

// Only set the timeout if the error details aren't expanded
if (!this.errorExpanded) {
this.errorTimeout = setTimeout(() => {
this.errorMessage = null;
this.errorExpanded = false;
}, 30 * 1000);
}
this.setError(error);
} finally {
this.generating = false;
}
Expand Down Expand Up @@ -467,6 +403,26 @@ document.addEventListener("alpine:init", () => {
this.fetchDownloadProgress();
}, 1000); // Poll every second
},

/**
 * Record an error for display and schedule its automatic dismissal.
 *
 * Stores `{ basic, stack }` in `errorMessage`, collapses the details panel,
 * cancels any pending dismissal timer and arms a fresh 30-second one.
 *
 * @param {Error|string|null|undefined} error - The failure to report. A plain
 *   string is shown as-is; a missing value or a missing `message` falls back
 *   to a generic text, and a missing `stack` to an empty string.
 */
setError(error) {
  // Thrown values are not guaranteed to be Error objects (`throw "x"`,
  // a promise rejected with undefined), so never dereference them blindly.
  const basic =
    (typeof error === "string" ? error : error?.message) ||
    "An unknown error occurred";
  const stack = error?.stack || "";

  this.errorMessage = { basic, stack };
  this.errorExpanded = false;

  if (this.errorTimeout) {
    clearTimeout(this.errorTimeout);
  }

  // errorExpanded was reset just above, so the timer is always armed here.
  this.errorTimeout = setTimeout(() => {
    this.errorMessage = null;
    this.errorExpanded = false;
    // Drop the handle once it has fired so later checks don't see a stale id.
    this.errorTimeout = null;
  }, 30 * 1000);
},
}));
});

Expand Down