Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/runpod_flash/cli/commands/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,21 @@


def init_command(
ctx: typer.Context,
project_name: Optional[str] = typer.Argument(
None, help="Project name or '.' for current directory"
None, help="Project name, or '.' to initialize in current directory"
),
force: bool = typer.Option(False, "--force", "-f", help="Overwrite existing files"),
):
"""Create new Flash project with Flash Server and GPU workers."""

# No argument provided — show usage and exit
if project_name is None:
console.print(Panel(ctx.get_help(), title="flash init", expand=False))
raise typer.Exit(0)

# Determine target directory and initialization mode
if project_name is None or project_name == ".":
if project_name == ".":
# Initialize in current directory
project_dir = Path.cwd()
is_current_dir = True
Expand Down
129 changes: 89 additions & 40 deletions tests/unit/cli/commands/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,20 @@
from unittest.mock import MagicMock, Mock, patch

import pytest
import typer
from rich.panel import Panel

from runpod_flash.cli.commands.init import init_command


@pytest.fixture
def mock_typer_ctx():
"""Create a mock typer.Context for direct init_command calls."""
ctx = MagicMock(spec=typer.Context)
ctx.get_help.return_value = "Usage: flash init [OPTIONS] [PROJECT_NAME]"
return ctx


@pytest.fixture
def mock_context(monkeypatch):
"""Set up mocks for init command testing."""
Expand Down Expand Up @@ -44,11 +54,13 @@ def mock_context(monkeypatch):
class TestInitCommandNewDirectory:
"""Tests for init command when creating a new directory."""

def test_create_new_directory(self, mock_context, tmp_path, monkeypatch):
def test_create_new_directory(
self, mock_typer_ctx, mock_context, tmp_path, monkeypatch
):
"""Test creating new project directory."""
monkeypatch.chdir(tmp_path)

init_command("my_project")
init_command(mock_typer_ctx, "my_project")

# Verify directory was created
assert (tmp_path / "my_project").exists()
Expand All @@ -59,45 +71,70 @@ def test_create_new_directory(self, mock_context, tmp_path, monkeypatch):
# Verify console output
mock_context["console"].print.assert_called()

def test_create_nested_directory(self, mock_context, tmp_path, monkeypatch):
def test_create_nested_directory(
self, mock_typer_ctx, mock_context, tmp_path, monkeypatch
):
"""Test creating project in nested directory structure."""
monkeypatch.chdir(tmp_path)

init_command("path/to/my_project")
init_command(mock_typer_ctx, "path/to/my_project")

# Verify nested directory was created
assert (tmp_path / "path/to/my_project").exists()

def test_force_flag_skips_confirmation(self, mock_context, tmp_path, monkeypatch):
def test_force_flag_skips_confirmation(
self, mock_typer_ctx, mock_context, tmp_path, monkeypatch
):
"""Test that force flag bypasses conflict prompts."""
monkeypatch.chdir(tmp_path)
mock_context["detect_conflicts"].return_value = ["main.py", "requirements.txt"]

init_command("my_project", force=True)
init_command(mock_typer_ctx, "my_project", force=True)

# Verify skeleton was created
mock_context["create_skeleton"].assert_called_once()


class TestInitCommandCurrentDirectory:
"""Tests for init command when using current directory."""
class TestInitCommandNoArgs:
"""Tests for init command when called with no arguments."""

@patch("pathlib.Path.cwd")
def test_init_current_directory_with_none(self, mock_cwd, mock_context, tmp_path):
"""Test initialization in current directory with None argument."""
mock_cwd.return_value = tmp_path
def test_no_args_shows_help_and_exits(self, mock_typer_ctx, mock_context):
"""flash init with no args should show help and exit."""
with pytest.raises(typer.Exit) as exc_info:
init_command(mock_typer_ctx, None)

init_command(None)
assert exc_info.value.exit_code == 0

# Verify skeleton was created
mock_context["create_skeleton"].assert_called_once()
def test_no_args_does_not_create_skeleton(self, mock_typer_ctx, mock_context):
"""flash init with no args should not create project skeleton."""
with pytest.raises(typer.Exit):
init_command(mock_typer_ctx, None)

mock_context["create_skeleton"].assert_not_called()

def test_no_args_prints_usage_info(self, mock_typer_ctx, mock_context):
"""flash init with no args should print usage information."""
with pytest.raises(typer.Exit):
init_command(mock_typer_ctx, None)

# Verify console.print was called with a Panel containing usage info
mock_context["console"].print.assert_called_once()
panel_arg = mock_context["console"].print.call_args[0][0]
assert isinstance(panel_arg, Panel)
assert "flash init" in panel_arg.title


class TestInitCommandCurrentDirectory:
"""Tests for init command when using current directory."""

@patch("pathlib.Path.cwd")
def test_init_current_directory_with_dot(self, mock_cwd, mock_context, tmp_path):
def test_init_current_directory_with_dot(
self, mock_cwd, mock_typer_ctx, mock_context, tmp_path
):
"""Test initialization in current directory with '.' argument."""
mock_cwd.return_value = tmp_path

init_command(".")
init_command(mock_typer_ctx, ".")

# Verify skeleton was created
mock_context["create_skeleton"].assert_called_once()
Expand All @@ -106,21 +143,25 @@ def test_init_current_directory_with_dot(self, mock_cwd, mock_context, tmp_path)
class TestInitCommandConflictDetection:
"""Tests for init command file conflict detection and resolution."""

