diff --git a/cli/azd/.vscode/cspell.yaml b/cli/azd/.vscode/cspell.yaml index 6ae1d0c06b5..1cec1704844 100644 --- a/cli/azd/.vscode/cspell.yaml +++ b/cli/azd/.vscode/cspell.yaml @@ -72,11 +72,35 @@ words: - subcmd - genproto - errdetails + - slogger + - SSRF + - ssrf + - Teredo + - allowlist + - blocklist + - metacharacter + - metacharacters - yarnpkg - azconfig - hostnames - seekable - seekability + - APFS + - NTFS + - mcpgo + - cpus + - unsanitized + - PATHEXT + - mintty + - dockerenv + - exfiltration + - Fprintf + - gocritic + - IMDS + - myhost + - preconfigured + - Println + - sctx languageSettings: - languageId: go ignoreRegExpList: @@ -90,6 +114,9 @@ dictionaryDefinitions: dictionaries: - azdProjectDictionary overrides: + - filename: pkg/azdext/config_helper.go + words: + - myext - filename: internal/tracing/fields/domains.go words: - azmk diff --git a/cli/azd/CHANGELOG.md b/cli/azd/CHANGELOG.md index 8e7c1076ff0..da555bc9c14 100644 --- a/cli/azd/CHANGELOG.md +++ b/cli/azd/CHANGELOG.md @@ -4,12 +4,24 @@ ### Features Added +- [[#2743]](https://github.com/Azure/azure-dev/issues/2743) Support deploying Container App Jobs (`Microsoft.App/jobs`) via `host: containerapp`. The Bicep template determines whether the target is a Container App or Container App Job. +- Add `ConfigHelper` for typed, ergonomic access to azd user and environment configuration through gRPC services, with validation support, shallow/deep merge, and structured error types (`ConfigError`). +- Add `Pager[T]` generic pagination helper with SSRF-safe nextLink validation, `Collect` with `MaxPages`/`MaxItems` bounds, and `Truncated()` detection for callers. +- Add `ResilientClient` hardening: exponential backoff with jitter, upfront body seekability validation, and `Retry-After` header cap at 120 s. +- Add `SSRFGuard` standalone SSRF protection with metadata endpoint blocking, private network blocking, HTTPS enforcement, DNS fail-closed, IPv6 embedding extraction, and allowlist bypass. 
+- Add atomic file operations (`WriteFileAtomic`, `CopyFileAtomic`, `BackupFile`, `EnsureDir`) with crash-safe write-temp-rename pattern. +- Add runtime process utilities for cross-platform process management, tool discovery, and shell execution helpers. + ### Breaking Changes ### Bugs Fixed ### Other Changes +- Add Extension SDK Reference documentation covering `NewExtensionRootCommand`, `MCPServerBuilder`, `ToolArgs`, `MCPSecurityPolicy`, `BaseServiceTargetProvider`, and all SDK helpers introduced in [#6856](https://github.com/Azure/azure-dev/pull/6856). See [Extension SDK Reference](docs/extensions/extension-sdk-reference.md). +- Add Extension Migration Guide with before/after examples for migrating from legacy patterns to SDK helpers. See [Extension Migration Guide](docs/extensions/extension-migration-guide.md). +- Add Extension End-to-End Walkthrough demonstrating root command setup, MCP server construction, lifecycle event handlers, and security policy usage. See [Extension End-to-End Walkthrough](docs/extensions/extension-e2e-walkthrough.md). + ## 1.23.8 (2026-03-06) ### Features Added diff --git a/cli/azd/docs/extensions/extension-e2e-walkthrough.md b/cli/azd/docs/extensions/extension-e2e-walkthrough.md new file mode 100644 index 00000000000..79c1ead73cc --- /dev/null +++ b/cli/azd/docs/extensions/extension-e2e-walkthrough.md @@ -0,0 +1,477 @@ +# Extension End-to-End Walkthrough + +Build a complete azd extension from scratch using the `azdext` SDK helpers. This +walkthrough creates a **resource-tagging** extension that: + +1. Registers custom commands for managing Azure resource tags. +2. Exposes an MCP server so AI assistants can call the tagging tools. +3. Hooks into the `postprovision` lifecycle event to auto-tag resources. +4. Uses `MCPSecurityPolicy` to validate user-supplied URLs. 
+ +> **Prerequisites:** +> +> - Go ≥ 1.22 +> - azd ≥ 1.23.7 (with extension SDK helpers from [#6856](https://github.com/Azure/azure-dev/pull/6856)) +> - The `microsoft.azd.extensions` developer extension installed (`azd extension install microsoft.azd.extensions`) + +--- + +## Table of Contents + +- [Step 1: Scaffold the Extension](#step-1-scaffold-the-extension) +- [Step 2: Define the Root Command](#step-2-define-the-root-command) +- [Step 3: Add a Custom Command](#step-3-add-a-custom-command) +- [Step 4: Build an MCP Server with Tools](#step-4-build-an-mcp-server-with-tools) +- [Step 5: Register Lifecycle Event Handlers](#step-5-register-lifecycle-event-handlers) +- [Step 6: Wire It All Together](#step-6-wire-it-all-together) +- [Step 7: Build, Install, and Test](#step-7-build-install-and-test) +- [Project Structure Summary](#project-structure-summary) +- [What You Have Built](#what-you-have-built) + +--- + +## Step 1: Scaffold the Extension + +```bash +cd cli/azd/extensions +azd x init +``` + +Follow the prompts: + +- **Name:** `contoso.azd.tagger` +- **Language:** Go +- **Capabilities:** `custom-commands`, `metadata`, `lifecycle-events` + +This creates a directory with `extension.yaml`, `main.go`, build scripts, and a +`CHANGELOG.md`. + +Edit `extension.yaml` to match: + +```yaml +# yaml-language-server: $schema=https://raw.githubusercontent.com/Azure/azure-dev/refs/heads/main/cli/azd/extensions/extension.schema.json +id: contoso.azd.tagger +namespace: tagger +displayName: Resource Tagger +description: Auto-tag Azure resources and expose tagging tools via MCP. 
+version: 0.1.0 +capabilities: + - custom-commands + - metadata + - lifecycle-events +``` + +--- + +## Step 2: Define the Root Command + +Replace `main.go` with the SDK-helper entry point: + +```go +package main + +import ( + "github.com/azure/azure-dev/cli/azd/extensions/contoso.azd.tagger/internal/cmd" + "github.com/azure/azure-dev/cli/azd/pkg/azdext" +) + +func main() { + azdext.Run(cmd.NewRootCommand()) +} +``` + +Create `internal/cmd/root.go`: + +```go +package cmd + +import ( + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/spf13/cobra" +) + +const ( + extensionID = "contoso.azd.tagger" + version = "0.1.0" +) + +func NewRootCommand() *cobra.Command { + // NewExtensionRootCommand registers --debug, --no-prompt, --cwd, + // -e/--environment, --output and sets up trace context automatically. + rootCmd, extCtx := azdext.NewExtensionRootCommand(azdext.ExtensionCommandOptions{ + Name: "tagger", + Version: version, + Short: "Manage Azure resource tags", + }) + + // Custom commands + rootCmd.AddCommand(newTagCommand(extCtx)) + rootCmd.AddCommand(newMCPCommand(extCtx)) + + // Standard lifecycle, metadata, and version commands + rootCmd.AddCommand(azdext.NewListenCommand(configureListen)) + rootCmd.AddCommand(azdext.NewMetadataCommand("1.0", extensionID, NewRootCommand)) + rootCmd.AddCommand(azdext.NewVersionCommand(extensionID, version, &extCtx.OutputFormat)) + + return rootCmd +} +``` + +**What this gives you:** + +- All azd global flags parsed and available via `extCtx`. +- OpenTelemetry trace context propagated from the parent azd process. +- gRPC access token injected into the command context. 
+ +--- + +## Step 3: Add a Custom Command + +Create `internal/cmd/tag.go`: + +```go +package cmd + +import ( + "fmt" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/spf13/cobra" +) + +func newTagCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + var tagKey, tagValue string + + cmd := &cobra.Command{ + Use: "tag", + Short: "Apply a tag to resources in the current environment", + RunE: func(cmd *cobra.Command, args []string) error { + ctx := extCtx.Context() + + // Create the azd gRPC client + client, err := azdext.NewAzdClient() + if err != nil { + return fmt.Errorf("creating azd client: %w", err) + } + defer client.Close() + + // Use the ConfigHelper to read environment config + configHelper, err := azdext.NewConfigHelper(client) + if err != nil { + return fmt.Errorf("creating config helper: %w", err) + } + + // Read the resource group from env config + rg, found, err := configHelper.GetEnvString(ctx, "AZURE_RESOURCE_GROUP") + if err != nil { + return fmt.Errorf("reading resource group: %w", err) + } + if !found { + return &azdext.LocalError{ + Message: "no resource group configured", + Category: azdext.LocalErrorCategoryValidation, + Suggestion: "Run 'azd provision' first to create Azure resources.", + } + } + + // Format-aware output + output := azdext.NewOutput(azdext.OutputOptions{ + Format: azdext.OutputFormat(extCtx.OutputFormat), + }) + + // Log the operation + logger := azdext.NewLogger("tagger") + logger.Info("applying tag", "key", tagKey, "value", tagValue, "rg", rg) + + // ... tag application logic using Azure SDK ... 
+ + output.Success("Tagged resources in %s: %s=%s", rg, tagKey, tagValue) + return nil + }, + } + + cmd.Flags().StringVar(&tagKey, "key", "", "Tag key (required)") + cmd.Flags().StringVar(&tagValue, "value", "", "Tag value (required)") + _ = cmd.MarkFlagRequired("key") + _ = cmd.MarkFlagRequired("value") + + return cmd +} +``` + +**Key patterns demonstrated:** + +- `extCtx.Context()` for a trace-aware, token-injected context. +- `ConfigHelper` for reading azd environment configuration. +- `LocalError` with category and suggestion for structured error reporting. +- `Output` for format-aware display (text or JSON). +- `Logger` for structured logging. + +--- + +## Step 4: Build an MCP Server with Tools + +Create `internal/cmd/mcp.go`: + +```go +package cmd + +import ( + "context" + "fmt" + "os" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + mcp "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/spf13/cobra" +) + +func newMCPCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + return &cobra.Command{ + Use: "mcp", + Short: "Start the MCP server for AI-assisted tagging", + RunE: func(cmd *cobra.Command, args []string) error { + // Build the MCP server with the fluent builder API + mcpServer := azdext.NewMCPServerBuilder("tagger", "0.1.0"). + // Rate limit: max 10 concurrent, refill 2/sec + WithRateLimit(10, 2.0). + // Security policy to validate any user-provided URLs + WithSecurityPolicy(azdext.DefaultMCPSecurityPolicy()). + // System instructions for AI clients + WithInstructions(`Use these tools to manage Azure resource tags. +Always confirm tag operations with the user before applying.`). 
+ // Register tools + AddTool("list_tags", listTagsHandler, azdext.MCPToolOptions{ + Description: "List tags on resources in a resource group", + }, + mcp.WithString("resourceGroup", + mcp.Required(), + mcp.Description("Azure resource group name"), + ), + mcp.WithString("subscription", + mcp.Description("Azure subscription ID (uses default if omitted)"), + ), + ). + AddTool("set_tag", setTagHandler, azdext.MCPToolOptions{ + Description: "Set a tag on all resources in a resource group", + }, + mcp.WithString("resourceGroup", + mcp.Required(), + mcp.Description("Azure resource group name"), + ), + mcp.WithString("key", + mcp.Required(), + mcp.Description("Tag key"), + ), + mcp.WithString("value", + mcp.Required(), + mcp.Description("Tag value"), + ), + ). + Build() + + // Serve over stdio (standard MCP transport) + sseServer := server.NewStdioServer(mcpServer) + return sseServer.Listen(cmd.Context(), os.Stdin, os.Stdout) + }, + } +} + +// listTagsHandler demonstrates typed argument parsing and JSON result helpers. +func listTagsHandler(ctx context.Context, args azdext.ToolArgs) (*mcp.CallToolResult, error) { + rg, err := args.RequireString("resourceGroup") + if err != nil { + return azdext.MCPErrorResult("missing argument: %v", err), nil + } + sub := args.OptionalString("subscription", "") + + logger := azdext.NewLogger("mcp.list_tags") + logger.Info("listing tags", "resourceGroup", rg, "subscription", sub) + + // ... Azure SDK call to list tags ... + tags := map[string]string{ + "environment": "production", + "owner": "platform-team", + } + + return azdext.MCPJSONResult(tags), nil +} + +// setTagHandler demonstrates security policy usage and error handling. 
+func setTagHandler(ctx context.Context, args azdext.ToolArgs) (*mcp.CallToolResult, error) { + rg, err := args.RequireString("resourceGroup") + if err != nil { + return azdext.MCPErrorResult("missing argument: %v", err), nil + } + key, err := args.RequireString("key") + if err != nil { + return azdext.MCPErrorResult("missing argument: %v", err), nil + } + value, err := args.RequireString("value") + if err != nil { + return azdext.MCPErrorResult("missing argument: %v", err), nil + } + + logger := azdext.NewLogger("mcp.set_tag") + logger.Info("setting tag", "resourceGroup", rg, "key", key, "value", value) + + // ... Azure SDK call to set tag ... + + return azdext.MCPTextResult("Tag %s=%s applied to resource group %s", key, value, rg), nil +} +``` + +**Key patterns demonstrated:** + +- `MCPServerBuilder` fluent API with rate limiting and security policy. +- `ToolArgs.RequireString` / `OptionalString` for typed argument access. +- `MCPTextResult`, `MCPJSONResult`, `MCPErrorResult` for response construction. +- `DefaultMCPSecurityPolicy` for SSRF protection. + +--- + +## Step 5: Register Lifecycle Event Handlers + +Create `internal/cmd/listen.go`: + +```go +package cmd + +import ( + "context" + "fmt" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" +) + +// configureListen is called by NewListenCommand to register event handlers. 
+func configureListen(host *azdext.ExtensionHost) { + host.WithProjectEventHandler("postprovision", handlePostProvision) +} + +func handlePostProvision(ctx context.Context, args *azdext.ProjectEventArgs) error { + logger := azdext.NewLogger("tagger.postprovision") + logger.Info("auto-tagging resources after provision", "project", args.Project.Name) + + client := args.Client // The host provides access to the azd gRPC client + + // Read environment values + envResp, err := client.Environment().GetValues(ctx, &azdext.GetEnvironmentRequest{}) + if err != nil { + logger.Warn("could not read environment", "error", err) + return nil // Non-fatal: don't block provisioning + } + + // Apply standard tags to all resources + for key, value := range envResp.Values { + logger.Debug("found env value", "key", key, "value", value) + } + + logger.Info("auto-tagging complete") + return nil +} +``` + +**Key patterns demonstrated:** + +- `NewListenCommand` + configure callback for clean lifecycle registration. +- Project event handlers receive `ProjectEventArgs` with access to the project + metadata and gRPC client. +- Non-fatal error handling — the handler logs warnings but doesn't block the + parent azd workflow. + +--- + +## Step 6: Wire It All Together + +Your final `main.go`: + +```go +package main + +import ( + "github.com/azure/azure-dev/cli/azd/extensions/contoso.azd.tagger/internal/cmd" + "github.com/azure/azure-dev/cli/azd/pkg/azdext" +) + +func main() { + azdext.Run(cmd.NewRootCommand()) +} +``` + +That is the **entire** entry point. 
`azdext.Run` handles: + +- `FORCE_COLOR` detection +- Trace context propagation +- Access token injection +- Structured error reporting +- Exit code management + +--- + +## Step 7: Build, Install, and Test + +```bash +# Build for current platform +azd x build + +# Build for all platforms +azd x build --all + +# Test the custom command +azd tagger tag --key team --value platform -e dev + +# Test the MCP server (connects over stdio) +echo '{"jsonrpc":"2.0","id":1,"method":"tools/list"}' | azd tagger mcp + +# Test lifecycle integration (azd invokes 'listen' automatically) +azd provision # postprovision handler fires +``` + +--- + +## Project Structure Summary + +``` +contoso.azd.tagger/ +├── main.go # Entry point: azdext.Run(cmd.NewRootCommand()) +├── extension.yaml # Extension manifest +├── CHANGELOG.md # Release notes +├── go.mod # Go module +├── go.sum +├── build.ps1 # Windows build +├── build.sh # Unix build +└── internal/ + └── cmd/ + ├── root.go # Root command + subcommand wiring + ├── tag.go # Custom 'tag' command + ├── mcp.go # MCP server with tools + └── listen.go # Lifecycle event handlers +``` + +--- + +## What You Have Built + +| Capability | SDK Helper Used | Lines Saved | +|------------|----------------|-------------| +| Root command with global flags + tracing | `NewExtensionRootCommand` | ~40 lines | +| Entry point with error reporting | `Run` | ~15 lines | +| Listen command with host setup | `NewListenCommand` | ~20 lines | +| MCP server with rate limiting + security | `MCPServerBuilder` | ~60 lines | +| Typed MCP argument parsing | `ToolArgs` | ~10 lines per tool | +| MCP response construction | `MCPTextResult`, `MCPJSONResult`, `MCPErrorResult` | ~5 lines per tool | +| SSRF/path protection | `DefaultMCPSecurityPolicy` | ~50 lines | +| Metadata + version commands | `NewMetadataCommand`, `NewVersionCommand` | ~30 lines | +| **Total boilerplate eliminated** | | **~250+ lines** | + +--- + +## See Also + +- [Extension SDK 
Reference](./extension-sdk-reference.md) — Full API reference for all helpers. +- [Extension Migration Guide](./extension-migration-guide.md) — Migrate existing extensions from legacy patterns. +- [Extension Framework](./extension-framework.md) — General framework documentation. +- [Extension Framework Services](./extension-framework-services.md) — gRPC service reference. +- [Extension Style Guide](./extensions-style-guide.md) — Design guidelines. diff --git a/cli/azd/docs/extensions/extension-framework.md b/cli/azd/docs/extensions/extension-framework.md index 9b8369e7eb9..a5f4a478035 100644 --- a/cli/azd/docs/extensions/extension-framework.md +++ b/cli/azd/docs/extensions/extension-framework.md @@ -22,6 +22,16 @@ Table of Contents - [Compose Service](#compose-service) - [Workflow Service](#workflow-service) +### Related Guides + +| Guide | Description | +|-------|-------------| +| [Extension SDK Reference](./extension-sdk-reference.md) | Complete API reference for `azdext` SDK helpers (command scaffolding, MCP builder, security policy, service-target base). | +| [Extension Migration Guide](./extension-migration-guide.md) | Before/after cookbook for migrating from pre-#6856 patterns to SDK helpers. | +| [Extension End-to-End Walkthrough](./extension-e2e-walkthrough.md) | Build a complete extension from scratch with root command, MCP server, lifecycle events, and security. | +| [Extension Framework Services](./extension-framework-services.md) | Custom language/framework support via `FrameworkServiceProvider`. | +| [Extension Style Guide](./extensions-style-guide.md) | Design guidelines for command integration, flags, and discoverability. | + ## Getting Started `azd` extensions are currently an alpha feature within `azd`. @@ -474,18 +484,20 @@ The build process automatically creates binaries for multiple platforms and arch `azd` uses OpenTelemetry and W3C Trace Context for distributed tracing. 
`azd` sets `TRACEPARENT` in the environment when it launches the extension process. -Use `azdext.NewContext()` to hydrate the root context with trace context: +The recommended approach is to use `azdext.Run`, which automatically creates a trace-aware context, injects the access token, reports structured errors, and handles `os.Exit`: ```go func main() { - ctx := azdext.NewContext() - rootCmd := cmd.NewRootCommand() - if err := rootCmd.ExecuteContext(ctx); err != nil { - // Handle error - } + azdext.Run(cmd.NewRootCommand()) } ``` +For lifecycle-listener extensions, `azdext.NewListenCommand` sets up trace context and access token automatically within its handler. + +> **Note:** `azdext.NewContext()` is deprecated. Use `azdext.Run` for custom-command extensions +> or `azdext.NewListenCommand`/`azdext.NewExtensionRootCommand` for lifecycle listeners. +> `NewContext` remains available for backward compatibility but new extensions should not use it. + To correlate Azure SDK calls with the parent trace, add the correlation policy to your client options: ```go @@ -1138,11 +1150,12 @@ func main() { } ``` -Alternatively, you can use `azdext.ReportError` directly for lower-level control: +Alternatively, you can use `azdext.ReportError` directly for lower-level control +(note: `NewContext` is deprecated — prefer `Run` for new extensions): ```go func main() { - ctx := azdext.NewContext() + ctx := azdext.NewContext() // Deprecated: prefer azdext.Run ctx = azdext.WithAccessToken(ctx) rootCmd := cmd.NewRootCommand() diff --git a/cli/azd/docs/extensions/extension-migration-guide.md b/cli/azd/docs/extensions/extension-migration-guide.md new file mode 100644 index 00000000000..f313d690d5a --- /dev/null +++ b/cli/azd/docs/extensions/extension-migration-guide.md @@ -0,0 +1,554 @@ +# Extension Migration Guide + +This guide helps extension authors migrate from pre-[#6856](https://github.com/Azure/azure-dev/pull/6856) patterns to the new `azdext` SDK helpers. 
Each section shows a **before** (legacy) and **after** (recommended) pattern with a brief explanation of what changed and why.
+
+> **Applies to:** Extensions targeting azd ≥ 1.23.7 with `azdext` SDK helpers.
+
+---
+
+## Table of Contents
+
+- [M1: Entry Point — NewContext → Run](#m1-entry-point--newcontext--run)
+- [M2: Root Command — Manual Flags → NewExtensionRootCommand](#m2-root-command--manual-flags--newextensionrootcommand)
+- [M3: Listen Command — Manual Host Setup → NewListenCommand](#m3-listen-command--manual-host-setup--newlistencommand)
+- [M4: MCP Server — Manual Construction → MCPServerBuilder](#m4-mcp-server--manual-construction--mcpserverbuilder)
+- [M5: MCP Tool Arguments — Raw Map Access → ToolArgs](#m5-mcp-tool-arguments--raw-map-access--toolargs)
+- [M6: MCP Responses — Manual Result Construction → Result Helpers](#m6-mcp-responses--manual-result-construction--result-helpers)
+- [M7: SSRF / Path Validation — Custom Checks → MCPSecurityPolicy](#m7-ssrf--path-validation--custom-checks--mcpsecuritypolicy)
+- [M8: Service Target — Full Interface → BaseServiceTargetProvider](#m8-service-target--full-interface--baseservicetargetprovider)
+- [M9: Metadata Command — Hand-Rolled → NewMetadataCommand](#m9-metadata-command--hand-rolled--newmetadatacommand)
+- [M10: Version Command — Custom → NewVersionCommand](#m10-version-command--custom--newversioncommand)
+- [Compatibility Notes](#compatibility-notes)
+- [Step-by-Step Migration Checklist](#step-by-step-migration-checklist)
+
+---
+
+## M1: Entry Point — NewContext → Run
+
+### Before (legacy)
+
+```go
+func main() {
+    ctx := azdext.NewContext()
+
+    rootCmd := cmd.NewRootCommand()
+    rootCmd.SetContext(ctx)
+    if err := rootCmd.Execute(); err != nil {
+        fmt.Fprintf(os.Stderr, "Error: %v\n", err)
+        os.Exit(1)
+    }
+}
+```
+
+**Problems:** Manual error display, no structured error reporting to azd, no
+`FORCE_COLOR` handling, exit codes not standardized.
+ +### After (recommended) + +```go +func main() { + azdext.Run(cmd.NewRootCommand()) +} +``` + +**What changed:** `Run` handles context creation, `FORCE_COLOR`, trace +propagation, access-token injection, structured error reporting via gRPC +`ReportError`, and `os.Exit`. One line replaces ~15 lines of boilerplate. + +> **Note:** `NewContext()` is deprecated but remains available for backward +> compatibility. New extensions should not use it. + +--- + +## M2: Root Command — Manual Flags → NewExtensionRootCommand + +### Before (legacy) + +```go +var ( + debug bool + noPrompt bool + cwd string + environment string + output string +) + +func NewRootCommand() *cobra.Command { + rootCmd := &cobra.Command{ + Use: "my-extension", + Short: "My extension", + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + // Manual trace context extraction + traceparent := os.Getenv("TRACEPARENT") + tracestate := os.Getenv("TRACESTATE") + if traceparent != "" { + ctx := propagation.TraceContext{}.Extract(cmd.Context(), + propagation.MapCarrier{ + "traceparent": traceparent, + "tracestate": tracestate, + }) + cmd.SetContext(ctx) + } + // Manual access token + cmd.SetContext(azdext.WithAccessToken(cmd.Context())) + return nil + }, + } + + rootCmd.PersistentFlags().BoolVar(&debug, "debug", false, "Enable debug logging") + rootCmd.PersistentFlags().BoolVar(&noPrompt, "no-prompt", false, "Disable prompts") + rootCmd.PersistentFlags().StringVar(&cwd, "cwd", "", "Working directory") + rootCmd.PersistentFlags().StringVarP(&environment, "environment", "e", "", "Environment") + rootCmd.PersistentFlags().StringVar(&output, "output", "", "Output format") + + return rootCmd +} +``` + +**Problems:** 30-50 lines of identical boilerplate in every extension. Flag names +and trace-context extraction can drift from what azd expects. 
+ +### After (recommended) + +```go +func NewRootCommand() *cobra.Command { + rootCmd, extCtx := azdext.NewExtensionRootCommand(azdext.ExtensionCommandOptions{ + Name: "my-extension", + Version: "1.0.0", + Short: "My extension", + }) + + // Use extCtx.Debug, extCtx.Cwd, etc. in subcommands + rootCmd.AddCommand(newServeCommand(extCtx)) + + return rootCmd +} +``` + +**What changed:** `NewExtensionRootCommand` registers all standard flags, reads +`AZD_*` env vars, and sets up trace context + access token in +`PersistentPreRunE`. The `ExtensionContext` struct provides typed access to the +parsed values. + +--- + +## M3: Listen Command — Manual Host Setup → NewListenCommand + +### Before (legacy) + +```go +func newListenCommand() *cobra.Command { + return &cobra.Command{ + Use: "listen", + Short: "Starts the extension and listens for events.", + RunE: func(cmd *cobra.Command, args []string) error { + ctx := azdext.WithAccessToken(cmd.Context()) + azdClient, err := azdext.NewAzdClient() + if err != nil { + return fmt.Errorf("failed to create azd client: %w", err) + } + defer azdClient.Close() + + host := azdext.NewExtensionHost(azdClient). + WithServiceTarget("myhost", func() azdext.ServiceTargetProvider { + return &MyProvider{client: azdClient} + }) + + return host.Run(ctx) + }, + } +} +``` + +### After (recommended) + +```go +rootCmd.AddCommand(azdext.NewListenCommand(func(host *azdext.ExtensionHost) { + host.WithServiceTarget("myhost", func() azdext.ServiceTargetProvider { + return &MyProvider{client: host.Client()} + }) +})) +``` + +**What changed:** `NewListenCommand` handles client creation, context injection, +and `defer Close()` internally. The `configure` callback receives the fully +initialized host. 
+ +--- + +## M4: MCP Server — Manual Construction → MCPServerBuilder + +### Before (legacy) + +```go +mcpServer := server.NewMCPServer("my-mcp", "1.0.0") + +// Manual rate limiter setup +limiter := rate.NewLimiter(rate.Limit(2.0), 10) + +// Manual tool registration with raw handler +mcpServer.AddTool( + mcp.NewTool("list_items", + mcp.WithDescription("List items"), + mcp.WithString("query", mcp.Required(), mcp.Description("Search query")), + ), + func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Manual rate limiting check + if !limiter.Allow() { + return &mcp.CallToolResult{ + IsError: true, + Content: []mcp.Content{{Type: "text", Text: ptr("rate limited")}}, + }, nil + } + // Manual arg extraction + args := req.Params.Arguments + query, ok := args["query"].(string) + if !ok { + return nil, fmt.Errorf("missing required argument: query") + } + // ... handler logic ... + return nil, nil + }, +) +``` + +### After (recommended) + +```go +mcpServer := azdext.NewMCPServerBuilder("my-mcp", "1.0.0"). + WithRateLimit(10, 2.0). + WithSecurityPolicy(azdext.DefaultMCPSecurityPolicy()). + AddTool("list_items", listItemsHandler, azdext.MCPToolOptions{ + Description: "List items", + }, + mcp.WithString("query", mcp.Required(), mcp.Description("Search query")), + ). + Build() + +func listItemsHandler(ctx context.Context, args azdext.ToolArgs) (*mcp.CallToolResult, error) { + query, err := args.RequireString("query") + if err != nil { + return azdext.MCPErrorResult("missing argument: %v", err), nil + } + // ... handler logic ... + return azdext.MCPJSONResult(results), nil +} +``` + +**What changed:** The builder handles rate-limiter wiring, security policy +attachment, and argument parsing. Tool handlers receive `ToolArgs` instead of +raw `CallToolRequest`. 
+ +--- + +## M5: MCP Tool Arguments — Raw Map Access → ToolArgs + +### Before (legacy) + +```go +func handler(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args := req.Params.Arguments + name, ok := args["name"].(string) + if !ok { + return nil, fmt.Errorf("missing name") + } + // JSON numbers are float64, manual conversion needed + countRaw, ok := args["count"] + count := 10 // default + if ok { + if f, ok := countRaw.(float64); ok { + count = int(f) + } + } + verbose := false + if v, ok := args["verbose"].(bool); ok { + verbose = v + } + // ... +} +``` + +### After (recommended) + +```go +func handler(ctx context.Context, args azdext.ToolArgs) (*mcp.CallToolResult, error) { + name, err := args.RequireString("name") + if err != nil { + return azdext.MCPErrorResult("%v", err), nil + } + count := args.OptionalInt("count", 10) + verbose := args.OptionalBool("verbose", false) + // ... +} +``` + +**What changed:** `ToolArgs` handles JSON `float64` → `int` conversion, type +checking, and defaults in a single method call. `Require*` methods return +errors; `Optional*` methods return defaults. 
+ +--- + +## M6: MCP Responses — Manual Result Construction → Result Helpers + +### Before (legacy) + +```go +// Text result +textPtr := func(s string) *string { return &s } +result := &mcp.CallToolResult{ + Content: []mcp.Content{{Type: "text", Text: textPtr("Success: 5 items found")}}, +} + +// JSON result +jsonBytes, err := json.Marshal(data) +if err != nil { + return nil, err +} +result := &mcp.CallToolResult{ + Content: []mcp.Content{{Type: "text", Text: textPtr(string(jsonBytes))}}, +} + +// Error result +result := &mcp.CallToolResult{ + IsError: true, + Content: []mcp.Content{{Type: "text", Text: textPtr("failed: " + err.Error())}}, +} +``` + +### After (recommended) + +```go +return azdext.MCPTextResult("Success: %d items found", 5), nil +return azdext.MCPJSONResult(data), nil +return azdext.MCPErrorResult("failed: %v", err), nil +``` + +**What changed:** One-liners replace 3-5 lines each. `MCPJSONResult` handles +marshal errors internally (returns an error result if marshaling fails). + +--- + +## M7: SSRF / Path Validation — Custom Checks → MCPSecurityPolicy + +### Before (legacy) + +```go +func validateURL(rawURL string) error { + u, err := url.Parse(rawURL) + if err != nil { + return err + } + // Check for metadata endpoints + if u.Hostname() == "169.254.169.254" { + return fmt.Errorf("metadata endpoint blocked") + } + // Check for private IPs - incomplete, easy to miss ranges + ip := net.ParseIP(u.Hostname()) + if ip != nil && ip.IsPrivate() { + return fmt.Errorf("private network blocked") + } + // Missing: CGNAT, IPv6 transition, cloud metadata variants, etc. + return nil +} +``` + +### After (recommended) + +```go +policy := azdext.DefaultMCPSecurityPolicy() + +// Or build a custom policy: +policy := azdext.NewMCPSecurityPolicy(). + BlockMetadataEndpoints(). + BlockPrivateNetworks(). + RequireHTTPS(). 
+ ValidatePathsWithinBase(projectDir) + +if err := policy.CheckURL(userURL); err != nil { + return azdext.MCPErrorResult("blocked: %v", err), nil +} +if err := policy.CheckPath(userPath); err != nil { + return azdext.MCPErrorResult("blocked: %v", err), nil +} +``` + +**What changed:** `MCPSecurityPolicy` covers cloud metadata (AWS, GCP, Azure +IMDS), RFC 1918, CGNAT (RFC 6598), IPv6 transition mechanisms (6to4, Teredo, +NAT64), symlink resolution, and sensitive header redaction — areas that manual +checks commonly miss. + +--- + +## M8: Service Target — Full Interface → BaseServiceTargetProvider + +### Before (legacy) + +```go +type MyProvider struct { + client *azdext.AzdClient +} + +// Must implement ALL 6 methods even if unused +func (p *MyProvider) Initialize(ctx context.Context, sc *azdext.ServiceConfig) error { + return nil +} +func (p *MyProvider) Endpoints(ctx context.Context, sc *azdext.ServiceConfig, + tr *azdext.TargetResource) ([]string, error) { + return nil, nil +} +func (p *MyProvider) GetTargetResource(ctx context.Context, subId string, + sc *azdext.ServiceConfig, defaultResolver func() (*azdext.TargetResource, error), +) (*azdext.TargetResource, error) { + return nil, nil +} +func (p *MyProvider) Package(ctx context.Context, sc *azdext.ServiceConfig, + sctx *azdext.ServiceContext, progress azdext.ProgressReporter, +) (*azdext.ServicePackageResult, error) { + return nil, nil +} +func (p *MyProvider) Publish(ctx context.Context, sc *azdext.ServiceConfig, + sctx *azdext.ServiceContext, tr *azdext.TargetResource, + opts *azdext.PublishOptions, progress azdext.ProgressReporter, +) (*azdext.ServicePublishResult, error) { + return nil, nil +} +func (p *MyProvider) Deploy(ctx context.Context, sc *azdext.ServiceConfig, + sctx *azdext.ServiceContext, tr *azdext.TargetResource, progress azdext.ProgressReporter, +) (*azdext.ServiceDeployResult, error) { + // Only method actually needed + // ... deploy logic ... 
+ return &azdext.ServiceDeployResult{}, nil +} +``` + +### After (recommended) + +```go +type MyProvider struct { + azdext.BaseServiceTargetProvider // no-op defaults for all methods + client *azdext.AzdClient +} + +// Override only what you need +func (p *MyProvider) Deploy(ctx context.Context, sc *azdext.ServiceConfig, + sctx *azdext.ServiceContext, tr *azdext.TargetResource, progress azdext.ProgressReporter, +) (*azdext.ServiceDeployResult, error) { + // ... deploy logic ... + return &azdext.ServiceDeployResult{}, nil +} +``` + +**What changed:** Embed `BaseServiceTargetProvider` and override only the +methods you need. Eliminates ~40 lines of no-op stubs per provider. + +--- + +## M9: Metadata Command — Hand-Rolled → NewMetadataCommand + +### Before (legacy) + +```go +func newMetadataCommand() *cobra.Command { + return &cobra.Command{ + Use: "metadata", + Short: "Generate extension metadata", + Hidden: true, + RunE: func(cmd *cobra.Command, args []string) error { + rootCmd := newRootCommand() + metadata := azdext.GenerateExtensionMetadata("1.0", "my.extension", rootCmd) + jsonBytes, _ := json.MarshalIndent(metadata, "", " ") + fmt.Println(string(jsonBytes)) + return nil + }, + } +} +``` + +### After (recommended) + +```go +rootCmd.AddCommand(azdext.NewMetadataCommand("1.0", "my.extension", newRootCommand)) +``` + +--- + +## M10: Version Command — Custom → NewVersionCommand + +### Before (legacy) + +```go +func newVersionCommand() *cobra.Command { + return &cobra.Command{ + Use: "version", + Short: "Print version", + RunE: func(cmd *cobra.Command, args []string) error { + fmt.Printf("my.extension version %s\n", Version) + return nil + }, + } +} +``` + +### After (recommended) + +```go +rootCmd.AddCommand(azdext.NewVersionCommand("my.extension", "1.0.0", &extCtx.OutputFormat)) +``` + +Supports `--output json` automatically when the output-format pointer is +provided. + +--- + +## Compatibility Notes + +1. 
**`azdext.NewContext()` is deprecated** — it remains functional but should + not be used in new code. Use `azdext.Run` + `NewExtensionRootCommand` instead. + +2. **All migrations are additive** — existing extensions continue to compile + and run without changes. You can migrate incrementally. + +3. **Minimum azd version** — The SDK helpers ship in azd ≥ 1.23.7. If your + extension must support older azd versions, guard imports behind a build tag + or version check. + +4. **`mcp-go` dependency** — MCP helpers wrap `mark3labs/mcp-go`. Your + extension's `go.mod` must include this dependency. Run `go mod tidy` after + migration. + +--- + +## Step-by-Step Migration Checklist + +Use this checklist to migrate an existing extension: + +- [ ] **Replace entry point:** Change `main()` to use `azdext.Run(rootCmd)`. +- [ ] **Replace root command:** Use `NewExtensionRootCommand` instead of manual + flag registration and trace-context extraction. +- [ ] **Replace listen command:** Use `NewListenCommand(configure)` instead of + manual host setup. +- [ ] **Replace MCP server construction:** Use `MCPServerBuilder` with + `AddTool`, `WithRateLimit`, `WithSecurityPolicy`. +- [ ] **Replace argument parsing:** Switch tool handlers to `MCPToolHandler` + signature and use `ToolArgs` methods. +- [ ] **Replace result construction:** Use `MCPTextResult`, `MCPJSONResult`, + `MCPErrorResult`. +- [ ] **Add security policy:** Use `DefaultMCPSecurityPolicy()` or build a + custom policy. +- [ ] **Embed BaseServiceTargetProvider:** Remove no-op method stubs; embed + the base struct. +- [ ] **Replace metadata/version commands:** Use `NewMetadataCommand` and + `NewVersionCommand`. +- [ ] **Run `go mod tidy`** to pick up new dependencies. +- [ ] **Test:** Build and run the extension against azd ≥ 1.23.7. + +--- + +## See Also + +- [Extension SDK Reference](./extension-sdk-reference.md) — Full API reference. 
+- [Extension End-to-End Walkthrough](./extension-e2e-walkthrough.md) — Build a complete extension from scratch. +- [Extension Framework](./extension-framework.md) — General framework documentation. diff --git a/cli/azd/docs/extensions/extension-sdk-reference.md b/cli/azd/docs/extensions/extension-sdk-reference.md new file mode 100644 index 00000000000..ba91ef45e3a --- /dev/null +++ b/cli/azd/docs/extensions/extension-sdk-reference.md @@ -0,0 +1,705 @@ +# Extension SDK Reference + +This document is the API reference for the `azdext` SDK helpers introduced in [PR #6856](https://github.com/Azure/azure-dev/pull/6856). These helpers eliminate boilerplate that every azd extension must otherwise implement manually, covering command scaffolding, MCP server construction, typed argument parsing, security policy, and service-target base implementations. + +> **Package import:** `"github.com/azure/azure-dev/cli/azd/pkg/azdext"` + +--- + +## Table of Contents + +- [Entry Point & Lifecycle](#entry-point--lifecycle) + - [Run](#run) + - [RunOption / WithPreExecute](#runoption--withpreexecute) +- [Command Scaffolding](#command-scaffolding) + - [NewExtensionRootCommand](#newextensionrootcommand) + - [ExtensionCommandOptions](#extensioncommandoptions) + - [ExtensionContext](#extensioncontext) + - [NewListenCommand](#newlistencommand) + - [NewMetadataCommand](#newmetadatacommand) + - [NewVersionCommand](#newversioncommand) +- [MCP Server Builder](#mcp-server-builder) + - [NewMCPServerBuilder](#newmcpserverbuilder) + - [MCPServerBuilder Methods](#mcpserverbuilder-methods) + - [MCPToolHandler](#mcptoolhandler) + - [MCPToolOptions](#mcptooloptions) +- [Typed Argument Parsing](#typed-argument-parsing) + - [ToolArgs](#toolargs) + - [ParseToolArgs](#parsetoolargs) +- [MCP Result Helpers](#mcp-result-helpers) + - [MCPTextResult](#mcptextresult) + - [MCPJSONResult](#mcpjsonresult) + - [MCPErrorResult](#mcperrorresult) +- [MCP Security Policy](#mcp-security-policy) + - 
[NewMCPSecurityPolicy](#newmcpsecuritypolicy) + - [DefaultMCPSecurityPolicy](#defaultmcpsecuritypolicy) + - [MCPSecurityPolicy Methods](#mcpsecuritypolicy-methods) +- [Service Target Providers](#service-target-providers) + - [ServiceTargetProvider Interface](#servicetargetprovider-interface) + - [BaseServiceTargetProvider](#baseservicetargetprovider) +- [Extension Host](#extension-host) + - [NewExtensionHost](#newextensionhost) + - [ExtensionHost Methods](#extensionhost-methods) +- [Client & Utilities](#client--utilities) + - [AzdClient](#azdclient) + - [ConfigHelper](#confighelper) + - [TokenProvider](#tokenprovider) + - [Logger](#logger) + - [Output](#output) + - [Runtime Utilities](#runtime-utilities) + - [Shell Helpers](#shell-helpers) + - [Tool Discovery Helpers](#tool-discovery-helpers) + - [Interactive/TUI Helpers](#interactivetui-helpers) + - [Atomic File Helpers](#atomic-file-helpers) +- [Error Handling](#error-handling) + - [LocalError](#localerror) + - [ServiceError](#serviceerror) + - [LocalErrorCategory](#localerrorcategory) + +--- + +## Entry Point & Lifecycle + +### Run + +```go +func Run(rootCmd *cobra.Command, opts ...RunOption) +``` + +`Run` is the **recommended entry point** for all azd extensions. It handles the +full lifecycle that every extension needs: + +1. Reads `FORCE_COLOR` environment variable and configures `color.NoColor`. +2. Silences cobra's built-in error output (extensions control error display). +3. Creates a context with OpenTelemetry trace propagation from `TRACEPARENT`/`TRACESTATE`. +4. Injects the gRPC access token via `WithAccessToken`. +5. Executes the cobra command tree. +6. On failure, reports the error to azd via gRPC `ReportError` for structured telemetry. +7. Displays the error and any suggestion text to stderr. +8. Calls `os.Exit(1)` on failure. 
+ +**Usage:** + +```go +func main() { + rootCmd := cmd.NewRootCommand() + azdext.Run(rootCmd) +} +``` + +### RunOption / WithPreExecute + +```go +type RunOption func(*runConfig) + +func WithPreExecute(fn func(ctx context.Context, cmd *cobra.Command) error) RunOption +``` + +`WithPreExecute` registers a hook that runs **after** context creation but +**before** command execution. If the hook returns a non-nil error, `Run` prints +it and exits. This is useful for extensions that need special setup such as +dual-mode host detection or working-directory changes. + +**Usage:** + +```go +func main() { + rootCmd := cmd.NewRootCommand() + azdext.Run(rootCmd, azdext.WithPreExecute(func(ctx context.Context, cmd *cobra.Command) error { + // Validate prerequisites + if _, err := exec.LookPath("docker"); err != nil { + return fmt.Errorf("docker is required: %w", err) + } + return nil + })) +} +``` + +--- + +## Command Scaffolding + +### NewExtensionRootCommand + +```go +func NewExtensionRootCommand(opts ExtensionCommandOptions) (*cobra.Command, *ExtensionContext) +``` + +Creates a root `cobra.Command` pre-configured for azd extensions. It +automatically: + +- Registers azd's global flags (`--debug`, `--no-prompt`, `--cwd`, + `-e`/`--environment`, `--output`). +- Reads `AZD_*` environment variables set by the azd framework. +- Sets up OpenTelemetry trace context from `TRACEPARENT`/`TRACESTATE` env vars. +- Calls `WithAccessToken()` on the command context. + +The returned command has `PersistentPreRunE` configured to populate the +`ExtensionContext` before any subcommand runs. 
+ +**Usage:** + +```go +rootCmd, extCtx := azdext.NewExtensionRootCommand(azdext.ExtensionCommandOptions{ + Name: "my-extension", + Version: "1.0.0", + Short: "My custom azd extension", +}) + +// Add subcommands +rootCmd.AddCommand(newServeCommand(extCtx)) + +azdext.Run(rootCmd) +``` + +### ExtensionCommandOptions + +```go +type ExtensionCommandOptions struct { + Name string // Extension name (used as cobra Use field) + Version string // Extension version + Use string // Overrides the default Use string (defaults to Name) + Short string // Short description + Long string // Long description +} +``` + +### ExtensionContext + +```go +type ExtensionContext struct { + Debug bool // --debug flag value + NoPrompt bool // --no-prompt flag value + Cwd string // --cwd flag value + Environment string // -e/--environment flag value + OutputFormat string // --output flag value +} + +func (ec *ExtensionContext) Context() context.Context +``` + +`Context()` returns a `context.Context` with the tracing span and access token +already injected. Use this context for all downstream calls (gRPC, HTTP, +Azure SDK). + +### NewListenCommand + +```go +func NewListenCommand(configure func(host *ExtensionHost)) *cobra.Command +``` + +Creates the standard `listen` command for lifecycle-event extensions. The +`configure` callback receives an `ExtensionHost` to register service targets, +framework services, and event handlers before the host starts its gRPC listener. + +If `configure` is nil, the host runs with no custom registrations. 
+ +**Usage:** + +```go +rootCmd.AddCommand(azdext.NewListenCommand(func(host *azdext.ExtensionHost) { + host.WithServiceTarget("myhost", func() azdext.ServiceTargetProvider { + return &MyProvider{} + }) + host.WithProjectEventHandler("preprovision", myHandler) +})) +``` + +### NewMetadataCommand + +```go +func NewMetadataCommand( + schemaVersion, extensionId string, + rootCmdProvider func() *cobra.Command, +) *cobra.Command +``` + +Creates the standard hidden `metadata` command that outputs extension command +metadata for IntelliSense/discovery. `rootCmdProvider` returns the root command +to introspect. + +### NewVersionCommand + +```go +func NewVersionCommand(extensionId, version string, outputFormat *string) *cobra.Command +``` + +Creates the standard `version` command. Pass a pointer to the output-format +string so JSON output is supported when `--output json` is used. + +--- + +## MCP Server Builder + +### NewMCPServerBuilder + +```go +func NewMCPServerBuilder(name, version string) *MCPServerBuilder +``` + +Creates a new builder for an MCP (Model Context Protocol) server. The builder +provides a fluent API to configure tools, resources, rate limiting, security +policies, and instructions. + +### MCPServerBuilder Methods + +| Method | Signature | Description | +|--------|-----------|-------------| +| `WithRateLimit` | `(burst int, refillRate float64) *MCPServerBuilder` | Configure a token-bucket rate limiter. `burst` = max concurrent requests; `refillRate` = tokens/second. | +| `WithSecurityPolicy` | `(policy *MCPSecurityPolicy) *MCPServerBuilder` | Attach a security policy for URL/path validation on tool calls. | +| `WithInstructions` | `(instructions string) *MCPServerBuilder` | Set system instructions that guide AI clients on how to use the server's tools. | +| `WithResourceCapabilities` | `(subscribe, listChanged bool) *MCPServerBuilder` | Enable resource support. | +| `WithPromptCapabilities` | `(listChanged bool) *MCPServerBuilder` | Enable prompt support. 
| +| `WithServerOption` | `(opt server.ServerOption) *MCPServerBuilder` | Add a raw `mcp-go` server option for capabilities not directly exposed by the builder. | +| `AddTool` | `(name string, handler MCPToolHandler, opts MCPToolOptions, params ...mcp.ToolOption) *MCPServerBuilder` | Register a tool with the server. The handler receives parsed `ToolArgs` (not raw `mcp.CallToolRequest`). | +| `AddResources` | `(resources ...server.ServerResource) *MCPServerBuilder` | Register static resources. | +| `Build` | `() *server.MCPServer` | Create the configured MCP server. | +| `SecurityPolicy` | `() *MCPSecurityPolicy` | Return the configured security policy, or `nil`. | + +**Usage:** + +```go +mcpServer := azdext.NewMCPServerBuilder("my-ext", "1.0.0"). + WithRateLimit(10, 2.0). + WithSecurityPolicy(azdext.DefaultMCPSecurityPolicy()). + WithInstructions("Use these tools to manage Azure resources."). + AddTool("list_resources", listHandler, azdext.MCPToolOptions{ + Description: "List Azure resources in a resource group", + }, + mcp.WithString("resourceGroup", mcp.Required(), mcp.Description("Resource group name")), + mcp.WithString("subscription", mcp.Description("Subscription ID")), + ). + Build() +``` + +### MCPToolHandler + +```go +type MCPToolHandler func(ctx context.Context, args ToolArgs) (*mcp.CallToolResult, error) +``` + +Handler function for MCP tools. The `args` parameter provides typed access to +tool arguments (see [ToolArgs](#toolargs)). + +### MCPToolOptions + +```go +type MCPToolOptions struct { + Description string // Human-readable tool description +} +``` + +--- + +## Typed Argument Parsing + +### ToolArgs + +Wraps parsed MCP tool arguments for typed, safe access. JSON numbers from MCP +requests arrive as `float64`; the `RequireInt`/`OptionalInt` methods handle +conversion automatically. 
+ +| Method | Signature | Description | +|--------|-----------|-------------| +| `RequireString` | `(key string) (string, error)` | Returns a string or error if missing/wrong type. | +| `OptionalString` | `(key, defaultValue string) string` | Returns a string or the default. | +| `RequireInt` | `(key string) (int, error)` | Returns an int or error if missing/wrong type. | +| `OptionalInt` | `(key string, defaultValue int) int` | Returns an int or the default. | +| `OptionalBool` | `(key string, defaultValue bool) bool` | Returns a bool or the default. | +| `OptionalFloat` | `(key string, defaultValue float64) float64` | Returns a float64 or the default. | +| `Has` | `(key string) bool` | True if the key exists in the arguments. | +| `Raw` | `() map[string]interface{}` | Returns the underlying argument map. | + +### ParseToolArgs + +```go +func ParseToolArgs(request mcp.CallToolRequest) ToolArgs +``` + +Extracts the arguments map from an MCP `CallToolRequest`. + +**Usage:** + +```go +func listHandler(ctx context.Context, args azdext.ToolArgs) (*mcp.CallToolResult, error) { + rg, err := args.RequireString("resourceGroup") + if err != nil { + return azdext.MCPErrorResult("missing required argument: %v", err), nil + } + sub := args.OptionalString("subscription", "") + limit := args.OptionalInt("limit", 50) + + // ... perform operation ... + + return azdext.MCPJSONResult(results), nil +} +``` + +--- + +## MCP Result Helpers + +### MCPTextResult + +```go +func MCPTextResult(format string, args ...interface{}) *mcp.CallToolResult +``` + +Creates a text-content `CallToolResult` using `fmt.Sprintf` formatting. + +### MCPJSONResult + +```go +func MCPJSONResult(data interface{}) *mcp.CallToolResult +``` + +Marshals `data` to JSON and creates a text-content `CallToolResult`. Returns an +error result if marshaling fails. 
+ +### MCPErrorResult + +```go +func MCPErrorResult(format string, args ...interface{}) *mcp.CallToolResult +``` + +Creates an error `CallToolResult` with `IsError` set to `true`. + +--- + +## MCP Security Policy + +The `MCPSecurityPolicy` validates URLs and file paths used by MCP tool calls to +prevent SSRF, directory traversal, and data exfiltration. + +### NewMCPSecurityPolicy + +```go +func NewMCPSecurityPolicy() *MCPSecurityPolicy +``` + +Creates an empty security policy. Chain methods to build up the desired rules. + +### DefaultMCPSecurityPolicy + +```go +func DefaultMCPSecurityPolicy() *MCPSecurityPolicy +``` + +Returns a policy with recommended defaults: + +- Cloud metadata endpoints blocked (AWS, GCP, Azure IMDS). +- RFC 1918 private networks blocked. +- HTTPS required (except localhost/127.0.0.1). +- Common sensitive headers redacted (`Authorization`, `Cookie`, `X-Api-Key`, etc.). + +### MCPSecurityPolicy Methods + +| Method | Signature | Description | +|--------|-----------|-------------| +| `BlockMetadataEndpoints` | `() *MCPSecurityPolicy` | Block cloud metadata service endpoints (`169.254.169.254`, `fd00:ec2::254`, `metadata.google.internal`, etc.). | +| `BlockPrivateNetworks` | `() *MCPSecurityPolicy` | Block RFC 1918 private networks, loopback, link-local, CGNAT (RFC 6598), and deprecated IPv6 transition mechanisms. | +| `RequireHTTPS` | `() *MCPSecurityPolicy` | Require HTTPS for all URLs except `localhost`/`127.0.0.1`. | +| `RedactHeaders` | `(headers ...string) *MCPSecurityPolicy` | Mark headers that should be blocked/redacted in outgoing requests. | +| `ValidatePathsWithinBase` | `(basePaths ...string) *MCPSecurityPolicy` | Restrict file paths to the given base directories. Resolves symlinks and blocks `../` traversal. | +| `CheckURL` | `(rawURL string) error` | Validate a URL against the policy. Returns `nil` if allowed. | +| `CheckPath` | `(path string) error` | Validate a file path against the policy. 
| +| `IsHeaderBlocked` | `(header string) bool` | Check if a header name is in the redacted set. | + +**Usage:** + +```go +policy := azdext.NewMCPSecurityPolicy(). + BlockMetadataEndpoints(). + BlockPrivateNetworks(). + RequireHTTPS(). + RedactHeaders("Authorization", "X-Custom-Secret"). + ValidatePathsWithinBase("/home/user/project") + +if err := policy.CheckURL(userProvidedURL); err != nil { + return azdext.MCPErrorResult("blocked URL: %v", err), nil +} +``` + +--- + +## Service Target Providers + +### ServiceTargetProvider Interface + +```go +type ServiceTargetProvider interface { + Initialize(ctx context.Context, serviceConfig *ServiceConfig) error + Endpoints(ctx context.Context, serviceConfig *ServiceConfig, targetResource *TargetResource) ([]string, error) + GetTargetResource(ctx context.Context, subscriptionId string, serviceConfig *ServiceConfig, defaultResolver func() (*TargetResource, error)) (*TargetResource, error) + Package(ctx context.Context, serviceConfig *ServiceConfig, serviceContext *ServiceContext, progress ProgressReporter) (*ServicePackageResult, error) + Publish(ctx context.Context, serviceConfig *ServiceConfig, serviceContext *ServiceContext, targetResource *TargetResource, publishOptions *PublishOptions, progress ProgressReporter) (*ServicePublishResult, error) + Deploy(ctx context.Context, serviceConfig *ServiceConfig, serviceContext *ServiceContext, targetResource *TargetResource, progress ProgressReporter) (*ServiceDeployResult, error) +} +``` + +### BaseServiceTargetProvider + +```go +type BaseServiceTargetProvider struct{} +``` + +Provides **no-op default implementations** for all `ServiceTargetProvider` +methods. Extensions should embed this struct and override only the methods they +need. 
+ +**Usage:** + +```go +type MyProvider struct { + azdext.BaseServiceTargetProvider // embed defaults + client *azdext.AzdClient +} + +// Override only what you need +func (p *MyProvider) Package(ctx context.Context, sc *azdext.ServiceConfig, + sctx *azdext.ServiceContext, progress azdext.ProgressReporter, +) (*azdext.ServicePackageResult, error) { + progress.Report("Packaging...") + // custom packaging logic + return &azdext.ServicePackageResult{PackagePath: "/out/app.tar.gz"}, nil +} +``` + +--- + +## Extension Host + +### NewExtensionHost + +```go +func NewExtensionHost(client *AzdClient) *ExtensionHost +``` + +Creates an `ExtensionHost` that manages service targets, framework services, +and event handlers. The host starts a gRPC listener and blocks until azd shuts +down the connection. + +### ExtensionHost Methods + +| Method | Signature | Description | +|--------|-----------|-------------| +| `Client` | `() *AzdClient` | Returns the underlying gRPC client. | +| `WithServiceTarget` | `(host string, factory ServiceTargetFactory) *ExtensionHost` | Register a custom deployment target. | +| `WithFrameworkService` | `(language string, factory FrameworkServiceFactory) *ExtensionHost` | Register a custom language/framework build service. | +| `WithProjectEventHandler` | `(eventName string, handler ProjectEventHandler) *ExtensionHost` | Register a project-level lifecycle event handler. | +| `WithServiceEventHandler` | `(eventName string, handler ServiceEventHandler, options *ServiceEventOptions) *ExtensionHost` | Register a service-level lifecycle event handler (with optional filtering). | +| `Run` | `(ctx context.Context) error` | Start the host and block until shutdown. | + +--- + +## Client & Utilities + +### AzdClient + +```go +func NewAzdClient(opts ...AzdClientOption) (*AzdClient, error) +``` + +gRPC client connecting to the azd framework. Auto-discovers the socket via +`AZD_RPC_SERVER_ENDPOINT`. 
Provides typed accessors for all framework services: + +| Accessor | Returns | +|----------|---------| +| `Project()` | `ProjectServiceClient` | +| `Environment()` | `EnvironmentServiceClient` | +| `UserConfig()` | `UserConfigServiceClient` | +| `Prompt()` | `PromptServiceClient` | +| `Deployment()` | `DeploymentServiceClient` | +| `Events()` | `EventServiceClient` | +| `Compose()` | `ComposeServiceClient` | +| `Workflow()` | `WorkflowServiceClient` | +| `ServiceTarget()` | `ServiceTargetServiceClient` | +| `FrameworkService()` | `FrameworkServiceClient` | +| `Container()` | `ContainerServiceClient` | +| `Extension()` | `ExtensionServiceClient` | +| `Account()` | `AccountServiceClient` | +| `Ai()` | `AiModelServiceClient` | + +Always call `defer client.Close()` after creation. + +### ConfigHelper + +```go +func NewConfigHelper(client *AzdClient) (*ConfigHelper, error) +``` + +Provides read/write access to azd user and environment configuration: + +| Method | Description | +|--------|-------------| +| `GetUserString(ctx, path)` | Read a string from user config. | +| `GetUserJSON(ctx, path, out)` | Unmarshal user config into a struct. | +| `SetUserJSON(ctx, path, value)` | Write a value to user config. | +| `UnsetUser(ctx, path)` | Remove a user config key. | +| `GetEnvString(ctx, path)` | Read a string from env config. | +| `GetEnvJSON(ctx, path, out)` | Unmarshal env config into a struct. | +| `SetEnvJSON(ctx, path, value)` | Write a value to env config. | +| `UnsetEnv(ctx, path)` | Remove an env config key. | + +Utility functions: + +- `MergeJSON(base, override)` — Shallow-merge two JSON maps. +- `DeepMergeJSON(base, override)` — Deep recursive merge. +- `ValidateConfig(path, data, validators...)` — Validate config data. +- `RequiredKeys(keys...)` — Returns a `ConfigValidator` that checks for required keys. 
+ +### TokenProvider + +```go +func NewTokenProvider(ctx context.Context, client *AzdClient, opts *TokenProviderOptions) (*TokenProvider, error) +``` + +Obtains Azure access tokens for authenticated API calls. Implements +`azcore.TokenCredential` semantics. + +| Method | Description | +|--------|-------------| +| `GetToken(ctx, options)` | Returns an `azcore.AccessToken`. | +| `TenantID()` | Returns the resolved tenant ID. | + +### Logger + +```go +func NewLogger(component string, opts ...LoggerOptions) *Logger +``` + +Structured logging with component tagging, backed by `slog`: + +| Method | Description | +|--------|-------------| +| `Debug(msg, args...)` | Log at DEBUG level. | +| `Info(msg, args...)` | Log at INFO level. | +| `Warn(msg, args...)` | Log at WARN level. | +| `Error(msg, args...)` | Log at ERROR level. | +| `With(args...)` | Create a child logger with additional fields. | +| `WithComponent(name)` | Create a child logger for a sub-component. | +| `WithOperation(name)` | Create a child logger tagged with an operation name. | +| `Slogger()` | Return the underlying `*slog.Logger`. | + +Call `azdext.SetupLogging(LoggerOptions{Debug: true})` during initialization to +configure the global log level. + +### Output + +```go +func NewOutput(opts OutputOptions) *Output +``` + +Format-aware output (text or JSON): + +| Method | Description | +|--------|-------------| +| `IsJSON()` | True if output format is JSON. | +| `Success(fmt, args...)` | Print a success message (green). | +| `Warning(fmt, args...)` | Print a warning (yellow). | +| `Error(fmt, args...)` | Print an error (red). | +| `Info(fmt, args...)` | Print informational text. | +| `Message(fmt, args...)` | Print plain text. | +| `JSON(data)` | Marshal and print JSON. | +| `Table(headers, rows)` | Print a formatted table. | + +### Runtime Utilities + +These helpers are intended to remove common extension boilerplate for shell execution, tool checks, TTY detection, and safe file writes. 
+ 

#### Shell Helpers

| API | Description |
|-----|-------------|
| `DetectShell()` | Detects the current shell using `SHELL`, `PSModulePath`, `ComSpec`, then platform defaults. |
| `ShellCommand(ctx, script)` | Builds an `exec.Cmd` using detected shell conventions (`cmd /C`, `pwsh -Command`, or `-c` for POSIX shells). |
| `ShellCommandWith(ctx, info, script)` | Same as `ShellCommand` but uses explicit `ShellInfo` for deterministic behavior/testing. |
| `IsInteractiveTerminal(f)` / `IsStdinTerminal()` / `IsStdoutTerminal()` | Terminal detection helpers. |

#### Tool Discovery Helpers

| API | Description |
|-----|-------------|
| `LookupTool(name)` | Looks up tools on `PATH` and also checks the current project directory for local wrappers (for example `./mvnw`). |
| `LookupTools(names...)` | Batch lookup for multiple tools. |
| `RequireTools(names...)` | Returns a typed error when required tools are missing. |
| `PrependPATH` / `AppendPATH` / `PATHContains` | Cross-platform `PATH` mutation and detection helpers. |

#### Interactive/TUI Helpers

| API | Description |
|-----|-------------|
| `DetectInteractive()` | Detects TTY mode (`full` / `limited` / `none`), `AZD_NO_PROMPT`, CI, and known agent environments. |
| `InteractiveInfo.CanPrompt()` | Safe prompt gate (`stdin/stdout tty`, not no-prompt, not CI, not agent). |
| `InteractiveInfo.CanColorize()` | Color output gate honoring `FORCE_COLOR` and `NO_COLOR`. |

#### Atomic File Helpers

| API | Description |
|-----|-------------|
| `WriteFileAtomic(path, data, perm)` | Writes via temp-file + atomic rename, with Windows rename retry behavior. |
| `CopyFileAtomic(src, dst, perm)` | Atomic copy via `WriteFileAtomic`. |
| `BackupFile(path, suffix)` | Creates an atomic backup file (`.bak` by default). |
| `EnsureDir(dir, perm)` | Convenience wrapper around `os.MkdirAll` with extension-prefixed errors. 
| + +--- + +## Error Handling + +### LocalError + +```go +type LocalError struct { + Message string + Code string + Category LocalErrorCategory + Suggestion string +} +``` + +Represents an error originating within the extension. The `Suggestion` field +provides actionable guidance displayed to the user. + +### ServiceError + +```go +type ServiceError struct { + Message string + ErrorCode string + StatusCode int + ServiceName string + Suggestion string +} +``` + +Represents an error from an Azure service call. + +### LocalErrorCategory + +```go +type LocalErrorCategory string + +const ( + LocalErrorCategoryValidation LocalErrorCategory = "validation" + LocalErrorCategoryAuth LocalErrorCategory = "auth" + LocalErrorCategoryDependency LocalErrorCategory = "dependency" + LocalErrorCategoryCompatibility LocalErrorCategory = "compatibility" + LocalErrorCategoryUser LocalErrorCategory = "user" + LocalErrorCategoryInternal LocalErrorCategory = "internal" + LocalErrorCategoryLocal LocalErrorCategory = "local" +) +``` + +Error categories enable structured telemetry classification and targeted error +guidance. Use `WrapError(err)` to convert a `LocalError` or `ServiceError` to +the gRPC `ExtensionError` proto for reporting. + +--- + +## See Also + +- [Extension Framework](./extension-framework.md) — Getting started, managing extensions, developing extensions. +- [Extension Migration Guide](./extension-migration-guide.md) — Migrate from pre-#6856 patterns to new SDK helpers. +- [Extension End-to-End Walkthrough](./extension-e2e-walkthrough.md) — Build a complete extension from scratch. +- [Extension Framework Services](./extension-framework-services.md) — gRPC service reference for custom language frameworks. +- [Extension Style Guide](./extensions-style-guide.md) — Design guidelines and best practices. 
diff --git a/cli/azd/extensions/azure.appservice/go.mod b/cli/azd/extensions/azure.appservice/go.mod index 4f681d16180..364b90cad04 100644 --- a/cli/azd/extensions/azure.appservice/go.mod +++ b/cli/azd/extensions/azure.appservice/go.mod @@ -15,14 +15,20 @@ require ( dario.cat/mergo v1.0.2 // indirect github.com/AlecAivazis/survey/v2 v2.3.7 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.5.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions v1.3.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets v1.4.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.2.0 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect github.com/Masterminds/semver/v3 v3.4.0 // indirect + github.com/adam-lavrik/go-imath v0.0.0-20210910152346-265a42a96f0b // indirect github.com/alecthomas/chroma/v2 v2.23.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/braydonk/yaml v0.9.0 // indirect + github.com/buger/goterm v1.0.4 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/charmbracelet/colorprofile v0.4.1 // indirect @@ -32,6 +38,7 @@ require ( github.com/charmbracelet/x/cellbuf v0.0.15 // indirect github.com/charmbracelet/x/exp/slice v0.0.0-20260204111555-7642919e0bee // indirect github.com/charmbracelet/x/term v0.2.2 // indirect + github.com/cli/browser v1.3.0 // indirect github.com/clipperhouse/displaywidth v0.9.0 // indirect github.com/clipperhouse/stringish v0.1.1 // indirect github.com/clipperhouse/uax29/v2 v2.5.0 // indirect @@ -40,6 +47,7 @@ require ( github.com/drone/envsubst v1.0.3 // indirect 
github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/gofrs/flock v0.12.1 // indirect github.com/golang-jwt/jwt/v5 v5.3.1 // indirect github.com/golobby/container/v3 v3.3.2 // indirect github.com/google/uuid v1.6.0 // indirect @@ -48,6 +56,7 @@ require ( github.com/invopop/jsonschema v0.13.0 // indirect github.com/jmespath-community/go-jmespath v1.1.1 // indirect github.com/joho/godotenv v1.5.1 // indirect + github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mailru/easyjson v0.9.1 // indirect @@ -55,6 +64,7 @@ require ( github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.19 // indirect + github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect github.com/microcosm-cc/bluemonday v1.0.27 // indirect github.com/microsoft/ApplicationInsights-Go v0.4.4 // indirect github.com/microsoft/go-deviceid v1.0.0 // indirect @@ -69,6 +79,7 @@ require ( github.com/spf13/cast v1.10.0 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/stretchr/testify v1.11.1 // indirect + github.com/theckman/yacspin v0.13.12 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect @@ -80,6 +91,7 @@ require ( go.opentelemetry.io/otel/sdk v1.40.0 // indirect go.opentelemetry.io/otel/trace v1.40.0 // indirect go.uber.org/atomic v1.11.0 // indirect + go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.47.0 // indirect golang.org/x/exp v0.0.0-20260112195511-716be5621a96 // indirect golang.org/x/net v0.49.0 // indirect @@ -93,4 +105,9 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) +// replace: The extension lives inside the azd monorepo and depends on the parent +// 
module at cli/azd. During local development (and in CI for this repo) the +// replace directive lets `go build` resolve the dependency from the working tree +// instead of requiring a published module version. It is stripped automatically +// by `go mod tidy` when the module is consumed as a standalone dependency. replace github.com/azure/azure-dev/cli/azd => ../.. diff --git a/cli/azd/extensions/azure.appservice/go.sum b/cli/azd/extensions/azure.appservice/go.sum index 27760cc3a93..70b42074618 100644 --- a/cli/azd/extensions/azure.appservice/go.sum +++ b/cli/azd/extensions/azure.appservice/go.sum @@ -13,12 +13,27 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDo github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/appservice/armappservice/v2 v2.3.0 h1:JI8PcWOImyvIUEZ0Bbmfe05FOlWkMi2KhjG+cAKaUms= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/appservice/armappservice/v2 v2.3.0/go.mod h1:nJLFPGJkyKfDDyJiPuHIXsCi/gpJkm07EvRgiX7SGlI= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0 h1:PTFGRSlMKCQelWwxUyYVEUqseBJVemLyqWJjvMyt0do= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0/go.mod h1:LRr2FzBTQlONPPa5HREE5+RjSCTXl7BwOvYOaWTqCaI= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v3 v3.1.0 h1:2qsIIvxVT+uE6yrNldntJKlLRgxGbZ85kgtz5SNBhMw= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v3 v3.1.0/go.mod h1:AW8VEadnhw9xox+VaVd9sP7NjzOAnaZBLRH6Tq3cJ38= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.5.0 h1:nnQ9vXH039UrEFxi08pPuZBE7VfqSJt343uJLw0rhWI= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.5.0/go.mod h1:4YIVtzMFVsPwBvitCDX7J9sqthSj43QD1sP6fYc1egc= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.2.0 
h1:Dd+RhdJn0OTtVGaeDLZpcumkIVCtA/3/Fo42+eoYvVM= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.2.0/go.mod h1:5kakwfW5CjC9KK+Q4wjXAg+ShuIm2mBMua0ZFj2C8PE= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions v1.3.0 h1:wxQx2Bt4xzPIKvW59WQf1tJNx/ZZKPfN+EhPX3Z6CYY= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions v1.3.0/go.mod h1:TpiwjwnW/khS0LKs4vW5UmmT9OWcxaveS8U7+tlknzo= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets v1.4.0 h1:/g8S6wk65vfC6m3FIxJ+i5QDyN9JWwXI8Hb0Img10hU= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets v1.4.0/go.mod h1:gpl+q95AzZlKVI3xSoseF9QPrypk0hQqBiJYeB/cR/I= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.2.0 h1:nCYfgcSyHZXJI8J0IWE5MsCGlb2xp9fJiXyxWgmOFg4= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.2.0/go.mod h1:ucUjca2JtSZboY8IoUqyQyuuXvwbMBVwFOm0vdQPNhA= github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0= github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= +github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2 h1:+vx7roKuyA63nhn5WAunQHLTznkw5W8b1Xc0dNjp83s= github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2/go.mod h1:HBCaDeC1lPdgDeDbhX8XFpy1jqjK0IBG8W5K+xYqA0w= github.com/adam-lavrik/go-imath v0.0.0-20210910152346-265a42a96f0b h1:g9SuFmxM/WucQFKTMSP+irxyf5m0RiUJreBDhGI6jSA= 
github.com/adam-lavrik/go-imath v0.0.0-20210910152346-265a42a96f0b/go.mod h1:XjvqMUpGd3Xn9Jtzk/4GEBCSoBX0eB2RyriXgne0IdM= @@ -66,6 +81,8 @@ github.com/charmbracelet/x/exp/slice v0.0.0-20260204111555-7642919e0bee h1:B/JPE github.com/charmbracelet/x/exp/slice v0.0.0-20260204111555-7642919e0bee/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA= github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= +github.com/cli/browser v1.3.0 h1:LejqCrpWr+1pRqmEPDGnTZOjsMe7sehifLynZJuqJpo= +github.com/cli/browser v1.3.0/go.mod h1:HH8s+fOAxjhQoBUAsKuPCbqUuxZDhQ2/aD+SzsEfBTk= github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA= github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA= github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= @@ -73,6 +90,7 @@ github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEX github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U= github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/creack/pty v1.1.17 h1:QeVUsEDNrLBW4tMgZHvxy18sKtr6VI492kBhUfhDJNI= github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -92,6 +110,8 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod 
h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E= +github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0= github.com/gofrs/uuid v3.3.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= @@ -109,6 +129,7 @@ github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= +github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -229,6 +250,8 @@ go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZY go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto 
v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= @@ -250,6 +273,7 @@ golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210331175145-43e1dd70ce54/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/cli/azd/extensions/azure.appservice/internal/cmd/swap.go b/cli/azd/extensions/azure.appservice/internal/cmd/swap.go index b0ffa6e4ca6..dcbbd3451c3 100644 --- a/cli/azd/extensions/azure.appservice/internal/cmd/swap.go +++ b/cli/azd/extensions/azure.appservice/internal/cmd/swap.go @@ -143,7 +143,12 @@ func runSwap(ctx context.Context, flags *swapFlags, rootFlags rootFlagsDefinitio if err != nil { return fmt.Errorf("selecting service: %w", err) } - selectedService = appserviceServices[prompt.GetValue()] + + idx := int(prompt.GetValue()) + if idx < 0 || idx >= len(appserviceServices) { + return fmt.Errorf("invalid service selection index: %d", idx) + } + selectedService = appserviceServices[idx] } color.Cyan("Using service: %s", selectedService.Name) @@ -207,6 +212,18 @@ func runSwap(ctx context.Context, flags *swapFlags, rootFlags rootFlagsDefinitio srcProvided := flags.src != "" dstProvided := flags.dst != "" + // Validate slot name format before any further processing. 
+ if srcProvided { + if err := validateSlotName(srcSlot); err != nil { + return err + } + } + if dstProvided { + if err := validateSlotName(dstSlot); err != nil { + return err + } + } + // Build the list of all slot names (including production as empty string) slotNames := []string{""} // Production is represented as empty string for _, slot := range slots { @@ -250,7 +267,11 @@ func runSwap(ctx context.Context, flags *swapFlags, rootFlags rootFlagsDefinitio return fmt.Errorf("selecting source slot: %w", err) } - srcSlot = srcChoices[prompt.GetValue()].Value + idx := int(prompt.GetValue()) + if idx < 0 || idx >= len(srcChoices) { + return fmt.Errorf("invalid source slot selection index: %d", idx) + } + srcSlot = srcChoices[idx].Value } // Prompt for destination slot (excluding the selected source) @@ -275,7 +296,11 @@ func runSwap(ctx context.Context, flags *swapFlags, rootFlags rootFlagsDefinitio return fmt.Errorf("selecting destination slot: %w", err) } - dstSlot = dstChoices[prompt.GetValue()].Value + idx := int(prompt.GetValue()) + if idx < 0 || idx >= len(dstChoices) { + return fmt.Errorf("invalid destination slot selection index: %d", idx) + } + dstSlot = dstChoices[idx].Value } } @@ -340,6 +365,27 @@ func normalizeSlotName(slot string) string { return slot } +// validateSlotName checks that a slot name is safe to use as an Azure App +// Service deployment slot identifier. Empty string is allowed (represents +// the production slot). Valid slot names contain only alphanumeric characters, +// hyphens, and underscores (matching Azure's naming constraints). 
+func validateSlotName(name string) error { + if name == "" { + return nil // empty = production slot + } + for i, r := range name { + if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || + (r >= '0' && r <= '9') || r == '-' || r == '_') { + return fmt.Errorf("invalid slot name %q: contains invalid character %q at position %d "+ + "(only alphanumeric, hyphens, and underscores are allowed)", name, string(r), i) + } + } + if len(name) > 64 { + return fmt.Errorf("invalid slot name %q: exceeds 64-character limit", name) + } + return nil +} + func isValidSlotName(name string, availableSlots []string) bool { for _, slot := range availableSlots { if slot == name { diff --git a/cli/azd/extensions/azure.appservice/main.go b/cli/azd/extensions/azure.appservice/main.go index 1491b685a3c..a06d026d336 100644 --- a/cli/azd/extensions/azure.appservice/main.go +++ b/cli/azd/extensions/azure.appservice/main.go @@ -4,28 +4,11 @@ package main import ( - "os" - "azureappservice/internal/cmd" "github.com/azure/azure-dev/cli/azd/pkg/azdext" - "github.com/fatih/color" ) -func init() { - forceColorVal, has := os.LookupEnv("FORCE_COLOR") - if has && forceColorVal == "1" { - color.NoColor = false - } -} - func main() { - // Execute the root command - ctx := azdext.NewContext() - rootCmd := cmd.NewRootCommand() - - if err := rootCmd.ExecuteContext(ctx); err != nil { - color.Red("Error: %v", err) - os.Exit(1) - } + azdext.Run(cmd.NewRootCommand()) } diff --git a/cli/azd/pkg/azdext/atomicfile.go b/cli/azd/pkg/azdext/atomicfile.go new file mode 100644 index 00000000000..af4b19fcd64 --- /dev/null +++ b/cli/azd/pkg/azdext/atomicfile.go @@ -0,0 +1,196 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package azdext + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + + "github.com/azure/azure-dev/cli/azd/pkg/osutil" +) + +// --------------------------------------------------------------------------- +// Atomic file operations +// --------------------------------------------------------------------------- + +// WriteFileAtomic writes data to the named file atomically. It writes to a +// temporary file in the same directory as path and renames it into place. This +// ensures that readers never see a partially-written file and that the +// operation is crash-safe on filesystems that support atomic rename (ext4, +// APFS, NTFS). +// +// Platform behavior: +// - Unix: os.Rename is atomic within the same filesystem. +// - Windows: os.Rename replaces the target if it exists (Go 1.16+). On +// older Go runtimes or cross-device moves, the operation may fail. +// WriteFileAtomic always places the temp file in the same directory to +// avoid cross-device issues. +// +// The file is created with the specified permissions. If the target already +// exists its permissions are preserved unless perm is explicitly non-zero. +// +// Returns an error if the directory does not exist, the temp file cannot be +// created, data cannot be written, or the rename fails. +func WriteFileAtomic(path string, data []byte, perm os.FileMode) error { + dir := filepath.Dir(path) + + // Validate that the target directory exists. + if _, err := os.Stat(dir); err != nil { + return fmt.Errorf("azdext.WriteFileAtomic: target directory: %w", err) + } + + // If perm is zero and the target exists, preserve existing permissions. + if perm == 0 { + if fi, err := os.Stat(path); err == nil { + perm = fi.Mode().Perm() + } else { + perm = 0o644 + } + } + + // Create temp file in the same directory (same filesystem = atomic rename). 
+ tmp, err := os.CreateTemp(dir, ".azdext-atomic-*") + if err != nil { + return fmt.Errorf("azdext.WriteFileAtomic: create temp: %w", err) + } + tmpPath := tmp.Name() + + // Ensure cleanup on any failure path. + success := false + defer func() { + if !success { + _ = tmp.Close() + _ = os.Remove(tmpPath) + } + }() + + // Write data and sync to disk. + if _, err := tmp.Write(data); err != nil { + return fmt.Errorf("azdext.WriteFileAtomic: write: %w", err) + } + if err := tmp.Sync(); err != nil { + return fmt.Errorf("azdext.WriteFileAtomic: sync: %w", err) + } + if err := tmp.Close(); err != nil { + return fmt.Errorf("azdext.WriteFileAtomic: close: %w", err) + } + + // Set permissions on temp file before rename. + //nolint:gosec // G703: tmpPath is constructed internally + if err := os.Chmod(tmpPath, perm); err != nil { + return fmt.Errorf("azdext.WriteFileAtomic: chmod: %w", err) + } + + // Atomic rename into place. + if err := osutil.Rename(context.Background(), tmpPath, path); err != nil { + return fmt.Errorf("azdext.WriteFileAtomic: rename: %w", err) + } + + success = true + return nil +} + +// CopyFileAtomic copies src to dst atomically using the write-temp-rename +// pattern. The destination file is never in a partially-written state. +// +// Platform behavior: see [WriteFileAtomic]. +// +// If perm is zero, the source file's permissions are used. +func CopyFileAtomic(src, dst string, perm os.FileMode) error { + srcFile, err := os.Open(src) + if err != nil { + return fmt.Errorf("azdext.CopyFileAtomic: open source: %w", err) + } + defer srcFile.Close() + + // Determine permissions. 
+ if perm == 0 { + if fi, err := srcFile.Stat(); err == nil { + perm = fi.Mode().Perm() + } else { + perm = 0o644 + } + } + + dir := filepath.Dir(dst) + if _, err := os.Stat(dir); err != nil { + return fmt.Errorf("azdext.CopyFileAtomic: target directory: %w", err) + } + + tmp, err := os.CreateTemp(dir, ".azdext-atomic-*") + if err != nil { + return fmt.Errorf("azdext.CopyFileAtomic: create temp: %w", err) + } + tmpPath := tmp.Name() + + success := false + defer func() { + if !success { + _ = tmp.Close() + _ = os.Remove(tmpPath) + } + }() + + if _, err := io.Copy(tmp, srcFile); err != nil { + return fmt.Errorf("azdext.CopyFileAtomic: copy source: %w", err) + } + if err := tmp.Sync(); err != nil { + return fmt.Errorf("azdext.CopyFileAtomic: sync: %w", err) + } + if err := tmp.Close(); err != nil { + return fmt.Errorf("azdext.CopyFileAtomic: close: %w", err) + } + //nolint:gosec // G703: tmpPath is constructed internally + if err := os.Chmod(tmpPath, perm); err != nil { + return fmt.Errorf("azdext.CopyFileAtomic: chmod: %w", err) + } + if err := osutil.Rename(context.Background(), tmpPath, dst); err != nil { + return fmt.Errorf("azdext.CopyFileAtomic: rename: %w", err) + } + + success = true + return nil +} + +// BackupFile creates a backup copy of path at path+suffix using atomic copy. +// If the source file does not exist, it returns nil (no backup needed). +// +// The default suffix is ".bak" if suffix is empty. +// +// Returns the backup path on success, or an error if the copy fails. +func BackupFile(path, suffix string) (string, error) { + if suffix == "" { + suffix = ".bak" + } + + if _, err := os.Stat(path); os.IsNotExist(err) { + return "", nil // Nothing to back up. + } + + backupPath := path + suffix + if err := CopyFileAtomic(path, backupPath, 0); err != nil { + return "", fmt.Errorf("azdext.BackupFile: %w", err) + } + + return backupPath, nil +} + +// EnsureDir creates directory dir and any necessary parents with the given +// permissions. 
If the directory already exists, EnsureDir is a no-op and +// returns nil. +// +// This is a convenience wrapper around [os.MkdirAll] with an explicit error +// prefix for diagnostics. +func EnsureDir(dir string, perm os.FileMode) error { + if perm == 0 { + perm = 0o755 + } + if err := os.MkdirAll(dir, perm); err != nil { + return fmt.Errorf("azdext.EnsureDir: %w", err) + } + return nil +} diff --git a/cli/azd/pkg/azdext/atomicfile_test.go b/cli/azd/pkg/azdext/atomicfile_test.go new file mode 100644 index 00000000000..4c50517a3f5 --- /dev/null +++ b/cli/azd/pkg/azdext/atomicfile_test.go @@ -0,0 +1,322 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "bytes" + "os" + "path/filepath" + "testing" +) + +// --------------------------------------------------------------------------- +// WriteFileAtomic +// --------------------------------------------------------------------------- + +func TestWriteFileAtomic_NewFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.txt") + data := []byte("hello, atomic world") + + if err := WriteFileAtomic(path, data, 0o644); err != nil { + t.Fatalf("WriteFileAtomic() error: %v", err) + } + + // Verify content. + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile() error: %v", err) + } + if string(got) != string(data) { + t.Errorf("WriteFileAtomic() content = %q, want %q", got, data) + } + + // Verify permissions. + fi, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat() error: %v", err) + } + // On Windows, permission bits are limited. Just verify the file exists. + if fi.Size() != int64(len(data)) { + t.Errorf("WriteFileAtomic() file size = %d, want %d", fi.Size(), len(data)) + } +} + +func TestWriteFileAtomic_Overwrite(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.txt") + + // Write initial content. 
+ if err := os.WriteFile(path, []byte("old"), 0o600); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + // Overwrite atomically. + newData := []byte("new content here") + if err := WriteFileAtomic(path, newData, 0o644); err != nil { + t.Fatalf("WriteFileAtomic() error: %v", err) + } + + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile() error: %v", err) + } + if string(got) != string(newData) { + t.Errorf("WriteFileAtomic() overwrite content = %q, want %q", got, newData) + } +} + +func TestWriteFileAtomic_EmptyData(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "empty.txt") + + if err := WriteFileAtomic(path, []byte{}, 0o644); err != nil { + t.Fatalf("WriteFileAtomic(empty) error: %v", err) + } + + fi, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat() error: %v", err) + } + if fi.Size() != 0 { + t.Errorf("WriteFileAtomic(empty) size = %d, want 0", fi.Size()) + } +} + +func TestWriteFileAtomic_MissingDirectory(t *testing.T) { + path := filepath.Join(t.TempDir(), "nonexistent", "file.txt") + err := WriteFileAtomic(path, []byte("data"), 0o644) + if err == nil { + t.Error("WriteFileAtomic() with missing directory = nil, want error") + } +} + +func TestWriteFileAtomic_PreservesPermissions(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "perm.txt") + + // Write with specific permissions. + if err := WriteFileAtomic(path, []byte("first"), 0o600); err != nil { + t.Fatalf("WriteFileAtomic(first) error: %v", err) + } + + // Overwrite with perm=0 to preserve existing. + if err := WriteFileAtomic(path, []byte("second"), 0); err != nil { + t.Fatalf("WriteFileAtomic(second) error: %v", err) + } + + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile() error: %v", err) + } + if string(got) != "second" { + t.Errorf("content = %q, want %q", got, "second") + } +} + +func TestWriteFileAtomic_NoTempFileLeftOnFailure(t *testing.T) { + dir := t.TempDir() + + // Write a valid file first. 
+ path := filepath.Join(dir, "test.txt") + if err := WriteFileAtomic(path, []byte("ok"), 0o644); err != nil { + t.Fatalf("WriteFileAtomic() error: %v", err) + } + + // List directory to establish baseline. + before, _ := os.ReadDir(dir) + beforeCount := len(before) + + // Attempt write to a non-existent sub-directory (should fail). + badPath := filepath.Join(dir, "nodir", "bad.txt") + _ = WriteFileAtomic(badPath, []byte("fail"), 0o644) + + // Verify no temp files were left behind. + after, _ := os.ReadDir(dir) + if len(after) != beforeCount { + t.Errorf("temp file leak: before=%d entries, after=%d entries", beforeCount, len(after)) + } +} + +// --------------------------------------------------------------------------- +// CopyFileAtomic +// --------------------------------------------------------------------------- + +func TestCopyFileAtomic(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "source.txt") + dst := filepath.Join(dir, "dest.txt") + data := []byte("copy me atomically") + + if err := os.WriteFile(src, data, 0o600); err != nil { + t.Fatalf("WriteFile(src) error: %v", err) + } + + if err := CopyFileAtomic(src, dst, 0); err != nil { + t.Fatalf("CopyFileAtomic() error: %v", err) + } + + got, err := os.ReadFile(dst) + if err != nil { + t.Fatalf("ReadFile(dst) error: %v", err) + } + if string(got) != string(data) { + t.Errorf("CopyFileAtomic() content = %q, want %q", got, data) + } +} + +func TestCopyFileAtomic_SourceNotFound(t *testing.T) { + dir := t.TempDir() + err := CopyFileAtomic(filepath.Join(dir, "missing.txt"), filepath.Join(dir, "dst.txt"), 0) + if err == nil { + t.Error("CopyFileAtomic() with missing source = nil, want error") + } +} + +func TestCopyFileAtomic_LargeFileStreaming(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "large-source.bin") + dst := filepath.Join(dir, "large-dest.bin") + + const size = 12 << 20 // 12 MiB + chunk := bytes.Repeat([]byte("azdext-streaming-copy"), 512) + + f, err := os.Create(src) 
+ if err != nil { + t.Fatalf("Create(src) error: %v", err) + } + written := 0 + for written < size { + n := size - written + if n > len(chunk) { + n = len(chunk) + } + copied, err := f.Write(chunk[:n]) + if err != nil { + _ = f.Close() + t.Fatalf("Write(src) error: %v", err) + } + written += copied + } + if err := f.Close(); err != nil { + t.Fatalf("Close(src) error: %v", err) + } + + if err := CopyFileAtomic(src, dst, 0); err != nil { + t.Fatalf("CopyFileAtomic(large) error: %v", err) + } + + srcData, err := os.ReadFile(src) + if err != nil { + t.Fatalf("ReadFile(src) error: %v", err) + } + dstData, err := os.ReadFile(dst) + if err != nil { + t.Fatalf("ReadFile(dst) error: %v", err) + } + if !bytes.Equal(srcData, dstData) { + t.Fatal("large file copy mismatch") + } +} + +// --------------------------------------------------------------------------- +// BackupFile +// --------------------------------------------------------------------------- + +func TestBackupFile_CreatesBackup(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + data := []byte("config: value") + + if err := os.WriteFile(path, data, 0o600); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + backupPath, err := BackupFile(path, ".bak") + if err != nil { + t.Fatalf("BackupFile() error: %v", err) + } + if backupPath != path+".bak" { + t.Errorf("BackupFile() path = %q, want %q", backupPath, path+".bak") + } + + got, err := os.ReadFile(backupPath) + if err != nil { + t.Fatalf("ReadFile(backup) error: %v", err) + } + if string(got) != string(data) { + t.Errorf("BackupFile() content = %q, want %q", got, data) + } +} + +func TestBackupFile_DefaultSuffix(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "data.json") + if err := os.WriteFile(path, []byte("{}"), 0o600); err != nil { + t.Fatalf("WriteFile() error: %v", err) + } + + backupPath, err := BackupFile(path, "") + if err != nil { + t.Fatalf("BackupFile() error: %v", err) + } + if backupPath != 
path+".bak" { + t.Errorf("BackupFile() default suffix: path = %q, want %q", backupPath, path+".bak") + } +} + +func TestBackupFile_SourceNotFound(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "nonexistent.txt") + + backupPath, err := BackupFile(path, ".bak") + if err != nil { + t.Fatalf("BackupFile() error: %v", err) + } + if backupPath != "" { + t.Errorf("BackupFile() for nonexistent source: path = %q, want empty", backupPath) + } +} + +// --------------------------------------------------------------------------- +// EnsureDir +// --------------------------------------------------------------------------- + +func TestEnsureDir_CreatesNew(t *testing.T) { + dir := filepath.Join(t.TempDir(), "a", "b", "c") + + if err := EnsureDir(dir, 0o755); err != nil { + t.Fatalf("EnsureDir() error: %v", err) + } + + fi, err := os.Stat(dir) + if err != nil { + t.Fatalf("Stat() error: %v", err) + } + if !fi.IsDir() { + t.Error("EnsureDir() did not create a directory") + } +} + +func TestEnsureDir_ExistingIsNoOp(t *testing.T) { + dir := t.TempDir() + if err := EnsureDir(dir, 0o755); err != nil { + t.Fatalf("EnsureDir() on existing dir error: %v", err) + } +} + +func TestEnsureDir_DefaultPermissions(t *testing.T) { + dir := filepath.Join(t.TempDir(), "default-perm") + if err := EnsureDir(dir, 0); err != nil { + t.Fatalf("EnsureDir(perm=0) error: %v", err) + } + + fi, err := os.Stat(dir) + if err != nil { + t.Fatalf("Stat() error: %v", err) + } + if !fi.IsDir() { + t.Error("EnsureDir(perm=0) did not create a directory") + } +} diff --git a/cli/azd/pkg/azdext/config_helper.go b/cli/azd/pkg/azdext/config_helper.go new file mode 100644 index 00000000000..cc28ab67a01 --- /dev/null +++ b/cli/azd/pkg/azdext/config_helper.go @@ -0,0 +1,426 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package azdext + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "regexp" +) + +// ConfigHelper provides typed, ergonomic access to azd configuration through +// the gRPC UserConfig and Environment services. It eliminates the boilerplate +// of raw gRPC calls and JSON marshaling that extension authors otherwise need. +// +// Configuration sources (in merge priority, lowest to highest): +// 1. User config (global azd config) — via UserConfigService +// 2. Environment config (per-env) — via EnvironmentService +// +// Usage: +// +// ch := azdext.NewConfigHelper(client) +// port, err := ch.GetUserString(ctx, "extensions.myext.port") +// var cfg MyConfig +// err = ch.GetUserJSON(ctx, "extensions.myext", &cfg) +type ConfigHelper struct { + client *AzdClient +} + +// NewConfigHelper creates a [ConfigHelper] for the given AZD client. +func NewConfigHelper(client *AzdClient) (*ConfigHelper, error) { + if client == nil { + return nil, errors.New("azdext.NewConfigHelper: client must not be nil") + } + + return &ConfigHelper{client: client}, nil +} + +// --- User Config (global) --- + +// GetUserString retrieves a string value from the global user config at the +// given dot-separated path. Returns ("", false, nil) when the path does not +// exist, and ("", false, err) on gRPC errors. +func (ch *ConfigHelper) GetUserString(ctx context.Context, path string) (string, bool, error) { + if err := validatePath(path); err != nil { + return "", false, err + } + + resp, err := ch.client.UserConfig().GetString(ctx, &GetUserConfigStringRequest{Path: path}) + if err != nil { + return "", false, fmt.Errorf("azdext.ConfigHelper.GetUserString: gRPC call failed for path %q: %w", path, err) + } + + return resp.GetValue(), resp.GetFound(), nil +} + +// GetUserJSON retrieves a value from the global user config and unmarshals it +// into out. Returns (false, nil) when the path does not exist. 
+func (ch *ConfigHelper) GetUserJSON(ctx context.Context, path string, out any) (bool, error) { + if err := validatePath(path); err != nil { + return false, err + } + + if out == nil { + return false, errors.New("azdext.ConfigHelper.GetUserJSON: out must not be nil") + } + + resp, err := ch.client.UserConfig().Get(ctx, &GetUserConfigRequest{Path: path}) + if err != nil { + return false, fmt.Errorf("azdext.ConfigHelper.GetUserJSON: gRPC call failed for path %q: %w", path, err) + } + + if !resp.GetFound() { + return false, nil + } + + data := resp.GetValue() + if len(data) == 0 { + return false, nil + } + + if err := json.Unmarshal(data, out); err != nil { + return true, &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to unmarshal config at path %q: %w", path, err), + } + } + + return true, nil +} + +// SetUserJSON marshals value as JSON and writes it to the global user config +// at the given path. +func (ch *ConfigHelper) SetUserJSON(ctx context.Context, path string, value any) error { + if err := validatePath(path); err != nil { + return err + } + + if value == nil { + return errors.New("azdext.ConfigHelper.SetUserJSON: value must not be nil") + } + + data, err := json.Marshal(value) + if err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to marshal value for path %q: %w", path, err), + } + } + + _, err = ch.client.UserConfig().Set(ctx, &SetUserConfigRequest{ + Path: path, + Value: data, + }) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.SetUserJSON: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// UnsetUser removes a value from the global user config. 
+func (ch *ConfigHelper) UnsetUser(ctx context.Context, path string) error { + if err := validatePath(path); err != nil { + return err + } + + _, err := ch.client.UserConfig().Unset(ctx, &UnsetUserConfigRequest{Path: path}) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.UnsetUser: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// --- Environment Config (per-environment) --- + +// GetEnvString retrieves a string config value from the current environment. +// Returns ("", false, nil) when the path does not exist. +func (ch *ConfigHelper) GetEnvString(ctx context.Context, path string) (string, bool, error) { + if err := validatePath(path); err != nil { + return "", false, err + } + + resp, err := ch.client.Environment().GetConfigString(ctx, &GetConfigStringRequest{Path: path}) + if err != nil { + return "", false, fmt.Errorf("azdext.ConfigHelper.GetEnvString: gRPC call failed for path %q: %w", path, err) + } + + return resp.GetValue(), resp.GetFound(), nil +} + +// GetEnvJSON retrieves a value from the current environment's config and +// unmarshals it into out. Returns (false, nil) when the path does not exist. 
+func (ch *ConfigHelper) GetEnvJSON(ctx context.Context, path string, out any) (bool, error) { + if err := validatePath(path); err != nil { + return false, err + } + + if out == nil { + return false, errors.New("azdext.ConfigHelper.GetEnvJSON: out must not be nil") + } + + resp, err := ch.client.Environment().GetConfig(ctx, &GetConfigRequest{Path: path}) + if err != nil { + return false, fmt.Errorf("azdext.ConfigHelper.GetEnvJSON: gRPC call failed for path %q: %w", path, err) + } + + if !resp.GetFound() { + return false, nil + } + + data := resp.GetValue() + if len(data) == 0 { + return false, nil + } + + if err := json.Unmarshal(data, out); err != nil { + return true, &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to unmarshal env config at path %q: %w", path, err), + } + } + + return true, nil +} + +// SetEnvJSON marshals value as JSON and writes it to the current environment's config. +func (ch *ConfigHelper) SetEnvJSON(ctx context.Context, path string, value any) error { + if err := validatePath(path); err != nil { + return err + } + + if value == nil { + return errors.New("azdext.ConfigHelper.SetEnvJSON: value must not be nil") + } + + data, err := json.Marshal(value) + if err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("failed to marshal value for env config path %q: %w", path, err), + } + } + + _, err = ch.client.Environment().SetConfig(ctx, &SetConfigRequest{ + Path: path, + Value: data, + }) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.SetEnvJSON: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// UnsetEnv removes a value from the current environment's config. 
+func (ch *ConfigHelper) UnsetEnv(ctx context.Context, path string) error { + if err := validatePath(path); err != nil { + return err + } + + _, err := ch.client.Environment().UnsetConfig(ctx, &UnsetConfigRequest{Path: path}) + if err != nil { + return fmt.Errorf("azdext.ConfigHelper.UnsetEnv: gRPC call failed for path %q: %w", path, err) + } + + return nil +} + +// --- Merge --- + +// MergeJSON performs a shallow merge of override into base, returning a new map. +// Both inputs must be JSON-compatible maps (map[string]any). Keys in override +// take precedence over keys in base. +// +// This is NOT a deep merge — nested maps are replaced entirely by the override +// value. For predictable extension config behavior, keep config structures flat +// or use explicit path-based Set operations for nested values. +func MergeJSON(base, override map[string]any) map[string]any { + merged := make(map[string]any, len(base)+len(override)) + + for k, v := range base { + merged[k] = v + } + + for k, v := range override { + merged[k] = v + } + + return merged +} + +// deepMergeMaxDepth is the maximum recursion depth for [DeepMergeJSON]. +// This prevents stack overflow from deeply nested or adversarial JSON +// structures. 32 levels is far deeper than any legitimate config hierarchy. +const deepMergeMaxDepth = 32 + +// DeepMergeJSON performs a recursive merge of override into base. +// When both base and override have a map value for the same key, those maps +// are merged recursively. Otherwise the override value replaces the base value. +// +// Recursion is bounded to [deepMergeMaxDepth] levels to prevent stack overflow +// from deeply nested or adversarial inputs. Beyond the limit, the override +// value replaces the base value (merge degrades to shallow at that level). 
+func DeepMergeJSON(base, override map[string]any) map[string]any { + return deepMergeJSON(base, override, 0) +} + +func deepMergeJSON(base, override map[string]any, depth int) map[string]any { + merged := make(map[string]any, len(base)+len(override)) + + for k, v := range base { + merged[k] = v + } + + for k, v := range override { + baseVal, exists := merged[k] + if !exists { + merged[k] = v + continue + } + + baseMap, baseIsMap := baseVal.(map[string]any) + overMap, overIsMap := v.(map[string]any) + + if baseIsMap && overIsMap && depth < deepMergeMaxDepth { + merged[k] = deepMergeJSON(baseMap, overMap, depth+1) + } else { + merged[k] = v + } + } + + return merged +} + +// --- Validation --- + +// ConfigValidator defines a function that validates a config value. +// It returns nil if valid, or an error describing the validation failure. +type ConfigValidator func(value any) error + +// ValidateConfig unmarshals the raw JSON data and runs all supplied validators. +// Returns the first validation error encountered, wrapped in a [*ConfigError]. +func ValidateConfig(path string, data []byte, validators ...ConfigValidator) error { + if len(data) == 0 { + return &ConfigError{ + Path: path, + Reason: ConfigReasonMissing, + Err: fmt.Errorf("config at path %q is empty", path), + } + } + + var value any + if err := json.Unmarshal(data, &value); err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonInvalidFormat, + Err: fmt.Errorf("config at path %q is not valid JSON: %w", path, err), + } + } + + for _, v := range validators { + if err := v(value); err != nil { + return &ConfigError{ + Path: path, + Reason: ConfigReasonValidationFailed, + Err: fmt.Errorf("config validation failed at path %q: %w", path, err), + } + } + } + + return nil +} + +// RequiredKeys returns a [ConfigValidator] that checks for the presence of +// the specified keys in a map value. 
+func RequiredKeys(keys ...string) ConfigValidator { + return func(value any) error { + m, ok := value.(map[string]any) + if !ok { + return fmt.Errorf("expected object, got %T", value) + } + + for _, key := range keys { + if _, exists := m[key]; !exists { + return fmt.Errorf("required key %q is missing", key) + } + } + + return nil + } +} + +// --- Error types --- + +// ConfigReason classifies the cause of a [ConfigError]. +type ConfigReason int + +const ( + // ConfigReasonMissing indicates the config path does not exist or is empty. + ConfigReasonMissing ConfigReason = iota + + // ConfigReasonInvalidFormat indicates the config value is not valid JSON + // or cannot be unmarshaled into the target type. + ConfigReasonInvalidFormat + + // ConfigReasonValidationFailed indicates a validator rejected the config value. + ConfigReasonValidationFailed +) + +// String returns a human-readable label. +func (r ConfigReason) String() string { + switch r { + case ConfigReasonMissing: + return "missing" + case ConfigReasonInvalidFormat: + return "invalid_format" + case ConfigReasonValidationFailed: + return "validation_failed" + default: + return "unknown" + } +} + +// ConfigError is returned by [ConfigHelper] methods on domain-level failures. +type ConfigError struct { + // Path is the config path that was being accessed. + Path string + + // Reason classifies the failure. + Reason ConfigReason + + // Err is the underlying error. + Err error +} + +func (e *ConfigError) Error() string { + return fmt.Sprintf("azdext.ConfigHelper: %s (path=%s): %v", e.Reason, e.Path, e.Err) +} + +func (e *ConfigError) Unwrap() error { + return e.Err +} + +var configPathRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]*$`) + +// validatePath checks that a config path is non-empty. 
+func validatePath(path string) error { + if path == "" { + return errors.New("azdext.ConfigHelper: config path must not be empty") + } + if !configPathRe.MatchString(path) { + return errors.New( + "azdext.ConfigHelper: config path must start with alphanumeric and contain only [a-zA-Z0-9._-]", + ) + } + + return nil +} diff --git a/cli/azd/pkg/azdext/config_helper_test.go b/cli/azd/pkg/azdext/config_helper_test.go new file mode 100644 index 00000000000..c68464f16bc --- /dev/null +++ b/cli/azd/pkg/azdext/config_helper_test.go @@ -0,0 +1,967 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "google.golang.org/grpc" +) + +// --- Stub UserConfigService --- + +type stubUserConfigService struct { + getResp *GetUserConfigResponse + getStringResp *GetUserConfigStringResponse + getSectionErr error + getErr error + getStringErr error + setErr error + unsetErr error +} + +func (s *stubUserConfigService) Get( + _ context.Context, _ *GetUserConfigRequest, _ ...grpc.CallOption, +) (*GetUserConfigResponse, error) { + return s.getResp, s.getErr +} + +func (s *stubUserConfigService) GetString( + _ context.Context, _ *GetUserConfigStringRequest, _ ...grpc.CallOption, +) (*GetUserConfigStringResponse, error) { + return s.getStringResp, s.getStringErr +} + +func (s *stubUserConfigService) GetSection( + _ context.Context, _ *GetUserConfigSectionRequest, _ ...grpc.CallOption, +) (*GetUserConfigSectionResponse, error) { + return nil, s.getSectionErr +} + +func (s *stubUserConfigService) Set( + _ context.Context, _ *SetUserConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.setErr +} + +func (s *stubUserConfigService) Unset( + _ context.Context, _ *UnsetUserConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.unsetErr +} + +// --- Stub EnvironmentService --- + +type 
stubEnvironmentService struct { + getConfigResp *GetConfigResponse + getConfigStringResp *GetConfigStringResponse + getConfigErr error + getConfigStringErr error + setConfigErr error + unsetConfigErr error +} + +func (s *stubEnvironmentService) GetCurrent( + _ context.Context, _ *EmptyRequest, _ ...grpc.CallOption, +) (*EnvironmentResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) List( + _ context.Context, _ *EmptyRequest, _ ...grpc.CallOption, +) (*EnvironmentListResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) Get( + _ context.Context, _ *GetEnvironmentRequest, _ ...grpc.CallOption, +) (*EnvironmentResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) Select( + _ context.Context, _ *SelectEnvironmentRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) GetValues( + _ context.Context, _ *GetEnvironmentRequest, _ ...grpc.CallOption, +) (*KeyValueListResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) GetValue( + _ context.Context, _ *GetEnvRequest, _ ...grpc.CallOption, +) (*KeyValueResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) SetValue( + _ context.Context, _ *SetEnvRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) GetConfig( + _ context.Context, _ *GetConfigRequest, _ ...grpc.CallOption, +) (*GetConfigResponse, error) { + return s.getConfigResp, s.getConfigErr +} + +func (s *stubEnvironmentService) GetConfigString( + _ context.Context, _ *GetConfigStringRequest, _ ...grpc.CallOption, +) (*GetConfigStringResponse, error) { + return s.getConfigStringResp, s.getConfigStringErr +} + +func (s *stubEnvironmentService) GetConfigSection( + _ context.Context, _ *GetConfigSectionRequest, _ ...grpc.CallOption, +) (*GetConfigSectionResponse, error) { + return nil, nil +} + +func (s *stubEnvironmentService) SetConfig( + _ 
context.Context, _ *SetConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.setConfigErr +} + +func (s *stubEnvironmentService) UnsetConfig( + _ context.Context, _ *UnsetConfigRequest, _ ...grpc.CallOption, +) (*EmptyResponse, error) { + return &EmptyResponse{}, s.unsetConfigErr +} + +// --- NewConfigHelper --- + +func TestNewConfigHelper_NilClient(t *testing.T) { + t.Parallel() + + _, err := NewConfigHelper(nil) + if err == nil { + t.Fatal("expected error for nil client") + } +} + +func TestNewConfigHelper_Success(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, err := NewConfigHelper(client) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if ch == nil { + t.Fatal("expected non-nil ConfigHelper") + } +} + +// --- GetUserString --- + +func TestGetUserString_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + _, _, err := ch.GetUserString(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestGetUserString_Found(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getStringResp: &GetUserConfigStringResponse{Value: "8080", Found: true}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + val, found, err := ch.GetUserString(context.Background(), "extensions.myext.port") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if val != "8080" { + t.Errorf("value = %q, want %q", val, "8080") + } +} + +func TestGetUserString_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getStringResp: &GetUserConfigStringResponse{Value: "", Found: false}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + val, found, err := ch.GetUserString(context.Background(), "nonexistent.path") + if err != nil { + 
t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } + + if val != "" { + t.Errorf("value = %q, want empty", val) + } +} + +func TestGetUserString_GRPCError(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getStringErr: errors.New("grpc unavailable"), + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + _, _, err := ch.GetUserString(context.Background(), "some.path") + if err == nil { + t.Fatal("expected error for gRPC failure") + } +} + +// --- GetUserJSON --- + +func TestGetUserJSON_Found(t *testing.T) { + t.Parallel() + + type myConfig struct { + Port int `json:"port"` + Host string `json:"host"` + } + + data, _ := json.Marshal(myConfig{Port: 3000, Host: "localhost"}) + stub := &stubUserConfigService{ + getResp: &GetUserConfigResponse{Value: data, Found: true}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg myConfig + found, err := ch.GetUserJSON(context.Background(), "extensions.myext", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if cfg.Port != 3000 { + t.Errorf("Port = %d, want 3000", cfg.Port) + } + + if cfg.Host != "localhost" { + t.Errorf("Host = %q, want %q", cfg.Host, "localhost") + } +} + +func TestGetUserJSON_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getResp: &GetUserConfigResponse{Value: nil, Found: false}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + found, err := ch.GetUserJSON(context.Background(), "nonexistent", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } +} + +func TestGetUserJSON_NilOut(t *testing.T) { + t.Parallel() + + client := &AzdClient{userConfigClient: &stubUserConfigService{}} + ch, _ := NewConfigHelper(client) + + _, 
err := ch.GetUserJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil out parameter") + } +} + +func TestGetUserJSON_InvalidJSON(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + getResp: &GetUserConfigResponse{Value: []byte("not json"), Found: true}, + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + _, err := ch.GetUserJSON(context.Background(), "bad.json", &cfg) + if err == nil { + t.Fatal("expected error for invalid JSON") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonInvalidFormat { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonInvalidFormat) + } +} + +func TestGetUserJSON_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + _, err := ch.GetUserJSON(context.Background(), "", &cfg) + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- SetUserJSON --- + +func TestSetUserJSON_Success(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{} + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "extensions.myext.port", 3000) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestSetUserJSON_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "", "value") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestSetUserJSON_NilValue(t *testing.T) { + t.Parallel() + + client := &AzdClient{userConfigClient: &stubUserConfigService{}} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil 
value") + } +} + +func TestSetUserJSON_GRPCError(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{ + setErr: errors.New("grpc write error"), + } + + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.SetUserJSON(context.Background(), "some.path", "value") + if err == nil { + t.Fatal("expected error for gRPC failure") + } +} + +func TestSetUserJSON_UnmarshalableValue(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{} + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + // Channels cannot be marshaled to JSON + err := ch.SetUserJSON(context.Background(), "some.path", make(chan int)) + if err == nil { + t.Fatal("expected error for unmarshalable value") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonInvalidFormat { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonInvalidFormat) + } +} + +// --- UnsetUser --- + +func TestUnsetUser_Success(t *testing.T) { + t.Parallel() + + stub := &stubUserConfigService{} + client := &AzdClient{userConfigClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetUser(context.Background(), "some.path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestUnsetUser_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetUser(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- GetEnvString --- + +func TestGetEnvString_Found(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{ + getConfigStringResp: &GetConfigStringResponse{Value: "prod", Found: true}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + val, found, err := ch.GetEnvString(context.Background(), "extensions.myext.mode") + if err != nil { + 
t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if val != "prod" { + t.Errorf("value = %q, want %q", val, "prod") + } +} + +func TestGetEnvString_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{ + getConfigStringResp: &GetConfigStringResponse{Value: "", Found: false}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + _, found, err := ch.GetEnvString(context.Background(), "nonexistent") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected found = false") + } +} + +func TestGetEnvString_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + _, _, err := ch.GetEnvString(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- GetEnvJSON --- + +func TestGetEnvJSON_Found(t *testing.T) { + t.Parallel() + + type envConfig struct { + Debug bool `json:"debug"` + } + + data, _ := json.Marshal(envConfig{Debug: true}) + stub := &stubEnvironmentService{ + getConfigResp: &GetConfigResponse{Value: data, Found: true}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg envConfig + found, err := ch.GetEnvJSON(context.Background(), "extensions.myext", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !found { + t.Error("expected found = true") + } + + if !cfg.Debug { + t.Error("expected Debug = true") + } +} + +func TestGetEnvJSON_NotFound(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{ + getConfigResp: &GetConfigResponse{Value: nil, Found: false}, + } + + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + var cfg map[string]any + found, err := ch.GetEnvJSON(context.Background(), "nonexistent", &cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if found { + t.Error("expected 
found = false") + } +} + +func TestGetEnvJSON_NilOut(t *testing.T) { + t.Parallel() + + client := &AzdClient{environmentClient: &stubEnvironmentService{}} + ch, _ := NewConfigHelper(client) + + _, err := ch.GetEnvJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil out parameter") + } +} + +// --- SetEnvJSON --- + +func TestSetEnvJSON_Success(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{} + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.SetEnvJSON(context.Background(), "extensions.myext.mode", "prod") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestSetEnvJSON_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.SetEnvJSON(context.Background(), "", "value") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +func TestSetEnvJSON_NilValue(t *testing.T) { + t.Parallel() + + client := &AzdClient{environmentClient: &stubEnvironmentService{}} + ch, _ := NewConfigHelper(client) + + err := ch.SetEnvJSON(context.Background(), "some.path", nil) + if err == nil { + t.Fatal("expected error for nil value") + } +} + +// --- UnsetEnv --- + +func TestUnsetEnv_Success(t *testing.T) { + t.Parallel() + + stub := &stubEnvironmentService{} + client := &AzdClient{environmentClient: stub} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetEnv(context.Background(), "some.path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestUnsetEnv_EmptyPath(t *testing.T) { + t.Parallel() + + client := &AzdClient{} + ch, _ := NewConfigHelper(client) + + err := ch.UnsetEnv(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty path") + } +} + +// --- MergeJSON --- + +func TestMergeJSON_Basic(t *testing.T) { + t.Parallel() + + base := map[string]any{"a": 1, "b": 2} + override := map[string]any{"b": 3, "c": 4} + + result := 
MergeJSON(base, override) + + if result["a"] != 1 { + t.Errorf("a = %v, want 1", result["a"]) + } + + if result["b"] != 3 { + t.Errorf("b = %v, want 3 (override wins)", result["b"]) + } + + if result["c"] != 4 { + t.Errorf("c = %v, want 4", result["c"]) + } +} + +func TestMergeJSON_EmptyBase(t *testing.T) { + t.Parallel() + + result := MergeJSON(nil, map[string]any{"x": "y"}) + + if result["x"] != "y" { + t.Errorf("x = %v, want y", result["x"]) + } +} + +func TestMergeJSON_EmptyOverride(t *testing.T) { + t.Parallel() + + result := MergeJSON(map[string]any{"x": "y"}, nil) + + if result["x"] != "y" { + t.Errorf("x = %v, want y", result["x"]) + } +} + +func TestMergeJSON_BothEmpty(t *testing.T) { + t.Parallel() + + result := MergeJSON(nil, nil) + + if len(result) != 0 { + t.Errorf("len(result) = %d, want 0", len(result)) + } +} + +func TestMergeJSON_DoesNotMutateInputs(t *testing.T) { + t.Parallel() + + base := map[string]any{"a": 1} + override := map[string]any{"b": 2} + + _ = MergeJSON(base, override) + + if _, ok := base["b"]; ok { + t.Error("MergeJSON mutated base map") + } + + if _, ok := override["a"]; ok { + t.Error("MergeJSON mutated override map") + } +} + +// --- DeepMergeJSON --- + +func TestDeepMergeJSON_RecursiveMerge(t *testing.T) { + t.Parallel() + + base := map[string]any{ + "server": map[string]any{ + "host": "localhost", + "port": 3000, + }, + "debug": false, + } + + override := map[string]any{ + "server": map[string]any{ + "port": 8080, + "tls": true, + }, + "version": "1.0", + } + + result := DeepMergeJSON(base, override) + + server, ok := result["server"].(map[string]any) + if !ok { + t.Fatal("server should be a map") + } + + if server["host"] != "localhost" { + t.Errorf("server.host = %v, want localhost", server["host"]) + } + + if server["port"] != 8080 { + t.Errorf("server.port = %v, want 8080 (override wins)", server["port"]) + } + + if server["tls"] != true { + t.Errorf("server.tls = %v, want true", server["tls"]) + } + + if result["debug"] 
!= false { + t.Errorf("debug = %v, want false", result["debug"]) + } + + if result["version"] != "1.0" { + t.Errorf("version = %v, want 1.0", result["version"]) + } +} + +func TestDeepMergeJSON_OverrideReplacesNonMap(t *testing.T) { + t.Parallel() + + base := map[string]any{"x": "string-value"} + override := map[string]any{"x": map[string]any{"nested": true}} + + result := DeepMergeJSON(base, override) + + nested, ok := result["x"].(map[string]any) + if !ok { + t.Fatal("override should replace string with map") + } + + if nested["nested"] != true { + t.Errorf("x.nested = %v, want true", nested["nested"]) + } +} + +func TestDeepMergeJSON_DoesNotMutateInputs(t *testing.T) { + t.Parallel() + + base := map[string]any{"a": map[string]any{"x": 1}} + override := map[string]any{"a": map[string]any{"y": 2}} + + _ = DeepMergeJSON(base, override) + + baseA := base["a"].(map[string]any) + if _, ok := baseA["y"]; ok { + t.Error("DeepMergeJSON mutated base nested map") + } +} + +// --- ValidateConfig --- + +func TestValidateConfig_EmptyData(t *testing.T) { + t.Parallel() + + err := ValidateConfig("test.path", nil) + if err == nil { + t.Fatal("expected error for empty data") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonMissing { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonMissing) + } +} + +func TestValidateConfig_InvalidJSON(t *testing.T) { + t.Parallel() + + err := ValidateConfig("test.path", []byte("not json")) + if err == nil { + t.Fatal("expected error for invalid JSON") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonInvalidFormat { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonInvalidFormat) + } +} + +func TestValidateConfig_ValidatorFails(t *testing.T) { + t.Parallel() + + data, _ := json.Marshal(map[string]any{"a": 1}) 
+ failValidator := func(_ any) error { return errors.New("validation failed") } + + err := ValidateConfig("test.path", data, failValidator) + if err == nil { + t.Fatal("expected error from failing validator") + } + + var cfgErr *ConfigError + if !errors.As(err, &cfgErr) { + t.Fatalf("error type = %T, want *ConfigError", err) + } + + if cfgErr.Reason != ConfigReasonValidationFailed { + t.Errorf("Reason = %v, want %v", cfgErr.Reason, ConfigReasonValidationFailed) + } +} + +func TestValidateConfig_AllValidatorsPass(t *testing.T) { + t.Parallel() + + data, _ := json.Marshal(map[string]any{"a": 1, "b": 2}) + passValidator := func(_ any) error { return nil } + + err := ValidateConfig("test.path", data, passValidator, passValidator) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidateConfig_NoValidators(t *testing.T) { + t.Parallel() + + data, _ := json.Marshal(map[string]any{"a": 1}) + + err := ValidateConfig("test.path", data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +// --- RequiredKeys --- + +func TestRequiredKeys_AllPresent(t *testing.T) { + t.Parallel() + + validator := RequiredKeys("host", "port") + value := map[string]any{"host": "localhost", "port": 3000, "extra": true} + + err := validator(value) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestRequiredKeys_MissingKey(t *testing.T) { + t.Parallel() + + validator := RequiredKeys("host", "port") + value := map[string]any{"host": "localhost"} + + err := validator(value) + if err == nil { + t.Fatal("expected error for missing key") + } +} + +func TestRequiredKeys_NotAMap(t *testing.T) { + t.Parallel() + + validator := RequiredKeys("key") + + err := validator("not a map") + if err == nil { + t.Fatal("expected error for non-map value") + } +} + +// --- ConfigError --- + +func TestConfigError_Error(t *testing.T) { + t.Parallel() + + err := &ConfigError{ + Path: "test.path", + Reason: ConfigReasonMissing, + Err: errors.New("not 
found"), + } + + got := err.Error() + if got == "" { + t.Fatal("Error() returned empty string") + } +} + +func TestConfigError_Unwrap(t *testing.T) { + t.Parallel() + + inner := errors.New("inner error") + err := &ConfigError{ + Path: "test.path", + Reason: ConfigReasonInvalidFormat, + Err: inner, + } + + if !errors.Is(err, inner) { + t.Error("Unwrap should expose inner error via errors.Is") + } +} + +func TestConfigReason_String(t *testing.T) { + t.Parallel() + + tests := []struct { + reason ConfigReason + want string + }{ + {ConfigReasonMissing, "missing"}, + {ConfigReasonInvalidFormat, "invalid_format"}, + {ConfigReasonValidationFailed, "validation_failed"}, + {ConfigReason(99), "unknown"}, + } + + for _, tt := range tests { + if got := tt.reason.String(); got != tt.want { + t.Errorf("ConfigReason(%d).String() = %q, want %q", tt.reason, got, tt.want) + } + } +} diff --git a/cli/azd/pkg/azdext/context.go b/cli/azd/pkg/azdext/context.go index 47e4b1915fe..d4d9d5bef85 100644 --- a/cli/azd/pkg/azdext/context.go +++ b/cli/azd/pkg/azdext/context.go @@ -21,6 +21,13 @@ const ( ) // NewContext initializes a new context with tracing information extracted from environment variables. +// +// Deprecated: Use [Run] for custom-command extensions — it creates the context, +// injects the access token, reports structured errors, and handles os.Exit. +// For lifecycle-listener extensions, use [NewListenCommand] which sets up +// context and access token automatically. +// If you need parsed global flags (--debug, --no-prompt, --cwd, -e), use +// [NewExtensionRootCommand] together with [Run]. 
func NewContext() context.Context { ctx := context.Background() parent := os.Getenv(TraceparentEnv) diff --git a/cli/azd/pkg/azdext/extension_command.go b/cli/azd/pkg/azdext/extension_command.go index 8ea892af734..e96a7851cd2 100644 --- a/cli/azd/pkg/azdext/extension_command.go +++ b/cli/azd/pkg/azdext/extension_command.go @@ -55,6 +55,11 @@ type ExtensionCommandOptions struct { // - Calls WithAccessToken() on the command context // // The returned command has PersistentPreRunE configured to set up the ExtensionContext. +// +// NOTE: This function and its companion helpers ([NewListenCommand], [NewMetadataCommand], +// [NewVersionCommand]) depend on [github.com/spf13/cobra]. If non-cobra CLI frameworks +// gain adoption among extension authors, these symbols are candidates for extraction into +// an azdext/cobra sub-package so the core SDK remains framework-agnostic. func NewExtensionRootCommand(opts ExtensionCommandOptions) (*cobra.Command, *ExtensionContext) { extCtx := &ExtensionContext{} diff --git a/cli/azd/pkg/azdext/logger.go b/cli/azd/pkg/azdext/logger.go new file mode 100644 index 00000000000..bc246a49fa5 --- /dev/null +++ b/cli/azd/pkg/azdext/logger.go @@ -0,0 +1,170 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "io" + "log/slog" + "os" + "strings" +) + +// LoggerOptions configures [SetupLogging] and [NewLogger]. +type LoggerOptions struct { + // Debug enables debug-level logging. When false, messages below Info are + // suppressed. If not set explicitly, [NewLogger] checks the AZD_DEBUG + // environment variable. + Debug bool + // Structured selects JSON output when true, human-readable text when false. + Structured bool + // Writer overrides the output destination. Defaults to os.Stderr. + Writer io.Writer +} + +// SetupLogging configures the process-wide default [slog.Logger]. 
+// It is typically called once at startup (for example from +// [NewExtensionRootCommand]'s PersistentPreRunE callback). +// +// Calling SetupLogging is optional — [NewLogger] works without it and creates +// loggers that inherit from [slog.Default]. SetupLogging is provided for +// extensions that want explicit control over the global log level and format. +func SetupLogging(opts LoggerOptions) { + handler := newHandler(opts) + slog.SetDefault(slog.New(handler)) +} + +// Logger provides component-scoped structured logging built on [log/slog]. +// +// Each Logger carries a "component" attribute so log lines can be filtered or +// routed by subsystem. Additional context can be attached via [Logger.With], +// [Logger.WithComponent], or [Logger.WithOperation]. +// +// Logger writes to stderr by default and never writes to stdout, so it does +// not interfere with command output or JSON-mode piping. +type Logger struct { + slogger *slog.Logger + component string +} + +// NewLogger creates a Logger scoped to the given component name. +// +// If the AZD_DEBUG environment variable is set to a truthy value ("1", "true", +// "yes") and opts.Debug is false, debug logging is enabled automatically. This +// lets extension authors respect the framework's debug flag without extra +// plumbing. +// +// When opts is omitted (zero value), the logger uses Info level with text +// format on stderr. +func NewLogger(component string, opts ...LoggerOptions) *Logger { + var o LoggerOptions + if len(opts) > 0 { + o = opts[0] + } + + // Auto-detect debug from environment when not explicitly set. + if !o.Debug { + o.Debug = isDebugEnv() + } + + handler := newHandler(o) + base := slog.New(handler).With("component", component) + + return &Logger{ + slogger: base, + component: component, + } +} + +// Component returns the component name this logger was created with. 
+func (l *Logger) Component() string { + return l.component +} + +// Debug logs a message at debug level with optional key-value pairs. +func (l *Logger) Debug(msg string, args ...any) { + l.slogger.Debug(msg, args...) +} + +// Info logs a message at info level with optional key-value pairs. +func (l *Logger) Info(msg string, args ...any) { + l.slogger.Info(msg, args...) +} + +// Warn logs a message at warn level with optional key-value pairs. +func (l *Logger) Warn(msg string, args ...any) { + l.slogger.Warn(msg, args...) +} + +// Error logs a message at error level with optional key-value pairs. +func (l *Logger) Error(msg string, args ...any) { + l.slogger.Error(msg, args...) +} + +// With returns a new Logger that includes the given key-value pairs in every +// subsequent log entry. Keys must be strings; values can be any type +// supported by [slog]. +// +// Example: +// +// l := logger.With("request_id", reqID) +// l.Info("processing") // includes component + request_id +func (l *Logger) With(args ...any) *Logger { + return &Logger{ + slogger: l.slogger.With(args...), + component: l.component, + } +} + +// WithComponent returns a new Logger with a different component name. The +// original component is preserved as "parent_component". +func (l *Logger) WithComponent(name string) *Logger { + return &Logger{ + slogger: l.slogger.With("parent_component", l.component, "component", name), + component: name, + } +} + +// WithOperation returns a new Logger with an "operation" attribute. +func (l *Logger) WithOperation(name string) *Logger { + return &Logger{ + slogger: l.slogger.With("operation", name), + component: l.component, + } +} + +// Slogger returns the underlying [*slog.Logger] for advanced use cases such +// as passing to libraries that accept a standard slog logger. 
+func (l *Logger) Slogger() *slog.Logger { + return l.slogger +} + +// --------------------------------------------------------------------------- +// Internal helpers +// --------------------------------------------------------------------------- + +// newHandler creates an slog.Handler from LoggerOptions. +func newHandler(opts LoggerOptions) slog.Handler { + w := opts.Writer + if w == nil { + w = os.Stderr + } + + level := slog.LevelInfo + if opts.Debug { + level = slog.LevelDebug + } + + handlerOpts := &slog.HandlerOptions{Level: level} + + if opts.Structured { + return slog.NewJSONHandler(w, handlerOpts) + } + return slog.NewTextHandler(w, handlerOpts) +} + +// isDebugEnv checks the AZD_DEBUG environment variable. +func isDebugEnv() bool { + v := strings.ToLower(os.Getenv("AZD_DEBUG")) + return v == "1" || v == "true" || v == "yes" +} diff --git a/cli/azd/pkg/azdext/logger_test.go b/cli/azd/pkg/azdext/logger_test.go new file mode 100644 index 00000000000..9e21b763ada --- /dev/null +++ b/cli/azd/pkg/azdext/logger_test.go @@ -0,0 +1,295 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// NewLogger — basic construction +// --------------------------------------------------------------------------- + +func TestNewLogger_DefaultOptions(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("test-component", LoggerOptions{Writer: &buf}) + + require.NotNil(t, logger) + require.Equal(t, "test-component", logger.Component()) +} + +func TestNewLogger_ZeroOptions(t *testing.T) { + // Zero-value opts should not panic (writes to stderr). 
+ logger := NewLogger("safe") + require.NotNil(t, logger) + require.Equal(t, "safe", logger.Component()) +} + +func TestNewLogger_NoOpts(t *testing.T) { + // Calling without variadic opts should not panic. + logger := NewLogger("minimal") + require.NotNil(t, logger) +} + +// --------------------------------------------------------------------------- +// Log levels — Info +// --------------------------------------------------------------------------- + +func TestLogger_Info(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("mycomp", LoggerOptions{Writer: &buf}) + + logger.Info("hello world", "key", "val") + + output := buf.String() + require.Contains(t, output, "hello world") + require.Contains(t, output, "key=val") + require.Contains(t, output, "component=mycomp") +} + +// --------------------------------------------------------------------------- +// Log levels — Debug +// --------------------------------------------------------------------------- + +func TestLogger_Debug_Enabled(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("dbg", LoggerOptions{Debug: true, Writer: &buf}) + + logger.Debug("debug message", "detail", "x") + + require.Contains(t, buf.String(), "debug message") + require.Contains(t, buf.String(), "detail=x") +} + +func TestLogger_Debug_Disabled(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("dbg", LoggerOptions{Debug: false, Writer: &buf}) + + logger.Debug("should not appear") + + require.Empty(t, buf.String()) +} + +func TestLogger_Debug_AZD_DEBUG_EnvVar(t *testing.T) { + t.Setenv("AZD_DEBUG", "true") + + var buf bytes.Buffer + logger := NewLogger("env-debug", LoggerOptions{Writer: &buf}) + + logger.Debug("from env var") + + require.Contains(t, buf.String(), "from env var") +} + +func TestLogger_Debug_AZD_DEBUG_EnvVar_One(t *testing.T) { + t.Setenv("AZD_DEBUG", "1") + + var buf bytes.Buffer + logger := NewLogger("env-one", LoggerOptions{Writer: &buf}) + + logger.Debug("debug via 1") + + require.Contains(t, 
buf.String(), "debug via 1") +} + +func TestLogger_Debug_AZD_DEBUG_EnvVar_Yes(t *testing.T) { + t.Setenv("AZD_DEBUG", "yes") + + var buf bytes.Buffer + logger := NewLogger("env-yes", LoggerOptions{Writer: &buf}) + + logger.Debug("debug via yes") + + require.Contains(t, buf.String(), "debug via yes") +} + +func TestLogger_Debug_AZD_DEBUG_Unset(t *testing.T) { + t.Setenv("AZD_DEBUG", "") + + var buf bytes.Buffer + logger := NewLogger("env-empty", LoggerOptions{Writer: &buf}) + + logger.Debug("hidden") + + require.Empty(t, buf.String()) +} + +// --------------------------------------------------------------------------- +// Log levels — Warn / Error +// --------------------------------------------------------------------------- + +func TestLogger_Warn(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("warn-test", LoggerOptions{Writer: &buf}) + + logger.Warn("something concerning", "retries", 3) + + require.Contains(t, buf.String(), "something concerning") + require.Contains(t, buf.String(), "retries=3") +} + +func TestLogger_Error(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("err-test", LoggerOptions{Writer: &buf}) + + logger.Error("bad thing happened", "code", 500) + + require.Contains(t, buf.String(), "bad thing happened") + require.Contains(t, buf.String(), "code=500") +} + +// --------------------------------------------------------------------------- +// Structured (JSON) output +// --------------------------------------------------------------------------- + +func TestLogger_StructuredJSON(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("json-comp", LoggerOptions{Structured: true, Writer: &buf}) + + logger.Info("structured entry", "env", "prod") + + // Each line should be valid JSON. 
+ var parsed map[string]any + err := json.Unmarshal(buf.Bytes(), &parsed) + require.NoError(t, err) + require.Equal(t, "structured entry", parsed["msg"]) + require.Equal(t, "prod", parsed["env"]) + require.Equal(t, "json-comp", parsed["component"]) +} + +// --------------------------------------------------------------------------- +// Context chaining — With +// --------------------------------------------------------------------------- + +func TestLogger_With(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("base", LoggerOptions{Writer: &buf}) + + child := logger.With("request_id", "abc-123") + child.Info("processing") + + output := buf.String() + require.Contains(t, output, "request_id=abc-123") + require.Contains(t, output, "component=base") + require.Contains(t, output, "processing") + + // Child should have the same component. + require.Equal(t, "base", child.Component()) +} + +func TestLogger_With_ChainMultiple(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("chain", LoggerOptions{Writer: &buf}) + + child := logger.With("a", "1").With("b", "2") + child.Info("chained") + + output := buf.String() + require.Contains(t, output, "a=1") + require.Contains(t, output, "b=2") +} + +// --------------------------------------------------------------------------- +// Context chaining — WithComponent +// --------------------------------------------------------------------------- + +func TestLogger_WithComponent(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("parent", LoggerOptions{Structured: true, Writer: &buf}) + + child := logger.WithComponent("child-subsystem") + child.Info("from child") + + require.Equal(t, "child-subsystem", child.Component()) + + var parsed map[string]any + err := json.Unmarshal(buf.Bytes(), &parsed) + require.NoError(t, err) + require.Equal(t, "child-subsystem", parsed["component"]) + require.Equal(t, "parent", parsed["parent_component"]) +} + +// 
--------------------------------------------------------------------------- +// Context chaining — WithOperation +// --------------------------------------------------------------------------- + +func TestLogger_WithOperation(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("ops", LoggerOptions{Writer: &buf}) + + child := logger.WithOperation("deploy") + child.Info("starting deploy") + + output := buf.String() + require.Contains(t, output, "operation=deploy") + require.Equal(t, "ops", child.Component()) +} + +// --------------------------------------------------------------------------- +// Slogger accessor +// --------------------------------------------------------------------------- + +func TestLogger_Slogger(t *testing.T) { + logger := NewLogger("access", LoggerOptions{Writer: &bytes.Buffer{}}) + require.NotNil(t, logger.Slogger()) +} + +// --------------------------------------------------------------------------- +// SetupLogging — global logger configuration +// --------------------------------------------------------------------------- + +func TestSetupLogging_DoesNotPanic(t *testing.T) { + // SetupLogging modifies slog.Default which is global state. + // We only verify it does not panic here. + var buf bytes.Buffer + SetupLogging(LoggerOptions{Debug: true, Structured: true, Writer: &buf}) + + // Restore a sensible default after the test. 
+ SetupLogging(LoggerOptions{Writer: &bytes.Buffer{}}) +} + +// --------------------------------------------------------------------------- +// isDebugEnv internal helper +// --------------------------------------------------------------------------- + +func TestIsDebugEnv_Truthy(t *testing.T) { + truthy := []string{"1", "true", "TRUE", "True", "yes", "YES", "Yes"} + for _, v := range truthy { + t.Run(v, func(t *testing.T) { + t.Setenv("AZD_DEBUG", v) + require.True(t, isDebugEnv()) + }) + } +} + +func TestIsDebugEnv_Falsy(t *testing.T) { + falsy := []string{"", "0", "false", "no", "maybe"} + for _, v := range falsy { + t.Run("value="+v, func(t *testing.T) { + t.Setenv("AZD_DEBUG", v) + require.False(t, isDebugEnv()) + }) + } +} + +// --------------------------------------------------------------------------- +// Text format verification +// --------------------------------------------------------------------------- + +func TestLogger_TextFormat_ContainsLevel(t *testing.T) { + var buf bytes.Buffer + logger := NewLogger("lvl", LoggerOptions{Writer: &buf}) + + logger.Info("test level") + + output := buf.String() + require.True(t, + strings.Contains(output, "INFO") || strings.Contains(output, "level=INFO"), + "expected level indicator in text output: %s", output) +} diff --git a/cli/azd/pkg/azdext/mcp_security.go b/cli/azd/pkg/azdext/mcp_security.go index 9f3c84f46c0..6c27173f8cc 100644 --- a/cli/azd/pkg/azdext/mcp_security.go +++ b/cli/azd/pkg/azdext/mcp_security.go @@ -24,12 +24,10 @@ type MCPSecurityPolicy struct { allowedBasePaths []string blockedCIDRs []*net.IPNet blockedHosts map[string]bool + // onBlocked is invoked whenever a URL or path is blocked, for audit logging. + onBlocked func(violation string) // lookupHost is used for DNS resolution; override in tests. lookupHost func(string) ([]string, error) - // onBlocked is an optional callback invoked when a URL or path is blocked. 
- // Parameters: action ("url_blocked", "path_blocked"), - // detail (human-readable explanation). Safe for concurrent use. - onBlocked func(action, detail string) } // NewMCPSecurityPolicy creates an empty security policy. @@ -47,12 +45,7 @@ func (p *MCPSecurityPolicy) BlockMetadataEndpoints() *MCPSecurityPolicy { p.mu.Lock() defer p.mu.Unlock() p.blockMetadata = true - for _, host := range []string{ - "169.254.169.254", - "fd00:ec2::254", - "metadata.google.internal", - "100.100.100.200", - } { + for _, host := range ssrfMetadataHosts { p.blockedHosts[strings.ToLower(host)] = true } return p @@ -65,23 +58,7 @@ func (p *MCPSecurityPolicy) BlockPrivateNetworks() *MCPSecurityPolicy { p.mu.Lock() defer p.mu.Unlock() p.blockPrivate = true - for _, cidr := range []string{ - "0.0.0.0/8", // "this" network (reaches loopback on Linux/macOS) - "10.0.0.0/8", // RFC 1918 private - "172.16.0.0/12", // RFC 1918 private - "192.168.0.0/16", // RFC 1918 private - "127.0.0.0/8", // loopback - "100.64.0.0/10", // RFC 6598 shared/CGNAT (internal in cloud environments) - "169.254.0.0/16", // IPv4 link-local - "::1/128", // IPv6 loopback - "::/128", // IPv6 unspecified (reaches loopback) - "fc00::/7", // IPv6 unique local addresses (RFC 4193, equiv of RFC 1918) - "fe80::/10", // IPv6 link-local - "2002::/16", // 6to4 relay (deprecated RFC 7526; can embed private IPv4) - "2001::/32", // Teredo tunneling (deprecated; can embed private IPv4) - "64:ff9b::/96", // NAT64 well-known prefix (RFC 6052; embeds IPv4 in last 32 bits) - "64:ff9b:1::/48", // NAT64 local-use prefix (RFC 8215; embeds IPv4 in last 32 bits) - } { + for _, cidr := range ssrfBlockedCIDRs { _, ipNet, err := net.ParseCIDR(cidr) if err == nil { p.blockedCIDRs = append(p.blockedCIDRs, ipNet) @@ -116,14 +93,10 @@ func (p *MCPSecurityPolicy) ValidatePathsWithinBase(basePaths ...string) *MCPSec return p } -// OnBlocked registers a callback that is invoked whenever a URL or path is -// blocked by the security policy. 
This enables security audit -// logging without coupling the policy to a specific logging framework. -// -// The callback receives an action tag ("url_blocked", "path_blocked") -// and a human-readable detail string. It must be safe -// for concurrent invocation. -func (p *MCPSecurityPolicy) OnBlocked(fn func(action, detail string)) *MCPSecurityPolicy { +// OnBlocked registers a callback invoked whenever a URL or path check fails. +// The callback receives a human-readable description of the violation. +// This is intended for audit logging; the callback must not block. +func (p *MCPSecurityPolicy) OnBlocked(fn func(violation string)) *MCPSecurityPolicy { p.mu.Lock() defer p.mu.Unlock() p.onBlocked = fn @@ -144,19 +117,17 @@ func isLocalhostHost(host string) bool { // Returns an error describing the violation, or nil if allowed. func (p *MCPSecurityPolicy) CheckURL(rawURL string) error { p.mu.RLock() - fn := p.onBlocked err := p.checkURLCore(rawURL) + onBlocked := p.onBlocked p.mu.RUnlock() - if fn != nil && err != nil { - fn("url_blocked", err.Error()) + if err != nil && onBlocked != nil { + onBlocked(err.Error()) } return err } -// checkURLCore performs URL validation without acquiring the lock or invoking -// the onBlocked callback. Callers must hold p.mu (at least RLock). 
func (p *MCPSecurityPolicy) checkURLCore(rawURL string) error { u, err := url.Parse(rawURL) if err != nil { @@ -171,7 +142,7 @@ func (p *MCPSecurityPolicy) checkURLCore(rawURL string) error { // always allowed case "http": if p.requireHTTPS && !isLocalhostHost(host) { - return fmt.Errorf("HTTPS required: %s", redactSecurityURL(rawURL)) + return fmt.Errorf("HTTPS required: %s", rawURL) } default: return fmt.Errorf("scheme not allowed: %q (only http and https are permitted)", u.Scheme) @@ -210,44 +181,9 @@ func (p *MCPSecurityPolicy) checkURLCore(rawURL string) error { return nil } -func redactSecurityURL(rawURL string) string { - u, err := url.Parse(rawURL) - if err != nil { - return "" - } - u.RawQuery = "" - u.Fragment = "" - return u.String() -} - func (p *MCPSecurityPolicy) checkIP(ip net.IP, originalHost string) error { - for _, cidr := range p.blockedCIDRs { - if cidr.Contains(ip) { - return fmt.Errorf("blocked IP %s (CIDR %s) for host %s", ip, cidr, originalHost) - } - } - - if p.blockPrivate { - // Catch encoding variants (e.g., IPv4-compatible IPv6 like ::127.0.0.1) - // that may not match CIDR entries due to byte-length mismatch. - if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsUnspecified() { - return fmt.Errorf("blocked IP %s (private/loopback/link-local) for host %s", ip, originalHost) - } - - // Handle encoding variants that Go's net.IP methods don't classify, - // by extracting the embedded IPv4 and re-checking it. 
- if v4 := extractEmbeddedIPv4(ip); v4 != nil { - for _, cidr := range p.blockedCIDRs { - if cidr.Contains(v4) { - return fmt.Errorf("blocked IP %s (embedded %s, CIDR %s) for host %s", - ip, v4, cidr, originalHost) - } - } - if v4.IsLoopback() || v4.IsPrivate() || v4.IsLinkLocalUnicast() || v4.IsUnspecified() { - return fmt.Errorf("blocked IP %s (embedded %s, private/loopback) for host %s", - ip, v4, originalHost) - } - } + if _, detail, blocked := ssrfCheckIP(ip, originalHost, p.blockedCIDRs, p.blockPrivate); blocked { + return fmt.Errorf("%s", detail) } return nil @@ -256,35 +192,25 @@ func (p *MCPSecurityPolicy) checkIP(ip net.IP, originalHost string) error { // CheckPath validates a file path against the security policy. // Resolves symlinks and checks for directory traversal. // -// Security note (TOCTOU): There is an inherent time-of-check to time-of-use -// gap between the symlink resolution performed here and the caller's -// subsequent file operation. An adversary with write access to the filesystem -// could create or modify a symlink between the check and the use. This is a -// fundamental limitation of path-based validation on POSIX systems. -// -// Mitigations callers should consider: -// - Use O_NOFOLLOW when opening files after validation (prevents symlink -// following at the final component). -// - Use file-descriptor-based approaches (openat2 with RESOLVE_BENEATH on -// Linux 5.6+) where possible. -// - Avoid writing to directories that untrusted users can modify. -// - Consider validating the opened fd's path post-open via /proc/self/fd/N -// or fstat. +// SECURITY NOTE (TOCTOU): This check is inherently susceptible to +// time-of-check-to-time-of-use races — the filesystem state may change between +// the validation here and the actual file access by the caller. 
Callers that +// operate in adversarial environments (e.g., shared file systems) should open +// the file immediately after validation and re-verify the resolved path via +// /proc/self/fd or fstat before processing. func (p *MCPSecurityPolicy) CheckPath(path string) error { p.mu.RLock() - fn := p.onBlocked err := p.checkPathCore(path) + onBlocked := p.onBlocked p.mu.RUnlock() - if fn != nil && err != nil { - fn("path_blocked", err.Error()) + if err != nil && onBlocked != nil { + onBlocked(err.Error()) } return err } -// checkPathCore performs path validation without acquiring the lock or invoking -// the onBlocked callback. Callers must hold p.mu (at least RLock). func (p *MCPSecurityPolicy) checkPathCore(path string) error { if len(p.allowedBasePaths) == 0 { return nil @@ -380,12 +306,10 @@ func resolveExistingPrefix(p string) string { } } -// --------------------------------------------------------------------------- -// Redirect SSRF protection -// --------------------------------------------------------------------------- - -// redirectBlockedHosts lists cloud metadata service endpoints that must never -// be the target of an HTTP redirect. +// redirectBlockedHosts lists hostnames that HTTP redirects must never follow. +// This covers cloud metadata services that attackers commonly target via +// redirect-based SSRF (the initial request hits an allowed host which 302s +// to the metadata endpoint). var redirectBlockedHosts = map[string]bool{ "169.254.169.254": true, "fd00:ec2::254": true, @@ -395,9 +319,8 @@ var redirectBlockedHosts = map[string]bool{ // SSRFSafeRedirect is an [http.Client] CheckRedirect function that blocks // redirects to private/loopback IP literals, hostnames that resolve to private -// networks, and cloud metadata endpoints. It prevents -// redirect-based SSRF attacks where an attacker-controlled URL redirects to -// an internal service. +// networks, and cloud metadata endpoints. 
It prevents redirect-based SSRF +// attacks where an attacker-controlled URL redirects to an internal service. // // Usage: // @@ -414,8 +337,6 @@ func ssrfSafeRedirect(req *http.Request, via []*http.Request, lookupHost func(st // Block HTTPS → HTTP scheme downgrades to prevent leaking // Authorization headers (including Bearer tokens) in cleartext. - // Go's net/http preserves headers on same-host redirects regardless - // of scheme change. if len(via) > 0 && via[len(via)-1].URL.Scheme == "https" && req.URL.Scheme != "https" { return fmt.Errorf( "redirect from HTTPS to %s blocked (credential protection)", req.URL.Scheme) @@ -428,24 +349,18 @@ func ssrfSafeRedirect(req *http.Request, via []*http.Request, lookupHost func(st return fmt.Errorf("redirect to metadata endpoint %s blocked (SSRF protection)", host) } - // Block redirects to localhost hostnames (e.g. "localhost", - // "127.0.0.1") regardless of how they are spelled, preventing - // hostname-based SSRF bypasses of the IP-literal checks below. + // Block redirects to localhost hostnames. if isLocalhostHost(host) { return fmt.Errorf("redirect to localhost %s blocked (SSRF protection)", host) } // Block redirects to private/loopback IP addresses, including - // IPv4-compatible and IPv4-translated IPv6 encoding variants - // that bypass Go's IsPrivate()/IsLoopback() classification. + // IPv6 encoding variants that embed private IPv4 addresses. if ip := net.ParseIP(host); ip != nil { if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsUnspecified() { return fmt.Errorf("redirect to private/loopback IP %s blocked (SSRF protection)", ip) } - // Check IPv6 encoding variants (IPv4-compatible, IPv4-translated) - // that embed private IPv4 addresses but aren't caught by Go's - // net.IP classifier methods. 
if err := checkIPEncodingVariants(ip, host); err != nil { return err } @@ -491,51 +406,15 @@ func checkIPEncodingVariants(ip net.IP, originalHost string) error { } // extractEmbeddedIPv4 returns the embedded IPv4 address from IPv4-compatible -// (::x.x.x.x, RFC 4291 §2.5.5.1) or IPv4-translated (::ffff:0:x.x.x.x, -// RFC 2765 §4.2.1) IPv6 encodings. Returns nil if the address is not one of -// these encoding variants. -// -// This handles addresses that Go's net.IP.To4() does not classify as IPv4 -// (To4 returns nil for these), which means Go's IsPrivate()/IsLoopback() -// methods also return false for them. +// (::x.x.x.x) or IPv4-translated (::ffff:0:x.x.x.x) IPv6 encodings. +// Returns nil if the address is not one of these encoding variants. func extractEmbeddedIPv4(ip net.IP) net.IP { if len(ip) != net.IPv6len || ip.To4() != nil { - return nil // Not a pure IPv6 address or already handled as IPv4-mapped - } - - // Check if last 4 bytes are non-zero (otherwise it's just :: which is - // already handled by IsUnspecified). - if ip[12] == 0 && ip[13] == 0 && ip[14] == 0 && ip[15] == 0 { return nil } - // IPv4-compatible (::x.x.x.x): first 12 bytes are zero. - isV4Compatible := true - for i := 0; i < 12; i++ { - if ip[i] != 0 { - isV4Compatible = false - break - } - } - if isV4Compatible { - return net.IPv4(ip[12], ip[13], ip[14], ip[15]) + if v4 := extractIPv4Compatible(ip); v4 != nil { + return v4 } - - // IPv4-translated (::ffff:0:x.x.x.x, RFC 2765): bytes 0-7 zero, - // bytes 8-9 = 0xFF 0xFF, bytes 10-11 = 0x00 0x00, bytes 12-15 = IPv4. - // Distinct from IPv4-mapped (bytes 10-11 = 0xFF), so To4() returns nil. 
- if ip[8] == 0xFF && ip[9] == 0xFF && ip[10] == 0x00 && ip[11] == 0x00 { - allZero := true - for i := 0; i < 8; i++ { - if ip[i] != 0 { - allZero = false - break - } - } - if allZero { - return net.IPv4(ip[12], ip[13], ip[14], ip[15]) - } - } - - return nil + return extractIPv4Translated(ip) } diff --git a/cli/azd/pkg/azdext/mcp_security_test.go b/cli/azd/pkg/azdext/mcp_security_test.go index d4d568c41dc..d0dd2f71333 100644 --- a/cli/azd/pkg/azdext/mcp_security_test.go +++ b/cli/azd/pkg/azdext/mcp_security_test.go @@ -446,16 +446,14 @@ func TestMCPSecurityOnBlocked_URLCallback(t *testing.T) { t.Parallel() var ( - gotAction string - gotDetail string - callCount int + gotViolation string + callCount int ) policy := NewMCPSecurityPolicy(). RequireHTTPS(). - OnBlocked(func(action, detail string) { - gotAction = action - gotDetail = detail + OnBlocked(func(violation string) { + gotViolation = violation callCount++ }) @@ -468,18 +466,15 @@ func TestMCPSecurityOnBlocked_URLCallback(t *testing.T) { if callCount != 1 { t.Errorf("callCount = %d, want 1", callCount) } - if gotAction != "url_blocked" { - t.Errorf("action = %q, want %q", gotAction, "url_blocked") - } - if !strings.Contains(gotDetail, "HTTPS required") { - t.Errorf("detail = %q, want to contain %q", gotDetail, "HTTPS required") + if !strings.Contains(gotViolation, "HTTPS required") { + t.Errorf("violation = %q, want to contain %q", gotViolation, "HTTPS required") } } func TestMCPSecurityOnBlocked_PathCallback(t *testing.T) { t.Parallel() - var gotAction string + var gotViolation string base := t.TempDir() outside := t.TempDir() @@ -490,8 +485,8 @@ func TestMCPSecurityOnBlocked_PathCallback(t *testing.T) { policy := NewMCPSecurityPolicy(). ValidatePathsWithinBase(base). 
- OnBlocked(func(action, detail string) { - gotAction = action + OnBlocked(func(violation string) { + gotViolation = violation }) err := policy.CheckPath(outsideFile) @@ -499,8 +494,8 @@ func TestMCPSecurityOnBlocked_PathCallback(t *testing.T) { t.Fatal("expected error for path outside base") } - if gotAction != "path_blocked" { - t.Errorf("action = %q, want %q", gotAction, "path_blocked") + if gotViolation == "" { + t.Error("expected OnBlocked callback to be invoked with violation message") } } diff --git a/cli/azd/pkg/azdext/mcp_server_builder.go b/cli/azd/pkg/azdext/mcp_server_builder.go index c94eaeb4e73..1775f7fb81c 100644 --- a/cli/azd/pkg/azdext/mcp_server_builder.go +++ b/cli/azd/pkg/azdext/mcp_server_builder.go @@ -34,6 +34,12 @@ type serverToolEntry struct { } // MCPServerBuilder provides a fluent API for building MCP servers with middleware. +// +// NOTE: This builder and the associated types ([MCPToolHandler], [MCPToolOptions], +// [ToolArgs], MCP result helpers) depend on [github.com/mark3labs/mcp-go]. +// If alternative Go MCP libraries gain traction, these symbols are candidates for +// extraction into an azdext/mcpgo sub-package so the core SDK remains +// MCP-library-agnostic. type MCPServerBuilder struct { name string version string diff --git a/cli/azd/pkg/azdext/output.go b/cli/azd/pkg/azdext/output.go new file mode 100644 index 00000000000..489696c2887 --- /dev/null +++ b/cli/azd/pkg/azdext/output.go @@ -0,0 +1,255 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "encoding/json" + "fmt" + "io" + "os" + "strings" + + "github.com/fatih/color" +) + +// OutputFormat represents the output format for extension commands. +type OutputFormat string + +const ( + // OutputFormatDefault is human-readable text with optional color. + OutputFormatDefault OutputFormat = "default" + // OutputFormatJSON outputs structured JSON for machine consumption. 
+ OutputFormatJSON OutputFormat = "json" +) + +// ParseOutputFormat converts a string to an OutputFormat. +// Returns OutputFormatDefault for unrecognized values and a non-nil error. +func ParseOutputFormat(s string) (OutputFormat, error) { + switch strings.ToLower(s) { + case "default", "": + return OutputFormatDefault, nil + case "json": + return OutputFormatJSON, nil + default: + return OutputFormatDefault, fmt.Errorf("invalid output format %q (valid: default, json)", s) + } +} + +// OutputOptions configures an [Output] instance. +type OutputOptions struct { + // Format controls the output style. Defaults to OutputFormatDefault. + Format OutputFormat + // Writer is the destination for normal output. Defaults to os.Stdout. + Writer io.Writer + // ErrWriter is the destination for error/warning output. Defaults to os.Stderr. + ErrWriter io.Writer +} + +// Output provides formatted, format-aware output for extension commands. +// In default mode it writes human-readable text with ANSI color; in JSON mode +// it writes structured JSON objects to stdout and suppresses decorative output. +// +// Output is safe for use from a single goroutine. If concurrent use is needed +// callers should synchronize externally. +type Output struct { + writer io.Writer + errWriter io.Writer + format OutputFormat + + // Color printers — configured once at construction. + successColor *color.Color + warningColor *color.Color + errorColor *color.Color + infoColor *color.Color + headerColor *color.Color + dimColor *color.Color +} + +// NewOutput creates an Output configured by opts. +// If opts.Writer or opts.ErrWriter are nil they default to os.Stdout / os.Stderr. 
+func NewOutput(opts OutputOptions) *Output { + w := opts.Writer + if w == nil { + w = os.Stdout + } + ew := opts.ErrWriter + if ew == nil { + ew = os.Stderr + } + + return &Output{ + writer: w, + errWriter: ew, + format: opts.Format, + successColor: color.New(color.FgGreen), + warningColor: color.New(color.FgYellow), + errorColor: color.New(color.FgRed), + infoColor: color.New(color.FgCyan), + headerColor: color.New(color.Bold), + dimColor: color.New(color.Faint), + } +} + +// IsJSON returns true when the output format is JSON. +// Callers can use this to skip decorative output that is only relevant in +// human-readable mode. +func (o *Output) IsJSON() bool { + return o.format == OutputFormatJSON +} + +// Success prints a success message prefixed with a green check mark. +// In JSON mode the call is a no-op (use [Output.JSON] for structured data). +func (o *Output) Success(format string, args ...any) { + if o.IsJSON() { + return + } + msg := fmt.Sprintf(format, args...) + o.successColor.Fprintf(o.writer, "(✓) Done: %s\n", msg) +} + +// Warning prints a warning message prefixed with a yellow exclamation mark. +// Warnings are written to ErrWriter in both default and JSON mode so they are +// visible even when stdout is piped through a JSON consumer. +func (o *Output) Warning(format string, args ...any) { + msg := fmt.Sprintf(format, args...) + if o.IsJSON() { + // In JSON mode emit a structured warning to stderr. + _ = json.NewEncoder(o.errWriter).Encode(map[string]string{ + "level": "warning", + "message": msg, + }) + return + } + o.warningColor.Fprintf(o.errWriter, "(!) Warning: %s\n", msg) +} + +// Error prints an error message prefixed with a red cross. +// Errors are always written to ErrWriter. +func (o *Output) Error(format string, args ...any) { + msg := fmt.Sprintf(format, args...) 
+ if o.IsJSON() { + _ = json.NewEncoder(o.errWriter).Encode(map[string]string{ + "level": "error", + "message": msg, + }) + return + } + o.errorColor.Fprintf(o.errWriter, "(✗) Error: %s\n", msg) +} + +// Info prints an informational message prefixed with an info symbol. +// In JSON mode the call is a no-op (use [Output.JSON] for structured data). +func (o *Output) Info(format string, args ...any) { + if o.IsJSON() { + return + } + msg := fmt.Sprintf(format, args...) + o.infoColor.Fprintf(o.writer, "(i) %s\n", msg) +} + +// Message prints an undecorated message to stdout. +// In JSON mode the call is a no-op. +func (o *Output) Message(format string, args ...any) { + if o.IsJSON() { + return + } + fmt.Fprintf(o.writer, format+"\n", args...) +} + +// JSON writes data as a pretty-printed JSON object to stdout. +// It is active in all output modes so callers can unconditionally emit +// structured payloads (in default mode the JSON is still human-readable). +func (o *Output) JSON(data any) error { + enc := json.NewEncoder(o.writer) + enc.SetIndent("", " ") + if err := enc.Encode(data); err != nil { + return fmt.Errorf("output: failed to encode JSON: %w", err) + } + return nil +} + +// Table prints a formatted text table with headers and rows. +// In JSON mode the table is emitted as a JSON array of objects instead. +// +// headers defines the column names. Each row is a slice of cell values +// with the same length as headers. Rows with fewer cells are padded with +// empty strings; extra cells are silently ignored. +func (o *Output) Table(headers []string, rows [][]string) { + if len(headers) == 0 { + return + } + + if o.IsJSON() { + o.tableJSON(headers, rows) + return + } + + o.tableText(headers, rows) +} + +// tableJSON emits the table as a JSON array of objects keyed by header name. 
+func (o *Output) tableJSON(headers []string, rows [][]string) { + out := make([]map[string]string, 0, len(rows)) + for _, row := range rows { + obj := make(map[string]string, len(headers)) + for i, h := range headers { + if i < len(row) { + obj[h] = row[i] + } else { + obj[h] = "" + } + } + out = append(out, obj) + } + _ = o.JSON(out) +} + +// tableText renders an aligned text table with a header separator. +func (o *Output) tableText(headers []string, rows [][]string) { + // Calculate column widths. + widths := make([]int, len(headers)) + for i, h := range headers { + widths[i] = len(h) + } + for _, row := range rows { + for i := range headers { + if i < len(row) && len(row[i]) > widths[i] { + widths[i] = len(row[i]) + } + } + } + + // Print header row. + for i, h := range headers { + if i > 0 { + fmt.Fprint(o.writer, " ") + } + o.headerColor.Fprintf(o.writer, "%-*s", widths[i], h) + } + fmt.Fprintln(o.writer) + + // Print separator. + for i, w := range widths { + if i > 0 { + fmt.Fprint(o.writer, " ") + } + fmt.Fprint(o.writer, strings.Repeat("─", w)) + } + fmt.Fprintln(o.writer) + + // Print data rows. + for _, row := range rows { + for i := range headers { + if i > 0 { + fmt.Fprint(o.writer, " ") + } + cell := "" + if i < len(row) { + cell = row[i] + } + fmt.Fprintf(o.writer, "%-*s", widths[i], cell) + } + fmt.Fprintln(o.writer) + } +} diff --git a/cli/azd/pkg/azdext/output_test.go b/cli/azd/pkg/azdext/output_test.go new file mode 100644 index 00000000000..4a31fe7ccae --- /dev/null +++ b/cli/azd/pkg/azdext/output_test.go @@ -0,0 +1,358 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package azdext + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// ParseOutputFormat +// --------------------------------------------------------------------------- + +func TestParseOutputFormat(t *testing.T) { + tests := []struct { + name string + input string + expected OutputFormat + expectErr bool + }{ + {name: "default string", input: "default", expected: OutputFormatDefault}, + {name: "empty string", input: "", expected: OutputFormatDefault}, + {name: "json lowercase", input: "json", expected: OutputFormatJSON}, + {name: "JSON uppercase", input: "JSON", expected: OutputFormatJSON}, + {name: "Json mixed case", input: "Json", expected: OutputFormatJSON}, + {name: "DEFAULT uppercase", input: "DEFAULT", expected: OutputFormatDefault}, + {name: "invalid format", input: "xml", expected: OutputFormatDefault, expectErr: true}, + {name: "invalid format yaml", input: "yaml", expected: OutputFormatDefault, expectErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseOutputFormat(tt.input) + if tt.expectErr { + require.Error(t, err) + require.Contains(t, err.Error(), "invalid output format") + } else { + require.NoError(t, err) + } + require.Equal(t, tt.expected, got) + }) + } +} + +// --------------------------------------------------------------------------- +// NewOutput defaults +// --------------------------------------------------------------------------- + +func TestNewOutput_DefaultWriters(t *testing.T) { + out := NewOutput(OutputOptions{}) + require.NotNil(t, out) + // Default format should be "default" (zero-value of OutputFormat). 
+ require.False(t, out.IsJSON()) +} + +func TestNewOutput_JSONMode(t *testing.T) { + out := NewOutput(OutputOptions{Format: OutputFormatJSON}) + require.True(t, out.IsJSON()) +} + +// --------------------------------------------------------------------------- +// Success +// --------------------------------------------------------------------------- + +func TestOutput_Success_DefaultFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Success("deployed %s", "myapp") + + // Should contain the message text (color codes may wrap it). + require.Contains(t, buf.String(), "Done: deployed myapp") +} + +func TestOutput_Success_JSONFormat_IsNoop(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + out.Success("should not appear") + + require.Empty(t, buf.String()) +} + +// --------------------------------------------------------------------------- +// Warning +// --------------------------------------------------------------------------- + +func TestOutput_Warning_DefaultFormat(t *testing.T) { + var errBuf bytes.Buffer + out := NewOutput(OutputOptions{ErrWriter: &errBuf}) + + out.Warning("deprecated %s", "v1") + + require.Contains(t, errBuf.String(), "Warning: deprecated v1") +} + +func TestOutput_Warning_JSONFormat(t *testing.T) { + var errBuf bytes.Buffer + out := NewOutput(OutputOptions{ErrWriter: &errBuf, Format: OutputFormatJSON}) + + out.Warning("api deprecated") + + var parsed map[string]string + err := json.Unmarshal(errBuf.Bytes(), &parsed) + require.NoError(t, err) + require.Equal(t, "warning", parsed["level"]) + require.Equal(t, "api deprecated", parsed["message"]) +} + +// --------------------------------------------------------------------------- +// Error +// --------------------------------------------------------------------------- + +func TestOutput_Error_DefaultFormat(t *testing.T) { + var errBuf bytes.Buffer + out := 
NewOutput(OutputOptions{ErrWriter: &errBuf}) + + out.Error("connection failed: %s", "timeout") + + require.Contains(t, errBuf.String(), "Error: connection failed: timeout") +} + +func TestOutput_Error_JSONFormat(t *testing.T) { + var errBuf bytes.Buffer + out := NewOutput(OutputOptions{ErrWriter: &errBuf, Format: OutputFormatJSON}) + + out.Error("disk full") + + var parsed map[string]string + err := json.Unmarshal(errBuf.Bytes(), &parsed) + require.NoError(t, err) + require.Equal(t, "error", parsed["level"]) + require.Equal(t, "disk full", parsed["message"]) +} + +// --------------------------------------------------------------------------- +// Info +// --------------------------------------------------------------------------- + +func TestOutput_Info_DefaultFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Info("fetching %d items", 5) + + require.Contains(t, buf.String(), "fetching 5 items") +} + +func TestOutput_Info_JSONFormat_IsNoop(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + out.Info("hidden") + + require.Empty(t, buf.String()) +} + +// --------------------------------------------------------------------------- +// Message +// --------------------------------------------------------------------------- + +func TestOutput_Message_DefaultFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Message("plain text %d", 42) + + require.Equal(t, "plain text 42\n", buf.String()) +} + +func TestOutput_Message_JSONFormat_IsNoop(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + out.Message("should not appear") + + require.Empty(t, buf.String()) +} + +// --------------------------------------------------------------------------- +// JSON +// --------------------------------------------------------------------------- + +func 
TestOutput_JSON_Struct(t *testing.T) { + type result struct { + Name string `json:"name"` + Count int `json:"count"` + } + + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + err := out.JSON(result{Name: "test", Count: 7}) + require.NoError(t, err) + + var decoded result + require.NoError(t, json.Unmarshal(buf.Bytes(), &decoded)) + require.Equal(t, "test", decoded.Name) + require.Equal(t, 7, decoded.Count) +} + +func TestOutput_JSON_Map(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + err := out.JSON(map[string]string{"key": "value"}) + require.NoError(t, err) + + var decoded map[string]string + require.NoError(t, json.Unmarshal(buf.Bytes(), &decoded)) + require.Equal(t, "value", decoded["key"]) +} + +func TestOutput_JSON_Unmarshalable(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + err := out.JSON(make(chan int)) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to encode JSON") +} + +func TestOutput_JSON_PrettyPrinted(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + err := out.JSON(map[string]int{"a": 1}) + require.NoError(t, err) + + // Verify indentation is present (pretty-printed). 
+ require.Contains(t, buf.String(), " ") +} + +// --------------------------------------------------------------------------- +// Table — default format +// --------------------------------------------------------------------------- + +func TestOutput_Table_DefaultFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + headers := []string{"Name", "Status"} + rows := [][]string{ + {"api", "running"}, + {"web", "stopped"}, + } + + out.Table(headers, rows) + + text := buf.String() + require.Contains(t, text, "Name") + require.Contains(t, text, "Status") + require.Contains(t, text, "api") + require.Contains(t, text, "running") + require.Contains(t, text, "web") + require.Contains(t, text, "stopped") + + // Separator line should be present. + require.Contains(t, text, "─") +} + +func TestOutput_Table_EmptyHeaders(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Table(nil, [][]string{{"a"}}) + + require.Empty(t, buf.String()) +} + +func TestOutput_Table_EmptyRows(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + out.Table([]string{"Name"}, nil) + + // Header + separator should still be printed. + text := buf.String() + require.Contains(t, text, "Name") + require.Contains(t, text, "─") +} + +func TestOutput_Table_ShortRow(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + // Row has fewer cells than headers — should pad with empty strings. + out.Table([]string{"A", "B", "C"}, [][]string{{"only-a"}}) + + text := buf.String() + require.Contains(t, text, "only-a") + // No panic from short row. 
+ lines := strings.Split(strings.TrimSpace(text), "\n") + require.Len(t, lines, 3) // header + separator + 1 data row +} + +func TestOutput_Table_ColumnAlignment(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf}) + + headers := []string{"ID", "LongerName"} + rows := [][]string{ + {"1", "short"}, + {"2", "a-much-longer-value"}, + } + + out.Table(headers, rows) + + lines := strings.Split(strings.TrimSpace(buf.String()), "\n") + require.GreaterOrEqual(t, len(lines), 3) + + // All separator dashes should align with header width. + sepLine := lines[1] + require.NotEmpty(t, sepLine) +} + +// --------------------------------------------------------------------------- +// Table — JSON format +// --------------------------------------------------------------------------- + +func TestOutput_Table_JSONFormat(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + headers := []string{"Service", "Port"} + rows := [][]string{ + {"api", "8080"}, + {"web", "3000"}, + } + + out.Table(headers, rows) + + var decoded []map[string]string + require.NoError(t, json.Unmarshal(buf.Bytes(), &decoded)) + require.Len(t, decoded, 2) + require.Equal(t, "api", decoded[0]["Service"]) + require.Equal(t, "8080", decoded[0]["Port"]) + require.Equal(t, "web", decoded[1]["Service"]) + require.Equal(t, "3000", decoded[1]["Port"]) +} + +func TestOutput_Table_JSONFormat_ShortRow(t *testing.T) { + var buf bytes.Buffer + out := NewOutput(OutputOptions{Writer: &buf, Format: OutputFormatJSON}) + + out.Table([]string{"A", "B"}, [][]string{{"only-a"}}) + + var decoded []map[string]string + require.NoError(t, json.Unmarshal(buf.Bytes(), &decoded)) + require.Len(t, decoded, 1) + require.Equal(t, "only-a", decoded[0]["A"]) + require.Equal(t, "", decoded[0]["B"]) +} diff --git a/cli/azd/pkg/azdext/pagination.go b/cli/azd/pkg/azdext/pagination.go index 35030a25446..e8860bb90c9 100644 --- 
a/cli/azd/pkg/azdext/pagination.go +++ b/cli/azd/pkg/azdext/pagination.go @@ -12,6 +12,7 @@ import ( "net/http" "net/url" "strings" + "unicode" ) const ( @@ -51,11 +52,10 @@ type Pager[T any] struct { client HTTPDoer nextURL string done bool - initErr error + truncated bool opts PagerOptions originHost string // host of the initial URL for SSRF protection pageCount int // number of pages fetched so far - truncated bool } // PageResponse is a single page returned by [Pager.NextPage]. @@ -97,10 +97,6 @@ type stdHTTPDoer struct { } func (s *stdHTTPDoer) Do(ctx context.Context, method, url string, body io.Reader) (*http.Response, error) { - if s.client == nil { - return nil, errors.New("azdext.Pager.NextPage: client must not be nil") - } - req, err := http.NewRequestWithContext(ctx, method, url, body) if err != nil { return nil, err @@ -122,32 +118,25 @@ func NewPager[T any](client HTTPDoer, firstURL string, opts *PagerOptions) *Page opts.Method = http.MethodGet } - var ( - originHost string - initErr error - ) - if firstURL != "" { - u, err := url.Parse(firstURL) - if err != nil { - initErr = fmt.Errorf("azdext.NewPager: invalid first URL: %w", err) - } else if u.Hostname() == "" { - initErr = errors.New("azdext.NewPager: invalid first URL: missing host") - } else { - originHost = strings.ToLower(u.Hostname()) - } + var originHost string + if u, err := url.Parse(firstURL); err == nil { + originHost = strings.ToLower(u.Hostname()) } return &Pager[T]{ client: client, nextURL: firstURL, - initErr: initErr, opts: *opts, originHost: originHost, } } // NewPagerFromHTTPClient creates a [Pager] backed by a standard [*http.Client]. +// If client is nil, [http.DefaultClient] is used. 
func NewPagerFromHTTPClient[T any](client *http.Client, firstURL string, opts *PagerOptions) *Pager[T] { + if client == nil { + client = http.DefaultClient + } return NewPager[T](&stdHTTPDoer{client: client}, firstURL, opts) } @@ -156,10 +145,8 @@ func (p *Pager[T]) More() bool { return !p.done && p.nextURL != "" } -// Truncated reports whether the last [Collect] call stopped early due to -// MaxPages or MaxItems limits. This allows callers to detect truncation -// without a breaking API change (Collect still returns ([]T, nil) on -// successful truncation). +// Truncated reports whether the last [Collect] call stopped early because +// a collection bound (MaxPages or MaxItems) was hit. func (p *Pager[T]) Truncated() bool { return p.truncated } @@ -176,11 +163,6 @@ func (p *Pager[T]) NextPage(ctx context.Context) (*PageResponse[T], error) { if !p.More() { return nil, errors.New("azdext.Pager.NextPage: no more pages") } - if p.initErr != nil { - p.done = true - p.nextURL = "" - return nil, p.initErr - } if p.client == nil { return nil, errors.New("azdext.Pager.NextPage: client must not be nil") @@ -197,17 +179,14 @@ func (p *Pager[T]) NextPage(ctx context.Context) (*PageResponse[T], error) { return nil, &PaginationError{ StatusCode: resp.StatusCode, URL: p.nextURL, - Body: sanitizeErrorBody(string(body)), + Body: sanitizeControlChars(string(body)), } } - data, err := io.ReadAll(io.LimitReader(resp.Body, maxPageResponseSize+1)) + data, err := io.ReadAll(io.LimitReader(resp.Body, maxPageResponseSize)) if err != nil { return nil, fmt.Errorf("azdext.Pager.NextPage: failed to read response: %w", err) } - if int64(len(data)) > maxPageResponseSize { - return nil, fmt.Errorf("azdext.Pager.NextPage: response exceeds max page size (%d bytes)", maxPageResponseSize) - } var page PageResponse[T] if err := json.Unmarshal(data, &page); err != nil { @@ -240,6 +219,12 @@ func (p *Pager[T]) validateNextLink(nextLink string) error { return fmt.Errorf("invalid nextLink URL: %w", err) } 
+ // Reject relative URLs (empty scheme) — they would fail at request time + // with a "missing protocol scheme" error, and non-absolute URLs may be + // used for path-based SSRF attacks. + if u.Scheme == "" { + return fmt.Errorf("nextLink must be an absolute URL with an HTTPS scheme") + } if u.Scheme != "https" { return fmt.Errorf("nextLink must use HTTPS (got %q)", u.Scheme) } @@ -286,18 +271,21 @@ func (p *Pager[T]) Collect(ctx context.Context) ([]T, error) { // Enforce MaxItems: truncate and stop if exceeded. if p.opts.MaxItems > 0 && len(all) >= p.opts.MaxItems { + truncatedByItems := len(all) > p.opts.MaxItems if len(all) > p.opts.MaxItems { all = all[:p.opts.MaxItems] } - p.truncated = true - p.done = true + if truncatedByItems || p.More() { + p.truncated = true + } break } // Enforce MaxPages: stop after collecting the configured number of pages. if p.pageCount >= maxPages { - p.truncated = true - p.done = true + if p.More() { + p.truncated = true + } break } } @@ -305,20 +293,11 @@ func (p *Pager[T]) Collect(ctx context.Context) ([]T, error) { return all, nil } -// maxPaginationErrorBodyLen limits the response body length stored in -// PaginationError to prevent sensitive data leakage through error messages. -// Response bodies from non-2xx pages may contain credentials, tokens, or -// other secrets embedded by the upstream service. -const maxPaginationErrorBodyLen = 1024 - // PaginationError is returned when a page request receives a non-2xx response. type PaginationError struct { StatusCode int URL string - // Body is a truncated, sanitized excerpt of the error response body for - // diagnostics. It is capped at [maxPaginationErrorBodyLen] bytes and - // stripped of control characters to prevent log forging. 
- Body string + Body string } func (e *PaginationError) Error() string { @@ -328,33 +307,6 @@ func (e *PaginationError) Error() string { ) } -// sanitizeErrorBody truncates and strips control characters from an error -// response body to prevent log forging and sensitive data leakage. -func sanitizeErrorBody(body string) string { - if len(body) > maxPaginationErrorBodyLen { - body = body[:maxPaginationErrorBodyLen] + "...[truncated]" - } - return stripControlChars(body) -} - -// stripControlChars replaces ASCII control characters (except tab) with a -// space to prevent log forging via CR/LF injection or terminal escape -// sequences. Tab (0x09) is preserved as it appears in legitimate JSON. -func stripControlChars(s string) string { - var b strings.Builder - b.Grow(len(s)) - for _, r := range s { - if r < 0x20 && r != '\t' { - b.WriteRune(' ') - } else if r == 0x7F { - b.WriteRune(' ') - } else { - b.WriteRune(r) - } - } - return b.String() -} - // redactURL strips query parameters and fragments from a URL to avoid leaking // tokens, SAS signatures, or other secrets in log/error messages. func redactURL(rawURL string) string { @@ -366,3 +318,14 @@ func redactURL(rawURL string) string { u.Fragment = "" return u.String() } + +// sanitizeControlChars replaces control characters (except newlines and tabs) +// with spaces to prevent log-forging attacks in stored error bodies. 
+func sanitizeControlChars(s string) string { + return strings.Map(func(r rune) rune { + if unicode.IsControl(r) && r != '\n' && r != '\t' { + return ' ' + } + return r + }, s) +} diff --git a/cli/azd/pkg/azdext/pagination_test.go b/cli/azd/pkg/azdext/pagination_test.go index 807353c7e84..7073903ff19 100644 --- a/cli/azd/pkg/azdext/pagination_test.go +++ b/cli/azd/pkg/azdext/pagination_test.go @@ -7,7 +7,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "io" "net/http" "strings" @@ -262,24 +261,6 @@ func TestPager_EmptyFirstURL(t *testing.T) { } } -func TestPager_InvalidFirstURL(t *testing.T) { - t.Parallel() - - doer := &mockDoer{} - pager := NewPager[string](doer, "://bad", nil) - if !pager.More() { - t.Fatal("expected More() = true for non-empty initial URL") - } - - _, err := pager.NextPage(context.Background()) - if err == nil { - t.Fatal("expected error for invalid first URL") - } - if !strings.Contains(err.Error(), "invalid first URL") { - t.Errorf("error = %q, want mention of invalid first URL", err.Error()) - } -} - type testStruct struct { Name string `json:"name"` Count int `json:"count"` @@ -382,19 +363,6 @@ func TestPager_NilClient(t *testing.T) { } } -func TestPager_NilStdHTTPClient(t *testing.T) { - t.Parallel() - - pager := NewPagerFromHTTPClient[string](nil, "https://example.com/api", nil) - _, err := pager.NextPage(context.Background()) - if err == nil { - t.Fatal("expected error for nil std http client") - } - if !strings.Contains(err.Error(), "client must not be nil") { - t.Errorf("error = %q, want mention of nil client", err.Error()) - } -} - func TestPager_NextLinkSSRF_DifferentHost(t *testing.T) { t.Parallel() @@ -459,33 +427,6 @@ func TestPager_NextLinkHTTP(t *testing.T) { } } -func TestPager_NextLinkRelativeURL(t *testing.T) { - t.Parallel() - - page1 := pageJSON([]string{"a"}, "/page2") - - doer := &mockDoer{ - responses: []*doerResponse{ - {resp: &http.Response{ - StatusCode: http.StatusOK, - Body: 
io.NopCloser(strings.NewReader(page1)), - Header: http.Header{}, - }}, - }, - } - - pager := NewPager[string](doer, "https://example.com/api", nil) - - _, err := pager.NextPage(context.Background()) - if err == nil { - t.Fatal("expected error for relative nextLink") - } - - if !strings.Contains(err.Error(), "HTTPS") { - t.Errorf("error = %q, want mention of HTTPS", err.Error()) - } -} - func TestPager_NextLinkUserCredentials(t *testing.T) { t.Parallel() @@ -541,203 +482,93 @@ func TestPager_CollectWithSSRFError(t *testing.T) { } } -func TestPager_ResponseTooLarge(t *testing.T) { +func TestPager_CollectTruncatedByMaxPages(t *testing.T) { t.Parallel() - oversized := pageJSON([]string{strings.Repeat("a", int(maxPageResponseSize))}, "") + page1 := pageJSON([]int{1, 2}, "https://example.com/api?page=2") + page2 := pageJSON([]int{3, 4}, "") doer := &mockDoer{ responses: []*doerResponse{ {resp: &http.Response{ StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(oversized)), + Body: io.NopCloser(strings.NewReader(page1)), Header: http.Header{}, }}, - }, - } - - pager := NewPager[string](doer, "https://example.com/api", nil) - - _, err := pager.NextPage(context.Background()) - if err == nil { - t.Fatal("expected error for oversized response") - } - if !strings.Contains(err.Error(), "response exceeds max page size") { - t.Errorf("error = %q, want explicit max page size error", err.Error()) - } -} - -func TestPager_CollectMaxPages(t *testing.T) { - t.Parallel() - - // Build 5 pages; set MaxPages to 3. 
- var responses []*doerResponse - for i := 1; i <= 5; i++ { - nextLink := "" - if i < 5 { - nextLink = fmt.Sprintf("https://example.com/api?page=%d", i+1) - } - body := pageJSON([]int{i}, nextLink) - responses = append(responses, &doerResponse{ - resp: &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(body)), - Header: http.Header{}, - }, - }) - } - - doer := &mockDoer{responses: responses} - pager := NewPager[int](doer, "https://example.com/api?page=1", &PagerOptions{MaxPages: 3}) - - all, err := pager.Collect(context.Background()) - if err != nil { - t.Fatalf("Collect failed: %v", err) - } - - if len(all) != 3 { - t.Errorf("len(all) = %d, want 3 (MaxPages=3)", len(all)) - } - for i, want := range []int{1, 2, 3} { - if all[i] != want { - t.Errorf("all[%d] = %d, want %d", i, all[i], want) - } - } -} - -func TestPager_CollectMaxItems(t *testing.T) { - t.Parallel() - - // Build 3 pages of 4 items each; set MaxItems to 5. - var responses []*doerResponse - for i := 0; i < 3; i++ { - items := []int{i*4 + 1, i*4 + 2, i*4 + 3, i*4 + 4} - nextLink := "" - if i < 2 { - nextLink = fmt.Sprintf("https://example.com/api?page=%d", i+2) - } - body := pageJSON(items, nextLink) - responses = append(responses, &doerResponse{ - resp: &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(body)), - Header: http.Header{}, - }, - }) - } - - doer := &mockDoer{responses: responses} - pager := NewPager[int](doer, "https://example.com/api?page=1", &PagerOptions{MaxItems: 5}) - - all, err := pager.Collect(context.Background()) - if err != nil { - t.Fatalf("Collect failed: %v", err) - } - - if len(all) != 5 { - t.Errorf("len(all) = %d, want 5 (MaxItems=5)", len(all)) - } -} - -func TestPager_TruncatedByMaxPages(t *testing.T) { - t.Parallel() - - var responses []*doerResponse - for i := 1; i <= 5; i++ { - nextLink := "" - if i < 5 { - nextLink = fmt.Sprintf("https://example.com/api?page=%d", i+1) - } - body := pageJSON([]int{i}, 
nextLink) - responses = append(responses, &doerResponse{ - resp: &http.Response{ + {resp: &http.Response{ StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(body)), + Body: io.NopCloser(strings.NewReader(page2)), Header: http.Header{}, - }, - }) + }}, + }, } - doer := &mockDoer{responses: responses} - pager := NewPager[int](doer, "https://example.com/api?page=1", &PagerOptions{MaxPages: 3}) - + pager := NewPager[int](doer, "https://example.com/api?page=1", &PagerOptions{MaxPages: 1}) all, err := pager.Collect(context.Background()) if err != nil { t.Fatalf("Collect failed: %v", err) } - if len(all) != 3 { - t.Errorf("len(all) = %d, want 3", len(all)) + if len(all) != 2 { + t.Fatalf("len(all) = %d, want 2", len(all)) } - if !pager.Truncated() { - t.Error("Truncated() = false, want true (stopped at MaxPages)") + t.Fatal("expected Truncated() = true when MaxPages stops collection early") } } -func TestPager_TruncatedByMaxItems(t *testing.T) { +func TestPager_CollectTruncatedByMaxItems(t *testing.T) { t.Parallel() - var responses []*doerResponse - for i := 0; i < 3; i++ { - items := []int{i*4 + 1, i*4 + 2, i*4 + 3, i*4 + 4} - nextLink := "" - if i < 2 { - nextLink = fmt.Sprintf("https://example.com/api?page=%d", i+2) - } - body := pageJSON(items, nextLink) - responses = append(responses, &doerResponse{ - resp: &http.Response{ + page := pageJSON([]string{"a", "b", "c"}, "") + doer := &mockDoer{ + responses: []*doerResponse{ + {resp: &http.Response{ StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(body)), + Body: io.NopCloser(strings.NewReader(page)), Header: http.Header{}, - }, - }) + }}, + }, } - doer := &mockDoer{responses: responses} - pager := NewPager[int](doer, "https://example.com/api?page=1", &PagerOptions{MaxItems: 5}) - + pager := NewPager[string](doer, "https://example.com/api", &PagerOptions{MaxItems: 2}) all, err := pager.Collect(context.Background()) if err != nil { t.Fatalf("Collect failed: %v", err) } - if len(all) != 5 { - 
t.Errorf("len(all) = %d, want 5", len(all)) + if len(all) != 2 { + t.Fatalf("len(all) = %d, want 2", len(all)) } - if !pager.Truncated() { - t.Error("Truncated() = false, want true (stopped at MaxItems)") + t.Fatal("expected Truncated() = true when MaxItems truncates page data") } } -func TestPager_NotTruncatedOnNaturalEnd(t *testing.T) { +func TestPager_CollectNotTruncated(t *testing.T) { t.Parallel() - body := pageJSON([]string{"a", "b"}, "") + page := pageJSON([]string{"x", "y"}, "") doer := &mockDoer{ responses: []*doerResponse{ {resp: &http.Response{ StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(body)), + Body: io.NopCloser(strings.NewReader(page)), Header: http.Header{}, }}, }, } - pager := NewPager[string](doer, "https://example.com/api", nil) - + pager := NewPager[string](doer, "https://example.com/api", &PagerOptions{MaxPages: 10, MaxItems: 10}) all, err := pager.Collect(context.Background()) if err != nil { t.Fatalf("Collect failed: %v", err) } if len(all) != 2 { - t.Errorf("len(all) = %d, want 2", len(all)) + t.Fatalf("len(all) = %d, want 2", len(all)) } - if pager.Truncated() { - t.Error("Truncated() = true, want false (natural end)") + t.Fatal("expected Truncated() = false when all data is collected") } } diff --git a/cli/azd/pkg/azdext/process.go b/cli/azd/pkg/azdext/process.go new file mode 100644 index 00000000000..7ae99ca3f8b --- /dev/null +++ b/cli/azd/pkg/azdext/process.go @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "fmt" + "os" + "runtime" + "strings" +) + +// --------------------------------------------------------------------------- +// Cross-platform process detection +// --------------------------------------------------------------------------- + +// ProcessInfo contains information about a running process. +type ProcessInfo struct { + // PID is the process identifier. 
+ PID int + // Name is the process name (executable basename without extension). + Name string + // Executable is the full path to the process executable, if available. + Executable string + // Running is true if the process was found and appears to be alive. + Running bool +} + +// IsProcessRunning checks whether a process with the given PID exists and +// is still running. +// +// Platform behavior: +// - Unix (Linux/macOS): Sends signal 0 to the process. If the process +// exists (even if owned by another user), this returns true. If the +// process does not exist, it returns false. This does NOT verify that +// the process is the expected one (PID reuse is possible). +// - Windows: Opens the process with PROCESS_QUERY_LIMITED_INFORMATION +// access and checks the exit code. If the process handle is valid and +// the exit code is STILL_ACTIVE, returns true. +// +// Note: PID reuse can cause false positives on all platforms. For critical +// use cases, combine PID checks with process name verification using +// [GetProcessInfo]. +// +// Returns false if the PID is invalid (≤ 0). +func IsProcessRunning(pid int) bool { + if pid <= 0 { + return false + } + return isProcessRunningOS(pid) +} + +// GetProcessInfo retrieves information about the process with the given PID. +// +// Platform behavior: +// - Linux: Reads /proc//comm, /proc//exe. +// - macOS: Uses ps(1) to query process info. +// - Windows: Uses QueryFullProcessImageName via Windows API. +// +// Returns a [ProcessInfo] with Running=false if the process does not exist +// or cannot be queried (e.g., insufficient permissions). +func GetProcessInfo(pid int) ProcessInfo { + if pid <= 0 { + return ProcessInfo{PID: pid, Running: false} + } + return getProcessInfoOS(pid) +} + +// CurrentProcessInfo returns [ProcessInfo] for the current process. 
+func CurrentProcessInfo() ProcessInfo { + pid := os.Getpid() + exe, _ := os.Executable() + + name := "" + if exe != "" { + name = extractBaseName(exe) + } + + return ProcessInfo{ + PID: pid, + Name: name, + Executable: exe, + Running: true, + } +} + +// ParentProcessInfo returns [ProcessInfo] for the parent of the current +// process. +// +// Platform behavior: +// - All platforms: Uses os.Getppid() to obtain the parent PID, then +// delegates to [GetProcessInfo]. +// - On orphaned processes (parent PID = 1 on Unix), the returned info +// describes the init/launchd process. +func ParentProcessInfo() ProcessInfo { + return GetProcessInfo(os.Getppid()) +} + +// FindProcessByName searches for running processes with the given name. +// The search is case-insensitive and matches the executable basename +// (without file extension on Windows). +// +// Platform behavior: +// - Linux: Scans /proc/*/comm. +// - macOS: Uses ps(1) to list processes. +// - Windows: Uses CreateToolhelp32Snapshot to enumerate processes. +// +// Returns a slice of matching [ProcessInfo]. If no processes are found, +// returns an empty (non-nil) slice. +// +// This function is best-effort: some processes may be inaccessible due to +// permissions. +func FindProcessByName(name string) []ProcessInfo { + if name == "" { + return []ProcessInfo{} + } + return findProcessByNameOS(name) +} + +// ProcessEnvironment describes the process execution context for diagnostics. +type ProcessEnvironment struct { + // PID is the current process ID. + PID int + // PPID is the parent process ID. + PPID int + // Executable is the current process executable path. + Executable string + // WorkingDir is the current working directory. + WorkingDir string + // OS is the operating system (runtime.GOOS). + OS string + // Arch is the CPU architecture (runtime.GOARCH). + Arch string + // NumCPU is the number of logical CPUs available. 
+ NumCPU int +} + +// GetProcessEnvironment collects process execution context useful for +// diagnostics, logging, and support information. +func GetProcessEnvironment() ProcessEnvironment { + exe, _ := os.Executable() + cwd, _ := os.Getwd() + + return ProcessEnvironment{ + PID: os.Getpid(), + PPID: os.Getppid(), + Executable: exe, + WorkingDir: cwd, + OS: runtime.GOOS, + Arch: runtime.GOARCH, + NumCPU: runtime.NumCPU(), + } +} + +// String returns a human-readable summary of the process environment. +func (pe ProcessEnvironment) String() string { + return fmt.Sprintf("pid=%d ppid=%d os=%s arch=%s cpus=%d cwd=%s exe=%s", + pe.PID, pe.PPID, pe.OS, pe.Arch, pe.NumCPU, pe.WorkingDir, pe.Executable) +} + +// --------------------------------------------------------------------------- +// Internal shared helpers +// --------------------------------------------------------------------------- + +// extractBaseName returns the base name of a path without extension. +func extractBaseName(path string) string { + // Handle both Unix and Windows separators. + base := path + if idx := strings.LastIndexAny(path, `/\`); idx >= 0 { + base = path[idx+1:] + } + // Remove common extensions. + base = strings.TrimSuffix(base, ".exe") + return base +} diff --git a/cli/azd/pkg/azdext/process_darwin.go b/cli/azd/pkg/azdext/process_darwin.go new file mode 100644 index 00000000000..e7978f2779c --- /dev/null +++ b/cli/azd/pkg/azdext/process_darwin.go @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//go:build darwin + +package azdext + +import ( + "errors" + "os" + "os/exec" + "strconv" + "strings" + "syscall" +) + +// isProcessRunningOS checks if a process is running on macOS using signal 0. 
+func isProcessRunningOS(pid int) bool { + proc, err := os.FindProcess(pid) + if err != nil { + return false + } + err = proc.Signal(syscall.Signal(0)) + return err == nil || errors.Is(err, syscall.EPERM) +} + +// getProcessInfoOS retrieves process info on macOS using ps(1). +func getProcessInfoOS(pid int) ProcessInfo { + info := ProcessInfo{PID: pid} + + // Use ps to get process name and executable path. + cmd := exec.Command("ps", "-p", strconv.Itoa(pid), "-o", "comm=") + output, err := cmd.Output() + if err != nil { + return info // Process does not exist or is inaccessible. + } + + info.Name = extractBaseName(strings.TrimSpace(string(output))) + info.Running = true + + // Get full command path. + cmd = exec.Command("ps", "-p", strconv.Itoa(pid), "-o", "args=") + output, err = cmd.Output() + if err == nil { + args := strings.TrimSpace(string(output)) + if fields := strings.Fields(args); len(fields) > 0 { + info.Executable = fields[0] + } + } + + return info +} + +// findProcessByNameOS searches for processes by name on macOS using ps(1). +func findProcessByNameOS(name string) []ProcessInfo { + // ps -ax -o pid=,comm= lists all processes with PID and command name. + cmd := exec.Command("ps", "-ax", "-o", "pid=,comm=") + output, err := cmd.Output() + if err != nil { + return []ProcessInfo{} + } + + nameLower := strings.ToLower(name) + var results []ProcessInfo + + for _, line := range strings.Split(string(output), "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + // Format: " PID COMM" + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + + pid, err := strconv.Atoi(fields[0]) + if err != nil { + continue + } + + // The comm field is the rest of the line after PID. 
+ procPath := strings.Join(fields[1:], " ") + procName := extractBaseName(procPath) + + if strings.EqualFold(procName, nameLower) { + results = append(results, ProcessInfo{ + PID: pid, + Name: procName, + Executable: procPath, + Running: true, + }) + } + } + + if results == nil { + return []ProcessInfo{} + } + return results +} diff --git a/cli/azd/pkg/azdext/process_linux.go b/cli/azd/pkg/azdext/process_linux.go new file mode 100644 index 00000000000..93b7ecd40d1 --- /dev/null +++ b/cli/azd/pkg/azdext/process_linux.go @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//go:build !windows && !darwin + +package azdext + +import ( + "errors" + "fmt" + "os" + "strconv" + "strings" + "syscall" +) + +// isProcessRunningOS checks if a process is running on Linux using signal 0. +func isProcessRunningOS(pid int) bool { + proc, err := os.FindProcess(pid) + if err != nil { + return false + } + // Signal 0 does not send a signal but performs error checking. + // If the process exists, err is nil. If it doesn't, err is non-nil. + err = proc.Signal(syscall.Signal(0)) + return err == nil || errors.Is(err, syscall.EPERM) +} + +// getProcessInfoOS retrieves process info on Linux via /proc. +func getProcessInfoOS(pid int) ProcessInfo { + info := ProcessInfo{PID: pid} + + // Read process name from /proc//comm. + commPath := fmt.Sprintf("/proc/%d/comm", pid) + commData, err := os.ReadFile(commPath) + if err != nil { + return info // Process does not exist or is inaccessible. + } + info.Name = strings.TrimSpace(string(commData)) + info.Running = true + + // Read executable symlink from /proc//exe. + exePath := fmt.Sprintf("/proc/%d/exe", pid) + exe, err := os.Readlink(exePath) + if err == nil { + info.Executable = exe + } + + return info +} + +// findProcessByNameOS searches for processes by name on Linux via /proc. 
+func findProcessByNameOS(name string) []ProcessInfo { + entries, err := os.ReadDir("/proc") + if err != nil { + return []ProcessInfo{} + } + + nameLower := strings.ToLower(name) + var results []ProcessInfo + + for _, entry := range entries { + if !entry.IsDir() { + continue + } + pid, err := strconv.Atoi(entry.Name()) + if err != nil { + continue // Not a PID directory. + } + + commPath := fmt.Sprintf("/proc/%d/comm", pid) + commData, err := os.ReadFile(commPath) + if err != nil { + continue + } + + procName := strings.TrimSpace(string(commData)) + if strings.EqualFold(procName, nameLower) { + info := ProcessInfo{ + PID: pid, + Name: procName, + Running: true, + } + // Try to get executable path. + exePath := fmt.Sprintf("/proc/%d/exe", pid) + if exe, err := os.Readlink(exePath); err == nil { + info.Executable = exe + } + results = append(results, info) + } + } + + if results == nil { + return []ProcessInfo{} + } + return results +} diff --git a/cli/azd/pkg/azdext/process_test.go b/cli/azd/pkg/azdext/process_test.go new file mode 100644 index 00000000000..c59de779795 --- /dev/null +++ b/cli/azd/pkg/azdext/process_test.go @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package azdext + +import ( + "os" + "runtime" + "strings" + "testing" +) + +// --------------------------------------------------------------------------- +// IsProcessRunning +// --------------------------------------------------------------------------- + +func TestIsProcessRunning_CurrentProcess(t *testing.T) { + pid := os.Getpid() + if !IsProcessRunning(pid) { + t.Errorf("IsProcessRunning(%d) = false for current process, want true", pid) + } +} + +func TestIsProcessRunning_InvalidPID(t *testing.T) { + if IsProcessRunning(0) { + t.Error("IsProcessRunning(0) = true, want false") + } + if IsProcessRunning(-1) { + t.Error("IsProcessRunning(-1) = true, want false") + } +} + +func TestIsProcessRunning_NonexistentPID(t *testing.T) { + // PID 99999999 is extremely unlikely to exist. + if IsProcessRunning(99999999) { + t.Error("IsProcessRunning(99999999) = true, want false") + } +} + +// --------------------------------------------------------------------------- +// GetProcessInfo +// --------------------------------------------------------------------------- + +func TestGetProcessInfo_CurrentProcess(t *testing.T) { + pid := os.Getpid() + info := GetProcessInfo(pid) + if !info.Running { + t.Errorf("GetProcessInfo(%d).Running = false for current process, want true", pid) + } + if info.PID != pid { + t.Errorf("GetProcessInfo(%d).PID = %d, want %d", pid, info.PID, pid) + } + // Name should be non-empty for the current process. 
+ if info.Name == "" { + t.Errorf("GetProcessInfo(%d).Name is empty, want non-empty", pid) + } +} + +func TestGetProcessInfo_InvalidPID(t *testing.T) { + info := GetProcessInfo(-1) + if info.Running { + t.Error("GetProcessInfo(-1).Running = true, want false") + } +} + +func TestGetProcessInfo_NonexistentPID(t *testing.T) { + info := GetProcessInfo(99999999) + if info.Running { + t.Error("GetProcessInfo(99999999).Running = true, want false") + } +} + +// --------------------------------------------------------------------------- +// CurrentProcessInfo +// --------------------------------------------------------------------------- + +func TestCurrentProcessInfo(t *testing.T) { + info := CurrentProcessInfo() + if !info.Running { + t.Error("CurrentProcessInfo().Running = false, want true") + } + if info.PID != os.Getpid() { + t.Errorf("CurrentProcessInfo().PID = %d, want %d", info.PID, os.Getpid()) + } + if info.Executable == "" { + t.Error("CurrentProcessInfo().Executable is empty") + } + if info.Name == "" { + t.Error("CurrentProcessInfo().Name is empty") + } +} + +// --------------------------------------------------------------------------- +// ParentProcessInfo +// --------------------------------------------------------------------------- + +func TestParentProcessInfo(t *testing.T) { + info := ParentProcessInfo() + // Parent process should exist (the test runner). + if info.PID <= 0 { + t.Errorf("ParentProcessInfo().PID = %d, want > 0", info.PID) + } + // In most environments, the parent should be running. + // Skip assertion on Running since it depends on the test environment. 
+} + +// --------------------------------------------------------------------------- +// FindProcessByName +// --------------------------------------------------------------------------- + +func TestFindProcessByName_Empty(t *testing.T) { + results := FindProcessByName("") + if results == nil { + t.Error("FindProcessByName(\"\") returned nil, want empty slice") + } + if len(results) != 0 { + t.Errorf("FindProcessByName(\"\") returned %d results, want 0", len(results)) + } +} + +func TestFindProcessByName_CurrentProcess(t *testing.T) { + // Get current process name. + current := CurrentProcessInfo() + if current.Name == "" { + t.Skip("cannot determine current process name") + } + + results := FindProcessByName(current.Name) + if len(results) == 0 { + t.Errorf("FindProcessByName(%q) returned 0 results, want >= 1", current.Name) + } + + // Verify at least one result matches our PID. + found := false + for _, r := range results { + if r.PID == current.PID { + found = true + break + } + } + if !found { + t.Errorf("FindProcessByName(%q) did not find current process PID %d", current.Name, current.PID) + } +} + +func TestFindProcessByName_Nonexistent(t *testing.T) { + results := FindProcessByName("azdext-nonexistent-process-xyz") + if results == nil { + t.Error("FindProcessByName(nonexistent) returned nil, want empty slice") + } + if len(results) != 0 { + t.Errorf("FindProcessByName(nonexistent) returned %d results, want 0", len(results)) + } +} + +// --------------------------------------------------------------------------- +// ProcessEnvironment +// --------------------------------------------------------------------------- + +func TestGetProcessEnvironment(t *testing.T) { + env := GetProcessEnvironment() + + if env.PID != os.Getpid() { + t.Errorf("ProcessEnvironment.PID = %d, want %d", env.PID, os.Getpid()) + } + if env.PPID <= 0 { + t.Errorf("ProcessEnvironment.PPID = %d, want > 0", env.PPID) + } + if env.OS != runtime.GOOS { + t.Errorf("ProcessEnvironment.OS = %q, 
want %q", env.OS, runtime.GOOS) + } + if env.Arch != runtime.GOARCH { + t.Errorf("ProcessEnvironment.Arch = %q, want %q", env.Arch, runtime.GOARCH) + } + if env.NumCPU <= 0 { + t.Errorf("ProcessEnvironment.NumCPU = %d, want > 0", env.NumCPU) + } + if env.Executable == "" { + t.Error("ProcessEnvironment.Executable is empty") + } +} + +func TestProcessEnvironment_String(t *testing.T) { + env := GetProcessEnvironment() + s := env.String() + + if !strings.Contains(s, "pid=") { + t.Errorf("ProcessEnvironment.String() = %q, missing pid=", s) + } + if !strings.Contains(s, "os=") { + t.Errorf("ProcessEnvironment.String() = %q, missing os=", s) + } + if !strings.Contains(s, "arch=") { + t.Errorf("ProcessEnvironment.String() = %q, missing arch=", s) + } +} + +// --------------------------------------------------------------------------- +// extractBaseName +// --------------------------------------------------------------------------- + +func TestExtractBaseName(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"/usr/bin/bash", "bash"}, + {`C:\Windows\System32\cmd.exe`, "cmd"}, + {"simple", "simple"}, + {"program.exe", "program"}, + {"/path/to/my-tool", "my-tool"}, + {`C:\tools\mytool.exe`, "mytool"}, + {"", ""}, + } + for _, tc := range tests { + got := extractBaseName(tc.input) + if got != tc.want { + t.Errorf("extractBaseName(%q) = %q, want %q", tc.input, got, tc.want) + } + } +} diff --git a/cli/azd/pkg/azdext/process_windows.go b/cli/azd/pkg/azdext/process_windows.go new file mode 100644 index 00000000000..c3542d36a48 --- /dev/null +++ b/cli/azd/pkg/azdext/process_windows.go @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//go:build windows + +package azdext + +import ( + "strings" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +// isProcessRunningOS checks if a process is running on Windows using the +// Windows API. 
Opens the process with PROCESS_QUERY_LIMITED_INFORMATION and +// checks whether the exit code is STILL_ACTIVE (259). +func isProcessRunningOS(pid int) bool { + //nolint:gosec // G115: pid is validated in caller + handle, err := windows.OpenProcess( + windows.PROCESS_QUERY_LIMITED_INFORMATION, + false, + uint32(pid), + ) + if err != nil { + return false + } + defer windows.CloseHandle(handle) + + var exitCode uint32 + err = windows.GetExitCodeProcess(handle, &exitCode) + if err != nil { + return false + } + + // STILL_ACTIVE (259) means the process has not exited. + return exitCode == 259 +} + +// getProcessInfoOS retrieves process info on Windows using the Windows API. +func getProcessInfoOS(pid int) ProcessInfo { + info := ProcessInfo{PID: pid} + + //nolint:gosec // G115: pid is validated in caller + handle, err := windows.OpenProcess( + windows.PROCESS_QUERY_LIMITED_INFORMATION, + false, + uint32(pid), + ) + if err != nil { + return info + } + defer windows.CloseHandle(handle) + + // Check if process is still running. + var exitCode uint32 + if err := windows.GetExitCodeProcess(handle, &exitCode); err != nil { + return info + } + if exitCode != 259 { + return info // Process has exited. + } + + info.Running = true + + // Get executable path. + bufSize := uint32(windows.MAX_PATH) + buf := make([]uint16, bufSize) + if err := windows.QueryFullProcessImageName(handle, 0, &buf[0], &bufSize); err != nil { + // Try with a larger buffer. + bufSize = 32768 + buf = make([]uint16, bufSize) + if err := windows.QueryFullProcessImageName(handle, 0, &buf[0], &bufSize); err != nil { + return info + } + } + + info.Executable = syscall.UTF16ToString(buf[:bufSize]) + info.Name = extractBaseName(info.Executable) + + return info +} + +// findProcessByNameOS searches for processes by name on Windows using +// CreateToolhelp32Snapshot. 
+func findProcessByNameOS(name string) []ProcessInfo { + snapshot, err := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0) + if err != nil { + return []ProcessInfo{} + } + defer windows.CloseHandle(snapshot) + + var entry windows.ProcessEntry32 + entry.Size = uint32(unsafe.Sizeof(entry)) + + if err := windows.Process32First(snapshot, &entry); err != nil { + return []ProcessInfo{} + } + + nameLower := strings.ToLower(name) + var results []ProcessInfo + + for { + exeName := syscall.UTF16ToString(entry.ExeFile[:]) + baseName := extractBaseName(exeName) + + if strings.EqualFold(baseName, nameLower) { + info := ProcessInfo{ + PID: int(entry.ProcessID), + Name: baseName, + Running: true, + } + + // Try to get full executable path. + //nolint:gosec // G115: PID comes from OS snapshot + handle, err := windows.OpenProcess( + windows.PROCESS_QUERY_LIMITED_INFORMATION, + false, + entry.ProcessID, + ) + if err == nil { + bufSize := uint32(windows.MAX_PATH) + buf := make([]uint16, bufSize) + if err := windows.QueryFullProcessImageName(handle, 0, &buf[0], &bufSize); err == nil { + info.Executable = syscall.UTF16ToString(buf[:bufSize]) + } + windows.CloseHandle(handle) + } + + results = append(results, info) + } + + if err := windows.Process32Next(snapshot, &entry); err != nil { + break + } + } + + if results == nil { + return []ProcessInfo{} + } + return results +} diff --git a/cli/azd/pkg/azdext/resilient_http_client.go b/cli/azd/pkg/azdext/resilient_http_client.go index f6538a86a42..2a043c138e0 100644 --- a/cli/azd/pkg/azdext/resilient_http_client.go +++ b/cli/azd/pkg/azdext/resilient_http_client.go @@ -5,19 +5,17 @@ package azdext import ( "context" - "crypto/rand" - "encoding/binary" "errors" "fmt" "io" "math" + "math/rand/v2" "net/http" + "strconv" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/azure/azure-dev/cli/azd/pkg/httputil" - "github.com/google/uuid" ) const ( @@ -25,23 +23,14 
@@ const ( // a malicious or misconfigured server from stalling the client indefinitely. maxRetryAfterDuration = 120 * time.Second - // maxRetryBodyDrain limits how many bytes are consumed when draining a - // retryable response body before the next attempt. This prevents a - // malicious or misconfigured server from stalling the client with an - // unbounded response body. - maxRetryBodyDrain int64 = 1 << 20 // 1 MB - - userAgentHeaderName = "User-Agent" - clientRequestIDHeaderName = "x-ms-client-request-id" - msCorrelationIDHeaderName = "x-ms-correlation-request-id" - defaultUserAgent = "azdext-resilient-client/" + Version + // maxDrainBytes limits how many bytes we read from a retryable response body + // before discarding it. Prevents stalling on large or malicious payloads. + maxDrainBytes = 1 << 20 // 1 MiB ) // ResilientClient is an HTTP client with built-in retry, exponential backoff, // timeout, and optional bearer-token injection. It is designed for extension // authors who need to call Azure REST APIs directly. -// For full Azure SDK HTTP pipeline behavior (telemetry, logging, and -// policy-chain extensibility), prefer runtime.NewPipeline with TokenProvider. // // Usage: // @@ -54,13 +43,10 @@ type ResilientClient struct { opts ResilientClientOptions } -var _ HTTPDoer = (*ResilientClient)(nil) - // ResilientClientOptions configures a [ResilientClient]. type ResilientClientOptions struct { // MaxRetries is the maximum number of retry attempts for transient failures. - // A value of 0 disables retries. - // A negative value uses the default (3). + // Defaults to 3. MaxRetries int // InitialDelay is the base delay before the first retry. Subsequent retries @@ -75,10 +61,6 @@ type ResilientClientOptions struct { // A value of zero or less uses the default of 30s. Timeout time.Duration - // UserAgent overrides the default User-Agent header. - // When empty, defaults to "azdext-resilient-client/". 
- UserAgent string - // Transport overrides the default HTTP transport. Useful for testing. Transport http.RoundTripper @@ -89,7 +71,7 @@ type ResilientClientOptions struct { // defaults fills zero-value fields with production defaults. func (o *ResilientClientOptions) defaults() { - if o.MaxRetries < 0 { + if o.MaxRetries <= 0 { o.MaxRetries = 3 } @@ -113,9 +95,7 @@ func (o *ResilientClientOptions) defaults() { // resolved from the request URL via the [ScopeDetector]. func NewResilientClient(tokenProvider azcore.TokenCredential, opts *ResilientClientOptions) *ResilientClient { if opts == nil { - opts = &ResilientClientOptions{ - MaxRetries: 3, - } + opts = &ResilientClientOptions{} } opts.defaults() @@ -153,9 +133,6 @@ func (rc *ResilientClient) Do(ctx context.Context, method, url string, body io.R return nil, errors.New("azdext.ResilientClient.Do: context must not be nil") } - // Validate body seekability upfront when retries are enabled. - // Fail fast rather than discovering the body is not seekable after the - // first attempt has already consumed it. if body != nil && rc.opts.MaxRetries > 0 { if _, ok := body.(io.ReadSeeker); !ok { return nil, errors.New( @@ -186,12 +163,7 @@ func (rc *ResilientClient) Do(ctx context.Context, method, url string, body io.R // Reset body for retry — require io.ReadSeeker for non-nil bodies. 
if body != nil { - seeker, ok := body.(io.ReadSeeker) - if !ok { - return nil, errors.New( - "azdext.ResilientClient.Do: request body does not implement io.ReadSeeker; " + - "retries require a seekable body (use bytes.NewReader or strings.NewReader)") - } + seeker := body.(io.ReadSeeker) if _, err := seeker.Seek(0, io.SeekStart); err != nil { return nil, fmt.Errorf("azdext.ResilientClient.Do: failed to reset request body: %w", err) } @@ -202,7 +174,6 @@ func (rc *ResilientClient) Do(ctx context.Context, method, url string, body io.R if err != nil { return nil, fmt.Errorf("azdext.ResilientClient.Do: failed to create request: %w", err) } - rc.setRequestHeaders(req) // Inject bearer token when a token provider is available. if rc.tokenProvider != nil { @@ -228,14 +199,13 @@ func (rc *ResilientClient) Do(ctx context.Context, method, url string, body io.R } // Consume body before retry to release the connection. - // Bound the read to prevent a malicious server from stalling the - // client with an infinitely long response body. - _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxRetryBodyDrain)) + // Bounded drain prevents stalling on very large or unbounded responses. + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrainBytes)) resp.Body.Close() // Capture Retry-After for the next iteration's delay, // capped to prevent indefinite stalling. 
- if ra := httputil.RetryAfter(resp); ra > 0 { + if ra := retryAfterFromResponse(resp); ra > 0 { if ra > maxRetryAfterDuration { ra = maxRetryAfterDuration } @@ -273,40 +243,19 @@ func (rc *ResilientClient) applyAuth(ctx context.Context, req *http.Request) err return nil } -func (rc *ResilientClient) setRequestHeaders(req *http.Request) { - if req.Header.Get(userAgentHeaderName) == "" { - ua := rc.opts.UserAgent - if ua == "" { - ua = defaultUserAgent - } - req.Header.Set(userAgentHeaderName, ua) - } - correlationID := req.Header.Get(clientRequestIDHeaderName) - if correlationID == "" { - correlationID = uuid.NewString() - req.Header.Set(clientRequestIDHeaderName, correlationID) - } - if req.Header.Get(msCorrelationIDHeaderName) == "" { - req.Header.Set(msCorrelationIDHeaderName, correlationID) - } -} - // backoff computes the delay for a given attempt using exponential backoff. func (rc *ResilientClient) backoff(attempt int) time.Duration { delay := time.Duration(float64(rc.opts.InitialDelay) * math.Pow(2, float64(attempt-1))) if delay > rc.opts.MaxDelay { delay = rc.opts.MaxDelay } - - // Add jitter: randomize between [50%, 100%) of computed delay to prevent - // thundering herd when multiple clients retry simultaneously. - var b [8]byte - jitter := 0.75 - if _, err := rand.Read(b[:]); err == nil { - randFloat := float64(binary.BigEndian.Uint64(b[:])) / (float64(math.MaxUint64) + 1) - jitter = 0.5 + randFloat*0.5 + jitter := 0.8 + rand.Float64()*0.4 //nolint:gosec // G404: non-security jitter + delay = time.Duration(float64(delay) * jitter) + if delay > rc.opts.MaxDelay { + delay = rc.opts.MaxDelay } - return time.Duration(float64(delay) * jitter) + + return delay } // isRetryable returns true for status codes that indicate a transient failure. @@ -324,6 +273,51 @@ func isRetryable(statusCode int) bool { } } +// retryAfterFromResponse extracts the Retry-After duration from response headers. 
+// Checks: retry-after-ms, x-ms-retry-after-ms, retry-after (seconds or HTTP-date). +func retryAfterFromResponse(resp *http.Response) time.Duration { + if resp == nil { + return 0 + } + + type retryHeader struct { + header string + units time.Duration + custom func(string) time.Duration + } + + nop := func(string) time.Duration { return 0 } + + headers := []retryHeader{ + {header: "retry-after-ms", units: time.Millisecond, custom: nop}, + {header: "x-ms-retry-after-ms", units: time.Millisecond, custom: nop}, + {header: "retry-after", units: time.Second, custom: func(v string) time.Duration { + t, err := time.Parse(time.RFC1123, v) + if err != nil { + return 0 + } + return time.Until(t) + }}, + } + + for _, rh := range headers { + v := resp.Header.Get(rh.header) + if v == "" { + continue + } + + if n, _ := strconv.Atoi(v); n > 0 { + return time.Duration(n) * rh.units + } + + if d := rh.custom(v); d > 0 { + return d + } + } + + return 0 +} + // RetryableHTTPError represents a retryable HTTP failure. type RetryableHTTPError struct { StatusCode int diff --git a/cli/azd/pkg/azdext/resilient_http_client_test.go b/cli/azd/pkg/azdext/resilient_http_client_test.go index 54abd845329..5b6835071fa 100644 --- a/cli/azd/pkg/azdext/resilient_http_client_test.go +++ b/cli/azd/pkg/azdext/resilient_http_client_test.go @@ -16,7 +16,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/azure/azure-dev/cli/azd/pkg/httputil" ) // roundTripFunc is an adapter to allow ordinary functions as http.RoundTripper. 
@@ -58,45 +57,6 @@ func TestResilientClient_Success(t *testing.T) { } } -func TestResilientClient_AddsDefaultHeaders(t *testing.T) { - t.Parallel() - - var gotUserAgent string - var gotClientRequestID string - var gotCorrelationID string - transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { - gotUserAgent = r.Header.Get(userAgentHeaderName) - gotClientRequestID = r.Header.Get(clientRequestIDHeaderName) - gotCorrelationID = r.Header.Get(msCorrelationIDHeaderName) - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader(`{"ok":true}`)), - Header: http.Header{}, - }, nil - }) - - rc := NewResilientClient(nil, &ResilientClientOptions{Transport: transport}) - - resp, err := rc.Do(context.Background(), http.MethodGet, "https://example.com/api", nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - defer resp.Body.Close() - - if gotUserAgent != defaultUserAgent { - t.Errorf("User-Agent = %q, want %q", gotUserAgent, defaultUserAgent) - } - if gotClientRequestID == "" { - t.Fatal("expected x-ms-client-request-id to be set") - } - if gotCorrelationID == "" { - t.Fatal("expected x-ms-correlation-request-id to be set") - } - if gotCorrelationID != gotClientRequestID { - t.Fatalf("expected correlation IDs to match, got client=%q correlation=%q", gotClientRequestID, gotCorrelationID) - } -} - func TestResilientClient_RetriesTransientFailures(t *testing.T) { t.Parallel() @@ -398,7 +358,7 @@ func TestResilientClient_NilContext(t *testing.T) { func TestResilientClient_DefaultOptions(t *testing.T) { t.Parallel() - opts := &ResilientClientOptions{MaxRetries: -1} + opts := &ResilientClientOptions{} opts.defaults() if opts.MaxRetries != 3 { @@ -418,16 +378,6 @@ func TestResilientClient_DefaultOptions(t *testing.T) { } } -func TestResilientClient_DefaultOptions_ZeroRetriesPreserved(t *testing.T) { - t.Parallel() - - opts := &ResilientClientOptions{MaxRetries: 0} - opts.defaults() - if opts.MaxRetries != 0 { - 
t.Errorf("MaxRetries = %d, want 0", opts.MaxRetries) - } -} - func TestRetryAfterFromResponse(t *testing.T) { t.Parallel() @@ -453,7 +403,7 @@ func TestRetryAfterFromResponse(t *testing.T) { } resp := &http.Response{Header: h} - got := httputil.RetryAfter(resp) + got := retryAfterFromResponse(resp) if got != tt.want { t.Errorf("retryAfterFromResponse() = %v, want %v", got, tt.want) @@ -465,7 +415,7 @@ func TestRetryAfterFromResponse(t *testing.T) { func TestRetryAfterFromResponse_Nil(t *testing.T) { t.Parallel() - got := httputil.RetryAfter(nil) + got := retryAfterFromResponse(nil) if got != 0 { t.Errorf("retryAfterFromResponse(nil) = %v, want 0", got) } @@ -570,20 +520,19 @@ func TestResilientClient_NonSeekableBodyRetryError(t *testing.T) { }) // io.NopCloser wrapping strings.NewReader is NOT an io.ReadSeeker. - // With upfront validation, the error is caught before any HTTP call. body := io.NopCloser(strings.NewReader("payload")) _, err := rc.Do(context.Background(), http.MethodPost, "https://example.com/api", body) if err == nil { - t.Fatal("expected error for non-seekable body with retries enabled") + t.Fatal("expected error for non-seekable body on retry") } if !strings.Contains(err.Error(), "io.ReadSeeker") { t.Errorf("error = %q, want mention of io.ReadSeeker", err.Error()) } - // Should have made zero attempts — upfront check rejects before any HTTP call. + // Should fail fast before any request attempt. if attempts != 0 { - t.Errorf("attempts = %d, want 0 (fail fast before any request)", attempts) + t.Errorf("attempts = %d, want 0 (fail fast on non-seekable body)", attempts) } } @@ -665,94 +614,33 @@ func TestResilientClient_RetryAfterCapped(t *testing.T) { h.Set("retry-after", "999999") resp := &http.Response{Header: h} - got := httputil.RetryAfter(resp) - // RetryAfter parser itself doesn't cap (pure parser), but Do() caps it. + got := retryAfterFromResponse(resp) + // retryAfterFromResponse itself doesn't cap (pure parser), but Do() caps it. 
if got != 999999*time.Second { - t.Errorf("RetryAfter() = %v, want %v (capping happens in Do)", got, 999999*time.Second) + t.Errorf("retryAfterFromResponse() = %v, want %v (capping happens in Do)", got, 999999*time.Second) } } -func TestResilientClient_RetryBodyDrainBounded(t *testing.T) { - t.Parallel() - - // Verify the constant used for bounded retry body drain is set - // and reasonable: it should prevent memory exhaustion but allow - // realistic retryable response bodies to be fully drained. - if maxRetryBodyDrain <= 0 { - t.Fatal("maxRetryBodyDrain must be positive") - } - if maxRetryBodyDrain > 10<<20 { // 10 MB - t.Errorf("maxRetryBodyDrain = %d, should be <= 10 MB", maxRetryBodyDrain) - } - - // Simulate a retry scenario where the retryable response body is larger - // than the drain limit. The client should not hang or OOM. - var attempts int - transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { - attempts++ - if attempts == 1 { - // Return a retryable status with a body larger than the drain limit. - // Use a LimitedReader to simulate a large body without allocating. - bigBody := io.LimitReader(infiniteReader{}, maxRetryBodyDrain+1024) - return &http.Response{ - StatusCode: http.StatusServiceUnavailable, - Body: io.NopCloser(bigBody), - Header: http.Header{}, - }, nil - } - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader("ok")), - Header: http.Header{}, - }, nil - }) - - rc := NewResilientClient(nil, &ResilientClientOptions{ - Transport: transport, - MaxRetries: 1, - InitialDelay: time.Millisecond, - }) - - resp, err := rc.Do(context.Background(), http.MethodGet, "https://example.com/api", nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - defer resp.Body.Close() - - if attempts != 2 { - t.Errorf("attempts = %d, want 2", attempts) - } -} - -// infiniteReader is an io.Reader that produces zero bytes forever. 
-type infiniteReader struct{} - -func (infiniteReader) Read(p []byte) (int, error) { - clear(p) - return len(p), nil -} - func TestResilientClient_BackoffJitter(t *testing.T) { t.Parallel() rc := NewResilientClient(nil, &ResilientClientOptions{ InitialDelay: 100 * time.Millisecond, - MaxDelay: 10 * time.Second, + MaxDelay: 5 * time.Second, }) - // Run backoff multiple times for the same attempt and verify results - // vary (jitter produces different values). - seen := make(map[time.Duration]bool) - for range 20 { + const samples = 20 + delays := make(map[time.Duration]struct{}, samples) + for range samples { d := rc.backoff(1) - seen[d] = true - // With jitter in [50%, 100%), delay should be in [50ms, 100ms). - if d < 50*time.Millisecond || d >= 100*time.Millisecond { - t.Errorf("backoff(1) = %v, want in [50ms, 100ms)", d) + delays[d] = struct{}{} + if d < 80*time.Millisecond || d > 120*time.Millisecond { + t.Fatalf("backoff with jitter = %v, want in [80ms, 120ms]", d) } } - if len(seen) < 2 { - t.Error("backoff jitter produced identical values across 20 calls") + + if len(delays) < 2 { + t.Fatalf("expected jitter to produce varying delays, got %d unique values", len(delays)) } } @@ -770,67 +658,62 @@ func TestResilientClient_NonSeekableBodyFailsFast(t *testing.T) { }) rc := NewResilientClient(nil, &ResilientClientOptions{ - Transport: transport, - MaxRetries: 2, - InitialDelay: time.Millisecond, + Transport: transport, + MaxRetries: 1, }) - // Non-seekable body with retries enabled should fail before any request. body := io.NopCloser(strings.NewReader("payload")) _, err := rc.Do(context.Background(), http.MethodPost, "https://example.com/api", body) if err == nil { - t.Fatal("expected error for non-seekable body with retries enabled") + t.Fatal("expected non-seekable body error") } if !strings.Contains(err.Error(), "io.ReadSeeker") { t.Errorf("error = %q, want mention of io.ReadSeeker", err.Error()) } - // Should NOT have made any HTTP request. 
if attempts != 0 { - t.Errorf("attempts = %d, want 0 (fail fast before any request)", attempts) + t.Errorf("attempts = %d, want 0", attempts) } } func TestResilientClient_RetryAfterCappedInDo(t *testing.T) { t.Parallel() - // A huge Retry-After should be capped to maxRetryAfterDuration. - // We verify this by using a very short context timeout: if the raw - // value (999999s) were used, the context would expire instantly - // rather than letting the retry proceed. With capping, the context - // timeout (250ms here) is less than the cap, so we expect the - // context to cancel — proving the delay is finite and capped. var attempts int transport := roundTripFunc(func(r *http.Request) (*http.Response, error) { attempts++ - h := http.Header{} - h.Set("retry-after", "999999") + if attempts == 1 { + h := http.Header{} + h.Set("retry-after", "999999") + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Body: io.NopCloser(strings.NewReader("throttled")), + Header: h, + }, nil + } + return &http.Response{ - StatusCode: http.StatusTooManyRequests, - Body: io.NopCloser(strings.NewReader("throttled")), - Header: h, + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: http.Header{}, }, nil }) rc := NewResilientClient(nil, &ResilientClientOptions{ - Transport: transport, - MaxRetries: 1, - InitialDelay: time.Millisecond, + Transport: transport, + MaxRetries: 1, }) - ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) defer cancel() _, err := rc.Do(ctx, http.MethodGet, "https://example.com/api", nil) - // The context should cancel during the capped delay (120s > 250ms), - // which means the raw 999999s was replaced by the cap. 
if !errors.Is(err, context.DeadlineExceeded) { - t.Errorf("expected context.DeadlineExceeded (proving cap was applied), got: %v", err) + t.Fatalf("error = %v, want context deadline exceeded", err) } - // Only 1 attempt — the retry wait for the capped delay gets canceled. if attempts != 1 { - t.Errorf("attempts = %d, want 1", attempts) + t.Fatalf("attempts = %d, want 1", attempts) } } diff --git a/cli/azd/pkg/azdext/run_test.go b/cli/azd/pkg/azdext/run_test.go new file mode 100644 index 00000000000..eb18edfab6e --- /dev/null +++ b/cli/azd/pkg/azdext/run_test.go @@ -0,0 +1,166 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestErrorSuggestion(t *testing.T) { + tests := []struct { + name string + err error + want string + }{ + { + name: "LocalErrorWithSuggestion", + err: &LocalError{ + Message: "invalid config", + Code: "invalid_config", + Category: LocalErrorCategoryValidation, + Suggestion: "Check your azure.yaml file", + }, + want: "Check your azure.yaml file", + }, + { + name: "ServiceErrorWithSuggestion", + err: &ServiceError{ + Message: "rate limited", + ErrorCode: "TooManyRequests", + StatusCode: 429, + Suggestion: "Retry with exponential backoff", + }, + want: "Retry with exponential backoff", + }, + { + name: "LocalErrorWithoutSuggestion", + err: &LocalError{ + Message: "missing field", + Code: "missing_field", + Category: LocalErrorCategoryValidation, + }, + want: "", + }, + { + name: "ServiceErrorWithoutSuggestion", + err: &ServiceError{ + Message: "not found", + ErrorCode: "NotFound", + StatusCode: 404, + }, + want: "", + }, + { + name: "PlainError", + err: errors.New("something went wrong"), + want: "", + }, + { + name: "NilError", + err: nil, + want: "", + }, + { + name: "WrappedLocalError", + err: fmt.Errorf("operation failed: %w", &LocalError{ + Message: "bad input", + Suggestion: "Fix 
the input", + }), + want: "Fix the input", + }, + { + name: "WrappedServiceError", + err: fmt.Errorf("deploy failed: %w", &ServiceError{ + Message: "quota exceeded", + Suggestion: "Request a quota increase", + }), + want: "Request a quota increase", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ErrorSuggestion(tt.err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestErrorMessage(t *testing.T) { + tests := []struct { + name string + err error + want string + }{ + { + name: "LocalError", + err: &LocalError{ + Message: "invalid config", + Code: "invalid_config", + Category: LocalErrorCategoryValidation, + }, + want: "invalid config", + }, + { + name: "ServiceError", + err: &ServiceError{ + Message: "rate limited", + ErrorCode: "TooManyRequests", + StatusCode: 429, + }, + want: "rate limited", + }, + { + name: "LocalErrorEmptyMessage", + err: &LocalError{ + Code: "no_msg", + Category: LocalErrorCategoryLocal, + }, + want: "", + }, + { + name: "ServiceErrorEmptyMessage", + err: &ServiceError{ + ErrorCode: "Unknown", + StatusCode: 500, + }, + want: "", + }, + { + name: "PlainError", + err: errors.New("plain error"), + want: "", + }, + { + name: "NilError", + err: nil, + want: "", + }, + { + name: "WrappedLocalError", + err: fmt.Errorf("op: %w", &LocalError{ + Message: "wrapped local", + }), + want: "wrapped local", + }, + { + name: "WrappedServiceError", + err: fmt.Errorf("op: %w", &ServiceError{ + Message: "wrapped service", + }), + want: "wrapped service", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ErrorMessage(tt.err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/cli/azd/pkg/azdext/scope_detector.go b/cli/azd/pkg/azdext/scope_detector.go index 5b67d9c68ef..a9191f32be1 100644 --- a/cli/azd/pkg/azdext/scope_detector.go +++ b/cli/azd/pkg/azdext/scope_detector.go @@ -6,7 +6,7 @@ package azdext import ( "errors" "net/url" - "slices" + "sort" "strings" ) @@ -68,9 +68,7 @@ 
func defaultRules() []scopeRule { {match: suffix(".file.core.windows.net"), scope: "https://storage.azure.com/.default"}, {match: suffix(".dfs.core.windows.net"), scope: "https://storage.azure.com/.default"}, - // Azure Container Registry (control-plane operations). - // Data-plane operations may require a different scope, so extensions can - // override this mapping via CustomRules. + // Azure Container Registry {match: suffix(".azurecr.io"), scope: "https://management.azure.com/.default"}, // Azure Cognitive Services / OpenAI @@ -93,10 +91,7 @@ func defaultRules() []scopeRule { // Azure Cosmos DB {match: suffix(".documents.azure.com"), scope: "https://cosmos.azure.com/.default"}, - // Azure Event Hubs / Service Bus host suffix ambiguity: - // both services use .servicebus.windows.net. The default maps to Event Hubs; - // extensions targeting Service Bus should override with CustomRules: - // ".servicebus.windows.net" -> "https://servicebus.azure.net/.default". + // Azure Event Hubs {match: suffix(".servicebus.windows.net"), scope: "https://eventhubs.azure.net/.default"}, // Azure App Configuration @@ -106,10 +101,9 @@ func defaultRules() []scopeRule { // NewScopeDetector creates a [ScopeDetector] with the built-in Azure endpoint // mappings. Additional custom rules can be supplied via opts. -// Custom rules take precedence over defaults, allowing callers to override -// built-in mappings (e.g. Service Bus vs Event Hubs on .servicebus.windows.net). +// Custom rules are evaluated before defaults, so they can override built-in mappings. func NewScopeDetector(opts *ScopeDetectorOptions) *ScopeDetector { - var custom []scopeRule + var customRules []scopeRule if opts != nil { // Sort keys for deterministic rule evaluation order. 
@@ -117,7 +111,7 @@ func NewScopeDetector(opts *ScopeDetectorOptions) *ScopeDetector { for k := range opts.CustomRules { keys = append(keys, k) } - slices.Sort(keys) + sort.Strings(keys) for _, hostSuffix := range keys { if hostSuffix == "" { @@ -129,14 +123,14 @@ func NewScopeDetector(opts *ScopeDetectorOptions) *ScopeDetector { if strings.HasPrefix(hs, ".") { // Dot-prefixed: suffix match (subdomain matching). - custom = append(custom, scopeRule{ + customRules = append(customRules, scopeRule{ match: func(host string) bool { return strings.HasSuffix(host, hs) }, scope: scope, }) } else { // No dot prefix: exact host match to prevent partial-host // matching (e.g. "azure.com" matching "fakeazure.com"). - custom = append(custom, scopeRule{ + customRules = append(customRules, scopeRule{ match: func(host string) bool { return host == hs }, scope: scope, }) @@ -144,8 +138,8 @@ func NewScopeDetector(opts *ScopeDetectorOptions) *ScopeDetector { } } - // Prepend custom rules before defaults so they take precedence. - rules := append(custom, defaultRules()...) + // Custom rules first so they take precedence over built-in defaults. + rules := append(customRules, defaultRules()...) //nolint:gocritic return &ScopeDetector{rules: rules} } diff --git a/cli/azd/pkg/azdext/security_validation.go b/cli/azd/pkg/azdext/security_validation.go new file mode 100644 index 00000000000..5d5e5dd1c7e --- /dev/null +++ b/cli/azd/pkg/azdext/security_validation.go @@ -0,0 +1,249 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azdext + +import ( + "fmt" + "os" + "regexp" + "strings" +) + +// --------------------------------------------------------------------------- +// Validation error type +// --------------------------------------------------------------------------- + +// ValidationError describes a failed input validation with structured context. 
+type ValidationError struct { + // Field is the logical name of the input being validated (e.g. "service_name"). + Field string + + // Value is the rejected input value. For security-sensitive inputs the value + // may be truncated or redacted by the caller before constructing the error. + Value string + + // Rule is a short machine-readable tag for the violated constraint + // (e.g. "format", "length", "characters"). + Rule string + + // Message is a human-readable explanation suitable for end-user display. + Message string +} + +func (e *ValidationError) Error() string { + return fmt.Sprintf("azdext.Validate: %s %q: %s (%s)", e.Field, e.Value, e.Message, e.Rule) +} + +// --------------------------------------------------------------------------- +// Service name validation +// --------------------------------------------------------------------------- + +// serviceNameRe matches DNS-safe service names: +// - starts with alphanumeric +// - contains only alphanumeric, '.', '_', '-' +// - 1–63 characters total (DNS label limit) +var serviceNameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]{0,62}$`) + +// ValidateServiceName checks that name is a valid DNS-safe service identifier. +// +// Rules: +// - Must start with an alphanumeric character. +// - May contain alphanumeric characters, '.', '_', and '-'. +// - Must be 1–63 characters (DNS label length limit per RFC 1035). +// +// Returns a [*ValidationError] on failure. 
+func ValidateServiceName(name string) error { + if name == "" { + return &ValidationError{ + Field: "service_name", + Value: "", + Rule: "required", + Message: "service name must not be empty", + } + } + + if !serviceNameRe.MatchString(name) { + return &ValidationError{ + Field: "service_name", + Value: truncateValue(name, 64), + Rule: "format", + Message: "service name must start with alphanumeric and contain only [a-zA-Z0-9._-], max 63 chars", + } + } + + return nil +} + +// --------------------------------------------------------------------------- +// Hostname validation +// --------------------------------------------------------------------------- + +// hostnameRe matches RFC 952/1123 hostnames: +// - labels separated by '.' +// - each label: starts and ends with alphanumeric, may contain '-', 1–63 chars +// - total length <= 253 characters +var hostnameRe = regexp.MustCompile( + `^[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?` + + `(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$`, +) + +// ValidateHostname checks that hostname conforms to RFC 952/1123. +// +// Rules: +// - Each label must start and end with an alphanumeric character. +// - Labels may contain alphanumeric characters and '-'. +// - Each label is 1–63 characters. +// - Total hostname length is ≤ 253 characters. +// +// Returns a [*ValidationError] on failure. 
+func ValidateHostname(hostname string) error { + if hostname == "" { + return &ValidationError{ + Field: "hostname", + Value: "", + Rule: "required", + Message: "hostname must not be empty", + } + } + + if len(hostname) > 253 { + return &ValidationError{ + Field: "hostname", + Value: truncateValue(hostname, 64), + Rule: "length", + Message: "hostname must not exceed 253 characters", + } + } + + if !hostnameRe.MatchString(hostname) { + return &ValidationError{ + Field: "hostname", + Value: truncateValue(hostname, 64), + Rule: "format", + Message: "hostname must conform to RFC 952/1123 " + + "(labels: alphanumeric start/end, may contain '-', 1-63 chars each)", + } + } + + return nil +} + +// --------------------------------------------------------------------------- +// Script name validation +// --------------------------------------------------------------------------- + +// shellMetacharacters contains characters that have special meaning in common +// shells (bash, sh, zsh, cmd, PowerShell). A script name containing any of +// these is rejected to prevent command injection. +const shellMetacharacters = ";|&`$(){}[]<>!#~*?\"\\'%\n\r\x00" + +// ValidateScriptName checks that name does not contain shell metacharacters +// or path traversal sequences that could lead to command injection. +// +// Rejected patterns: +// - Shell metacharacters: ; | & ` $ ( ) { } [ ] < > ! # ~ * ? " ' \ % +// - Path traversal: ".." +// - Null bytes and newlines +// - Empty names +// +// Returns a [*ValidationError] on failure. 
+func ValidateScriptName(name string) error { + if name == "" { + return &ValidationError{ + Field: "script_name", + Value: "", + Rule: "required", + Message: "script name must not be empty", + } + } + + if strings.Contains(name, "..") { + return &ValidationError{ + Field: "script_name", + Value: truncateValue(name, 64), + Rule: "traversal", + Message: "script name must not contain path traversal sequences (..)", + } + } + + if idx := strings.IndexAny(name, shellMetacharacters); idx >= 0 { + return &ValidationError{ + Field: "script_name", + Value: truncateValue(name, 64), + Rule: "characters", + Message: fmt.Sprintf("script name contains forbidden shell metacharacter at position %d", idx), + } + } + + return nil +} + +// --------------------------------------------------------------------------- +// Container environment detection +// --------------------------------------------------------------------------- + +// containerEnvVars maps environment variables to the container runtime they indicate. +var containerEnvVars = map[string]string{ + "CODESPACES": "codespaces", + "KUBERNETES_SERVICE_HOST": "kubernetes", + "REMOTE_CONTAINERS": "devcontainer", + "REMOTE_CONTAINERS_IPC": "devcontainer", +} + +// IsContainerEnvironment reports whether the current process is running inside +// a container environment. It checks for: +// - GitHub Codespaces (CODESPACES env var) +// - Kubernetes (KUBERNETES_SERVICE_HOST env var) +// - VS Code Dev Containers (REMOTE_CONTAINERS / REMOTE_CONTAINERS_IPC env vars) +// - Docker (/.dockerenv file) +// +// The detection is best-effort and does not guarantee accuracy in all +// environments. It is intended for feature gating and diagnostics, not +// security decisions. +func IsContainerEnvironment() bool { + // Check well-known environment variables. + for envKey := range containerEnvVars { + if v := os.Getenv(envKey); v != "" { + return true + } + } + + // Check for Docker's marker file. 
+ if _, err := os.Stat("/.dockerenv"); err == nil { + return true + } + + return false +} + +// ContainerRuntime returns the detected container runtime name, or an empty +// string if no container environment is detected. +// +// Possible return values: "codespaces", "kubernetes", "devcontainer", "docker", "". +func ContainerRuntime() string { + for envKey, runtime := range containerEnvVars { + if v := os.Getenv(envKey); v != "" { + return runtime + } + } + + if _, err := os.Stat("/.dockerenv"); err == nil { + return "docker" + } + + return "" +} + +// --------------------------------------------------------------------------- +// Internal helpers +// --------------------------------------------------------------------------- + +// truncateValue truncates s to maxLen characters for safe inclusion in error +// messages. If truncated, an ellipsis is appended. +func truncateValue(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} diff --git a/cli/azd/pkg/azdext/security_validation_test.go b/cli/azd/pkg/azdext/security_validation_test.go new file mode 100644 index 00000000000..cd4165e67dd --- /dev/null +++ b/cli/azd/pkg/azdext/security_validation_test.go @@ -0,0 +1,354 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package azdext + +import ( + "errors" + "os" + "strings" + "testing" +) + +// --------------------------------------------------------------------------- +// ValidateServiceName +// --------------------------------------------------------------------------- + +func TestValidateServiceName_Valid(t *testing.T) { + valid := []string{ + "web", + "my-service", + "my_service", + "my.service", + "A1", + "a", + "service-v2.0", + "a23456789012345678901234567890123456789012345678901234567890123", // 63 chars + } + for _, name := range valid { + if err := ValidateServiceName(name); err != nil { + t.Errorf("ValidateServiceName(%q) = %v, want nil", name, err) + } + } +} + +func TestValidateServiceName_Invalid(t *testing.T) { + tests := []struct { + name string + rule string + }{ + {"", "required"}, + {"-starts-with-dash", "format"}, + {".starts-with-dot", "format"}, + {"_starts-with-underscore", "format"}, + {"has space", "format"}, + {"has;semicolon", "format"}, + {"has|pipe", "format"}, + {"has$dollar", "format"}, + {"has/slash", "format"}, + {"has@at", "format"}, + // 64 chars (too long) + {"a234567890123456789012345678901234567890123456789012345678901234", "format"}, + } + for _, tc := range tests { + err := ValidateServiceName(tc.name) + if err == nil { + t.Errorf("ValidateServiceName(%q) = nil, want error", tc.name) + continue + } + var ve *ValidationError + if ok := isValidationError(err, &ve); !ok { + t.Errorf("ValidateServiceName(%q) returned %T, want *ValidationError", tc.name, err) + continue + } + if ve.Rule != tc.rule { + t.Errorf("ValidateServiceName(%q).Rule = %q, want %q", tc.name, ve.Rule, tc.rule) + } + if ve.Field != "service_name" { + t.Errorf("ValidateServiceName(%q).Field = %q, want %q", tc.name, ve.Field, "service_name") + } + } +} + +// --------------------------------------------------------------------------- +// ValidateHostname +// --------------------------------------------------------------------------- + +func TestValidateHostname_Valid(t 
*testing.T) { + valid := []string{ + "example.com", + "sub.example.com", + "a.b.c.d.example.com", + "my-host", + "a", + "1", + "192-168-1-1.nip.io", + "xn--nxasmq6b.example.com", // punycode + } + for _, h := range valid { + if err := ValidateHostname(h); err != nil { + t.Errorf("ValidateHostname(%q) = %v, want nil", h, err) + } + } +} + +func TestValidateHostname_Invalid(t *testing.T) { + tests := []struct { + hostname string + rule string + }{ + {"", "required"}, + {"-starts-with-dash.com", "format"}, + {"ends-with-dash-.com", "format"}, + {"has space.com", "format"}, + {"has_underscore.com", "format"}, + {"has..double-dot.com", "format"}, + {".starts-with-dot.com", "format"}, + {"has;semicolon.com", "format"}, + // 254 chars (too long) + {strings.Repeat("a", 254), "length"}, + } + for _, tc := range tests { + err := ValidateHostname(tc.hostname) + if err == nil { + t.Errorf("ValidateHostname(%q) = nil, want error", tc.hostname) + continue + } + var ve *ValidationError + if ok := isValidationError(err, &ve); !ok { + t.Errorf("ValidateHostname(%q) returned %T, want *ValidationError", tc.hostname, err) + continue + } + if ve.Rule != tc.rule { + t.Errorf("ValidateHostname(%q).Rule = %q, want %q", tc.hostname, ve.Rule, tc.rule) + } + } +} + +func TestValidateHostname_LabelLength(t *testing.T) { + // Each label max 63 chars. A 64-char label should fail. + longLabel := strings.Repeat("a", 64) + ".com" + if err := ValidateHostname(longLabel); err == nil { + t.Errorf("ValidateHostname with 64-char label = nil, want error") + } + + // 63-char label should succeed. 
+ okLabel := strings.Repeat("a", 63) + ".com" + if err := ValidateHostname(okLabel); err != nil { + t.Errorf("ValidateHostname with 63-char label = %v, want nil", err) + } +} + +// --------------------------------------------------------------------------- +// ValidateScriptName +// --------------------------------------------------------------------------- + +func TestValidateScriptName_Valid(t *testing.T) { + valid := []string{ + "script.sh", + "my-script.py", + "build_project.ps1", + "run", + "deploy-v2.sh", + "test.cmd", + "start server", // spaces are OK in script names (not metacharacters) + } + for _, name := range valid { + if err := ValidateScriptName(name); err != nil { + t.Errorf("ValidateScriptName(%q) = %v, want nil", name, err) + } + } +} + +func TestValidateScriptName_ShellMetacharacters(t *testing.T) { + dangerous := []struct { + name string + desc string + }{ + {"script;rm -rf /", "semicolon command chaining"}, + {"script|cat /etc/passwd", "pipe"}, + {"script&background", "ampersand"}, + {"script`whoami`", "backtick command substitution"}, + {"script$(id)", "dollar-paren command substitution"}, + {"script > /dev/null", "output redirect"}, + {"script < /etc/passwd", "input redirect"}, + {"script\nrm -rf /", "newline injection"}, + {"script\x00null", "null byte"}, + {"script'quoted'", "single quote"}, + {"script\"quoted\"", "double quote"}, + {"script\\escaped", "backslash"}, + {"script!history", "exclamation/history expansion"}, + {"script#comment", "hash/comment"}, + {"script~home", "tilde expansion"}, + {"script*glob", "glob star"}, + {"script?glob", "glob question"}, + {"script%env", "percent"}, + {"script(sub)", "open paren"}, + {"script{brace}", "open brace"}, + {"script[bracket]", "open bracket"}, + } + for _, tc := range dangerous { + err := ValidateScriptName(tc.name) + if err == nil { + t.Errorf("ValidateScriptName(%q) = nil, want error (%s)", tc.name, tc.desc) + continue + } + var ve *ValidationError + if ok := isValidationError(err, &ve); 
!ok { + t.Errorf("ValidateScriptName(%q) returned %T, want *ValidationError", tc.name, err) + continue + } + if ve.Rule != "characters" { + t.Errorf("ValidateScriptName(%q).Rule = %q, want %q (%s)", tc.name, ve.Rule, "characters", tc.desc) + } + } +} + +func TestValidateScriptName_PathTraversal(t *testing.T) { + traversal := []string{ + "../etc/passwd", + "../../secret.sh", + "dir/../../../root.sh", + "..\\windows\\system32", + } + for _, name := range traversal { + err := ValidateScriptName(name) + if err == nil { + t.Errorf("ValidateScriptName(%q) = nil, want error", name) + continue + } + var ve *ValidationError + if ok := isValidationError(err, &ve); !ok { + t.Errorf("ValidateScriptName(%q) returned %T, want *ValidationError", name, err) + continue + } + if ve.Rule != "traversal" { + t.Errorf("ValidateScriptName(%q).Rule = %q, want %q", name, ve.Rule, "traversal") + } + } +} + +func TestValidateScriptName_Empty(t *testing.T) { + err := ValidateScriptName("") + if err == nil { + t.Error("ValidateScriptName(\"\") = nil, want error") + return + } + var ve *ValidationError + if ok := isValidationError(err, &ve); !ok { + t.Errorf("ValidateScriptName(\"\") returned %T, want *ValidationError", err) + return + } + if ve.Rule != "required" { + t.Errorf("ValidateScriptName(\"\").Rule = %q, want %q", ve.Rule, "required") + } +} + +// --------------------------------------------------------------------------- +// IsContainerEnvironment / ContainerRuntime +// --------------------------------------------------------------------------- + +func TestIsContainerEnvironment_EnvVars(t *testing.T) { + tests := []struct { + envKey string + runtime string + }{ + {"CODESPACES", "codespaces"}, + {"KUBERNETES_SERVICE_HOST", "kubernetes"}, + {"REMOTE_CONTAINERS", "devcontainer"}, + {"REMOTE_CONTAINERS_IPC", "devcontainer"}, + } + for _, tc := range tests { + t.Run(tc.envKey, func(t *testing.T) { + // Ensure the env var is clean before/after. 
+ orig := os.Getenv(tc.envKey) + t.Setenv(tc.envKey, "true") + + if !IsContainerEnvironment() { + t.Errorf("IsContainerEnvironment() = false with %s set", tc.envKey) + } + + rt := ContainerRuntime() + if rt != tc.runtime { + t.Errorf("ContainerRuntime() = %q with %s set, want %q", rt, tc.envKey, tc.runtime) + } + + // Restore and verify negative case. + if orig == "" { + os.Unsetenv(tc.envKey) + } else { + os.Setenv(tc.envKey, orig) + } + }) + } +} + +func TestIsContainerEnvironment_NoContainerEnv(t *testing.T) { + // Clear all container-related env vars. + for envKey := range containerEnvVars { + if v := os.Getenv(envKey); v != "" { + t.Setenv(envKey, "") + os.Unsetenv(envKey) + } + } + + // In CI or local dev without Docker marker, this should return false. + // We can't guarantee /.dockerenv doesn't exist, but in typical test + // environments it won't. + runtime := ContainerRuntime() + // Only assert no env-var-based detection. Docker file detection is + // environment-dependent and not worth mocking here. 
+	_ = runtime
}

// ---------------------------------------------------------------------------
// ValidationError
// ---------------------------------------------------------------------------

// TestValidationError_ErrorMessage verifies that ValidationError.Error()
// includes the field name, the offending value, and the violated rule.
func TestValidationError_ErrorMessage(t *testing.T) {
	err := &ValidationError{
		Field:   "test_field",
		Value:   "bad-value",
		Rule:    "format",
		Message: "value is invalid",
	}

	msg := err.Error()
	if !strings.Contains(msg, "test_field") {
		t.Errorf("error message should contain field name, got: %s", msg)
	}
	if !strings.Contains(msg, "bad-value") {
		t.Errorf("error message should contain value, got: %s", msg)
	}
	if !strings.Contains(msg, "format") {
		t.Errorf("error message should contain rule, got: %s", msg)
	}
}

// TestTruncateValue verifies truncateValue: strings at or under maxLen are
// returned unchanged; longer strings are cut at maxLen and suffixed with "...".
func TestTruncateValue(t *testing.T) {
	tests := []struct {
		input  string
		maxLen int
		want   string
	}{
		{"short", 10, "short"},
		{"exactly10!", 10, "exactly10!"},
		{"this is way too long", 10, "this is wa..."}, // truncated at 10 chars + ellipsis
		{"", 5, ""},
	}
	for _, tc := range tests {
		got := truncateValue(tc.input, tc.maxLen)
		if got != tc.want {
			t.Errorf("truncateValue(%q, %d) = %q, want %q", tc.input, tc.maxLen, got, tc.want)
		}
	}
}

// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

// isValidationError is a type assertion helper for testing. It reports whether
// err (or any error in its chain) is a *ValidationError, storing the match in
// target via errors.As.
func isValidationError(err error, target **ValidationError) bool {
	return errors.As(err, target)
}
diff --git a/cli/azd/pkg/azdext/shell.go b/cli/azd/pkg/azdext/shell.go
new file mode 100644
index 00000000000..4863ce7a5df
--- /dev/null
+++ b/cli/azd/pkg/azdext/shell.go
@@ -0,0 +1,254 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
package azdext

import (
	"context"
	"fmt"
	"os"
	"os/exec"
	"runtime"
	"strings"

	"golang.org/x/term"
)

// ---------------------------------------------------------------------------
// Shell detection and execution
// ---------------------------------------------------------------------------

// ShellType represents a detected shell environment.
// The zero value is ShellTypeUnknown.
type ShellType string

const (
	// ShellTypeBash is the Bourne Again Shell.
	ShellTypeBash ShellType = "bash"
	// ShellTypeSh is the POSIX shell.
	ShellTypeSh ShellType = "sh"
	// ShellTypeZsh is the Z Shell.
	ShellTypeZsh ShellType = "zsh"
	// ShellTypeFish is the Fish shell.
	ShellTypeFish ShellType = "fish"
	// ShellTypePowerShell is PowerShell (pwsh/powershell.exe).
	ShellTypePowerShell ShellType = "powershell"
	// ShellTypeCmd is Windows cmd.exe.
	ShellTypeCmd ShellType = "cmd"
	// ShellTypeUnknown indicates the shell could not be determined.
	ShellTypeUnknown ShellType = ""
)

// String returns the string representation of the shell type.
// The unknown shell (empty string) is rendered as "unknown" so callers
// never see an empty name.
func (s ShellType) String() string {
	if s == ShellTypeUnknown {
		return "unknown"
	}
	return string(s)
}

// ShellInfo contains information about the detected shell.
type ShellInfo struct {
	// Type is the detected shell type.
	Type ShellType
	// Path is the filesystem path to the shell executable, if known.
	// May be empty when only the type could be inferred.
	Path string
	// Source describes how the shell was detected (e.g. "SHELL",
	// "PSModulePath", "ComSpec", "platform-default").
	Source string
}

// DetectShell identifies the current shell environment.
//
// Detection strategy (in order):
// 1. SHELL environment variable (Unix) — most reliable on macOS/Linux.
// 2. PSModulePath environment variable — indicates PowerShell on any platform.
// 3. ComSpec environment variable (Windows) — standard Windows shell path.
// 4. Platform default fallback (sh on Unix, cmd on Windows).
//
// Platform behavior:
// - Windows: Detects cmd.exe (default), PowerShell, or WSL shells.
+// - macOS/Linux: Detects from $SHELL (bash, zsh, fish, sh).
// - If $SHELL is unset, falls back to platform default.
//
// DetectShell never returns an error. If detection fails, Type is [ShellTypeUnknown].
func DetectShell() ShellInfo {
	// Strategy 1: $SHELL (Unix convention, also set in some Windows terminals).
	// shellTypeFromPath maps the executable path to a ShellType; an
	// unrecognized path falls through to the next strategy.
	if shellEnv := os.Getenv("SHELL"); shellEnv != "" {
		st := shellTypeFromPath(shellEnv)
		if st != ShellTypeUnknown {
			return ShellInfo{Type: st, Path: shellEnv, Source: "SHELL"}
		}
	}

	// Strategy 2: PSModulePath indicates PowerShell is the active shell.
	// NOTE(review): PSModulePath is set machine-wide by default on modern
	// Windows, so a plain cmd.exe session with no $SHELL would be reported
	// as PowerShell here before ComSpec is consulted — confirm this ordering
	// is intended.
	if psPath := os.Getenv("PSModulePath"); psPath != "" {
		// Try to find pwsh/powershell on PATH.
		if p, err := exec.LookPath("pwsh"); err == nil {
			return ShellInfo{Type: ShellTypePowerShell, Path: p, Source: "PSModulePath"}
		}
		if p, err := exec.LookPath("powershell"); err == nil {
			return ShellInfo{Type: ShellTypePowerShell, Path: p, Source: "PSModulePath"}
		}
		// Neither executable is on PATH; still report PowerShell with an
		// empty Path.
		return ShellInfo{Type: ShellTypePowerShell, Path: "", Source: "PSModulePath"}
	}

	// Strategy 3: ComSpec (Windows).
	if comspec := os.Getenv("ComSpec"); comspec != "" {
		st := shellTypeFromPath(comspec)
		if st != ShellTypeUnknown {
			return ShellInfo{Type: st, Path: comspec, Source: "ComSpec"}
		}
		// ComSpec is set but not a recognized shell; assume cmd.
		return ShellInfo{Type: ShellTypeCmd, Path: comspec, Source: "ComSpec"}
	}

	// Strategy 4: Platform default. On Unix, /bin/sh is assumed to exist;
	// on Windows, Path is left empty because cmd.exe's location is unknown
	// without ComSpec.
	if runtime.GOOS == "windows" {
		return ShellInfo{Type: ShellTypeCmd, Path: "", Source: "platform-default"}
	}
	return ShellInfo{Type: ShellTypeSh, Path: "/bin/sh", Source: "platform-default"}
}

// ShellCommand creates an [exec.Cmd] that executes script through the
// appropriate shell for the current platform.
//
// Platform behavior:
// - Windows cmd: cmd.exe /C