3 changes: 3 additions & 0 deletions .env.example
@@ -1,6 +1,9 @@
# Server port. Optional (defaults to 4000)
PORT=4000

# File logger debug toggle: "1" logs DEBUG+ to file. Optional (defaults to "0")
DEBUG_MODE="0"

# Node polling interval in seconds. Optional (defaults to 30)
REFRESH_INTERVAL=30

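For reference, a minimal sketch of how these three variables are consumed, with defaults matching the comments above. The `PORT` and `DEBUG_MODE` reads appear in `src/main.py` and `src/logger.py` in this diff; the `REFRESH_INTERVAL` read is assumed to follow the same pattern.

```python
from os import environ

# Defaults mirror the .env.example comments; main.py keeps PORT as a string
port = environ.get("PORT", "4000")
debug_mode = environ.get("DEBUG_MODE", "0") == "1"
refresh_interval = int(environ.get("REFRESH_INTERVAL", "30"))
```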
12 changes: 7 additions & 5 deletions src/logger.py
@@ -1,8 +1,11 @@
import logging
from os import environ

import structlog
from structlog.typing import Processor

FILE_LOG_LEVEL = logging.DEBUG if environ.get("DEBUG_MODE") == "1" else logging.INFO

# Re-export logger
log = structlog.get_logger()

@@ -13,7 +16,6 @@
    structlog.stdlib.add_logger_name,  # Add logger name
    structlog.dev.set_exc_info,  # Exception info handling
    structlog.processors.TimeStamper("%Y-%m-%d %H:%M:%S", utc=False),  # Timestamp
-   structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
]


@@ -34,8 +36,8 @@ def setup_logging(log_path: str = "/tmp/infernet_router.log") -> None:
wrapper_class=structlog.stdlib.BoundLogger,
)

-    # Setup raw python logger
-    root_logger = logging.getLogger()
+    # Use structlog for root logger
+    root_logger = structlog.getLogger()
root_logger.setLevel(logging.NOTSET)

# Setup log handlers
@@ -56,7 +58,7 @@ def setup_logging(log_path: str = "/tmp/infernet_router.log") -> None:
processor=structlog.processors.JSONRenderer()
)
)
-    file_handler.setLevel(logging.DEBUG)  # Save to file DEBUG+
+    file_handler.setLevel(FILE_LOG_LEVEL)  # Save to file at the configured level

# Add log handlers to raw python logger
root_logger.addHandler(console_handler)
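A usage sketch for the new toggle. Since `FILE_LOG_LEVEL` is computed at module import time, `DEBUG_MODE` has to be in the environment before `logger` is imported; the console handler's level is not shown in this diff, so the console/file split below is an assumption.

```python
from os import environ

environ["DEBUG_MODE"] = "1"  # illustrative; normally exported via .env

from logger import log, setup_logging

setup_logging()  # defaults to /tmp/infernet_router.log
log.info("router started", port=4000)  # console and file
log.debug("node payload", node="203.0.113.7:4000")  # reaches the file because
# DEBUG_MODE=1; with the default "0" the file handler stays at INFO
```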
16 changes: 9 additions & 7 deletions src/main.py
@@ -18,9 +18,11 @@ def read_ips(filepath: str = "ips.txt") -> list[str]:
"""
    try:
        with open(filepath, "r") as file:
-            return file.read().splitlines()
+            nodes = file.read().splitlines()
+            log.debug("Loaded node IPs", count=len(nodes), nodes=nodes)
+            return nodes
    except Exception as e:
-        log.error(f"Failed to read IPs from {filepath}: {str(e)}")
+        log.error("Failed to read node IPs", filepath=filepath, error=str(e))
return []
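`read_ips()` treats `ips.txt` as one address per line. A sketch with illustrative addresses; the `ip:port` shape matches how `monitor.py` keys its nodes, and a trailing newline is harmless since `splitlines()` drops it.

```python
from pathlib import Path

from main import read_ips

# Hypothetical contents; real deployments list their own node addresses
Path("ips.txt").write_text("203.0.113.7:4000\n203.0.113.9:4001\n")

print(read_ips())  # ["203.0.113.7:4000", "203.0.113.9:4001"]
```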


@@ -33,19 +35,21 @@ async def shutdown(
        signal (signal.Signals): Signal to handle
        monitor (NodeMonitor): Node monitor
        rest (RESTServer): REST server
log.info(f"Received exit signal {signal.name}...")
log.info("Received exit signal", signal_name=signal.name)
await monitor.stop()
await rest.stop()
log.info("Shutdown complete.")


async def main() -> None:
"""Entry point for router"""
"""Entry point for router."""

# Read node IPs from file
nodes = read_ips()
port = environ.get("PORT", "4000")

log.info("Starting router", port=port, node_count=len(nodes))

monitor = NodeMonitor(nodes)
server = RESTServer(port, monitor)
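The `shutdown()` coroutine above needs to be registered against SIGINT/SIGTERM somewhere outside this hunk. A sketch of the conventional wiring via `loop.add_signal_handler`, assuming `shutdown()` takes the signal plus the monitor and REST server (its full signature is elided here):

```python
import asyncio
import signal as signals

def install_signal_handlers(monitor: NodeMonitor, rest: RESTServer) -> None:
    """Schedule shutdown() when SIGINT or SIGTERM arrives (sketch)."""
    loop = asyncio.get_running_loop()
    for sig in (signals.SIGINT, signals.SIGTERM):
        loop.add_signal_handler(
            sig, lambda s=sig: asyncio.create_task(shutdown(s, monitor, rest))
        )
```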

@@ -64,10 +68,8 @@ async def main() -> None:
# Wait for any task to complete
done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    for task in done:
-        # Check if any tasks failed
        if task.exception() is not None:
-            # Log exception
-            log.error(f"Task exception: {task.exception()}")
+            log.exception("Task failed", exc_info=task.exception())


if __name__ == "__main__":
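The loop above logs failures from whichever task finishes first; the elided remainder of `main()` presumably also cancels the tasks still pending. A sketch of that companion step, under that assumption:

```python
import asyncio

from logger import log

async def supervise(tasks: list[asyncio.Task]) -> None:
    """Wait for the first task to finish, then cancel the rest (sketch)."""
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    for task in done:
        if task.exception() is not None:
            log.exception("Task failed", exc_info=task.exception())
    for task in pending:
        task.cancel()
    # Swallow the resulting CancelledErrors so shutdown stays clean
    await asyncio.gather(*pending, return_exceptions=True)
```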
18 changes: 13 additions & 5 deletions src/monitor.py
@@ -94,8 +94,8 @@ async def _update_node(
)
return

-        except Exception:
-            pass
+        except Exception as e:
+            log.error("Node check failed", node=host, error=str(e))

# Node is not available
if nodes[host].available:
@@ -111,15 +111,19 @@ async def _get_live_nodes(self: NodeMonitor) -> dict[Hostname, NodeInfo]:
dict[Hostname, NodeInfo]: Node objects for live nodes
"""

log.info("Fetching live nodes", api_url=self._api_url)

        if self._api_url:
-            return {
+            # Hostname is ip:port for each node. Default port is 4000
+            live_nodes = {
                f'{node["ip"]}:{node["port"] if "port" in node else "4000"}': NodeInfo(
                    available=False, containers=[], container_ids=[], pending={}
                )
                for node in await fetch_live_nodes(self._api_url)
            }
+
+            log.debug("Live nodes retrieved", count=len(live_nodes))
+            return live_nodes

return {}
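For clarity, the hostname construction above assumes `fetch_live_nodes()` returns dicts with an `"ip"` and an optional `"port"`. A worked example with hypothetical entries:

```python
payload = [
    {"ip": "203.0.113.7", "port": "4001"},  # port given explicitly
    {"ip": "203.0.113.9"},                  # no port -> defaults to "4000"
]
hostnames = [
    f'{node["ip"]}:{node["port"] if "port" in node else "4000"}'
    for node in payload
]
assert hostnames == ["203.0.113.7:4001", "203.0.113.9:4000"]
```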

async def run_forever(self: NodeMonitor) -> None:
Expand All @@ -135,6 +139,8 @@ async def run_forever(self: NodeMonitor) -> None:
The availability of each node is checked by pinging the node's `/info` endpoint.
"""
while not self._shutdown:
log.debug("Starting node monitoring cycle")

# Base nodes
tasks = [
create_task(self._update_node(host, self._base_nodes))
@@ -258,7 +264,7 @@ async def fetch(session: ClientSession, url: str) -> Optional[dict[str, Any]]:
response.raise_for_status()
return cast(dict[str, Any], await response.json())
            except Exception as e:
-                log.warning(f"Error fetching data from {url}: {e}")
+                log.warning("Failed to fetch node resources", url=url, error=str(e))
return None

async with ClientSession() as session:
@@ -274,6 +280,8 @@ async def fetch(session: ClientSession, url: str) -> Optional[dict[str, Any]]:
# Gather results in parallel
results = await gather(*tasks.values(), return_exceptions=False)

log.info("Fetched node resources", node_count=len(results))

# Return a dictionary from host to resources
return {
host: result
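The resources hunk fans out one fetch per host, gathers the results in order, and zips them back into a host-to-result mapping (the diff cuts off before the comprehension's filter, so that detail is assumed). A self-contained sketch of the pattern, with a stand-in fetch:

```python
import asyncio

async def fan_out(hosts: list[str]) -> dict[str, int]:
    """Run one task per host in parallel, keep the host -> result mapping."""

    async def fake_fetch(host: str) -> int:  # stand-in for fetch(session, url)
        await asyncio.sleep(0)
        return len(host)

    tasks = {host: asyncio.create_task(fake_fetch(host)) for host in hosts}
    results = await asyncio.gather(*tasks.values())  # same order as tasks.keys()
    return dict(zip(tasks.keys(), results))

print(asyncio.run(fan_out(["203.0.113.7:4000", "203.0.113.9:4001"])))
```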
34 changes: 28 additions & 6 deletions src/rest.py
@@ -51,20 +51,35 @@ def __init__(
def register_routes(self: RESTServer) -> None:
"""Registers Quart webserver routes"""

@self._app.before_request
async def log_request() -> None:
"""Logs incoming requests"""
log.info(
"Incoming request",
method=request.method,
path=request.path,
args=request.args.to_dict(),
remote_addr=request.remote_addr,
)

@self._app.route("/api/v1/ips", methods=["GET"])
@rate_limit(RATELIMIT_REQS_PER_MIN, timedelta(seconds=30))
async def ips() -> Response:
"""Returns IPs of nodes that can fulfill a job request"""

containers = request.args.getlist("container")
if not containers:
log.warning("Invalid request: Missing containers", remote_addr=request.remote_addr)
abort(400, "No containers specified")

# Optional query parameters n and offset
n = request.args.get("n", default=3, type=int)
offset = request.args.get("offset", default=0, type=int)

-            return jsonify(self._monitor.get_nodes(containers, n, offset))
+            nodes = self._monitor.get_nodes(containers, n, offset)
+            log.info("Returning available nodes", nodes=nodes, count=len(nodes))
+
+            return jsonify(nodes)

@self._app.route("/api/v1/containers", methods=["GET"])
@rate_limit(RATELIMIT_REQS_PER_MIN, timedelta(seconds=30))
@@ -79,7 +94,15 @@ async def resources() -> Response:
"""Returns resources available on each node in the network"""

model_id = request.args.get("model_id")
-            return jsonify(await self._monitor.get_resources(model_id))
+            resources = await self._monitor.get_resources(model_id)
+            log.info("Returning node resources", node_count=len(resources))
+            return jsonify(resources)

        @self._app.errorhandler(Exception)
        async def handle_exception(e: Exception) -> HTTPException | tuple[Response, int]:
            """Global error handler"""
            # Pass HTTP errors (e.g. from abort(400) above) through unchanged so
            # they keep their status codes (HTTPException: werkzeug.exceptions)
            if isinstance(e, HTTPException):
                return e
            log.error("Unhandled exception in request", error=str(e), path=request.path)
            return jsonify({"error": "Internal server error"}), 500

async def run_forever(self: RESTServer) -> None:
"""Main RESTServer lifecycle loop. Uses production hypercorn server"""
@@ -103,11 +126,10 @@ async def shutdown_trigger() -> None:
try:
await server_task
except CancelledError:
-            pass  # Expected due to cancellation
+            log.warning("REST server task was cancelled.")  # Expected due to cancellation

    async def stop(self: RESTServer) -> None:
        """Stops the RESTServer."""
-        log.info("Stopping REST webserver")

        # Set shutdown event to stop server
+        log.warning("Shutdown signal received. Stopping REST webserver...")
        self._shutdown_event.set()
+        log.info("REST webserver shutdown completed.")
12 changes: 9 additions & 3 deletions src/sql.py
@@ -1,6 +1,6 @@
from typing import Any, cast

-from aiohttp import ClientSession
+from aiohttp import ClientSession, ClientError, ClientResponseError, ClientTimeout

from logger import log

@@ -16,15 +16,21 @@ async def fetch_live_nodes(api_url: str) -> list[dict[str, Any]]:
"""
url = f"{api_url}/api/nodes?minutes_past=60"
try:
-        async with ClientSession() as session:
+        async with ClientSession(timeout=ClientTimeout(total=10)) as session:
async with session.get(url) as response:
# Check if the HTTP request was successful
if response.status == 200:
body = await response.json()
return cast(list[dict[str, Any]], body["data"])
else:
log.error("Failed to fetch live nodes", status=response.status)
except ClientResponseError as e:
log.error("HTTP error fetching live nodes", status=e.status, message=str(e))
except ClientError as e:
log.error("Network error fetching live nodes", error=str(e))
except KeyError:
log.error("Unexpected API response format: Missing 'data' key")
except Exception as e:
log.error(f"Failed to fetch live nodes: {str(e)}")
log.error("Unknown error fetching live nodes", error=str(e))

return []
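Finally, a sketch of the response shape `fetch_live_nodes()` expects from `GET {api_url}/api/nodes?minutes_past=60`, reconstructed from the `body["data"]` access above; the field values are hypothetical.

```python
sample_body = {
    "data": [
        {"ip": "203.0.113.7", "port": "4001"},
        {"ip": "203.0.113.9"},
    ]
}
nodes = sample_body["data"]  # a missing "data" key triggers the KeyError branch
```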