def test_no_conflicts_no_prompt(self, mock_context, tmp_path, monkeypatch):
def test_no_conflicts_no_prompt(
self, mock_typer_ctx, mock_context, tmp_path, monkeypatch
):
"""Test that prompt is skipped when no conflicts exist."""
monkeypatch.chdir(tmp_path)
mock_context["detect_conflicts"].return_value = []

init_command("my_project")
init_command(mock_typer_ctx, "my_project")

# Verify skeleton was created
mock_context["create_skeleton"].assert_called_once()

def test_console_called_multiple_times(self, mock_context, tmp_path, monkeypatch):
def test_console_called_multiple_times(
self, mock_typer_ctx, mock_context, tmp_path, monkeypatch
):
"""Test that console prints multiple outputs."""
monkeypatch.chdir(tmp_path)

init_command("my_project")
init_command(mock_typer_ctx, "my_project")

# Verify console.print was called multiple times
assert mock_context["console"].print.call_count > 0
Expand All @@ -129,42 +170,50 @@ def test_console_called_multiple_times(self, mock_context, tmp_path, monkeypatch
class TestInitCommandOutput:
"""Tests for init command output messages."""

def test_panel_title_for_new_directory(self, mock_context, tmp_path, monkeypatch):
def test_panel_title_for_new_directory(
self, mock_typer_ctx, mock_context, tmp_path, monkeypatch
):
"""Test that panel output is created for new directory."""
monkeypatch.chdir(tmp_path)

init_command("my_project")
init_command(mock_typer_ctx, "my_project")

# Verify console.print was called multiple times
assert mock_context["console"].print.call_count > 0

@patch("pathlib.Path.cwd")
def test_panel_title_for_current_directory(self, mock_cwd, mock_context, tmp_path):
def test_panel_title_for_current_directory(
self, mock_cwd, mock_typer_ctx, mock_context, tmp_path
):
"""Test that panel output is created for current directory."""
mock_cwd.return_value = tmp_path

init_command(".")
init_command(mock_typer_ctx, ".")

# Verify console.print was called
assert mock_context["console"].print.call_count > 0

def test_next_steps_displayed(self, mock_context, tmp_path, monkeypatch):
def test_next_steps_displayed(
self, mock_typer_ctx, mock_context, tmp_path, monkeypatch
):
"""Test next steps are displayed."""
monkeypatch.chdir(tmp_path)

init_command("my_project")
init_command(mock_typer_ctx, "my_project")

# Verify console.print was called with next steps text
assert any(
"Next steps" in str(c) for c in mock_context["console"].print.call_args_list
)

@patch("pathlib.Path.cwd")
def test_flash_login_step_displayed(self, mock_cwd, mock_context, tmp_path):
"""Test flash login is shown in the next steps table."""
def test_flash_login_step_displayed(
self, mock_cwd, mock_typer_ctx, mock_context, tmp_path
):
"""Test flash login is shown in the next steps table.""" (fix(init): use Typer built-in help instead of hand-crafted usage panel)
mock_cwd.return_value = tmp_path

init_command(".")
init_command(mock_typer_ctx, ".")

# The steps table is a Rich Table passed to console.print.
# Render it to plain text and check for "flash login".
Expand All @@ -186,12 +235,12 @@ def test_flash_login_step_displayed(self, mock_cwd, mock_context, tmp_path):
assert "flash login" in buf.getvalue()

def test_status_message_for_new_directory(
self, mock_context, tmp_path, monkeypatch
self, mock_typer_ctx, mock_context, tmp_path, monkeypatch
):
"""Test status message while creating new directory."""
monkeypatch.chdir(tmp_path)

init_command("my_project")
init_command(mock_typer_ctx, "my_project")

# Check that status was called with appropriate message
mock_context["console"].status.assert_called_once()
Expand All @@ -200,12 +249,12 @@ def test_status_message_for_new_directory(

@patch("pathlib.Path.cwd")
def test_status_message_for_current_directory(
self, mock_cwd, mock_context, tmp_path
self, mock_cwd, mock_typer_ctx, mock_context, tmp_path
):
"""Test status message while initializing current directory."""
mock_cwd.return_value = tmp_path

init_command(".")
init_command(mock_typer_ctx, ".")

# Check that status was called with initialization message
mock_context["console"].status.assert_called_once()
Expand All @@ -217,36 +266,36 @@ class TestInitCommandProjectNameHandling:
"""Tests for project name handling."""

def test_special_characters_in_project_name(
self, mock_context, tmp_path, monkeypatch
self, mock_typer_ctx, mock_context, tmp_path, monkeypatch
):
"""Test project name with special characters."""
monkeypatch.chdir(tmp_path)

init_command("my-project_123")
init_command(mock_typer_ctx, "my-project_123")

# Verify directory was created with the exact name
assert (tmp_path / "my-project_123").exists()

def test_console_called_with_panels_and_tables(
self, mock_context, tmp_path, monkeypatch
self, mock_typer_ctx, mock_context, tmp_path, monkeypatch
):
"""Test that console prints panels and tables."""
monkeypatch.chdir(tmp_path)

init_command("test_project")
init_command(mock_typer_ctx, "test_project")

# Verify console.print was called multiple times
assert (
mock_context["console"].print.call_count >= 4
) # Panel, "Next steps:", Table, API key info

def test_directory_created_matches_argument(
self, mock_context, tmp_path, monkeypatch
self, mock_typer_ctx, mock_context, tmp_path, monkeypatch
):
"""Test that directory created matches the argument."""
monkeypatch.chdir(tmp_path)

init_command("my_awesome_project")
init_command(mock_typer_ctx, "my_awesome_project")

# Verify directory was created with exact name
assert (tmp_path / "my_awesome_project").exists()
Expand Down
Loading