diff --git a/pyproject.toml b/pyproject.toml index 3a97878d..fd3d41fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ test = [ ] [project.scripts] -flash = "runpod_flash.cli.main:app" +flash = "runpod_flash.cli.entrypoint:main" [build-system] requires = ["setuptools>=42", "wheel"] diff --git a/src/runpod_flash/cli/entrypoint.py b/src/runpod_flash/cli/entrypoint.py new file mode 100644 index 00000000..55f7b498 --- /dev/null +++ b/src/runpod_flash/cli/entrypoint.py @@ -0,0 +1,36 @@ +"""Thin CLI entrypoint that catches corrupted credentials at import time. + +The runpod package reads ~/.runpod/config.toml at import time (in its +__init__.py). If that file contains invalid TOML, the import raises a +TOMLDecodeError before any Flash error handling can run. This wrapper +catches that and surfaces a clean error message. +""" + +import sys + + +def main() -> None: + """Entry point for the ``flash`` console script.""" + try: + from runpod_flash.cli.main import app + except ValueError as exc: + # TOML decode errors from toml/tomli/tomllib are ValueError subclasses. + # The runpod package calls a TOML loader at import time; a corrupted + # ~/.runpod/config.toml triggers this before Flash code executes. + exc_type = type(exc) + exc_module = getattr(exc_type, "__module__", "").lower() + is_toml_decode_error = exc_type.__name__ == "TOMLDecodeError" and ( + exc_module.startswith("toml") + or exc_module.startswith("tomli") + or exc_module.startswith("tomllib") + ) + if is_toml_decode_error: + print( + "Error: ~/.runpod/config.toml is corrupted and cannot be parsed.\n" + "Run 'flash login' to re-authenticate, or delete the file and retry.", + file=sys.stderr, + ) + raise SystemExit(1) from None + raise + + app() diff --git a/tests/unit/cli/test_entrypoint.py b/tests/unit/cli/test_entrypoint.py new file mode 100644 index 00000000..0cf599c3 --- /dev/null +++ b/tests/unit/cli/test_entrypoint.py @@ -0,0 +1,77 @@ +"""Tests for the CLI entrypoint wrapper that catches corrupted credentials.""" + +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from runpod_flash.cli.entrypoint import main + + +class TestEntrypoint: + """Tests for runpod_flash.cli.entrypoint.main.""" + + def test_normal_import_runs_app(self): + """When import succeeds, the Typer app is invoked.""" + mock_app = MagicMock() + mock_module = MagicMock() + mock_module.app = mock_app + + with patch.dict(sys.modules, {"runpod_flash.cli.main": mock_module}): + main() + + mock_app.assert_called_once() + + def test_corrupted_toml_shows_clean_error(self, capsys): + """Import-time TOMLDecodeError surfaces a clean message, not a traceback.""" + # Create a ValueError whose class looks like a TOML decode error. + # tomli.TOMLDecodeError is a ValueError subclass with module "tomli._parser". + toml_exc_cls = type( + "TOMLDecodeError", (ValueError,), {"__module__": "tomli._parser"} + ) + toml_error = toml_exc_cls("Invalid value at line 1 col 9") + + # Remove the module from cache so the import inside main() re-executes + saved = sys.modules.pop("runpod_flash.cli.main", None) + try: + with patch.dict(sys.modules, {"runpod_flash.cli.main": None}): + # Patch __import__ to raise when the entrypoint tries to import main + real_import = __import__ + + def patched_import(name, *args, **kwargs): + if name == "runpod_flash.cli.main": + raise toml_error + return real_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=patched_import): + with pytest.raises(SystemExit) as exc_info: + main() + + assert exc_info.value.code == 1 + captured = capsys.readouterr() + assert "corrupted" in captured.err + assert "flash login" in captured.err + finally: + if saved is not None: + sys.modules["runpod_flash.cli.main"] = saved + + def test_non_toml_value_error_propagates(self): + """A ValueError unrelated to TOML is not caught.""" + saved = sys.modules.pop("runpod_flash.cli.main", None) + try: + with patch.dict(sys.modules, {"runpod_flash.cli.main": None}): + real_import = __import__ + + def patched_import(name, *args, **kwargs): + if name == "runpod_flash.cli.main": + raise ValueError("something completely different") + return real_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=patched_import): + with pytest.raises( + ValueError, match="something completely different" + ): + main() + finally: + if saved is not None: + sys.modules["runpod_flash.cli.main"] = saved